[
  {
    "path": ".clang-format",
    "content": "# Start with a built-in style and modify it\nBasedOnStyle: Google\n\n# Overrides\nColumnLimit: 120\n"
  },
  {
    "path": ".git-blame-ignore-revs",
    "content": "# Commits to ignore in git-blame\n# These commits are bulk formatting or refactoring changes that should be skipped when viewing blame history\n\n# Add pre-commit and GitHub Actions workflow for it (#1949)\n1f20398756f0eeba37d6887a2d3f65e0687ec94f\n# Remove github actions config of pre-commit in favor of pre-commit ci (#1958)\n27e0e8951352d9d58c88b2895cd8f2c752bda963\n# Enable Ruff pre-commit hooks (#1957)\n16fadfe71c0d57312351c2d8b056251a0c8ce1ef\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve apex\ntitle: ''\nlabels: bug\nassignees: ''\n\n---\n\n**Describe the Bug**\n\n**Minimal Steps/Code to Reproduce the Bug**\n<!--\nPlease list the *minimal* steps or provide a code snippet so that we can reproduce the bug.\n\nA helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports\n-->\n\n**Expected Behavior**\n<!-- A clear and concise description of what you expected to happen. -->\n\n**Environment**\n<!-- OS, version of Python, CUDA, PyTorch; collect these via `python -m torch.utils.collect_env` -->\n"
  },
  {
    "path": ".gitignore",
    "content": "apex.egg-info\ndist\nbuild\ndocs/build\n*~\n__pycache__\n.vscode\n\n# Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"apex/contrib/csrc/multihead_attn/cutlass\"]\n\tpath = apex/contrib/csrc/multihead_attn/cutlass\n\turl = https://github.com/NVIDIA/cutlass.git\n\tbranch = v1.2.0\n[submodule \"apex/contrib/csrc/cudnn-frontend\"]\n\tpath = apex/contrib/csrc/cudnn-frontend\n\turl = https://github.com/NVIDIA/cudnn-frontend.git\n"
  },
  {
    "path": ".nojekyll",
    "content": ""
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n- repo: https://github.com/pre-commit/mirrors-clang-format\n  rev: v22.1.1 # Or pin to your preferred clang-format version\n  hooks:\n  - id: clang-format\n    files: \\.(c|h|cpp|hpp|proto|cu|cuh)$\n    exclude: ^(apex/contrib/csrc/multihead_attn/cutlass|apex/contrib/csrc/cudnn-frontend)/\n\n- repo: https://github.com/astral-sh/ruff-pre-commit\n  rev: v0.15.6\n  hooks:\n  - id: ruff-check\n    args: [\"--fix\"]\n  - id: ruff-format\n    types_or: [python]\n    exclude: \"examples\"\n"
  },
  {
    "path": "LICENSE",
    "content": "All rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "README.md",
    "content": "# Introduction\n\nThis repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in PyTorch.\nSome of the code here will be included in upstream PyTorch eventually.\nThe intent of Apex is to make up-to-date utilities available to users as quickly as possible.\n\n# Installation\nEach [`apex.contrib`](./apex/contrib) module requires one or more install options other than `--cpp_ext` and `--cuda_ext`.\nNote that contrib modules do not necessarily support stable PyTorch releases; some of them might only be compatible with nightlies.\n\n## Containers\nNVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.\nThe containers come with all the custom extensions available at the moment.\n\nSee [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:\n- how to pull a container\n- how to run a pulled container\n- release notes\n\n## From Source\n\nTo install Apex from source, we recommend using the nightly PyTorch obtainable from https://github.com/pytorch/pytorch.\n\nThe latest stable release obtainable from https://pytorch.org should also work.\n\nWe recommend installing [`Ninja`](https://ninja-build.org/) to make compilation faster.\n\n### Linux\n\nFor performance and full functionality, we recommend installing Apex with CUDA and C++ extensions using environment variables:\n\n#### Using Environment Variables (Recommended)\n\n```bash\ngit clone https://github.com/NVIDIA/apex\ncd apex\n# Build with core extensions (cpp and cuda)\nAPEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation .\n\n# To build with additional extensions, specify them with environment variables\nAPEX_CPP_EXT=1 APEX_CUDA_EXT=1 APEX_FAST_MULTIHEAD_ATTN=1 APEX_FUSED_CONV_BIAS_RELU=1 pip install -v --no-build-isolation .\n\n# To build all contrib extensions at once\nAPEX_CPP_EXT=1 APEX_CUDA_EXT=1 
APEX_ALL_CONTRIB_EXT=1 pip install -v --no-build-isolation .\n```\n\nTo reduce the build time, parallel building can be enabled:\n\n```bash\nNVCC_APPEND_FLAGS=\"--threads 4\" APEX_PARALLEL_BUILD=8 APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation .\n```\n\nWhen CPU cores or memory are limited, the `--parallel` option is generally preferred over `--threads`. See [pull#1882](https://github.com/NVIDIA/apex/pull/1882) for more details.\n\n#### Using Command-Line Flags (Legacy Method)\n\nThe traditional command-line flags are still supported:\n\n```bash\n# Using pip config-settings (pip >= 23.1)\npip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./\n\n# For older pip versions\npip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\n\n# To build with additional extensions\npip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--fast_multihead_attn\" ./\n```\n\n#### Python-Only Build\n\nAPEX also supports a Python-only build via:\n```bash\npip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./\n```\nA Python-only build omits:\n- Fused kernels required to use `apex.optimizers.FusedAdam`.\n- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.\n- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.\n- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.\n`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.\n\n\n### [Experimental] Windows\n`pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation 
--config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" .` may work if you were able to build PyTorch from source\non your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.  \nIf you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.\n\n\n## Custom C++/CUDA Extensions and Install Options\n\nIf a module's requirements are not met, that module is not built.\n\n|  Module Name  |  Environment Variable  |  Install Option  |  Misc  |\n|---------------|------------------------|------------------|--------|\n|  `apex_C`     |  `APEX_CPP_EXT=1`      |  `--cpp_ext`     | |\n|  `amp_C`      |  `APEX_CUDA_EXT=1`     |  `--cuda_ext`    | |\n|  `syncbn`     |  `APEX_CUDA_EXT=1`     |  `--cuda_ext`    | |\n|  `fused_layer_norm_cuda`  |  `APEX_CUDA_EXT=1`  |  `--cuda_ext`  | [`apex.normalization`](./apex/normalization) |\n|  `mlp_cuda`   |  `APEX_CUDA_EXT=1`     |  `--cuda_ext`    | |\n|  `scaled_upper_triang_masked_softmax_cuda`  |  `APEX_CUDA_EXT=1`  |  `--cuda_ext`  | |\n|  `generic_scaled_masked_softmax_cuda`  |  `APEX_CUDA_EXT=1`  |  `--cuda_ext`  | |\n|  `scaled_masked_softmax_cuda`  |  `APEX_CUDA_EXT=1`  |  `--cuda_ext`  | |\n|  `fused_weight_gradient_mlp_cuda`  |  `APEX_CUDA_EXT=1`  |  `--cuda_ext`  | Requires CUDA>=11 |\n|  `permutation_search_cuda`  |  `APEX_PERMUTATION_SEARCH=1`  |  `--permutation_search`  | [`apex.contrib.sparsity`](./apex/contrib/sparsity)  |\n|  `bnp`        |  `APEX_BNP=1`          |  `--bnp`         |  [`apex.contrib.groupbn`](./apex/contrib/groupbn) |\n|  `xentropy`   |  `APEX_XENTROPY=1`     |  `--xentropy`    |  [`apex.contrib.xentropy`](./apex/contrib/xentropy)  |\n|  `focal_loss_cuda`  |  `APEX_FOCAL_LOSS=1`  |  `--focal_loss`  |  [`apex.contrib.focal_loss`](./apex/contrib/focal_loss)  |\n|  `fused_index_mul_2d`  |  `APEX_INDEX_MUL_2D=1`  |  `--index_mul_2d`  |  [`apex.contrib.index_mul_2d`](./apex/contrib/index_mul_2d)  
|\n|  `fused_adam_cuda`  |  `APEX_DEPRECATED_FUSED_ADAM=1`  |  `--deprecated_fused_adam`  |  [`apex.contrib.optimizers`](./apex/contrib/optimizers)  |\n|  `fused_lamb_cuda`  |  `APEX_DEPRECATED_FUSED_LAMB=1`  |  `--deprecated_fused_lamb`  |  [`apex.contrib.optimizers`](./apex/contrib/optimizers)  |\n|  `fast_layer_norm`  |  `APEX_FAST_LAYER_NORM=1`  |  `--fast_layer_norm`  |  [`apex.contrib.layer_norm`](./apex/contrib/layer_norm). different from `fused_layer_norm` |\n|  `fmhalib`    |  `APEX_FMHA=1`         |  `--fmha`        |  [`apex.contrib.fmha`](./apex/contrib/fmha)  |\n|  `fast_multihead_attn`  |  `APEX_FAST_MULTIHEAD_ATTN=1`  |  `--fast_multihead_attn`  |  [`apex.contrib.multihead_attn`](./apex/contrib/multihead_attn)  |\n|  `transducer_joint_cuda`  |  `APEX_TRANSDUCER=1`  |  `--transducer`  |  [`apex.contrib.transducer`](./apex/contrib/transducer)  |\n|  `transducer_loss_cuda`   |  `APEX_TRANSDUCER=1`  |  `--transducer`  |  [`apex.contrib.transducer`](./apex/contrib/transducer)  |\n|  `cudnn_gbn_lib`  |  `APEX_CUDNN_GBN=1`  |  `--cudnn_gbn`  | Requires cuDNN>=8.5, [`apex.contrib.cudnn_gbn`](./apex/contrib/cudnn_gbn) |\n|  `peer_memory_cuda`  |  `APEX_PEER_MEMORY=1`  |  `--peer_memory`  |  [`apex.contrib.peer_memory`](./apex/contrib/peer_memory)  |\n|  `nccl_p2p_cuda`  |  `APEX_NCCL_P2P=1`  |  `--nccl_p2p`  | Requires NCCL >= 2.10, [`apex.contrib.nccl_p2p`](./apex/contrib/nccl_p2p)  |\n|  `fast_bottleneck`  |  `APEX_FAST_BOTTLENECK=1`  |  `--fast_bottleneck`  |  Requires `peer_memory_cuda` and `nccl_p2p_cuda`, [`apex.contrib.bottleneck`](./apex/contrib/bottleneck) |\n|  `fused_conv_bias_relu`  |  `APEX_FUSED_CONV_BIAS_RELU=1`  |  `--fused_conv_bias_relu`  | Requires cuDNN>=8.4, [`apex.contrib.conv_bias_relu`](./apex/contrib/conv_bias_relu) |\n|  `distributed_adam_cuda`  |  `APEX_DISTRIBUTED_ADAM=1`  |  `--distributed_adam`  |  [`apex.contrib.optimizers`](./apex/contrib/optimizers)  |\n|  `distributed_lamb_cuda`  |  `APEX_DISTRIBUTED_LAMB=1`  |  
`--distributed_lamb`  |  [`apex.contrib.optimizers`](./apex/contrib/optimizers)  |\n|  `_apex_nccl_allocator`  |  `APEX_NCCL_ALLOCATOR=1`  |  `--nccl_allocator`  | Requires NCCL >= 2.19, [`apex.contrib.nccl_allocator`](./apex/contrib/nccl_allocator)  |\n|  `_apex_gpu_direct_storage`  |  `APEX_GPU_DIRECT_STORAGE=1`  |  `--gpu_direct_storage`  |  [`apex.contrib.gpu_direct_storage`](./apex/contrib/gpu_direct_storage)  |\n\nYou can also build all contrib extensions at once by setting `APEX_ALL_CONTRIB_EXT=1`.\n"
  },
  {
    "path": "apex/__init__.py",
    "content": "import logging\nimport warnings\n\n# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten\nimport torch\n\n# For optimizers and normalization there is no Python fallback.\n# Absence of cuda backend is a hard error.\n# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda\n# to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext\n# so they expect those backends to be available, but for some reason they actually aren't\n# available (for example because they built improperly in a way that isn't revealed until\n# load time) the error message is timely and visible.\nfrom . import optimizers\nfrom . import normalization\n\n\n__all__ = [\"optimizers\", \"normalization\"]\n\n\ndef check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:\n    cudnn_available = torch.backends.cudnn.is_available()\n    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None\n    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):\n        warnings.warn(\n            f\"`{global_option}` depends on cuDNN {required_cudnn_version} or later, \"\n            f\"but {'cuDNN is not available' if not cudnn_available else cudnn_version}\"\n        )\n        return False\n    return True\n\n\nclass DeprecatedFeatureWarning(FutureWarning):\n    pass\n\n\ndef deprecated_warning(msg: str) -> None:\n    # Warn when torch.distributed is unavailable or uninitialized, or on rank 0 only.\n    if (\n        not torch.distributed.is_available()\n        or not torch.distributed.is_initialized()\n        or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)\n    ):\n        warnings.warn(msg, DeprecatedFeatureWarning)\n"
  },
  {
    "path": "apex/_autocast_utils.py",
    "content": "from typing import Optional, Sequence\n\nimport torch\n\n\n__all__ = [\"_cast_if_autocast_enabled\"]\n\n\ndef _get_autocast_dtypes() -> Sequence[torch.dtype]:\n    if torch.cuda.is_bf16_supported():\n        return [torch.half, torch.bfloat16]\n    return [torch.half]\n\n\ndef _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:\n    if not torch.is_autocast_enabled():\n        # Fall back to torch.float only when the caller did not pass a dtype.\n        return dtype or torch.float\n    else:\n        return torch.get_autocast_gpu_dtype()\n\n\ndef _cast_if_autocast_enabled(*args):\n    if not torch.is_autocast_enabled():\n        return args\n    else:\n        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())\n"
  },
  {
    "path": "apex/contrib/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/bottleneck/__init__.py",
    "content": "from .bottleneck import Bottleneck, SpatialBottleneck\nfrom .halo_exchangers import (\n    HaloExchangerNoComm,\n    HaloExchangerAllGather,\n    HaloExchangerSendRecv,\n    HaloExchangerPeer,\n)\n"
  },
  {
    "path": "apex/contrib/bottleneck/bottleneck.py",
    "content": "import functools as func\n\nimport torch\nfrom torch import nn\n\nfrom apex import check_cudnn_version_and_warn\nimport fast_bottleneck\nimport nccl_p2p_cuda as inc\n\n\nassert check_cudnn_version_and_warn(__name__, 8400)\n\n\ndef kaiming_uniform_(tensor, a=0, mode=\"fan_in\", nonlinearity=\"leaky_relu\"):\n    weight_tensor_nchw = tensor\n    nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)\n\n\ndef compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):\n    scale = weight * running_var.rsqrt()\n    bias = bias - running_mean * scale\n    w_scale.copy_(scale)\n    w_bias.copy_(bias)\n\n\ndef compute_scale_bias_method(nhwc, args):\n    for arg in args:\n        # arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias)\n        compute_scale_bias_one(nhwc, *arg)\n\n\nclass FrozenBatchNorm2d(torch.jit.ScriptModule):\n    \"\"\"\n    BatchNorm2d where the batch statistics and the affine parameters are fixed\n    \"\"\"\n\n    def __init__(self, n):\n        super(FrozenBatchNorm2d, self).__init__()\n        self.register_buffer(\"weight\", torch.ones(n))\n        self.register_buffer(\"bias\", torch.zeros(n))\n        self.register_buffer(\"running_mean\", torch.zeros(n))\n        self.register_buffer(\"running_var\", torch.ones(n))\n\n    @torch.jit.script_method\n    def get_scale_bias(self, nhwc):\n        # type: (bool) -> List[torch.Tensor]\n        scale = self.weight * self.running_var.rsqrt()\n        bias = self.bias - self.running_mean * scale\n        if nhwc:\n            scale = scale.reshape(1, 1, 1, -1)\n            bias = bias.reshape(1, 1, 1, -1)\n        else:\n            scale = scale.reshape(1, -1, 1, 1)\n            bias = bias.reshape(1, -1, 1, 1)\n        return scale, bias\n\n    @torch.jit.script_method\n    def forward(self, x):\n        scale, bias = self.get_scale_bias(False)\n        return x * scale + 
bias\n\n\n@torch.jit.script\ndef drelu_dscale1(grad_o, output, scale1):\n    relu_mask = output > 0\n    dx_relu = relu_mask * grad_o\n    g1 = dx_relu * scale1\n    return g1, dx_relu\n\n\n@torch.jit.script\ndef drelu_dscale2(grad_o, output, scale1, scale2):\n    relu_mask = output > 0\n    dx_relu = relu_mask * grad_o\n    g1 = dx_relu * scale1\n    g2 = dx_relu * scale2\n    return g1, g2\n\n\nclass BottleneckFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):\n        # TODO: clean up order of tensors\n        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]\n        ctx.downsample = len(conv) > 3\n        if ctx.downsample:\n            args.append(conv[3])\n            args.append(scale[3])\n            args.append(bias[3])\n\n        # weight buffers are always in nhwc while shape can be nhwc or channels_last\n        # here we pass in flag and let c++ handle it\n        # alternatively, we can put all sizes into a fixed format and pass it in\n        outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)\n        ctx.save_for_backward(*(args + outputs))\n        # save relu outputs for drelu\n        ctx.nhwc = nhwc\n        ctx.stride_1x1 = stride_1x1\n        return outputs[2]\n\n    # backward relu is not exposed, MUL with mask used now\n    # only support dgrad\n    @staticmethod\n    def backward(ctx, grad_o):\n        outputs = ctx.saved_tensors[-3:]\n\n        if ctx.downsample:\n            grad_conv3, grad_conv4 = drelu_dscale2(\n                grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11]\n            )\n        else:\n            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])\n\n        # create input vector for backward\n        t_list = [*ctx.saved_tensors[0:10]]\n        t_list.append(grad_conv3)\n        t_list.append(grad_conv4)\n\n        # outputs used for wgrad and generating drelu mask\n        
t_list.append(outputs[0])\n        t_list.append(outputs[1])\n\n        # in case there is downsample\n        if ctx.downsample:\n            t_list.append(ctx.saved_tensors[10])\n\n        grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)\n\n        return (None, None, None, None, *grads)\n\n\nbottleneck_function = BottleneckFunction.apply\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=dilation,\n        groups=groups,\n        bias=False,\n        dilation=dilation,\n    )\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass Bottleneck(torch.nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n    # here we put it at 1x1\n\n    def __init__(\n        self,\n        in_channels,\n        bottleneck_channels,\n        out_channels,\n        stride=1,\n        groups=1,\n        dilation=1,\n        norm_func=None,\n        use_cudnn=False,\n        explicit_nhwc=False,\n    ):\n        super(Bottleneck, self).__init__()\n        if groups != 1:\n            raise RuntimeError(\"Only support groups == 1\")\n        if dilation != 1:\n            raise RuntimeError(\"Only support dilation == 1\")\n        if norm_func == None:\n            norm_func = FrozenBatchNorm2d\n        else:\n            raise 
RuntimeError(\"Only support frozen BN now.\")\n\n        if stride != 1 or in_channels != out_channels:\n            self.downsample = nn.Sequential(\n                conv1x1(in_channels, out_channels, stride),\n                norm_func(out_channels),\n            )\n        else:\n            self.downsample = None\n\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)\n        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)\n        self.conv3 = conv1x1(bottleneck_channels, out_channels)\n        self.relu = nn.ReLU(inplace=True)\n        self.stride = stride\n\n        self.bn1 = norm_func(bottleneck_channels)\n        self.bn2 = norm_func(bottleneck_channels)\n        self.bn3 = norm_func(out_channels)\n        self.w_scale = None\n\n        self.use_cudnn = use_cudnn\n\n        # setup conv weights\n        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]\n        if self.downsample is not None:\n            self.w_conv.append(self.downsample[0].weight)\n\n        # init weight in nchw format before possible transpose\n        for w in self.w_conv:\n            kaiming_uniform_(w, a=1)\n\n        # TODO: prevent unsupported case usage\n        # support cases\n        #                 native      cudnn\n        # normal             yes         no\n        # channel_last       yes        yes\n        # explicit_nhwc       no        yes\n        self.explicit_nhwc = explicit_nhwc\n        if self.explicit_nhwc:\n            for p in self.parameters():\n                with torch.no_grad():\n                    p.data = p.data.permute(0, 2, 3, 1).contiguous()\n\n        return\n\n    # Returns single callable that recomputes scale and bias for all frozen batch-norms.\n    # This method must be called before cuda graphing.\n    # The callable it returns can be called anytime.\n    # Calling this method will prevent 
these from being computed every forward call.\n    def get_scale_bias_callable(self):\n        self.w_scale, self.w_bias, args = [], [], []\n        batch_norms = [self.bn1, self.bn2, self.bn3]\n        if self.downsample is not None:\n            batch_norms.append(self.downsample[1])\n        for bn in batch_norms:\n            s = torch.empty_like(bn.weight)\n            b = torch.empty_like(s)\n            args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b))\n            if self.explicit_nhwc:\n                self.w_scale.append(s.reshape(1, 1, 1, -1))\n                self.w_bias.append(b.reshape(1, 1, 1, -1))\n            else:\n                self.w_scale.append(s.reshape(1, -1, 1, 1))\n                self.w_bias.append(b.reshape(1, -1, 1, 1))\n        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)\n\n    def forward(self, x):\n        if self.use_cudnn:\n            if self.w_scale is None:\n                # calculate scale/bias from registered buffers\n                # TODO: make this better\n                s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)\n                s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)\n                s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)\n                w_scale = [s1, s2, s3]\n                w_bias = [b1, b2, b3]\n                if self.downsample is not None:\n                    s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)\n                    w_scale.append(s4)\n                    w_bias.append(b4)\n                out = bottleneck_function(\n                    self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv\n                )\n            else:\n                out = bottleneck_function(\n                    self.explicit_nhwc,\n                    self.stride,\n                    self.w_scale,\n                    self.w_bias,\n                    x,\n                    *self.w_conv,\n               
 )\n            return out\n\n        if self.explicit_nhwc:\n            raise RuntimeError(\"explicit nhwc with native ops is not supported.\")\n\n        # fallback to native ops\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass SpatialBottleneckFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        spatial_group_size,\n        spatial_group_rank,\n        spatial_communicator,\n        spatial_halo_exchanger,\n        spatial_method,\n        use_delay_kernel,\n        explicit_nhwc,\n        stride_1x1,\n        scale,\n        bias,\n        thresholdTop,\n        thresholdBottom,\n        x,\n        *conv,\n    ):\n        if spatial_group_size > 1:\n            stream1 = spatial_halo_exchanger.stream1\n            stream2 = spatial_halo_exchanger.stream2\n            stream3 = spatial_halo_exchanger.stream3\n\n        # TODO: clean up order of tensors\n        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]\n        ctx.downsample = len(conv) > 3\n        if ctx.downsample:\n            args.append(conv[3])\n            args.append(scale[3])\n            args.append(bias[3])\n\n        # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last\n        # here we pass in flag and let c++ handle it\n        # alternatively, we can put all sizes into a fixed format and pass it in\n        outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)\n        fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)\n\n        if spatial_group_size > 1:\n            out1 = 
outputs[0]\n            if explicit_nhwc:\n                N, Hs, W, C = list(out1.shape)\n                memory_format = torch.contiguous_format\n                out1_pad = torch.empty([N, Hs + 2, W, C], dtype=out1.dtype, device=\"cuda\")\n            else:\n                N, C, Hs, W = list(out1.shape)\n                memory_format = (\n                    torch.channels_last\n                    if out1.is_contiguous(memory_format=torch.channels_last)\n                    else torch.contiguous_format\n                )\n                out1_pad = torch.empty(\n                    [N, C, Hs + 2, W],\n                    dtype=out1.dtype,\n                    device=\"cuda\",\n                    memory_format=memory_format,\n                )\n            stream1.wait_stream(torch.cuda.current_stream())\n            if spatial_method != 2:\n                stream3.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(stream1):\n                if explicit_nhwc:\n                    top_out1_halo = out1_pad[:, :1, :, :]\n                    btm_out1_halo = out1_pad[:, Hs + 1 : Hs + 2, :, :]\n                    spatial_halo_exchanger.left_right_halo_exchange(\n                        out1[:, :1, :, :],\n                        out1[:, Hs - 1 :, :, :],\n                        top_out1_halo,\n                        btm_out1_halo,\n                    )\n                else:\n                    top_out1_halo = out1_pad[:, :, :1, :]\n                    btm_out1_halo = out1_pad[:, :, Hs + 1 : Hs + 2, :]\n                    spatial_halo_exchanger.left_right_halo_exchange(\n                        out1[:, :, :1, :],\n                        out1[:, :, Hs - 1 :, :],\n                        top_out1_halo,\n                        btm_out1_halo,\n                    )\n            if spatial_method == 1:\n                # overlap mid convolution with halo transfer\n                if spatial_group_rank < spatial_group_size - 1:\n            
        stream2.wait_stream(stream1)\n                    with torch.cuda.stream(stream2):\n                        if explicit_nhwc:\n                            btm_fat_halo = torch.empty(\n                                (N, 3, W, C), dtype=out1.dtype, device=out1.device\n                            )\n                            btm_fat_halo[:, 0:2, :, :].copy_(out1[:, Hs - 2 :, :, :])\n                            btm_fat_halo[:, 2:, :, :].copy_(btm_out1_halo)\n                        else:\n                            btm_fat_halo = torch.empty(\n                                (N, C, 3, W), dtype=out1.dtype, device=out1.device\n                            )\n                            btm_fat_halo[:, :, 0:2, :].copy_(out1[:, :, Hs - 2 :, :])\n                            btm_fat_halo[:, :, 2:, :].copy_(btm_out1_halo)\n                        btm_out2 = fast_bottleneck.forward_out2_halo(\n                            explicit_nhwc, btm_fat_halo, args\n                        )\n                if spatial_group_rank > 0:\n                    with torch.cuda.stream(stream1):\n                        if explicit_nhwc:\n                            top_fat_halo = torch.empty(\n                                (N, 3, W, C), dtype=out1.dtype, device=out1.device\n                            )\n                            top_fat_halo[:, :1, :, :].copy_(top_out1_halo)\n                            top_fat_halo[:, 1:3, :, :].copy_(out1[:, :2, :, :])\n                        else:\n                            top_fat_halo = torch.empty(\n                                (N, C, 3, W), dtype=out1.dtype, device=out1.device\n                            )\n                            top_fat_halo[:, :, :1, :].copy_(top_out1_halo)\n                            top_fat_halo[:, :, 1:3, :].copy_(out1[:, :, :2, :])\n                        top_out2 = fast_bottleneck.forward_out2_halo(\n                            explicit_nhwc, top_fat_halo, args\n                        )\n            
    if use_delay_kernel:\n                    inc.add_delay(10)\n            elif spatial_method != 2 and spatial_method != 3:\n                assert False, \"spatial_method must be 1, 2 or 3\"\n\n        if spatial_group_size <= 1:\n            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)\n        elif spatial_method == 1:\n            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)\n            with torch.cuda.stream(stream3):\n                if explicit_nhwc:\n                    out1_pad[:, 1 : Hs + 1, :, :].copy_(out1)\n                else:\n                    out1_pad[:, :, 1 : Hs + 1, :].copy_(out1)\n        elif spatial_method == 2:\n            # wait for halo transfer to finish before doing a full convolution of padded x\n            if explicit_nhwc:\n                out1_pad[:, 1 : Hs + 1, :, :].copy_(out1)\n            else:\n                out1_pad[:, :, 1 : Hs + 1, :].copy_(out1)\n            torch.cuda.current_stream().wait_stream(stream1)\n            fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)\n        elif spatial_method == 3:\n            fast_bottleneck.forward_out2_mask(\n                explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom\n            )\n            with torch.cuda.stream(stream3):\n                if explicit_nhwc:\n                    out1_pad[:, 1 : Hs + 1, :, :].copy_(out1)\n                else:\n                    out1_pad[:, :, 1 : Hs + 1, :].copy_(out1)\n\n        # compute halo cells for outputs[1] (out2)\n        if spatial_group_size > 1:\n            out2 = outputs[1]\n            if explicit_nhwc:\n                top_out2_halo = out2[:, :1, :, :]\n                btm_out2_halo = out2[:, Hs - 1 :, :, :]\n            else:\n                top_out2_halo = out2[:, :, :1, :]\n                btm_out2_halo = out2[:, :, Hs - 1 :, :]\n            if spatial_method == 1:\n                if 
spatial_group_rank > 0:\n                    torch.cuda.current_stream().wait_stream(stream1)\n                    top_out2_halo.copy_(top_out2)\n                if spatial_group_rank < spatial_group_size - 1:\n                    torch.cuda.current_stream().wait_stream(stream2)\n                    btm_out2_halo.copy_(btm_out2)\n            elif spatial_method == 3:\n                # Note\n                # out2 halo correction cannot overlap with anything since it has\n                # to wait for out2_mask to finish, but itself has to finish before\n                # the first kernel of _forward_rest can launch.\n                # At least we can overlap the two halo correction kernels.\n                if spatial_group_rank < spatial_group_size - 1:\n                    stream2.wait_stream(stream1)  # wait for halo transfers to finish\n                    stream2.wait_stream(\n                        torch.cuda.current_stream()\n                    )  # wait for *_out2_mask to finish\n                    with torch.cuda.stream(stream2):\n                        w1by3 = args[2][:, 2:3, :, :].clone()\n                        btm_out1_halo = btm_out1_halo.clone()\n                        btm_out2 = fast_bottleneck.forward_out2_halo_corr(\n                            explicit_nhwc,\n                            btm_out1_halo,\n                            args,\n                            w1by3,\n                            btm_out2_halo.clone(),\n                        )\n                        btm_out2_halo.copy_(btm_out2)\n                if spatial_group_rank > 0:\n                    stream1.wait_stream(\n                        torch.cuda.current_stream()\n                    )  # wait for *_out2_mask to finish\n                    with torch.cuda.stream(stream1):\n                        w1by3 = args[2][:, :1, :, :].clone()\n                        top_out1_halo = top_out1_halo.clone()\n                        top_out2 = 
fast_bottleneck.forward_out2_halo_corr(\n                            explicit_nhwc,\n                            top_out1_halo,\n                            args,\n                            w1by3,\n                            top_out2_halo.clone(),\n                        )\n                        top_out2_halo.copy_(top_out2)\n                if spatial_group_rank < spatial_group_size - 1:\n                    torch.cuda.current_stream().wait_stream(stream2)\n                if spatial_group_rank > 0:\n                    torch.cuda.current_stream().wait_stream(stream1)\n\n        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)\n        # save halos for backward pass\n        if spatial_group_size > 1:\n            if spatial_method != 2:\n                # make sure copy of mid-section of out1 into out1_pad is done before exiting\n                torch.cuda.current_stream().wait_stream(stream3)\n            ctx.save_for_backward(\n                *(\n                    args\n                    + outputs\n                    + [\n                        out1_pad,\n                    ]\n                )\n            )\n        else:\n            ctx.save_for_backward(*(args + outputs))\n        # save relu outputs for drelu\n        ctx.explicit_nhwc = explicit_nhwc\n        ctx.stride_1x1 = stride_1x1\n        ctx.spatial_group_size = spatial_group_size\n        if spatial_group_size > 1:\n            ctx.spatial_group_rank = spatial_group_rank\n            ctx.spatial_halo_exchanger = spatial_halo_exchanger\n            ctx.spatial_method = spatial_method\n            ctx.use_delay_kernel = use_delay_kernel\n            ctx.thresholdTop = thresholdTop\n            ctx.thresholdBottom = thresholdBottom\n            ctx.stream1 = stream1\n            ctx.stream2 = stream2\n            ctx.stream3 = stream3\n        return outputs[2]\n\n    # backward relu is not exposed, MUL with mask used now\n    # only support dgrad\n    
@staticmethod\n    def backward(ctx, grad_o):\n        if ctx.spatial_group_size > 1:\n            out1_pad = ctx.saved_tensors[-1]\n            outputs = ctx.saved_tensors[-4:-1]\n        else:\n            outputs = ctx.saved_tensors[-3:]\n\n        if ctx.downsample:\n            grad_conv3, grad_conv4 = drelu_dscale2(\n                grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11]\n            )\n        else:\n            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])\n\n        # create input vector for backward\n        t_list = [*ctx.saved_tensors[0:10]]\n        t_list.append(grad_conv3)\n        t_list.append(grad_conv4)\n\n        # outputs used for wgrad and generating drelu mask\n        t_list.append(outputs[0])\n        t_list.append(outputs[1])\n\n        # in case there is downsample\n        if ctx.downsample:\n            t_list.append(ctx.saved_tensors[10])\n\n        grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)\n        wgrad3_stream = torch.cuda.Stream()\n        wgrad3_stream.wait_stream(torch.cuda.current_stream())\n        grad_out2 = fast_bottleneck.backward_grad_out2(\n            ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads\n        )\n        wgrad2_stream = torch.cuda.Stream()\n        wgrad2_stream.wait_stream(torch.cuda.current_stream())\n        # do halo exchange of grad_out2 here\n        # compute halo cells for grad_out1\n        if ctx.spatial_group_size > 1:\n            if ctx.explicit_nhwc:\n                N, Hs, W, C = list(grad_out2.shape)\n            else:\n                N, C, Hs, W = list(grad_out2.shape)\n            relu1 = t_list[12]\n            ctx.stream1.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(ctx.stream1):\n                top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(\n                    grad_out2[:, :1, :, :], grad_out2[:, Hs - 1 :, :, :]\n               
 )\n                # copy halos to send buffer\n            if ctx.spatial_method == 1 or ctx.spatial_method == 2:\n                # 1 -> halo recompute approach\n                # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)\n                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:\n                    ctx.stream2.wait_stream(ctx.stream1)\n                    with torch.cuda.stream(ctx.stream2):\n                        if ctx.explicit_nhwc:\n                            btm_fat_halo = torch.empty(\n                                (N, 3, W, C),\n                                dtype=grad_out2.dtype,\n                                device=grad_out2.device,\n                            )\n                            btm_fat_halo[:, :2, :, :].copy_(grad_out2[:, Hs - 2 :, :, :])\n                            btm_fat_halo[:, 2:, :, :].copy_(btm_halo)\n                            btm_fat_relu_halo = torch.empty(\n                                (N, 3, W, C),\n                                dtype=grad_out2.dtype,\n                                device=grad_out2.device,\n                            )\n                            btm_fat_relu_halo[:, :2, :, :].copy_(relu1[:, Hs - 2 :, :, :])\n                            btm_fat_relu_halo[:, 2:, :, :].zero_()\n                        else:\n                            btm_fat_halo = torch.empty(\n                                (N, C, 3, W),\n                                dtype=grad_out2.dtype,\n                                device=grad_out2.device,\n                            )\n                            btm_fat_halo[:, :, :2, :].copy_(grad_out2[:, :, Hs - 2 :, :])\n                            btm_fat_halo[:, :, 2:, :].copy_(btm_halo)\n                            btm_fat_relu_halo = torch.empty(\n                                (N, C, 3, W),\n                                dtype=grad_out2.dtype,\n                                
device=grad_out2.device,\n                            )\n                            btm_fat_relu_halo[:, :, :2, :].copy_(relu1[:, :, Hs - 2 :, :])\n                            btm_fat_relu_halo[:, :, 2:, :].zero_()\n                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(\n                            ctx.explicit_nhwc,\n                            ctx.stride_1x1,\n                            t_list,\n                            grads,\n                            btm_fat_halo,\n                            btm_fat_relu_halo,\n                        )\n                        if ctx.explicit_nhwc:\n                            btm_grad_out1_halo = btm_grad_out1_halo[:, 1:2, :, :]\n                        else:\n                            btm_grad_out1_halo = btm_grad_out1_halo[:, :, 1:2, :]\n                if ctx.spatial_group_rank > 0:\n                    with torch.cuda.stream(ctx.stream1):\n                        if ctx.explicit_nhwc:\n                            top_fat_halo = torch.empty(\n                                (N, 3, W, C),\n                                dtype=grad_out2.dtype,\n                                device=grad_out2.device,\n                            )\n                            top_fat_halo[:, :1, :, :].copy_(top_halo)\n                            top_fat_halo[:, 1:, :, :].copy_(grad_out2[:, :2, :, :])\n                            top_fat_relu_halo = torch.empty(\n                                (N, 3, W, C),\n                                dtype=grad_out2.dtype,\n                                device=grad_out2.device,\n                            )\n                            top_fat_relu_halo[:, :1, :, :].zero_()\n                            top_fat_relu_halo[:, 1:, :, :].copy_(relu1[:, :2, :, :])\n                        else:\n                            top_fat_halo = torch.empty(\n                                (N, C, 3, W),\n                                dtype=grad_out2.dtype,\n           
                     device=grad_out2.device,\n                            )\n                            top_fat_halo[:, :, :1, :].copy_(top_halo)\n                            top_fat_halo[:, :, 1:, :].copy_(grad_out2[:, :, :2, :])\n                            top_fat_relu_halo = torch.empty(\n                                (N, C, 3, W),\n                                dtype=grad_out2.dtype,\n                                device=grad_out2.device,\n                            )\n                            top_fat_relu_halo[:, :, :1, :].zero_()\n                            top_fat_relu_halo[:, :, 1:, :].copy_(relu1[:, :, :2, :])\n                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(\n                            ctx.explicit_nhwc,\n                            ctx.stride_1x1,\n                            t_list,\n                            grads,\n                            top_fat_halo,\n                            top_fat_relu_halo,\n                        )\n                        if ctx.explicit_nhwc:\n                            top_grad_out1_halo = top_grad_out1_halo[:, 1:2, :, :]\n                        else:\n                            top_grad_out1_halo = top_grad_out1_halo[:, :, 1:2, :]\n                if ctx.use_delay_kernel:\n                    inc.add_delay(10)\n            elif ctx.spatial_method != 3:\n                assert False, \"spatial_method must be 1, 2 or 3\"\n\n        # compute grad_out1 for internal cells\n        if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:\n            grad_out1 = fast_bottleneck.backward_grad_out1(\n                ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2\n            )\n        elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:\n            grad_out1 = fast_bottleneck.backward_grad_out1_mask(\n                ctx.explicit_nhwc,\n                ctx.stride_1x1,\n                t_list,\n                
grads,\n                grad_out2,\n                ctx.thresholdTop,\n                ctx.thresholdBottom,\n            )\n\n        # apply halo cells to grad_out1\n        if ctx.spatial_group_size > 1:\n            w = t_list[2]\n            z = t_list[4]\n            relu1 = t_list[12]\n            # print(\"w.shape = %s, z.shape = %s, relu1.shape = %s\" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))\n            if ctx.spatial_method == 1 or ctx.spatial_method == 2:\n                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:\n                    torch.cuda.current_stream().wait_stream(ctx.stream2)\n                    if ctx.explicit_nhwc:\n                        grad_out1[:, Hs - 1 :, :, :].copy_(btm_grad_out1_halo)\n                    else:\n                        grad_out1[:, :, Hs - 1 :, :].copy_(btm_grad_out1_halo)\n                    # print(\"ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)\" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))\n                if ctx.spatial_group_rank > 0:\n                    torch.cuda.current_stream().wait_stream(ctx.stream1)\n                    if ctx.explicit_nhwc:\n                        grad_out1[:, :1, :, :].copy_(top_grad_out1_halo)\n                    else:\n                        grad_out1[:, :, :1, :].copy_(top_grad_out1_halo)\n                    # print(\"ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)\" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))\n            elif ctx.spatial_method == 3:\n                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:\n                    if ctx.explicit_nhwc:\n                        btm_relu_halo = relu1[:, Hs - 1 :, :, :].clone()\n                        btm_grad_out1 = grad_out1[:, Hs - 1 :, :, :]\n                    else:\n                        btm_relu_halo = relu1[:, :, Hs - 1 :, :].clone()\n                        btm_grad_out1 = 
grad_out1[:, :, Hs - 1 :, :]\n                    w1by3 = w[:, :1, :, :].clone()\n                    ctx.stream2.wait_stream(ctx.stream1)  # wait for halo transfers to finish\n                    ctx.stream2.wait_stream(\n                        torch.cuda.current_stream()\n                    )  # wait for backward_grad_out1_mask to finish before launching halo correction kernel\n                    with torch.cuda.stream(ctx.stream2):\n                        btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(\n                            ctx.explicit_nhwc,\n                            ctx.stride_1x1,\n                            t_list,\n                            w1by3,\n                            grads,\n                            btm_halo,\n                            btm_relu_halo,\n                            btm_grad_out1.clone(),\n                        )\n                        btm_grad_out1.copy_(btm_grad_out1_halo)\n                if ctx.spatial_group_rank > 0:\n                    if ctx.explicit_nhwc:\n                        top_relu_halo = relu1[:, :1, :, :].clone()\n                        top_grad_out1 = grad_out1[:, :1, :, :]\n                    else:\n                        top_relu_halo = relu1[:, :, :1, :].clone()\n                        top_grad_out1 = grad_out1[:, :, :1, :]\n                    w1by3 = w[:, 2:, :, :].clone()\n                    ctx.stream1.wait_stream(\n                        torch.cuda.current_stream()\n                    )  # wait for backward_grad_out1_mask to finish before launching halo correction kernel\n                    with torch.cuda.stream(ctx.stream1):\n                        top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(\n                            ctx.explicit_nhwc,\n                            ctx.stride_1x1,\n                            t_list,\n                            w1by3,\n                            grads,\n                            top_halo,\n   
                         top_relu_halo,\n                            top_grad_out1.clone(),\n                        )\n                        top_grad_out1.copy_(top_grad_out1_halo)\n                if ctx.spatial_group_rank < ctx.spatial_group_size - 1:\n                    torch.cuda.current_stream().wait_stream(\n                        ctx.stream2\n                    )  # wait for halo correction to finish\n                if ctx.spatial_group_rank > 0:\n                    torch.cuda.current_stream().wait_stream(ctx.stream1)\n\n        wgrad1_stream = torch.cuda.Stream()\n        wgrad1_stream.wait_stream(torch.cuda.current_stream())\n        fast_bottleneck.backward_rest(\n            ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1\n        )\n        with torch.cuda.stream(wgrad3_stream):\n            fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)\n        with torch.cuda.stream(wgrad2_stream):\n            if ctx.spatial_group_size > 1:\n                fast_bottleneck.backward_wgrad2_pad(\n                    ctx.explicit_nhwc,\n                    ctx.stride_1x1,\n                    t_list,\n                    grads,\n                    out1_pad,\n                    grad_out2,\n                )\n            else:\n                fast_bottleneck.backward_wgrad2(\n                    ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2\n                )\n        with torch.cuda.stream(wgrad1_stream):\n            fast_bottleneck.backward_wgrad1(\n                ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1\n            )\n        torch.cuda.current_stream().wait_stream(wgrad3_stream)\n        torch.cuda.current_stream().wait_stream(wgrad2_stream)\n        torch.cuda.current_stream().wait_stream(wgrad1_stream)\n\n        return (\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n         
   None,\n            None,\n            None,\n            None,\n            None,\n            *grads,\n        )\n\n\nspatial_bottleneck_function = SpatialBottleneckFunction.apply\n\n\nclass SpatialBottleneck(torch.nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\" https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n    # here we put it at 1x1\n\n    def __init__(\n        self,\n        in_channels,\n        bottleneck_channels,\n        out_channels,\n        stride=1,\n        groups=1,\n        dilation=1,\n        norm_func=None,\n        use_cudnn=False,\n        explicit_nhwc=False,\n        spatial_parallel_args=None,\n    ):\n        super(SpatialBottleneck, self).__init__()\n        if groups != 1:\n            raise RuntimeError(\"Only support groups == 1\")\n        if dilation != 1:\n            raise RuntimeError(\"Only support dilation == 1\")\n        if norm_func is None:\n            norm_func = FrozenBatchNorm2d\n        else:\n            raise RuntimeError(\"Only support frozen BN now.\")\n\n        if stride != 1 or in_channels != out_channels:\n            self.downsample = nn.Sequential(\n                conv1x1(in_channels, out_channels, stride),\n                norm_func(out_channels),\n            )\n        else:\n            self.downsample = None\n\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)\n        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)\n        self.conv3 = conv1x1(bottleneck_channels, out_channels)\n        
self.relu = nn.ReLU(inplace=True)\n        self.stride = stride\n\n        self.bn1 = norm_func(bottleneck_channels)\n        self.bn2 = norm_func(bottleneck_channels)\n        self.bn3 = norm_func(out_channels)\n        self.w_scale = None\n\n        self.use_cudnn = use_cudnn\n\n        # setup conv weights\n        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]\n        if self.downsample is not None:\n            self.w_conv.append(self.downsample[0].weight)\n\n        # init weight in nchw format before possible transpose\n        for w in self.w_conv:\n            kaiming_uniform_(w, a=1)\n\n        self.thresholdTop, self.thresholdBottom = None, None\n\n        # TODO: prevent unsupported case usage\n        # support cases\n        #                 native      cudnn\n        # normal             yes         no\n        # channel_last       yes        yes\n        # explicit_nhwc       no        yes\n        self.explicit_nhwc = explicit_nhwc\n        if self.explicit_nhwc:\n            for p in self.parameters():\n                with torch.no_grad():\n                    p.data = p.data.permute(0, 2, 3, 1).contiguous()\n\n        # spatial communicator\n        if spatial_parallel_args is None:\n            self.spatial_parallel_args = (1, 0, None, None, 0, False)\n        else:\n            self.spatial_parallel_args = spatial_parallel_args\n        return\n\n    # Returns single callable that recomputes scale and bias for all frozen batch-norms.\n    # This method must be called before cuda graphing.\n    # The callable it returns can be called anytime.\n    # Calling this method will prevent these from being computed every forward call.\n    def get_scale_bias_callable(self):\n        self.w_scale, self.w_bias, args = [], [], []\n        batch_norms = [self.bn1, self.bn2, self.bn3]\n        if self.downsample is not None:\n            batch_norms.append(self.downsample[1])\n        for bn in batch_norms:\n            s = 
torch.empty_like(bn.weight)\n            b = torch.empty_like(s)\n            args.append((bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b))\n            if self.explicit_nhwc:\n                self.w_scale.append(s.reshape(1, 1, 1, -1))\n                self.w_bias.append(b.reshape(1, 1, 1, -1))\n            else:\n                self.w_scale.append(s.reshape(1, -1, 1, 1))\n                self.w_bias.append(b.reshape(1, -1, 1, 1))\n        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)\n\n    def forward(self, x):\n        if self.use_cudnn:\n            if self.thresholdTop is None:\n                spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args\n                if self.explicit_nhwc:\n                    N, H, W, C = list(x.shape)\n                else:\n                    N, C, H, W = list(x.shape)\n                self.thresholdTop = torch.tensor(\n                    [1 if spatial_group_rank > 0 else 0],\n                    dtype=torch.int32,\n                    device=\"cuda\",\n                )\n                self.thresholdBottom = torch.tensor(\n                    [H - 2 if spatial_group_rank < spatial_group_size - 1 else H - 1],\n                    dtype=torch.int32,\n                    device=\"cuda\",\n                )\n\n            if self.w_scale is None:\n                # calculate scale/bias from registered buffers\n                # TODO: make this better\n                s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)\n                s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)\n                s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)\n                w_scale = [s1, s2, s3]\n                w_bias = [b1, b2, b3]\n                if self.downsample is not None:\n                    s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)\n                    w_scale.append(s4)\n                    w_bias.append(b4)\n                out 
= spatial_bottleneck_function(\n                    *self.spatial_parallel_args,\n                    self.explicit_nhwc,\n                    self.stride,\n                    w_scale,\n                    w_bias,\n                    self.thresholdTop,\n                    self.thresholdBottom,\n                    x,\n                    *self.w_conv,\n                )\n            else:\n                out = spatial_bottleneck_function(\n                    *self.spatial_parallel_args,\n                    self.explicit_nhwc,\n                    self.stride,\n                    self.w_scale,\n                    self.w_bias,\n                    self.thresholdTop,\n                    self.thresholdBottom,\n                    x,\n                    *self.w_conv,\n                )\n            return out\n\n        if self.explicit_nhwc:\n            raise RuntimeError(\"explicit nhwc with native ops is not supported.\")\n\n        # fallback to native ops\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n"
  },
  {
    "path": "apex/contrib/bottleneck/halo_exchangers.py",
    "content": "import torch\nimport nccl_p2p_cuda as inc\nimport peer_memory_cuda as pm\n\n\n# Communication free halo exchanger.\n# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs\n# NB! This is only useful for performance testing.\n# NB! Do not use for actual production runs\nclass HaloExchanger(object):\n    def __init__(self, ranks, rank_in_group):\n        self.stream1 = torch.cuda.Stream()\n        self.stream2 = torch.cuda.Stream()\n        self.stream3 = torch.cuda.Stream()\n        self.group_size = len(ranks)\n        self.ranks = ranks\n        self.rank_in_group = rank_in_group\n        self.wrap_around_left_rank_in_group = (\n            rank_in_group + self.group_size - 1\n        ) % self.group_size\n        self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size\n        self.left_rank = ranks[rank_in_group - 1] if rank_in_group > 0 else -1\n        self.left_zero = True if rank_in_group == 0 else False\n        self.right_rank = ranks[rank_in_group + 1] if rank_in_group < self.group_size - 1 else -1\n        self.right_zero = True if rank_in_group == self.group_size - 1 else False\n\n\nclass HaloExchangerNoComm(HaloExchanger):\n    def __init__(self, ranks, rank_in_group):\n        super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)\n\n    def left_right_halo_exchange(\n        self,\n        left_output_halo,\n        right_output_halo,\n        left_input_halo=None,\n        right_input_halo=None,\n    ):\n        if left_input_halo is None:\n            return right_output_halo, left_output_halo\n        else:\n            left_input_halo.copy_(right_output_halo)\n            right_input_halo.copy_(left_output_halo)\n\n\nclass HaloExchangerAllGather(HaloExchanger):\n    def __init__(self, ranks, rank_in_group, comm):\n        super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)\n        # self.comm must be NCCL process_group created with 
torch.distributed.new_group(ranks=ranks)\n        self.comm = comm\n\n    def left_right_halo_exchange(\n        self,\n        left_output_halo,\n        right_output_halo,\n        left_input_halo=None,\n        right_input_halo=None,\n    ):\n        N, Hh, W, C = list(left_output_halo.shape)\n        send_halos = torch.empty(\n            (N, 2 * Hh, W, C),\n            dtype=left_output_halo.dtype,\n            device=left_output_halo.device,\n        )\n        send_halos[:, :Hh, :, :].copy_(left_output_halo)\n        send_halos[:, Hh:, :, :].copy_(right_output_halo)\n        all_halos = torch.empty(\n            (N, 2 * Hh * self.group_size, W, C),\n            dtype=left_output_halo.dtype,\n            device=left_output_halo.device,\n        )\n        all_halos = [\n            all_halos[:, i * 2 * Hh : (i + 1) * 2 * Hh, :, :] for i in range(self.group_size)\n        ]\n        torch.distributed.all_gather(all_halos, send_halos, group=self.comm, no_copy=True)\n        ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:, Hh:, :, :]\n        ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:, :Hh, :, :]\n        if left_input_halo is None:\n            if self.left_zero:\n                ag_left_input_halo.zero_()\n            if self.right_zero:\n                ag_right_input_halo.zero_()\n            return ag_left_input_halo, ag_right_input_halo\n        else:\n            if self.left_zero:\n                left_input_halo.zero_()\n            else:\n                left_input_halo.copy_(ag_left_input_halo)\n            if self.right_zero:\n                right_input_halo.zero_()\n            else:\n                right_input_halo.copy_(ag_right_input_halo)\n\n\nclass HaloExchangerSendRecv(HaloExchanger):\n    def __init__(self, ranks, rank_in_group):\n        super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)\n        nccl_id = inc.get_unique_nccl_id(1).cuda()\n        
torch.distributed.broadcast(nccl_id, 0)\n        nccl_id = nccl_id.cpu()\n        print(\"%d :: nccl_id = %s\" % (torch.distributed.get_rank(), str(nccl_id)))\n        # Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group(\"nccl\")\n        # This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence\n        # it cannot be accessed from another class.\n        # TODO: Figure out a way to avoid creating a second global communicator\n        assert torch.distributed.get_rank() == self.ranks[self.rank_in_group], (\n            \"ranks[%d](%d) != torch.distributed.get_rank()(%d)\"\n            % (\n                self.rank_in_group,\n                self.ranks[self.rank_in_group],\n                torch.distributed.get_rank(),\n            )\n        )\n        self.handle = inc.init_nccl_comm(\n            nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size()\n        )\n\n    def left_right_halo_exchange(\n        self,\n        left_output_halo,\n        right_output_halo,\n        left_input_halo=None,\n        right_input_halo=None,\n    ):\n        if left_input_halo is None:\n            left_input_halo, right_input_halo = inc.left_right_halo_exchange(\n                self.handle,\n                self.left_rank,\n                self.right_rank,\n                left_output_halo,\n                right_output_halo,\n            )\n            return left_input_halo, right_input_halo\n        else:\n            inc.left_right_halo_exchange_inplace(\n                self.handle,\n                self.left_rank,\n                self.right_rank,\n                left_output_halo,\n                right_output_halo,\n                left_input_halo,\n                right_input_halo,\n            )\n\n\nclass HaloExchangerPeer(HaloExchanger):\n    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):\n   
     super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)\n        self.diagnostics = False\n        self.explicit_nhwc = explicit_nhwc\n        self.numSM = numSM\n        self.peer_pool = peer_pool\n\n    def _allocate_peer_tensor(self, halo):\n        # Compute size in bytes\n        # Note: Pad buffer so each CUDA block gets required buffer size\n        size = 4 * halo.numel() * halo.element_size()\n        size_per_block = 128 * 2 * 16  # 128 threads each require two 128b buffers\n        size = (size + size_per_block - 1) // size_per_block * size_per_block\n\n        # Construct dtype peer buffer with desired size\n        shape = [1, 1, 1, size // halo.element_size()]\n        return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)\n\n    def left_right_halo_exchange(\n        self,\n        left_output_halo,\n        right_output_halo,\n        left_input_halo=None,\n        right_input_halo=None,\n    ):\n        inplace = False if left_input_halo is None and right_input_halo is None else True\n        if not inplace:\n            left_input_halo = torch.empty_like(right_output_halo)\n            right_input_halo = torch.empty_like(left_output_halo)\n        channels_last = (\n            left_output_halo.is_contiguous(memory_format=torch.channels_last)\n            and not self.explicit_nhwc\n        )\n        left_tx = self._allocate_peer_tensor(left_input_halo)\n        right_tx = self._allocate_peer_tensor(right_input_halo)\n        pm.push_pull_halos_1d(\n            self.diagnostics,\n            self.explicit_nhwc,\n            self.numSM,\n            self.rank_in_group,\n            self.left_zero,\n            left_output_halo,\n            left_tx[self.rank_in_group],\n            right_tx[self.wrap_around_left_rank_in_group],\n            left_input_halo,\n            self.right_zero,\n            right_output_halo,\n            right_tx[self.rank_in_group],\n            
left_tx[self.wrap_around_right_rank_in_group],\n            right_input_halo,\n        )\n        if not inplace:\n            return left_input_halo, right_input_halo\n\n\n# Class that combines input volume with halos from neighbors (1d).\nclass HaloPadder:\n    def __init__(self, halo_ex):\n        self.halo_ex = halo_ex\n        self.stream1 = torch.cuda.Stream()\n        self.stream2 = torch.cuda.Stream()\n\n    def __call__(self, y, half_halo, explicit_nhwc, H_split):\n        channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)\n        if explicit_nhwc:\n            N, H, W, C = list(y.shape)\n            if H_split:\n                padded_shape = [N, H + 2 * half_halo, W, C]\n                ypad = torch.empty(\n                    shape=padded_shape,\n                    dtype=y.dtype,\n                    device=y.device,\n                    memory_format=torch.contiguous_format,\n                )\n                yleft = ypad[:, :half_halo, :, :]\n                ymid = ypad[:, half_halo : H + half_halo, :, :]\n                yright = ypad[:, H + half_halo : H + 2 * half_halo, :, :]\n                oleft = y[:, :half_halo, :, :]\n                oright = y[:, H - half_halo :, :, :]\n            else:\n                padded_shape = [N, H, W + 2 * half_halo, C]\n                ypad = torch.empty(\n                    shape=padded_shape,\n                    dtype=y.dtype,\n                    device=y.device,\n                    memory_format=torch.contiguous_format,\n                )\n                yleft = ypad[:, :, :half_halo, :]\n                ymid = ypad[:, :, half_halo : W + half_halo, :]\n                yright = ypad[:, :, W + half_halo : W + 2 * half_halo, :]\n                oleft = y[:, :, :half_halo, :]\n                oright = y[:, :, W - half_halo :, :]\n        else:\n            N, C, H, W = list(y.shape)\n            if H_split:\n                padded_shape = [N, C, H + 2 * half_halo, W]\n 
               ypad = torch.empty(\n                    shape=padded_shape,\n                    dtype=y.dtype,\n                    device=y.device,\n                    memory_format=torch.channels_last,\n                )\n                yleft = ypad[:, :, :half_halo, :]\n                ymid = ypad[:, :, half_halo : H + half_halo, :]\n                yright = ypad[:, :, H + half_halo : H + 2 * half_halo, :]\n                oleft = y[:, :, :half_halo, :]\n                oright = y[:, :, H - half_halo :, :]\n            else:\n                padded_shape = [N, C, H, W + 2 * half_halo]\n                ypad = torch.empty(\n                    shape=padded_shape,\n                    dtype=y.dtype,\n                    device=y.device,\n                    memory_format=torch.channels_last,\n                )\n                yleft = ypad[:, :, :, :half_halo]\n                ymid = ypad[:, :, :, half_halo : W + half_halo]\n                yright = ypad[:, :, :, W + half_halo : W + 2 * half_halo]\n                oleft = y[:, :, :, :half_halo]\n                oright = y[:, :, :, W - half_halo :]\n        with torch.cuda.stream(self.stream1):\n            self.halo_ex(oleft, oright, yleft, yright)\n        with torch.cuda.stream(self.stream2):\n            ymid.copy_(y)\n        return ypad\n\n    def wait(self):\n        current_stream = torch.cuda.current_stream()\n        current_stream.wait_stream(self.stream1)\n        current_stream.wait_stream(self.stream2)\n"
  },
  {
    "path": "apex/contrib/bottleneck/test.py",
    "content": "import torch\nfrom bottleneck import Bottleneck\n\ntorch.manual_seed(23337)\n\n# use True to print layerwise sum for all outputs in reference code path\nDEBUG = False  # True\n\nfor stride, o_channel in [(1, 32), (1, 128), (2, 32)]:\n    print(\"testing stride ==\", stride, \", in_channel == 32 , out_channel ==\", o_channel)\n    a_ = torch.randn(17, 32, 28, 28)\n\n    a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()\n    model = (\n        Bottleneck(32, 8, o_channel, stride=stride)\n        .cuda()\n        .half()\n        .to(memory_format=torch.channels_last)\n    )\n\n    # test model\n    b = model(a)\n    b.mean().backward()\n    d_grad = a.grad.float()\n    a.grad = None\n    torch.cuda.synchronize()\n\n    if DEBUG:\n        print(\"[DEBUG] ref dx :\", d_grad.sum().item())\n        # print wgrad. we don't need to reset since later cpp print before accumulation\n        for i, w in enumerate(model.w_conv):\n            print(\"[DEBUG] ref wgrad{} :\".format(i + 1), w.grad.sum().item())\n\n    wgrads = []\n    for w in model.w_conv:\n        wgrads.append(w.grad.float())\n\n    model.use_cudnn = True\n    model.zero_grad()\n    c = model(a)\n    c.mean().backward()\n\n    torch.cuda.synchronize()\n    print(\"comparing native and channels_last:\")\n    print(\n        \"max error fprop:\",\n        (b - c).abs().max().item(),\n        \"max elem:\",\n        b.abs().max().item(),\n    )\n    print(\n        \"max error dgrad:\",\n        (d_grad - a.grad.float()).abs().max().item(),\n        \"max elem:\",\n        d_grad.abs().max().item(),\n    )\n    for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):\n        print(\n            \"max error wgrad{}:\".format(i + 1),\n            (wgrad - w.grad.float()).abs().max().item(),\n            \"max elem:\",\n            wgrad.abs().max().item(),\n        )\n\n    nhwc_a = a_.permute(0, 2, 3, 1).contiguous().cuda().half().requires_grad_()\n    nhwc_model = (\n   
     Bottleneck(32, 8, o_channel, stride=stride, explicit_nhwc=True, use_cudnn=True)\n        .cuda()\n        .half()\n    )\n    for p, q in zip(model.parameters(), nhwc_model.parameters()):\n        # model's storage is already in nhwc, we clone and assign to explicit nhwc model\n        q.data.copy_(p.data.permute(0, 2, 3, 1).contiguous())\n    for p, q in zip(model.buffers(), nhwc_model.buffers()):\n        q.data.copy_(p.data)\n\n    d = nhwc_model(nhwc_a)\n    d.mean().backward()\n    torch.cuda.synchronize()\n\n    # reset reference to cudnn channels_last permute\n    # c_s = c.storage().tolist()\n    # d_s = d.storage().tolist()\n    # print(max([x-y for x,y in zip(c_s,d_s)]))\n    c = c.contiguous(memory_format=torch.contiguous_format).permute(0, 2, 3, 1).contiguous()\n    d_grad = a.grad.float().permute(0, 2, 3, 1).contiguous()\n    wgrads = []\n    for w in model.w_conv:\n        wgrads.append(w.grad.float().permute(0, 2, 3, 1).contiguous())\n\n    torch.cuda.synchronize()\n    print(\"comparing nhwc and channels_last:\")\n    print(\n        \"max error fprop:\",\n        (d - c).abs().max().item(),\n        \"max elem:\",\n        c.abs().max().item(),\n    )\n    print(\n        \"max error dgrad:\",\n        (d_grad - nhwc_a.grad.float()).abs().max().item(),\n        \"max elem:\",\n        d_grad.abs().max().item(),\n    )\n    for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):\n        print(\n            \"max error wgrad{}:\".format(i + 1),\n            (wgrad - w.grad.float()).abs().max().item(),\n            \"max elem:\",\n            wgrad.abs().max().item(),\n        )\n"
  },
  {
    "path": "apex/contrib/clip_grad/__init__.py",
    "content": "from .clip_grad import clip_grad_norm_\n"
  },
  {
    "path": "apex/contrib/clip_grad/clip_grad.py",
    "content": "from typing import Union, Iterable\n\nimport torch\n\n_kernel_import_succeeded = False\ntry:\n    import amp_C\n    from apex.multi_tensor_apply import multi_tensor_applier\n\n    _kernel_import_succeeded = True\nexcept ImportError:\n    _kernel_import_succeeded = False\n\n_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]\n\n\ndef clip_grad_norm_(\n    parameters: _tensor_or_tensors,\n    max_norm: float,\n    norm_type: float = 2.0,\n    error_if_nonfinite: bool = False,\n) -> torch.Tensor:\n    r\"\"\"Clips gradient norm of an iterable of parameters.\n\n    The norm is computed over all gradients together, as if they were\n    concatenated into a single vector. Gradients are modified in-place.\n\n    This is identical to torch.nn.utils.clip_grad_norm_, except it\n    uses a fused CUDA kernel when computing the 2-norm of GPU tensors\n    in float32 and float16.\n\n    Args:\n        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a\n            single Tensor that will have gradients normalized\n        max_norm (float or int): max norm of the gradients\n        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for\n            infinity norm.\n        error_if_nonfinite (bool): if True, an error is thrown if the total\n            norm of the gradients from :attr:`parameters` is ``nan``,\n            ``inf``, or ``-inf``. 
Default: False (will switch to True in the future)\n\n    Returns:\n        Total norm of the parameters (viewed as a single vector).\n\n    \"\"\"\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p.grad is not None]\n    max_norm = float(max_norm)\n    norm_type = float(norm_type)\n\n    # Trivial case\n    if len(parameters) == 0:\n        return torch.tensor(0.0)\n\n    # Fallback implementation\n    if not (_kernel_import_succeeded and norm_type == 2.0 and any(p.is_cuda for p in parameters)):\n        return torch.nn.utils.clip_grad_norm_(\n            parameters,\n            max_norm,\n            norm_type=norm_type,\n            error_if_nonfinite=error_if_nonfinite,\n        )\n\n    # Find fp32 and fp16 gradients on GPU\n    device = next(p.device for p in parameters if p.is_cuda)\n    grads_fp32, grads_fp16, grads_misc = [], [], []\n    for p in parameters:\n        grad = p.grad.detach()\n        if p.dtype == torch.float32 and p.device == device:\n            grads_fp32.append(grad)\n        elif p.dtype == torch.float16 and p.device == device:\n            grads_fp16.append(grad)\n        else:\n            grads_misc.append(grad)\n\n    # Compute gradient L2 norms\n    norms = []\n    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)\n    if grads_fp32:\n        norms.append(\n            multi_tensor_applier(\n                amp_C.multi_tensor_l2norm,\n                dummy_overflow_buf,\n                [grads_fp32],\n                False,\n            )[0]\n        )\n    if grads_fp16:\n        norms.append(\n            multi_tensor_applier(\n                amp_C.multi_tensor_l2norm,\n                dummy_overflow_buf,\n                [grads_fp16],\n                False,\n            )[0],\n        )\n    for g in grads_misc:\n        norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))\n    total_norm = 
torch.linalg.norm(torch.cat(norms))\n\n    # Check for non-finite values\n    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):\n        raise RuntimeError(\n            f\"The total norm of order {norm_type} for gradients from \"\n            \"`parameters` is non-finite, so it cannot be clipped. To disable \"\n            \"this error and scale the gradients by the non-finite norm anyway, \"\n            \"set `error_if_nonfinite=False`\"\n        )\n\n    # Scale gradients\n    clip_coef = max_norm / (total_norm + 1e-6)\n    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)\n    if grads_fp32:\n        multi_tensor_applier(\n            amp_C.multi_tensor_scale,\n            dummy_overflow_buf,\n            [grads_fp32, grads_fp32],\n            clip_coef_clamped,\n        )\n    if grads_fp16:\n        multi_tensor_applier(\n            amp_C.multi_tensor_scale,\n            dummy_overflow_buf,\n            [grads_fp16, grads_fp16],\n            clip_coef_clamped,\n        )\n    for g in grads_misc:\n        g.mul_(clip_coef_clamped.to(g.device))\n\n    return total_norm\n"
  },
  {
    "path": "apex/contrib/conv_bias_relu/__init__.py",
    "content": "from .conv_bias_relu import (\n    ConvBiasReLU,\n    ConvBias,\n    ConvBiasMaskReLU,\n    ConvFrozenScaleBiasReLU,\n)\n"
  },
  {
    "path": "apex/contrib/conv_bias_relu/conv_bias_relu.py",
    "content": "import torch\n\nfrom apex import check_cudnn_version_and_warn\nimport fused_conv_bias_relu\n\ncheck_cudnn_version_and_warn(__name__, 8400)\n\n\nclass ConvBiasReLU_(torch.autograd.Function):\n    @staticmethod\n    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type=\"cuda\")\n    def forward(ctx, x, weight, bias, padding, stride):\n        outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)\n        ctx.save_for_backward(x, weight, outputs[0])\n        ctx.padding = padding\n        ctx.stride = stride\n\n        return outputs[0]\n\n    @staticmethod\n    @torch.amp.custom_bwd(device_type=\"cuda\")\n    def backward(ctx, grad_output):\n        bwd_args = [*ctx.saved_tensors, grad_output]\n        padding = ctx.padding\n        stride = ctx.stride\n        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)\n\n        return grads[0], grads[1], grads[2], None, None\n\n\nclass ConvBiasMaskReLU_(torch.autograd.Function):\n    @staticmethod\n    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type=\"cuda\")\n    def forward(ctx, x, weight, bias, mask, padding, stride):\n        outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)\n        ctx.save_for_backward(x, weight, outputs[0])\n        ctx.padding = padding\n        ctx.stride = stride\n\n        return outputs[0]\n\n    @staticmethod\n    @torch.amp.custom_bwd(device_type=\"cuda\")\n    def backward(ctx, grad_output):\n        bwd_args = [*ctx.saved_tensors, grad_output]\n        padding = ctx.padding\n        stride = ctx.stride\n        grads = fused_conv_bias_relu.backward(bwd_args, padding, stride)\n\n        return grads[0], grads[1], grads[2], None, None, None\n\n\nclass ConvBias_(torch.autograd.Function):\n    @staticmethod\n    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type=\"cuda\")\n    def forward(ctx, x, weight, bias, padding, stride):\n        outputs = fused_conv_bias_relu.forward_no_relu([x, 
weight, bias], padding, stride)\n        ctx.save_for_backward(x, weight)\n        ctx.padding = padding\n        ctx.stride = stride\n\n        return outputs[0]\n\n    @staticmethod\n    @torch.amp.custom_bwd(device_type=\"cuda\")\n    def backward(ctx, grad_output):\n        bwd_args = [*ctx.saved_tensors, grad_output]\n        padding = ctx.padding\n        stride = ctx.stride\n        grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride)\n\n        return grads[0], grads[1], grads[2], None, None\n\n\nclass ConvFrozenScaleBiasReLU_(torch.autograd.Function):\n    @staticmethod\n    @torch.amp.custom_fwd(cast_inputs=torch.half, device_type=\"cuda\")\n    def forward(ctx, x, weight, scale, bias, padding, stride):\n        output = fused_conv_bias_relu.forward_cscale_cbias_relu(\n            [x, weight, scale, bias], padding, stride\n        )\n        ctx.save_for_backward(x, weight, scale, output)\n        ctx.padding = padding\n        ctx.stride = stride\n\n        return output\n\n    @staticmethod\n    @torch.amp.custom_bwd(device_type=\"cuda\")\n    def backward(ctx, grad_output):\n        bwd_args = [*ctx.saved_tensors, grad_output]\n        padding = ctx.padding\n        stride = ctx.stride\n        grads = fused_conv_bias_relu.backward_cscale_cbias_relu(bwd_args, padding, stride)\n\n        return grads[0], grads[1], None, None, None, None\n\n\nConvBiasReLU = ConvBiasReLU_.apply\nConvBiasMaskReLU = ConvBiasMaskReLU_.apply\nConvBias = ConvBias_.apply\nConvFrozenScaleBiasReLU = ConvFrozenScaleBiasReLU_.apply\n"
  },
  {
    "path": "apex/contrib/csrc/bottleneck/bottleneck.cpp",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cudnn/Handle.h>  // for getcudnnhandle\n#include <cudnn_frontend.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include <iostream>\n#include <vector>\n\n#ifdef DEBUG\n#define DEBUG_MSG(str)             \\\n  do {                             \\\n    std::cout << str << std::endl; \\\n  } while (false)\n#else\n#define DEBUG_MSG(str) \\\n  do {                 \\\n  } while (false)\n#endif\n\n#ifdef DEBUG_CUDNN\n#define DEBUG_CUDNN_MSG(buf, str) \\\n  do {                            \\\n    buf << str << std::endl;      \\\n  } while (false)\n#else\n#define DEBUG_CUDNN_MSG(buf, str) \\\n  do {                            \\\n  } while (false)\n#endif\n\n#define checkCudnnErr(...)                                                    \\\n  do {                                                                        \\\n    int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \\\n    if (err) {                                                                \\\n      return;                                                                 \\\n    }                                                                         \\\n  } while (0)\n\nint checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {\n  if (code) {\n    printf(\"CUDNN error at %s:%d, code=%d (%s) in '%s'\\n\", file, line, (int)code, cudnnGetErrorString(code), expr);\n    return 1;\n  }\n  return 0;\n}\n\nvoid checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true);\n#define checkCUDAError(val)                      \\\n  {                                              \\\n    checkError((val), #val, __FILE__, __LINE__); \\\n  }  // in-line regular function\n\nvoid checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort) {\n  if (code != cudaSuccess) {\n    const char* errorMessage = cudaGetErrorString(code);\n    
fprintf(stderr, \"CUDA error returned from \\\"%s\\\" at %s:%d, Error code: %d (%s)\\n\", func, file, line, code,\n            errorMessage);\n    if (abort) {\n      cudaDeviceReset();\n      exit(code);\n    }\n  }\n}\n\nvoid generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {\n  // For INT8x4 and INT8x32 we still compute standard strides here to input\n  // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.\n  if (filterFormat == CUDNN_TENSOR_NCHW) {\n    strideA[nbDims - 1] = 1;\n    for (int64_t d = nbDims - 2; d >= 0; d--) {\n      strideA[d] = strideA[d + 1] * dimA[d + 1];\n    }\n  } else {\n    // Here we assume that the format is CUDNN_TENSOR_NHWC\n    strideA[1] = 1;\n    strideA[nbDims - 1] = strideA[1] * dimA[1];\n    for (int64_t d = nbDims - 2; d >= 2; d--) {\n      strideA[d] = strideA[d + 1] * dimA[d + 1];\n    }\n    strideA[0] = strideA[2] * dimA[2];\n  }\n}\n\nint getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((filterDim - 1) * dilation) + 1; }\n\nint getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim + (2 * pad); }\n\nint getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) {\n  int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;\n  return (p);\n}\n\nenum {\n  X_TENSOR,\n  Y_TENSOR,\n  W_TENSOR,\n  Z_TENSOR,\n  B_TENSOR,\n  AFTERADD_TENSOR,\n  AFTERBIAS_TENSOR,\n  AFTERCONV_TENSOR,\n  OPTIONAL,\n  AFTEROPT_TENSOR,\n};\n\nusing common_conv_descriptors =\n    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::ConvDesc>;\n\ncommon_conv_descriptors create_common_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,\n                                                  int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,\n                                              
    cudnnDataType_t dataType, cudnnConvolutionMode_t mode) {\n  const int convDim = 2;\n\n  int64_t strideA_padded[4];\n  int64_t outstrideA_padded[4];\n  int64_t filterstrideA_padded[4];\n\n  generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);\n\n  return common_conv_descriptors(cudnn_frontend::TensorBuilder()\n                                     .setDim(4, x_dim_padded)\n                                     .setStrides(4, strideA_padded)\n                                     .setId('x')\n                                     .setAlignment(16)\n                                     .setDataType(dataType)\n                                     .build(),\n                                 cudnn_frontend::TensorBuilder()\n                                     .setDim(4, y_dim_padded)\n                                     .setStrides(4, outstrideA_padded)\n                                     .setId('y')\n                                     .setAlignment(16)\n                                     .setDataType(dataType)\n                                     .build(),\n                                 cudnn_frontend::TensorBuilder()\n                                     .setDim(4, w_dim_padded)\n                                     .setStrides(4, filterstrideA_padded)\n                                     .setId('w')\n                                     .setAlignment(16)\n                                     .setDataType(dataType)\n                                     .build(),\n                                 cudnn_frontend::ConvDescBuilder()\n                                     .setDataType(CUDNN_DATA_FLOAT)\n                                     .setMathMode(mode)\n                                     .setNDims(convDim)\n                                     .setStrides(convDim, convstrideA)\n           
                          .setPrePadding(convDim, padA)\n                                     .setPostPadding(convDim, padA)\n                                     .setDilation(convDim, dilationA)\n                                     .build());\n}\n\nusing common_convbias_descriptors =\n    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor>;\n\ncommon_convbias_descriptors create_conv_bias_add_act_descriptors(int64_t* x_dim_padded, int64_t* padA,\n                                                                 int64_t* convstrideA, int64_t* dilationA,\n                                                                 int64_t* w_dim_padded, int64_t* y_dim_padded,\n                                                                 cudnnDataType_t dataType) {\n  const int convDim = 2;\n\n  int64_t b_dim_padded[4];\n  b_dim_padded[0] = 1;\n  b_dim_padded[1] = y_dim_padded[1];\n  b_dim_padded[2] = 1;\n  b_dim_padded[3] = 1;\n\n  int64_t x_stride_padded[4];\n  int64_t y_stride_padded[4];\n  int64_t w_stride_padded[4];\n  int64_t b_stride_padded[4];\n\n  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n\n  return common_convbias_descriptors(cudnn_frontend::TensorBuilder()\n                                         .setDim(4, x_dim_padded)\n                                         .setStrides(4, x_stride_padded)\n                                         .setId('x')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                         
                .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('y')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, w_dim_padded)\n                                         .setStrides(4, w_stride_padded)\n                                         .setId('w')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, b_dim_padded)\n                                         .setStrides(4, b_stride_padded)\n                                         .setId('z')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, b_dim_padded)\n                                         .setStrides(4, b_stride_padded)\n                                         .setId('b')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, 
y_stride_padded)\n                                         .setVirtual()\n                                         .setId('A')  // after add\n                                         .setAlignment(16)\n                                         .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setVirtual()\n                                         .setId('B')  // after bias\n                                         .setAlignment(16)\n                                         .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('C')  // after conv\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('i')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         
.setStrides(4, y_stride_padded)\n                                         .setId('D')  // after optional add\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_FLOAT)\n                                         .build());\n}\n\n// tensor descriptors used for dgrad\nenum {\n  X_OR_DX_TENSOR,\n  DY_TENSOR,\n  W_OR_DW_TENSOR,\n  SCALE_TENSOR,\n  RELU_TENSOR,\n  AFTER_DCONV_TENSOR,\n  AFTER_DRELU_TENSOR,\n};\n\nusing dconv_descriptors =\n    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor>;\n\ndconv_descriptors create_dconv_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,\n                                           int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,\n                                           cudnnDataType_t dataType) {\n  const int convDim = 2;\n\n  int64_t b_dim_padded[4];\n  b_dim_padded[0] = 1;\n  b_dim_padded[1] = x_dim_padded[1];\n  b_dim_padded[2] = 1;\n  b_dim_padded[3] = 1;\n\n  int64_t x_stride_padded[4];\n  int64_t y_stride_padded[4];\n  int64_t w_stride_padded[4];\n  int64_t b_stride_padded[4];\n\n  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n\n  return dconv_descriptors(cudnn_frontend::TensorBuilder()\n                               .setDim(4, x_dim_padded)\n                               .setStrides(4, x_stride_padded)\n                               .setId('x')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                          
     .build(),\n                           cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim_padded)\n                               .setStrides(4, y_stride_padded)\n                               .setId('y')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build(),\n                           cudnn_frontend::TensorBuilder()\n                               .setDim(4, w_dim_padded)\n                               .setStrides(4, w_stride_padded)\n                               .setId('w')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build(),\n                           cudnn_frontend::TensorBuilder()\n                               .setDim(4, b_dim_padded)\n                               .setStrides(4, b_stride_padded)\n                               .setId('s')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build(),\n                           cudnn_frontend::TensorBuilder()\n                               .setDim(4, x_dim_padded)\n                               .setStrides(4, x_stride_padded)\n                               .setId('r')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build(),\n                           cudnn_frontend::TensorBuilder()\n                               .setDim(4, x_dim_padded)\n                               .setStrides(4, x_stride_padded)\n                               .setVirtual()\n                               .setId('A')  // after dconv\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .build(),\n                           
cudnn_frontend::TensorBuilder()\n                               .setDim(4, x_dim_padded)\n                               .setStrides(4, x_stride_padded)\n                               .setVirtual()\n                               .setId('B')  // after drelu\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .build());\n}\n\n// create a cache for plan\nstd::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;\n\n// TODO: better name\nstd::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA,\n                                int64_t* w_dim_padded, cudnnDataType_t dataType, std::string fusion_string) {\n  for (int i = 0; i < 4; i++) {\n    fusion_string += 'X';\n    fusion_string += std::to_string(x_dim_padded[i]);\n  }\n  for (int i = 0; i < 4; i++) {\n    fusion_string += 'W';\n    fusion_string += std::to_string(w_dim_padded[i]);\n  }\n  for (int i = 0; i < 2; i++) {\n    fusion_string += 'P';\n    fusion_string += std::to_string(padA[i]);\n  }\n  for (int i = 0; i < 2; i++) {\n    fusion_string += 'S';\n    fusion_string += std::to_string(convstrideA[i]);\n  }\n  for (int i = 0; i < 2; i++) {\n    fusion_string += 'D';\n    fusion_string += std::to_string(dilationA[i]);\n  }\n  fusion_string += 'T';\n  fusion_string += std::to_string(dataType);\n  return fusion_string;\n}\n\ncudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::stringstream& log_buf,\n                                               cudnn_frontend::OperationGraph& opGraph, std::string cache_string,\n                                               bool use_heuristic = true) {\n  auto it = plan_cache.find(cache_string);\n  if (it != plan_cache.end()) {\n    DEBUG_CUDNN_MSG(log_buf, \"Found plan in cache\");\n    return it->second;\n  } else {\n    if (use_heuristic) {\n      // TODO: confirm which mode to use\n      
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()\n                            .setOperationGraph(opGraph)\n                            .setHeurMode(CUDNN_HEUR_MODE_INSTANT)\n                            .build();\n      // try 3 times for now as WAR for no heuristic training\n      int max_tries = 3, count = 0;\n      auto& engine_configs = heuristics.getEngineConfig(max_tries);\n      while (true) {\n        try {\n          plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()\n                                                         .setHandle(handle_)\n                                                         .setEngineConfig(engine_configs[count], opGraph.getTag())\n                                                         .build()));\n          break;\n        } catch (cudnn_frontend::cudnnException e) {\n          if (++count == max_tries) throw e;\n        }\n      }\n    } else {\n      DEBUG_CUDNN_MSG(log_buf, \"No plan in cache\");\n      // How many engines support this operation graph?\n      auto total_engines = opGraph.getEngineCount();\n      DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << \" has \" << total_engines << \" engines.\");\n      // We have to randomly pick one engine from [0, total_engines)\n      // Selecting \"0\" by default\n      auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();\n      DEBUG_CUDNN_MSG(log_buf, engine.describe());\n      auto& knobs = engine.getSupportedKnobs();\n      for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {\n        DEBUG_CUDNN_MSG(log_buf, it->describe());\n      }\n      if (knobs.begin() != knobs.end()) {\n        DEBUG_CUDNN_MSG(log_buf, \"Updated knob choice\");\n        knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);\n        DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());\n      }\n\n      // Create the requisite engine config\n      auto engine_config = 
cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();\n      DEBUG_CUDNN_MSG(log_buf, engine_config.describe());\n      plan_cache.emplace(\n          cache_string,\n          std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));\n    }\n\n    return plan_cache.find(cache_string)->second;\n  }\n}\n\nvoid run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,\n                                        int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType,\n                                        at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ,\n                                        at::Half* devPtrB, at::Half* devPtrI) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation,\n                                                                               w_dim_padded, y_dim_padded, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n\n    // Define the add operation\n    auto scaleDesc =\n        
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    // Define the bias operation\n    auto biasDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // optional add\n    auto addDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n    // Define the activation operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_FWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(std::get<X_TENSOR>(tensors))\n                       .setwDesc(std::get<W_TENSOR>(tensors))\n                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       
.setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create a Scale Node (pointwise multiply of the conv output by the scale tensor).\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(conv_op.getOutputTensor())\n                        .setbDesc(std::get<Z_TENSOR>(tensors))\n                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors))\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create a Bias Node.\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       .setxDesc(scale_op.getOutputTensor())\n                       .setbDesc(std::get<B_TENSOR>(tensors))\n                       .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))\n                       .setpwDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an optional add Node.\n    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(bias_op.getOutputTensor())\n                      .setbDesc(std::get<OPTIONAL>(tensors))\n                      .setyDesc(std::get<AFTEROPT_TENSOR>(tensors))\n                      .setpwDesc(addDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n    // Create an Activation Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())\n                      .setyDesc(std::get<Y_TENSOR>(tensors))\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution add bias activation\n    std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op,\n                                                           &act_op};\n\n    auto opGraph = cudnn_frontend::OperationGraphBuilder()\n                       .setHandle(handle_)\n                       .setOperationGraph(devPtrI ? ops.size() : 4, ops.data())\n                       .build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};\n    int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(devPtrI ? 6 : 5, data_ptrs)\n                           .setUids(devPtrI ? 
6 : 5, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,\n                         int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX,\n                         at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation,\n                                                                               w_dim_padded, y_dim_padded, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n\n    // Define the add operation\n    auto scaleDesc =\n        
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    // Define the bias operation\n    auto addDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(std::get<X_TENSOR>(tensors))\n                       .setwDesc(std::get<W_TENSOR>(tensors))\n                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create a Scale Node (pointwise multiply of the conv output by the scale tensor).\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(conv_op.getOutputTensor())\n                        .setbDesc(std::get<Z_TENSOR>(tensors))\n                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors))  // TODO: change enum to aftermul\n                        .setpwDesc(scaleDesc)\n   
                     .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create a Bias Node.\n    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(scale_op.getOutputTensor())\n                      .setbDesc(std::get<B_TENSOR>(tensors))\n                      .setyDesc(std::get<Y_TENSOR>(tensors))\n                      .setpwDesc(addDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n    // Create an Operation Graph. In this case it is convolution add bias activation\n    std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &scale_op, &add_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB};\n    int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(5, data_ptrs)\n                           .setUids(5, uids)\n           
                .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,\n                            int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX,\n                            at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    dconv_descriptors tensors =\n        create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, 
convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Define the activation backward operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_BWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the scale backward operation\n    auto scaleDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n                       .setdyDesc(std::get<DY_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // TODO: do we need getOutputTensor(), and what it returns in backward case?\n    // Create an relu backward Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                      .setxDesc(std::get<RELU_TENSOR>(tensors))\n                      .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, 
act_op.describe());\n\n    // Create a Scale Node.\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n                        .setbDesc(std::get<SCALE_TENSOR>(tensors))\n                        .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create an Operation Graph. In this case it is convolution add bias activation\n    std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &scale_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR};\n    int64_t uids[] = {'x', 'y', 'w', 's', 'r'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(5, data_ptrs)\n                           .setUids(5, uids)\n                           
.build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded,\n               int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY,\n               cudnnBackendDescriptorType_t mode) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    dconv_descriptors tensors =\n        create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, 
pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    // mode should be one of following\n    // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR\n    // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR\n    auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);\n    if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {\n      conv_op_builder.setdxDesc(std::get<X_OR_DX_TENSOR>(tensors))\n          .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n          .setdyDesc(std::get<DY_TENSOR>(tensors))\n          .setcDesc(convDesc)\n          .setAlpha(alpha)\n          .setBeta(beta);\n    } else {\n      conv_op_builder.setxDesc(std::get<X_OR_DX_TENSOR>(tensors))\n          .setdwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n          .setdyDesc(std::get<DY_TENSOR>(tensors))\n          .setcDesc(convDesc)\n          .setAlpha(alpha)\n          .setBeta(beta);\n    }\n    auto conv_op = conv_op_builder.build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution add bias activation\n    std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW};\n    int64_t uids[] = {'x', 'y', 'w'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(3, data_ptrs)\n                           .setUids(3, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_dconv_add(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* 
w_dim_padded,\n                   int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW,\n                   at::Half* devPtrY, at::Half* devPtrR) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    dconv_descriptors tensors =\n        create_dconv_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Define the add backward operation\n    auto addDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = 
cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n                       .setdyDesc(std::get<DY_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // TODO: do we need getOutputTensor(), and what it returns in backward case?\n    // Create add Node.\n    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                      .setbDesc(std::get<RELU_TENSOR>(tensors))\n                      .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n                      .setpwDesc(addDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution add bias activation\n    std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &add_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR};\n    int64_t uids[] = {'x', 'y', 'w', 'r'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(4, data_ptrs)\n                           .setUids(4, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\n// inputs contains x,w,z,b,(i)\nstd::vector<at::Tensor> bottleneck_forward(bool 
explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n  std::cout << std::fixed;\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t dimA[] = {0, 0, 0, 0};\n  int64_t filterdimA1[] = {0, 0, 0, 0};\n  int64_t filterdimA2[] = {0, 0, 0, 0};\n  int64_t filterdimA3[] = {0, 0, 0, 0};\n  int64_t filterdimA4[] = {0, 0, 0, 0};\n\n  // All dim calculations after this point are in n,c,h,w order\n  int axis[]{0, 1, 2, 3};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 3;\n    axis[2] = 1;\n    axis[3] = 2;\n  }\n  for (int dim = 0; dim < 4; dim++) {\n    dimA[dim] = inputs[0].size(axis[dim]);\n    filterdimA1[dim] = inputs[1].size(axis[dim]);\n    filterdimA2[dim] = inputs[2].size(axis[dim]);\n    filterdimA3[dim] = inputs[3].size(axis[dim]);\n  }\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    for (int dim = 0; dim < 4; dim++) {\n      filterdimA4[dim] = inputs[10].size(axis[dim]);\n    }\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t outdimA1[] = {0, 0, 0, 0};  // Computed Below\n  int64_t outdimA2[] = {0, 0, 0, 0};  // Computed Below\n  int64_t outdimA3[] = {0, 0, 0, 0};  // Computed Below\n\n  // use these fixed values for the test run\n  int64_t padA[] = {0, 0};\n  int64_t padA1[] = {1, 1};\n  int64_t dilationA[] = {1, 1};\n  int64_t convstrideA[] = {1, 1};\n  int64_t convstride1X1[] = {stride_1X1, stride_1X1};\n\n  // compute output from pad/stride/dilation\n  outdimA1[0] = dimA[0];\n  outdimA1[1] = filterdimA1[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA1[dim + 2] =\n        getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n  }\n\n  outdimA2[0] = outdimA1[0];\n  outdimA2[1] = filterdimA2[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA2[dim + 2] =\n        getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], 
filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  outdimA3[0] = outdimA2[0];\n  outdimA3[1] = filterdimA3[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA3[dim + 2] =\n        getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  // Create output tensor in the correct shape in pytorch's view\n  int64_t outdim1[] = {0, 0, 0, 0};\n  int64_t outdim2[] = {0, 0, 0, 0};\n  int64_t outdim3[] = {0, 0, 0, 0};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 2;\n    axis[2] = 3;\n    axis[3] = 1;\n  }\n  for (int dim = 0; dim < 4; dim++) {\n    outdim1[dim] = outdimA1[axis[dim]];\n    outdim2[dim] = outdimA2[axis[dim]];\n    outdim3[dim] = outdimA3[axis[dim]];\n  }\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n  at::Half* b = inputs[7].data_ptr<at::Half>();\n  auto out1 = at::empty(outdim1, inputs[0].type(), output_format);\n  at::Half* y1 = out1.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, w,\n                                     y1, z, b, nullptr);\n\n  DEBUG_MSG(\"[DEBUG] new relu1 : \" << out1.to(at::kFloat).sum().item<float>());\n\n  w = inputs[2].data_ptr<at::Half>();\n  z = inputs[5].data_ptr<at::Half>();\n  b = inputs[8].data_ptr<at::Half>();\n  auto out2 = at::empty(outdim2, inputs[0].type(), output_format);\n  at::Half* y2 = out2.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF,\n                                     y1, w, y2, z, b, nullptr);\n  DEBUG_MSG(\"[DEBUG] new relu2 : \" << out2.to(at::kFloat).sum().item<float>());\n\n  // create output of conv3\n  auto out3 = at::empty(outdim3, inputs[0].type(), output_format);\n  at::Half* y3 = out3.data_ptr<at::Half>();\n\n  // 
create output of conv4 that may exist\n  auto identity = at::empty_like(out3);\n  at::Half* yi = identity.data_ptr<at::Half>();\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    w = inputs[10].data_ptr<at::Half>();\n    z = inputs[11].data_ptr<at::Half>();\n    b = inputs[12].data_ptr<at::Half>();\n    run_conv_scale_bias(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b);\n    DEBUG_MSG(\"[DEBUG] new downsample : \" << identity.to(at::kFloat).sum().item<float>());\n  } else {\n    yi = x;\n  }\n\n  w = inputs[3].data_ptr<at::Half>();\n  z = inputs[6].data_ptr<at::Half>();\n  b = inputs[9].data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, y2,\n                                     w, y3, z, b, yi);\n  DEBUG_MSG(\"[DEBUG] new relu3 : \" << out3.to(at::kFloat).sum().item<float>());\n\n  outputs.push_back(out1);\n  outputs.push_back(out2);\n  outputs.push_back(out3);\n\n  return outputs;\n}\n\nstd::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t dimA[] = {0, 0, 0, 0};\n  int64_t filterdimA1[] = {0, 0, 0, 0};\n  int64_t filterdimA2[] = {0, 0, 0, 0};\n  int64_t filterdimA3[] = {0, 0, 0, 0};\n  int64_t filterdimA4[] = {0, 0, 0, 0};\n\n  // All dim calculations after this point are in n,c,h,w order\n  int axis[]{0, 1, 2, 3};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 3;\n    axis[2] = 1;\n    axis[3] = 2;\n  }\n  for (int dim = 0; dim < 4; dim++) {\n    dimA[dim] = inputs[0].size(axis[dim]);\n    filterdimA1[dim] = inputs[1].size(axis[dim]);\n    filterdimA2[dim] = inputs[2].size(axis[dim]);\n    filterdimA3[dim] = inputs[3].size(axis[dim]);\n  }\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    for (int dim = 0; dim < 4; dim++) {\n      filterdimA4[dim] = inputs[14].size(axis[dim]);\n    }\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t outdimA1[] = {0, 0, 0, 0};  // Computed Below\n  int64_t outdimA2[] = {0, 0, 0, 0};  // Computed Below\n  int64_t outdimA3[] = {0, 0, 0, 0};  // Computed Below\n\n  // use these fixed values for the test run\n  int64_t padA[] = {0, 0};\n  int64_t padA1[] = {1, 1};\n  int64_t dilationA[] = {1, 1};\n  int64_t convstrideA[] = {1, 1};\n  int64_t convstride1X1[] = {stride_1X1, stride_1X1};\n\n  // compute output from pad/stride/dilation\n  outdimA1[0] = dimA[0];\n  outdimA1[1] = filterdimA1[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA1[dim + 2] =\n        getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n  }\n\n  outdimA2[0] = outdimA1[0];\n  outdimA2[1] = filterdimA2[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA2[dim + 2] =\n        getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  outdimA3[0] = outdimA2[0];\n  outdimA3[1] = filterdimA3[0];\n  for (int dim = 0; dim < 2; dim++) {\n    outdimA3[dim + 2] =\n        
getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n  }\n\n  // Create output tensor in the correct shape in pytorch's view\n  int64_t outdim1[] = {0, 0, 0, 0};\n  int64_t outdim2[] = {0, 0, 0, 0};\n  int64_t outdim3[] = {0, 0, 0, 0};\n  if (explicit_nhwc) {\n    axis[0] = 0;\n    axis[1] = 2;\n    axis[2] = 3;\n    axis[3] = 1;\n  }\n  for (int dim = 0; dim < 4; dim++) {\n    outdim1[dim] = outdimA1[axis[dim]];\n    outdim2[dim] = outdimA2[axis[dim]];\n    outdim3[dim] = outdimA3[axis[dim]];\n  }\n\n  // dconv3+drelu2+dscale2\n  at::Half* conv_in = inputs[13].data_ptr<at::Half>();\n  at::Half* dy3 = inputs[10].data_ptr<at::Half>();\n\n  DEBUG_MSG(\"[DEBUG] new dconv3 : \" << inputs[10].to(at::kFloat).sum().item<float>());\n\n  // wgrad\n  auto wgrad3 = at::empty_like(inputs[3]);\n  at::Half* dw3 = wgrad3.data_ptr<at::Half>();\n  run_dconv(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format);\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n  at::Half* w = inputs[3].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n\n  at::Half* relu2 = inputs[13].data_ptr<at::Half>();\n\n  run_dconv_drelu_dscale(outdimA2, padA, convstrideA, dilationA, filterdimA3, outdimA3, CUDNN_DATA_HALF, dy2, w, dy3, z,\n                         relu2);\n\n  DEBUG_MSG(\"[DEBUG] new dconv2 : \" << grad_out2.to(at::kFloat).sum().item<float>());\n\n  // dconv2+drelu1+dscale1\n  conv_in = inputs[12].data_ptr<at::Half>();\n\n  // wgrad\n  auto wgrad2 = at::empty_like(inputs[2]);\n  at::Half* dw2 = wgrad2.data_ptr<at::Half>();\n  run_dconv(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2,\n            
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format);\n  at::Half* dy1 = grad_out1.data_ptr<at::Half>();\n  w = inputs[2].data_ptr<at::Half>();\n  z = inputs[4].data_ptr<at::Half>();\n\n  at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n  // fused dgrad\n  run_dconv_drelu_dscale(outdimA1, padA1, convstrideA, dilationA, filterdimA2, outdimA2, CUDNN_DATA_HALF, dy1, w, dy2,\n                         z, relu1);\n\n  /*\n    // backward strided conv cannot be fused\n    // if stride == 1 but channel changes, we can fuse here\n    if (stride_1X1 != 1){\n      // dgrad\n      run_dconv(outdimA1,\n                padA1,\n                convstride1X1,\n                dilationA,\n                filterdimA2,\n                outdimA2,\n                CUDNN_DATA_HALF,\n                dy1,\n                w,\n                dy2,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n\n      // mul fused mask\n      grad_out1.mul_(inputs[15]);\n    }\n    else {\n      at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n      // fused dgrad\n      run_dconv_drelu_dscale(outdimA1,\n                             padA1,\n                             convstride1X1,\n                             dilationA,\n                             filterdimA2,\n                             outdimA2,\n                             CUDNN_DATA_HALF,\n                             dy1,\n                             w,\n                             dy2,\n                             z,\n                             relu1);\n    }\n  */\n  DEBUG_MSG(\"[DEBUG] new dconv1 : \" << grad_out1.to(at::kFloat).sum().item<float>());\n\n  // create grads of conv4 that may exist\n  auto grad_x_conv4 = at::empty_like(inputs[0]);\n  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();\n  at::Tensor wgrad4;\n\n  // x used for dconv1 and dconv4 wgrad\n  at::Half* x = 
inputs[0].data_ptr<at::Half>();\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    w = inputs[14].data_ptr<at::Half>();\n    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();\n    if (requires_grad) {\n      run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx\n      // DEBUG_MSG(\"[DEBUG] new dx_identity : \" << grad_x_conv4.to(at::kFloat).sum().item<float>());\n    }\n    // wgrad\n    wgrad4 = at::empty_like(inputs[14]);\n    at::Half* dw4 = wgrad4.data_ptr<at::Half>();\n    run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA4, outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4,\n              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n  } else {\n    // if there is no downsample, dx_conv4 is a fork of drelu3\n    dx_conv4 = inputs[11].data_ptr<at::Half>();\n  }\n\n  // dconv1+add\n  // wgrad\n  auto wgrad1 = at::empty_like(inputs[1]);\n  at::Half* dw1 = wgrad1.data_ptr<at::Half>();\n  run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, x, dw1, dy1,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // dgrad\n  w = inputs[1].data_ptr<at::Half>();\n  auto grad_x = at::empty_like(inputs[0]);\n  at::Half* dx = grad_x.data_ptr<at::Half>();\n\n  // backward strided conv cannot be fused\n  // if stride == 1 but channel changes, we can fuse here\n  if (requires_grad) {\n    if (stride_1X1 != 1) {\n      run_dconv(dimA, padA, convstride1X1, dilationA, filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // add the two gradients (dx and dx_conv4) together\n      grad_x.add_(grad_x_conv4);\n    } else {\n      run_dconv_add(dimA, padA, convstride1X1, dilationA, 
filterdimA1, outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4);\n    }\n  }\n\n  DEBUG_MSG(\"[DEBUG] new dx : \" << grad_x.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad1 : \" << wgrad1.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad2 : \" << wgrad2.to(at::kFloat).sum().item<float>());\n  DEBUG_MSG(\"[DEBUG] new wgrad3 : \" << wgrad3.to(at::kFloat).sum().item<float>());\n  outputs.push_back(grad_x);\n  outputs.push_back(wgrad1);\n  outputs.push_back(wgrad2);\n  outputs.push_back(wgrad3);\n\n  if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n    DEBUG_MSG(\"[DEBUG] new wgrad4 : \" << wgrad4.to(at::kFloat).sum().item<float>());\n    outputs.push_back(wgrad4);\n  }\n\n  return outputs;\n}\n\nnamespace {\n\nenum {\n  X_TENSOR,\n  Y_TENSOR,\n  W_TENSOR,\n  Z_TENSOR,\n  B_TENSOR,\n  AFTERADD_TENSOR,\n  AFTERBIAS_TENSOR,\n  AFTERCONV_TENSOR,\n  OPTIONAL,\n  AFTEROPT_TENSOR,\n  AFTERACT_TENSOR,\n  GEN_INDEX_TENSOR,\n  MASK_TOP_TENSOR,\n  MASK_BOTTOM_TENSOR,\n  MASK_TENSOR,\n  THRESHOLD_TOP_TENSOR,\n  THRESHOLD_BOTTOM_TENSOR,\n};\n\nusing masked_convbias_descriptors =\n    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor>;\n\nmasked_convbias_descriptors create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded, int64_t* padA,\n                                                                      int64_t* convstrideA, int64_t* dilationA,\n                                                                      int64_t* w_dim_padded, int64_t* y_dim_padded,\n                                             
                         int64_t* threshold_dim,\n                                                                      cudnnDataType_t dataType) {\n  const int convDim = 2;\n\n  int64_t b_dim_padded[4];\n  b_dim_padded[0] = 1;\n  b_dim_padded[1] = y_dim_padded[1];\n  b_dim_padded[2] = 1;\n  b_dim_padded[3] = 1;\n\n  int64_t x_stride_padded[4];\n  int64_t y_stride_padded[4];\n  int64_t w_stride_padded[4];\n  int64_t b_stride_padded[4];\n  int64_t threshold_stride[4];\n\n  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);\n\n  return masked_convbias_descriptors(cudnn_frontend::TensorBuilder()\n                                         .setDim(4, x_dim_padded)\n                                         .setStrides(4, x_stride_padded)\n                                         .setId('x')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('y')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, w_dim_padded)\n                                         .setStrides(4, w_stride_padded)\n                                         .setId('w')\n       
                                  .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, b_dim_padded)\n                                         .setStrides(4, b_stride_padded)\n                                         .setId('z')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, b_dim_padded)\n                                         .setStrides(4, b_stride_padded)\n                                         .setId('b')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setVirtual()\n                                         .setId('A')  // after add\n                                         .setAlignment(16)\n                                         .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setVirtual()\n                                         .setId('B')  // after bias\n                                         .setAlignment(16)\n                              
           .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('C')  // after conv\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('i')\n                                         .setAlignment(16)\n                                         .setDataType(dataType)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('D')  // after optional add\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('E')  // after act for masked\n                                         .setAlignment(16)\n                                         .setVirtual()\n       
                                  .setDataType(CUDNN_DATA_FLOAT)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('I')  // output of the gen index operation\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_INT32)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('m')  // top half of the mask created after the less than\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_BOOLEAN)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                                         .setId('n')  // bottom half of the mask\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_BOOLEAN)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, y_dim_padded)\n                                         .setStrides(4, y_stride_padded)\n                              
           .setId('M')  // OR of the top and bottom masks\n                                         .setAlignment(16)\n                                         .setVirtual()\n                                         .setDataType(CUDNN_DATA_BOOLEAN)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, threshold_dim)\n                                         .setStrides(4, threshold_stride)\n                                         .setId('t')  // threshold for creating the top mask\n                                         .setAlignment(16)\n                                         .setDataType(CUDNN_DATA_INT32)\n                                         .build(),\n                                     cudnn_frontend::TensorBuilder()\n                                         .setDim(4, threshold_dim)\n                                         .setStrides(4, threshold_stride)\n                                         .setId('u')  // threshold for creating the bottom mask\n                                         .setAlignment(16)\n                                         .setDataType(CUDNN_DATA_INT32)\n                                         .build());\n}\n\n// tensor descriptors used for dgrad\nenum {\n  X_OR_DX_TENSOR,\n  DY_TENSOR,\n  W_OR_DW_TENSOR,\n  SCALE_TENSOR,\n  RELU_TENSOR,\n  AFTER_DCONV_TENSOR,\n  AFTER_DRELU_TENSOR,\n  DGRAD_INPUT_TENSOR,\n  DGRAD_OPTIONAL_TENSOR,\n  DGRAD_GEN_INDEX_TENSOR,\n  DGRAD_MASK_TOP_TENSOR,\n  DGRAD_MASK_BOTTOM_TENSOR,\n  DGRAD_MASK_TENSOR,\n  DGRAD_THRESHOLD_TOP_TENSOR,\n  DGRAD_THRESHOLD_BOTTOM_TENSOR,\n};\n\nusing dconv_add_descriptors = std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n                                         cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n                                         cudnn_frontend::Tensor, 
cudnn_frontend::Tensor, cudnn_frontend::Tensor>;\n\ndconv_add_descriptors create_dconv_add_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,\n                                                   int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,\n                                                   cudnnDataType_t dataType) {\n  const int convDim = 2;\n\n  int64_t b_dim_padded[4];\n  b_dim_padded[0] = 1;\n  b_dim_padded[1] = x_dim_padded[1];\n  b_dim_padded[2] = 1;\n  b_dim_padded[3] = 1;\n\n  int64_t x_stride_padded[4];\n  int64_t y_stride_padded[4];\n  int64_t w_stride_padded[4];\n  int64_t b_stride_padded[4];\n\n  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n\n  return dconv_add_descriptors(cudnn_frontend::TensorBuilder()\n                                   .setDim(4, x_dim_padded)\n                                   .setStrides(4, x_stride_padded)\n                                   .setId('x')\n                                   .setAlignment(16)\n                                   .setDataType(dataType)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   .setDim(4, y_dim_padded)\n                                   .setStrides(4, y_stride_padded)\n                                   .setId('y')\n                                   .setAlignment(16)\n                                   .setDataType(dataType)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   .setDim(4, w_dim_padded)\n                                   .setStrides(4, w_stride_padded)\n                                   .setId('w')\n         
                          .setAlignment(16)\n                                   .setDataType(dataType)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   .setDim(4, b_dim_padded)\n                                   .setStrides(4, b_stride_padded)\n                                   .setId('s')\n                                   .setAlignment(16)\n                                   .setDataType(dataType)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   .setDim(4, x_dim_padded)\n                                   .setStrides(4, x_stride_padded)\n                                   .setId('r')\n                                   .setAlignment(16)\n                                   .setDataType(dataType)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   .setDim(4, x_dim_padded)\n                                   .setStrides(4, x_stride_padded)\n                                   .setVirtual()\n                                   .setId('A')  // after dconv\n                                   .setAlignment(16)\n                                   .setDataType(CUDNN_DATA_FLOAT)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   .setDim(4, x_dim_padded)\n                                   .setStrides(4, x_stride_padded)\n                                   .setVirtual()\n                                   .setId('B')  // after drelu\n                                   .setAlignment(16)\n                                   .setDataType(CUDNN_DATA_FLOAT)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   
.setDim(4, y_dim_padded)\n                                   .setStrides(4, y_stride_padded)\n                                   .setId('i')\n                                   .setAlignment(16)\n                                   .setDataType(dataType)\n                                   .build(),\n                               cudnn_frontend::TensorBuilder()\n                                   .setDim(4, y_dim_padded)\n                                   .setStrides(4, y_stride_padded)\n                                   .setId('D')  // after optional add\n                                   .setAlignment(16)\n                                   .setVirtual()\n                                   .setDataType(CUDNN_DATA_FLOAT)\n                                   .build());\n}\n\nusing dconv_mask_descriptors =\n    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,\n               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor>;\n\ndconv_mask_descriptors create_dconv_mask_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA,\n                                                     int64_t* dilationA, int64_t* w_dim_padded, int64_t* y_dim_padded,\n                                                     int64_t* threshold_dim, cudnnDataType_t dataType) {\n  const int convDim = 2;\n\n  int64_t b_dim_padded[4];\n  b_dim_padded[0] = 1;\n  b_dim_padded[1] = x_dim_padded[1];\n  b_dim_padded[2] = 1;\n  b_dim_padded[3] = 1;\n\n  int64_t x_stride_padded[4];\n  int64_t y_stride_padded[4];\n  int64_t w_stride_padded[4];\n  int64_t b_stride_padded[4];\n  int64_t threshold_stride[4];\n\n  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);\n  generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);\n\n  return dconv_mask_descriptors(cudnn_frontend::TensorBuilder()\n                                    .setDim(4, x_dim_padded)\n                                    .setStrides(4, x_stride_padded)\n                                    .setId('x')\n                                    .setAlignment(16)\n                                    .setDataType(dataType)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, y_dim_padded)\n                                    .setStrides(4, y_stride_padded)\n                                    .setId('y')\n                                    .setAlignment(16)\n                                    .setDataType(dataType)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, w_dim_padded)\n                                    .setStrides(4, w_stride_padded)\n                                    .setId('w')\n                                    .setAlignment(16)\n                                    .setDataType(dataType)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, b_dim_padded)\n                                    .setStrides(4, b_stride_padded)\n                                    .setId('s')\n                                    .setAlignment(16)\n                                    .setDataType(dataType)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n             
                       .setDim(4, x_dim_padded)\n                                    .setStrides(4, x_stride_padded)\n                                    .setId('r')\n                                    .setAlignment(16)\n                                    .setDataType(dataType)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, x_dim_padded)\n                                    .setStrides(4, x_stride_padded)\n                                    .setVirtual()\n                                    .setId('A')  // after dconv\n                                    .setAlignment(16)\n                                    .setDataType(CUDNN_DATA_FLOAT)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, x_dim_padded)\n                                    .setStrides(4, x_stride_padded)\n                                    .setVirtual()\n                                    .setId('B')  // after drelu\n                                    .setAlignment(16)\n                                    .setDataType(CUDNN_DATA_FLOAT)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, y_dim_padded)\n                                    .setStrides(4, y_stride_padded)\n                                    .setId('i')\n                                    .setAlignment(16)\n                                    .setDataType(dataType)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, y_dim_padded)\n                                    .setStrides(4, y_stride_padded)\n                                    .setId('D')  // after optional add\n                                
    .setAlignment(16)\n                                    .setVirtual()\n                                    .setDataType(CUDNN_DATA_FLOAT)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, y_dim_padded)\n                                    .setStrides(4, y_stride_padded)\n                                    .setId('I')  // output of the gen index operation\n                                    .setAlignment(16)\n                                    .setVirtual()\n                                    .setDataType(CUDNN_DATA_INT32)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, y_dim_padded)\n                                    .setStrides(4, y_stride_padded)\n                                    .setId('m')  // top half of the mask created after the less than\n                                    .setAlignment(16)\n                                    .setVirtual()\n                                    .setDataType(CUDNN_DATA_BOOLEAN)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, y_dim_padded)\n                                    .setStrides(4, y_stride_padded)\n                                    .setId('n')  // bottom half of the mask\n                                    .setAlignment(16)\n                                    .setVirtual()\n                                    .setDataType(CUDNN_DATA_BOOLEAN)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, y_dim_padded)\n                                    .setStrides(4, y_stride_padded)\n                                    .setId('M')  // OR of the top and bottom masks\n          
                          .setAlignment(16)\n                                    .setVirtual()\n                                    .setDataType(CUDNN_DATA_BOOLEAN)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, threshold_dim)\n                                    .setStrides(4, threshold_stride)\n                                    .setId('t')  // threshold for creating the top mask\n                                    .setAlignment(16)\n                                    .setDataType(CUDNN_DATA_INT32)\n                                    .build(),\n                                cudnn_frontend::TensorBuilder()\n                                    .setDim(4, threshold_dim)\n                                    .setStrides(4, threshold_stride)\n                                    .setId('u')  // threshold for creating the bottom mask\n                                    .setAlignment(16)\n                                    .setDataType(CUDNN_DATA_INT32)\n                                    .build());\n}\n\nvoid run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,\n                                        int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType,\n                                        at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ,\n                                        at::Half* devPtrB, at::Half* devPtrI) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(x_dim_padded, pad, convstride, dilation,\n                                                                               w_dim_padded, y_dim_padded, dataType);\n    
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTEROPT_TENSOR>(tensors).describe());\n\n    // Define the scale (pointwise multiply) operation\n    auto scaleDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    // Define the bias operation\n    auto biasDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Define the optional add operation\n    auto addDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n    // Define the activation operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_FWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        
.setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(std::get<X_TENSOR>(tensors))\n                       .setwDesc(std::get<W_TENSOR>(tensors))\n                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create an Add Node.\n    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(conv_op.getOutputTensor())\n                      .setbDesc(std::get<OPTIONAL>(tensors))\n                      .setyDesc(std::get<AFTEROPT_TENSOR>(tensors))\n                      .setpwDesc(addDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n    // Create a Scale Node (pointwise multiply).\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(add_op.getOutputTensor())\n                        .setbDesc(std::get<Z_TENSOR>(tensors))\n                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors))\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create a Bias Node.\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       
.setxDesc(scale_op.getOutputTensor())\n                       .setbDesc(std::get<B_TENSOR>(tensors))\n                       .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))\n                       .setpwDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an Activation Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(bias_op.getOutputTensor())\n                      .setyDesc(std::get<Y_TENSOR>(tensors))\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create an Operation Graph. In this case it is convolution add bias activation\n    std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};\n    int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};\n    auto variantPack = 
cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(6, data_ptrs)\n                           .setUids(6, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (const cudnn_frontend::cudnnException& e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride,\n                                             int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded,\n                                             int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX,\n                                             at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB,\n                                             at::Half* devPtrI, int* devPtrT, int* devPtrU, int axis) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors(\n        x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());\n    
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERACT_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<GEN_INDEX_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TOP_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<MASK_BOTTOM_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_TOP_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_BOTTOM_TENSOR>(tensors).describe());\n\n    // Define the scale (pointwise multiply) operation\n    auto scaleDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    // Define the bias operation\n    auto biasDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Define the optional add operation\n    auto addDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n    // Define the activation operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_FWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n     
                   .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Define the genIndex descriptor\n    auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_GEN_INDEX)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .setAxis(axis)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());\n\n    // Define the lessThan descriptor\n    auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_CMP_LT)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());\n\n    // Define the greaterThan descriptor\n    auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()\n                               .setMode(CUDNN_POINTWISE_CMP_GT)\n                               .setMathPrecision(CUDNN_DATA_FLOAT)\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());\n\n    // Define the logical_or descriptor\n    auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()\n                             .setMode(CUDNN_POINTWISE_LOGICAL_OR)\n                             .setMathPrecision(CUDNN_DATA_BOOLEAN)\n                             .build();\n    DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());\n\n    // Define the binary_selection descriptor\n    auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()\n                             .setMode(CUDNN_POINTWISE_BINARY_SELECT)\n                 
            .setMathPrecision(CUDNN_DATA_FLOAT)\n                             .build();\n    DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(std::get<X_TENSOR>(tensors))\n                       .setwDesc(std::get<W_TENSOR>(tensors))\n                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create a Scale Node (pointwise multiply).\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(conv_op.getOutputTensor())\n                        .setbDesc(std::get<Z_TENSOR>(tensors))\n                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors))\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create a Bias Node.\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       .setxDesc(scale_op.getOutputTensor())\n                       .setbDesc(std::get<B_TENSOR>(tensors))\n                       .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))\n                       .setpwDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an optional Add Node.\n    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(bias_op.getOutputTensor())\n                      .setbDesc(std::get<OPTIONAL>(tensors))\n                      
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))\n                      .setpwDesc(addDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n    // Create an Activation Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())\n                      .setyDesc(std::get<AFTERACT_TENSOR>(tensors))\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create a Gen_Index Node.\n    auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(std::get<AFTERACT_TENSOR>(tensors))\n                           .setyDesc(std::get<GEN_INDEX_TENSOR>(tensors))\n                           .setpwDesc(genIndexDesc)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());\n\n    // Create a LessThan Node.\n    auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))\n                           .setbDesc(std::get<THRESHOLD_TOP_TENSOR>(tensors))\n                           .setyDesc(std::get<MASK_TOP_TENSOR>(tensors))\n                           .setpwDesc(lessThanDesc)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());\n\n    // Create a GreaterThan Node.\n    auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                              .setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))\n                              .setbDesc(std::get<THRESHOLD_BOTTOM_TENSOR>(tensors))\n                              .setyDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))\n                              .setpwDesc(greaterThanDesc)\n           
                   .build();\n    DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());\n\n    // Create a LogicalOr Node.\n    auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                            .setxDesc(std::get<MASK_TOP_TENSOR>(tensors))\n                            .setbDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))\n                            .setyDesc(std::get<MASK_TENSOR>(tensors))\n                            .setpwDesc(logicalOrDesc)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());\n\n    // Create a Binary_Selection Node.\n    auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                            .setxDesc(std::get<AFTERCONV_TENSOR>(tensors))\n                            .setbDesc(std::get<AFTERACT_TENSOR>(tensors))\n                            .settDesc(std::get<MASK_TENSOR>(tensors))\n                            .setyDesc(std::get<Y_TENSOR>(tensors))\n                            .setpwDesc(selectionDesc)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, selection_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution add bias activation\n    if (devPtrI) {\n      std::array<cudnn_frontend::Operation const*, 10> ops = {\n          &conv_op,     &scale_op,    &bias_op,        &add_op,       &act_op,\n          &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};\n\n      auto opGraph =\n          cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n      // Create string encoding for plan caching\n      auto cache_string =\n          getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n      DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n      auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n      DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n      auto workspace_size = plan.getWorkspaceSize();\n      DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n      void* workspace_ptr = nullptr;\n      auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n      if (workspace_size > 0) {\n        workspace_ptr = workspace_tensor.data_ptr<float>();\n      }\n      void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};\n      int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};\n      auto variantPack = cudnn_frontend::VariantPackBuilder()\n                             .setWorkspacePointer(workspace_ptr)\n                             .setDataPointers(8, data_ptrs)\n                             .setUids(8, uids)\n                             .build();\n      DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n      cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n      checkCudnnErr(status);\n      cudnn_frontend::throw_if([status]() { return (status != 
CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n    } else {\n      std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op,        &scale_op,     &bias_op,\n                                                             &act_op,         &genIndex_op,  &lessThan_op,\n                                                             &greaterThan_op, &logicalOr_op, &selection_op};\n\n      auto opGraph =\n          cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n      // Create string encoding for plan caching\n      auto cache_string =\n          getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n      DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n      auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n      DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n      auto workspace_size = plan.getWorkspaceSize();\n      DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n      void* workspace_ptr = nullptr;\n      auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n      if (workspace_size > 0) {\n        workspace_ptr = workspace_tensor.data_ptr<float>();\n      }\n      void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};\n      int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'};\n      auto variantPack = cudnn_frontend::VariantPackBuilder()\n                             .setWorkspacePointer(workspace_ptr)\n                             .setDataPointers(7, data_ptrs)\n                             .setUids(7, uids)\n                             .build();\n      DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n      cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n      
checkCudnnErr(status);\n      cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n    }\n  } catch (const cudnn_frontend::cudnnException& e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_dconv_add_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,\n                                int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType,\n                                at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ,\n                                at::Half* devPtrR, at::Half* devPtrI) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    dconv_add_descriptors tensors =\n        create_dconv_add_descriptors(x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_INPUT_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                    
    .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // optional add\n    auto addDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());\n\n    // Define the activation backward operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_BWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the scale backward operation\n    auto scaleDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n                       .setdyDesc(std::get<DY_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create add Node.\n    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                      .setbDesc(std::get<DGRAD_INPUT_TENSOR>(tensors))\n                      
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))\n                      .setpwDesc(addDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, add_op.describe());\n\n    // TODO: do we need getOutputTensor(), and what it returns in backward case?\n    // Create an relu backward Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setdyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))\n                      .setxDesc(std::get<RELU_TENSOR>(tensors))\n                      .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create a Scale Node.\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n                        .setbDesc(std::get<SCALE_TENSOR>(tensors))\n                        .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create an Operation Graph. 
In this case it is dgrad, add, ReLU backward, scale\n    std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &add_op, &act_op, &scale_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI};\n    int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(6, data_ptrs)\n                           .setUids(6, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid 
run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,\n                                 int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim,\n                                 cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY,\n                                 at::Half* devPtrZ, at::Half* devPtrR, int* devPtrT, int* devPtrU, int axis) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n\n    // Creates the necessary tensor descriptors\n    dconv_mask_descriptors tensors = create_dconv_mask_descriptors(x_dim_padded, pad, convstride, dilation,\n                                                                   w_dim_padded, y_dim_padded, threshold_dim, dataType);\n    DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_GEN_INDEX_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TOP_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors).describe());\n    DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors).describe());\n\n    // Define the 
convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Define the activation backward operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_BWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the scale backward operation\n    auto scaleDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    // Define the genIndex descriptor\n    auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_GEN_INDEX)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .setAxis(axis)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());\n\n    // Define the lessThan descriptor\n    auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()\n                            .setMode(CUDNN_POINTWISE_CMP_LT)\n                            .setMathPrecision(CUDNN_DATA_FLOAT)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());\n\n    // Define the greaterThan descriptor\n    auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()\n                               .setMode(CUDNN_POINTWISE_CMP_GT)\n                 
              .setMathPrecision(CUDNN_DATA_FLOAT)\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());\n\n    // Define the logical_or descriptor\n    auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()\n                             .setMode(CUDNN_POINTWISE_LOGICAL_OR)\n                             .setMathPrecision(CUDNN_DATA_BOOLEAN)\n                             .build();\n    DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());\n\n    // Define the binary_selection descriptor\n    auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()\n                             .setMode(CUDNN_POINTWISE_BINARY_SELECT)\n                             .setMathPrecision(CUDNN_DATA_FLOAT)\n                             .build();\n    DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());\n\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n                       .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                       .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))\n                       .setdyDesc(std::get<DY_TENSOR>(tensors))\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // TODO: do we need getOutputTensor(), and what it returns in backward case?\n    // Create an relu backward Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                      .setxDesc(std::get<RELU_TENSOR>(tensors))\n                      .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, 
act_op.describe());\n\n    // Create a Scale Node.\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))\n                        .setbDesc(std::get<SCALE_TENSOR>(tensors))\n                        .setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create a Gen_Index Node.\n    auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))\n                           .setyDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))\n                           .setpwDesc(genIndexDesc)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());\n\n    // Create a LessThan Node.\n    auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                           .setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))\n                           .setbDesc(std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors))\n                           .setyDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))\n                           .setpwDesc(lessThanDesc)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());\n\n    // Create a GreaterThan Node.\n    auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                              .setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))\n                              .setbDesc(std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors))\n                              .setyDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))\n                              .setpwDesc(greaterThanDesc)\n                              .build();\n    
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());\n\n    // Create a LogicalOr Node.\n    auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                            .setxDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))\n                            .setbDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))\n                            .setyDesc(std::get<DGRAD_MASK_TENSOR>(tensors))\n                            .setpwDesc(logicalOrDesc)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());\n\n    // Create a Binary_Selection Node.\n    auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                            .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))\n                            .setbDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))\n                            .settDesc(std::get<DGRAD_MASK_TENSOR>(tensors))\n                            .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))\n                            .setpwDesc(selectionDesc)\n                            .build();\n    DEBUG_CUDNN_MSG(log_buf, selection_op.describe());\n\n    // Create an Operation Graph. 
In this case it is dgrad, ReLU backward, scale, and an index-threshold mask select\n    std::array<cudnn_frontend::Operation const*, 8> ops = {&conv_op,     &act_op,         &scale_op,     &genIndex_op,\n                                                           &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU};\n    int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(7, data_ptrs)\n                           .setUids(7, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch 
(cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nstruct bottleneck_forward_status {\n  int64_t dimA[4];\n  int64_t filterdimA1[4];\n  int64_t filterdimA2[4];\n  int64_t filterdimA2hh[4];\n  int64_t filterdimA3[4];\n  int64_t filterdimA4[4];\n\n  int64_t threshdim[4];\n\n  int axis[4];\n\n  int64_t outdimA0[4];\n  int64_t outdimA1[4];\n  int64_t outdimA1b[4];  // out1_pad\n  int64_t outdimA2[4];\n  int64_t outdimA3[4];\n  int64_t outdimA4[4];\n\n  int64_t padA[2];\n  int64_t padA1[2];\n  int64_t padA2[2];  // halo padding\n  int64_t dilationA[2];\n  int64_t convstrideA[2];\n  int64_t convstride1X1[2];\n\n  int64_t outdim0[4];  // halo input shape\n  int64_t outdim1[4];\n  int64_t outdim1b[4];\n  int64_t outdim2[4];\n  int64_t outdim3[4];\n  int64_t outdim4[4];  // halo output shape\n\n  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;\n    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;\n    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;\n    filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;\n    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;\n    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;\n    threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;\n\n    // All dim calculation after this order of n,c,h,w\n    if (explicit_nhwc) {\n      axis[0] = 0;\n      axis[1] = 3;\n      axis[2] = 1;\n      axis[3] = 2;\n    } else {\n      axis[0] = 0;\n      axis[1] = 1;\n      axis[2] = 2;\n      axis[3] = 3;\n    }\n\n    for (int dim = 0; dim < 4; dim++) {\n      dimA[dim] = inputs[0].size(axis[dim]);\n      filterdimA1[dim] = inputs[1].size(axis[dim]);\n      filterdimA2[dim] = inputs[2].size(axis[dim]);\n      filterdimA3[dim] = inputs[3].size(axis[dim]);\n    }\n    if 
(stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n      for (int dim = 0; dim < 4; dim++) {\n        filterdimA4[dim] = inputs[10].size(axis[dim]);\n      }\n    }\n    for (int dim = 0; dim < 4; dim++) {\n      if (dim == 2) {\n        filterdimA2hh[dim] = 1;\n      } else {\n        filterdimA2hh[dim] = filterdimA2[dim];\n      }\n    }\n\n    // output dim in n,c,h,w used by backend\n    outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;\n    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;\n    outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;\n    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;\n    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;\n    outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;\n\n    // use these fixed value for test run\n    padA[0] = 0;\n    padA[1] = 0;\n    padA1[0] = 1;\n    padA1[1] = 1;\n    padA2[0] = 0;\n    padA2[1] = 1;\n    dilationA[0] = 1;\n    dilationA[1] = 1;\n    convstrideA[0] = 1;\n    convstrideA[1] = 1;\n    convstride1X1[0] = stride_1X1;\n    convstride1X1[1] = stride_1X1;\n\n    // compute output from pad/stride/dilation\n    outdimA1[0] = dimA[0];\n    outdimA1[1] = filterdimA1[0];\n    for (int dim = 0; dim < 2; dim++) {\n      outdimA1[dim + 2] =\n          getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n    }\n    for (int dim = 0; dim < 4; dim++) {\n      if (dim == 2) {\n        outdimA1b[dim] = outdimA1[dim] + 2;\n      } else {\n        outdimA1b[dim] = outdimA1[dim];\n      }\n    }\n\n    outdimA2[0] = outdimA1[0];\n    outdimA2[1] = filterdimA2[0];\n    for (int dim = 0; dim < 2; dim++) {\n      outdimA2[dim + 2] =\n          getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n    }\n\n    for (int dim = 0; dim < 4; dim++) {\n      if (dim == 2) {\n        outdimA0[dim] = 3;\n        outdimA4[dim] = 1;\n      } else {\n        
outdimA0[dim] = outdimA1[dim];\n        outdimA4[dim] = outdimA2[dim];\n      }\n    }\n\n    outdimA3[0] = outdimA2[0];\n    outdimA3[1] = filterdimA3[0];\n    for (int dim = 0; dim < 2; dim++) {\n      outdimA3[dim + 2] =\n          getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n    }\n\n    // Create output tensor in the correct shape in pytorch's view\n    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;\n    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;\n    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;\n    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;\n    if (explicit_nhwc) {\n      axis[0] = 0;\n      axis[1] = 2;\n      axis[2] = 3;\n      axis[3] = 1;\n    }\n    for (int dim = 0; dim < 4; dim++) {\n      outdim0[dim] = outdimA0[axis[dim]];\n      outdim1[dim] = outdimA1[axis[dim]];\n      outdim1b[dim] = outdimA1b[axis[dim]];\n      outdim2[dim] = outdimA2[axis[dim]];\n      outdim3[dim] = outdimA3[axis[dim]];\n      outdim4[dim] = outdimA4[axis[dim]];\n    }\n  }\n};\n\nbottleneck_forward_status forward_state;\n\n}  // end of anonymous namespace\n\nstd::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n  // NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method.\n  // NB! We use a global object to store state.\n  forward_state.init(explicit_nhwc, stride_1X1, inputs);\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // printf(\"outdim1 =\n  // (%d,%d,%d,%d)\\n\",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);\n  auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);\n  auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);\n  auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);\n\n  outputs.push_back(out1);\n  outputs.push_back(out2);\n  outputs.push_back(out3);\n\n  return outputs;\n}\n\n// inputs contains x,w,z,b,(i)\nvoid bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                             std::vector<at::Tensor> outputs) {\n  std::cout << std::fixed;\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n  at::Half* b = inputs[7].data_ptr<at::Half>();\n  auto out1 = outputs[0];\n  at::Half* y1 = out1.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(forward_state.dimA, forward_state.padA, forward_state.convstride1X1,\n                                     forward_state.dilationA, forward_state.filterdimA1, forward_state.outdimA1,\n                                     CUDNN_DATA_HALF, x, w, y1, z, b, nullptr);\n\n  DEBUG_MSG(\"[DEBUG] new relu1 : \" << out1.to(at::kFloat).sum().item<float>());\n}\n\n// computes halo (top or bottom) from fat halo input.\n// fat halo input is 3 pixels wide in H.\nat::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector<at::Tensor> inputs) {\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // run\n  at::Half* w = inputs[2].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n  at::Half* b = inputs[8].data_ptr<at::Half>();\n\n  at::Half* y1 = fat_halo_y1.data_ptr<at::Half>();\n\n  auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);\n  at::Half* y2 = halo_y2.data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(forward_state.outdimA0, forward_state.padA2, forward_state.convstrideA,\n                                     forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA4,\n                                     CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr);\n\n  return halo_y2;\n}\n\n// compute halo correction term (top or bottom) from slim halo input (N,C,1,W).\n// slim halo input is 1 pixel wide in H.\nat::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1,\n                                             std::vector<at::Tensor> inputs, at::Tensor w1by3,\n                                             at::Tensor out2_part_halo) {\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // run\n  at::Half* w = w1by3.data_ptr<at::Half>();  // C,C,1,3\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n  at::Half* b = inputs[8].data_ptr<at::Half>();\n\n  at::Half* y1 = slim_halo_y1.data_ptr<at::Half>();\n\n  at::Half* prev_out2 = out2_part_halo.data_ptr<at::Half>();\n\n  auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);\n  at::Half* y2 = halo_y2.data_ptr<at::Half>();\n\n  run_conv_add_scale_bias_activation(forward_state.outdimA4, forward_state.padA2, forward_state.convstrideA,\n                                     forward_state.dilationA, forward_state.filterdimA2hh, forward_state.outdimA4,\n                                     CUDNN_DATA_HALF, y1, w, y2, z, b, prev_out2);\n\n  return halo_y2;\n}\n\nvoid bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                             std::vector<at::Tensor> outputs) {\n  std::cout << std::fixed;\n\n  // from _out1 method\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  auto out1 = outputs[0];\n  at::Half* y1 = out1.data_ptr<at::Half>();\n\n  // run\n  at::Half* w = inputs[2].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n  at::Half* b = inputs[8].data_ptr<at::Half>();\n  auto out2 = outputs[1];\n  at::Half* y2 = out2.data_ptr<at::Half>();\n\n  // printf(\"forward_state.outdimA1 =\n  // {%d,%d,%d,%d}\\n\",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);\n  // printf(\"forward_state.padA1 = {%d,%d}\\n\",forward_state.padA1[0],forward_state.padA1[1]);\n  // printf(\"forward_state.convstrideA = {%d,%d}\\n\",forward_state.convstrideA[0],forward_state.convstrideA[1]);\n  // printf(\"forward_state.dilationA = {%d,%d}\\n\",forward_state.dilationA[0],forward_state.dilationA[1]);\n  // printf(\"forward_state.filterdimA2 =\n  // 
{%d,%d,%d,%d}\\n\",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);\n  // printf(\"forward_state.outdimA2 =\n  // {%d,%d,%d,%d}\\n\",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);\n  run_conv_scale_bias_add_activation(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA,\n                                     forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2,\n                                     CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr);\n  DEBUG_MSG(\"[DEBUG] new relu2 : \" << out2.to(at::kFloat).sum().item<float>());\n}\n\nvoid bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                  std::vector<at::Tensor> outputs, at::Tensor thresholdTop,\n                                  at::Tensor thresholdBottom) {\n  std::cout << std::fixed;\n\n  // from _out1 method\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  auto out1 = outputs[0];\n  at::Half* y1 = out1.data_ptr<at::Half>();\n\n  // run\n  at::Half* w = inputs[2].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n  at::Half* b = inputs[8].data_ptr<at::Half>();\n  auto out2 = outputs[1];\n  at::Half* y2 = out2.data_ptr<at::Half>();\n\n  // printf(\"forward_state.outdimA1 =\n  // {%d,%d,%d,%d}\\n\",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);\n  // printf(\"forward_state.padA1 = {%d,%d}\\n\",forward_state.padA1[0],forward_state.padA1[1]);\n  // printf(\"forward_state.convstrideA = {%d,%d}\\n\",forward_state.convstrideA[0],forward_state.convstrideA[1]);\n  // printf(\"forward_state.dilationA = {%d,%d}\\n\",forward_state.dilationA[0],forward_state.dilationA[1]);\n  // printf(\"forward_state.filterdimA2 =\n  // 
{%d,%d,%d,%d}\\n\",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);\n  // printf(\"forward_state.outdimA2 =\n  // {%d,%d,%d,%d}\\n\",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);\n  run_conv_scale_bias_add_activation_mask(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA,\n                                          forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2,\n                                          forward_state.threshdim, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr,\n                                          thresholdTop.data_ptr<int>(), thresholdBottom.data_ptr<int>(),\n                                          2);  // axis == 1 -> Does this assume explicit NHWC?\n  DEBUG_MSG(\"[DEBUG] new relu2 : \" << out2.to(at::kFloat).sum().item<float>());\n}\n\nvoid bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                 std::vector<at::Tensor> outputs, at::Tensor out1_pad) {\n  std::cout << std::fixed;\n\n  // from _out1 method\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  auto out1 = outputs[0];\n  at::Half* y1 = out1_pad.data_ptr<at::Half>();\n\n  // run\n  at::Half* w = inputs[2].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n  at::Half* b = inputs[8].data_ptr<at::Half>();\n  auto out2 = outputs[1];\n  at::Half* y2 = out2.data_ptr<at::Half>();\n\n  // printf(\"forward_state.outdimA1 =\n  // {%d,%d,%d,%d}\\n\",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);\n  // printf(\"forward_state.padA1 = {%d,%d}\\n\",forward_state.padA1[0],forward_state.padA1[1]);\n  // printf(\"forward_state.convstrideA = {%d,%d}\\n\",forward_state.convstrideA[0],forward_state.convstrideA[1]);\n  // printf(\"forward_state.dilationA = 
{%d,%d}\\n\",forward_state.dilationA[0],forward_state.dilationA[1]);\n  // printf(\"forward_state.filterdimA2 =\n  // {%d,%d,%d,%d}\\n\",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);\n  // printf(\"forward_state.outdimA2 =\n  // {%d,%d,%d,%d}\\n\",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);\n  run_conv_scale_bias_add_activation(forward_state.outdimA1b, forward_state.padA2, forward_state.convstrideA,\n                                     forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2,\n                                     CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr);\n  DEBUG_MSG(\"[DEBUG] new relu2 : \" << out2.to(at::kFloat).sum().item<float>());\n}\n\nvoid bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                             std::vector<at::Tensor> outputs) {\n  std::cout << std::fixed;\n\n  // from _out1 method\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n\n  // create output of conv3\n  auto out3 = outputs[2];\n  at::Half* y3 = out3.data_ptr<at::Half>();\n\n  // create output of conv4 that may exist\n  auto identity = at::empty_like(out3);\n  at::Half* yi = identity.data_ptr<at::Half>();\n\n  at::Half *w, *z, *b;\n\n  if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]) {\n    w = inputs[10].data_ptr<at::Half>();\n    z = inputs[11].data_ptr<at::Half>();\n    b = inputs[12].data_ptr<at::Half>();\n    run_conv_scale_bias(forward_state.dimA, forward_state.padA, forward_state.convstride1X1, forward_state.dilationA,\n                        forward_state.filterdimA4, forward_state.outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b);\n    DEBUG_MSG(\"[DEBUG] new downsample : \" << identity.to(at::kFloat).sum().item<float>());\n  } else {\n    yi = x;\n  }\n\n  auto out2 = outputs[1];\n  at::Half* y2 = out2.data_ptr<at::Half>();\n\n  w = 
inputs[3].data_ptr<at::Half>();\n  z = inputs[6].data_ptr<at::Half>();\n  b = inputs[9].data_ptr<at::Half>();\n\n  run_conv_scale_bias_add_activation(forward_state.outdimA2, forward_state.padA, forward_state.convstrideA,\n                                     forward_state.dilationA, forward_state.filterdimA3, forward_state.outdimA3,\n                                     CUDNN_DATA_HALF, y2, w, y3, z, b, yi);\n  DEBUG_MSG(\"[DEBUG] new relu3 : \" << out3.to(at::kFloat).sum().item<float>());\n}\n\nnamespace {\n\nstruct bottleneck_backward_state {\n  int64_t dimA[4];\n  int64_t filterdimA1[4];\n  int64_t filterdimA2[4];\n  int64_t filterdimA3[4];\n  int64_t filterdimA4[4];\n  int64_t filterdimA2hh[4];  // Cin,Cout,1,3\n  int64_t threshdim[4];\n\n  int axis[4];\n\n  int64_t outdimA1[4];   // grad_out1\n  int64_t outdimA1b[4];  // out1_pad\n  int64_t outdimA2[4];   // grad_out2\n  int64_t outdimA3[4];\n  int64_t outdimA1h[4];   // output: grad_out1 halo (H=3)\n  int64_t outdimA2h[4];   // input : grad_out2 halo cells (H=3)\n  int64_t outdimA1hh[4];  // input: grad_out2 halo (H=1)\n  int64_t outdimA2hh[4];  // input: out1 halo (H=1)\n\n  int64_t padA[2];\n  int64_t padA1[2];\n  int64_t padA2[2];\n  int64_t dilationA[2];\n  int64_t convstrideA[2];\n  int64_t convstride1X1[2];\n\n  int64_t filterdim2hh[4];  // Cin,1,3,Cout\n\n  int64_t outdim1[4];\n  int64_t outdim1b[4];\n  int64_t outdim2[4];\n  int64_t outdim3[4];\n  int64_t outdim1h[4];\n  int64_t outdim1hh[4];\n\n  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n    // setup dimensions\n    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;\n    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;\n    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;\n    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;\n    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;\n    filterdimA2hh[0] = filterdimA2hh[1] = 
filterdimA2hh[2] = filterdimA2hh[3] = 0;\n    threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;\n\n    // All dim calculation after this order of n,c,h,w\n    if (explicit_nhwc) {\n      axis[0] = 0;\n      axis[1] = 3;\n      axis[2] = 1;\n      axis[3] = 2;\n    } else {\n      axis[0] = 0;\n      axis[1] = 1;\n      axis[2] = 2;\n      axis[3] = 3;\n    }\n\n    for (int dim = 0; dim < 4; dim++) {\n      dimA[dim] = inputs[0].size(axis[dim]);\n      filterdimA1[dim] = inputs[1].size(axis[dim]);\n      filterdimA2[dim] = inputs[2].size(axis[dim]);\n      filterdimA3[dim] = inputs[3].size(axis[dim]);\n    }\n    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {\n      for (int dim = 0; dim < 4; dim++) {\n        filterdimA4[dim] = inputs[14].size(axis[dim]);\n      }\n    }\n\n    for (int dim = 0; dim < 4; dim++) {\n      if (dim == 2) {\n        filterdimA2hh[dim] = 1;\n      } else {\n        filterdimA2hh[dim] = filterdimA2[dim];\n      }\n    }\n\n    // output dim in n,c,h,w used by backend\n    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;\n    outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;\n    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;\n    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;\n    outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;\n    outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0;\n    outdimA1hh[0] = outdimA1hh[1] = outdimA1hh[2] = outdimA1hh[3] = 0;\n    outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0;\n\n    // use these fixed value for test run\n    padA[0] = 0;\n    padA[1] = 0;\n    padA1[0] = 1;\n    padA1[1] = 1;\n    padA2[0] = 0;\n    padA2[1] = 1;\n    dilationA[0] = 1;\n    dilationA[1] = 1;\n    convstrideA[0] = 1;\n    convstrideA[1] = 1;\n    convstride1X1[0] = stride_1X1;\n    convstride1X1[1] = stride_1X1;\n\n    // compute output from pad/stride/dilation\n    outdimA1[0] = dimA[0];\n    outdimA1[1] = 
filterdimA1[0];\n    for (int dim = 0; dim < 2; dim++) {\n      outdimA1[dim + 2] =\n          getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);\n    }\n    for (int dim = 0; dim < 4; dim++) {\n      if (dim == 2) {\n        outdimA1b[dim] = outdimA1[dim] + 2;\n      } else {\n        outdimA1b[dim] = outdimA1[dim];\n      }\n    }\n\n    outdimA2[0] = outdimA1[0];\n    outdimA2[1] = filterdimA2[0];\n    for (int dim = 0; dim < 2; dim++) {\n      outdimA2[dim + 2] =\n          getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);\n    }\n\n    outdimA3[0] = outdimA2[0];\n    outdimA3[1] = filterdimA3[0];\n    for (int dim = 0; dim < 2; dim++) {\n      outdimA3[dim + 2] =\n          getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);\n    }\n\n    for (int dim = 0; dim < 4; dim++) {\n      if (dim == 2) {\n        outdimA1h[dim] = 3;\n        outdimA2h[dim] = 3;\n        outdimA1hh[dim] = 1;\n        outdimA2hh[dim] = 1;\n      } else {\n        outdimA1h[dim] = outdimA1[dim];\n        outdimA2h[dim] = outdimA2[dim];\n        outdimA1hh[dim] = outdimA1[dim];\n        outdimA2hh[dim] = outdimA2[dim];\n      }\n    }\n\n    // Create output tensor in the correct shape in pytorch's view\n    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;\n    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;\n    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;\n    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;\n    outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;\n    outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0;\n    filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;\n    if (explicit_nhwc) {\n      axis[0] = 0;\n      axis[1] = 2;\n      axis[2] = 3;\n      axis[3] = 1;\n    }\n    for (int dim = 0; dim < 4; dim++) {\n      outdim1[dim] = 
outdimA1[axis[dim]];\n      outdim1b[dim] = outdimA1b[axis[dim]];\n      outdim2[dim] = outdimA2[axis[dim]];\n      outdim3[dim] = outdimA3[axis[dim]];\n      outdim1h[dim] = outdimA1h[axis[dim]];\n      outdim1hh[dim] = outdimA1hh[axis[dim]];\n      filterdim2hh[dim] = filterdimA2hh[axis[dim]];\n    }\n  }\n};\n\nbottleneck_backward_state backward_state;\n\n}  // namespace\n\nstd::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {\n  std::cout << std::fixed;\n\n  backward_state.init(explicit_nhwc, stride_1X1, inputs);\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  auto grad_x = at::empty_like(inputs[0]);\n  auto wgrad1 = at::empty_like(inputs[1]);\n  auto wgrad2 = at::empty_like(inputs[2]);\n  auto wgrad3 = at::empty_like(inputs[3]);\n\n  outputs.push_back(grad_x);\n  outputs.push_back(wgrad1);\n  outputs.push_back(wgrad2);\n  outputs.push_back(wgrad3);\n  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {\n    auto wgrad4 = at::empty_like(inputs[14]);\n    outputs.push_back(wgrad4);\n  }\n\n  return outputs;\n}\n\nvoid bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                std::vector<at::Tensor> outputs) {\n  // dconv3+drelu2+dscale2\n  at::Half* conv_in = inputs[13].data_ptr<at::Half>();\n  at::Half* dy3 = inputs[10].data_ptr<at::Half>();\n\n  // wgrad\n  auto wgrad3 = outputs[3];\n  at::Half* dw3 = wgrad3.data_ptr<at::Half>();\n  run_dconv(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA, backward_state.dilationA,\n            backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n  DEBUG_MSG(\"[DEBUG] new wgrad3 : \" << 
wgrad3.to(at::kFloat).sum().item<float>());\n}\n\nat::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                         std::vector<at::Tensor> outputs) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dconv3+drelu2+dscale2\n  at::Half* conv_in = inputs[13].data_ptr<at::Half>();\n  at::Half* dy3 = inputs[10].data_ptr<at::Half>();\n\n  DEBUG_MSG(\"[DEBUG] new dconv3 : \" << inputs[10].to(at::kFloat).sum().item<float>());\n\n  // dgrad\n  auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n  at::Half* w = inputs[3].data_ptr<at::Half>();\n  at::Half* z = inputs[5].data_ptr<at::Half>();\n\n  at::Half* relu2 = inputs[13].data_ptr<at::Half>();\n\n  run_dconv_drelu_dscale(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA,\n                         backward_state.dilationA, backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF,\n                         dy2, w, dy3, z, relu2);\n\n  // do halo exchange of dy2 here\n\n  DEBUG_MSG(\"[DEBUG] new dconv2 : \" << grad_out2.to(at::kFloat).sum().item<float>());\n\n  return grad_out2;\n}\n\nat::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                         std::vector<at::Tensor> outputs, at::Tensor grad_out2) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n\n  // dgrad\n  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);\n  at::Half* dy1 = grad_out1.data_ptr<at::Half>();\n  at::Half* w = inputs[2].data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n\n  at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n  // printf(\"relu.shape = [%d,%d,%d,%d]\\n\",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));\n\n  // fused dgrad\n  // printf(\"backward_state.outdim1 =\n  // {%d,%d,%d,%d}\\n\",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]);\n  run_dconv_drelu_dscale(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA,\n                         backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2, CUDNN_DATA_HALF,\n                         dy1, w, dy2, z, relu1);\n\n  return grad_out1;\n}\n\nat::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                              std::vector<at::Tensor> outputs, at::Tensor grad_out2,\n                                              at::Tensor thresholdTop, at::Tensor thresholdBottom) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n\n  // dgrad\n  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);\n  at::Half* dy1 = grad_out1.data_ptr<at::Half>();\n  at::Half* w = inputs[2].data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n\n  at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n  // printf(\"relu.shape = [%d,%d,%d,%d]\\n\",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));\n\n  // fused dgrad\n  run_dconv_drelu_dscale_mask(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA,\n                              backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2,\n                              backward_state.threshdim, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1,\n                              thresholdTop.data_ptr<int>(), thresholdBottom.data_ptr<int>(), 2);\n\n  return grad_out1;\n}\n\n// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1)\n// to produce output of shape [N,1,W,C]\nat::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                                   at::Tensor w1by3, std::vector<at::Tensor> outputs,\n                                                   at::Tensor grad_out2_halo, at::Tensor relu1_halo,\n                                                   at::Tensor part_grad_out1) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();\n\n  // dgrad\n  auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format);\n  at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();\n  // at::Half* w = inputs[2].data_ptr<at::Half>();  // use w1by3 instead, which is a sliced version of inputs[2]\n  at::Half* w = w1by3.data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n  at::Half* relu1h = relu1_halo.data_ptr<at::Half>();\n  at::Half* pdy1h = part_grad_out1.data_ptr<at::Half>();\n\n  // printf(\"relu.shape = [%d,%d,%d,%d]\\n\",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));\n  //  fused dgrad\n  // printf(\"backward_state.outdimA1h =\n  // {%d,%d,%d,%d}\\n\",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);\n  // printf(\"backward_state.outdimA2h =\n  // {%d,%d,%d,%d}\\n\",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);\n  // printf(\"backward_state.filterdimA2 =\n  // {%d,%d,%d,%d}\\n\",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);\n  run_dconv_add_drelu_dscale(backward_state.outdimA1hh,\n                             backward_state.padA2,  // 0,1\n                             backward_state.convstrideA, backward_state.dilationA,\n                             backward_state.filterdimA2hh,  // C,1,3,C\n                             backward_state.outdimA2hh, CUDNN_DATA_HALF, dy1h, w, dy2h, z, relu1h, pdy1h);\n\n  return grad_out1_halo;\n}\n\n// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1)\n// to produce output of shape [N,3,W,C]\nat::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, 
std::vector<at::Tensor> inputs,\n                                              std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo,\n                                              at::Tensor relu1_halo) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();\n\n  // dgrad\n  auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);\n  at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();\n  at::Half* w = inputs[2].data_ptr<at::Half>();\n  at::Half* z = inputs[4].data_ptr<at::Half>();\n\n  at::Half* relu1h = relu1_halo.data_ptr<at::Half>();\n  // printf(\"relu.shape = [%d,%d,%d,%d]\\n\",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));\n  //  fused dgrad\n  // printf(\"backward_state.outdimA1h =\n  // {%d,%d,%d,%d}\\n\",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);\n  // printf(\"backward_state.outdimA2h =\n  // {%d,%d,%d,%d}\\n\",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);\n  // printf(\"backward_state.filterdimA2 =\n  // {%d,%d,%d,%d}\\n\",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);\n  run_dconv_drelu_dscale(backward_state.outdimA1h, backward_state.padA1, backward_state.convstrideA,\n                         backward_state.dilationA, backward_state.filterdimA2, backward_state.outdimA2h,\n                         CUDNN_DATA_HALF, dy1h, w, dy2h, z, relu1h);\n\n  return grad_out1_halo;\n}\n\nvoid bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                    std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor 
grad_out2) {\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n\n  // dconv2+drelu1+dscale1\n  at::Half* conv_in = input.data_ptr<at::Half>();\n\n  // wgrad\n  auto wgrad2 = outputs[2];\n  at::Half* dw2 = wgrad2.data_ptr<at::Half>();\n\n  // printf(\"outdimA1b =\n  // (%d,%d,%d,%d)\\n\",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);\n  // printf(\"backward_state.padA2 = {%d,%d}\\n\",backward_state.padA2[0],backward_state.padA2[1]);\n  run_dconv(backward_state.outdimA1b,  // conv_in.shape (including H halos)\n            backward_state.padA2,      // 0, 1\n            backward_state.convstrideA, backward_state.dilationA,\n            backward_state.filterdimA2,  // dw2.shape\n            backward_state.outdimA2,     // dy2.shape\n            CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n  DEBUG_MSG(\"[DEBUG] new wgrad2 : \" << wgrad2.to(at::kFloat).sum().item<float>());\n}\n\nvoid bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                std::vector<at::Tensor> outputs, at::Tensor grad_out2) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n\n  // dconv2+drelu1+dscale1\n  at::Half* conv_in = inputs[12].data_ptr<at::Half>();\n\n  // wgrad\n  auto wgrad2 = outputs[2];\n  at::Half* dw2 = wgrad2.data_ptr<at::Half>();\n\n  // printf(\"outdimA1 =\n  // (%d,%d,%d,%d)\\n\",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);\n  run_dconv(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA,\n            backward_state.filterdimA2, backward_state.outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n  DEBUG_MSG(\"[DEBUG] new wgrad2 : \" << wgrad2.to(at::kFloat).sum().item<float>());\n}\n\n// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension\n// [N,1,W,C]. The input and grad_out2_halo tensors have the same shape; the output tensor is of shape [Cin,1,3,Cout]\n// (regular filter dims are [Cin,3,3,Cout]).\nat::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                           std::vector<at::Tensor> outputs, at::Tensor input,\n                                           at::Tensor grad_out2_halo) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? 
at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2 = grad_out2_halo.data_ptr<at::Half>();\n\n  // dconv2+drelu1+dscale1\n  at::Half* conv_in = input.data_ptr<at::Half>();\n\n  // wgrad\n  auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format);\n  at::Half* dw2 = wgrad2_halo.data_ptr<at::Half>();\n\n  // printf(\"backward_state.outdimA1hh =\n  // {%d,%d,%d,%d}\\n\",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]);\n  // printf(\"backward_state.outdimA2hh =\n  // {%d,%d,%d,%d}\\n\",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]);\n  // printf(\"backward_state.filterdim2hh =\n  // {%d,%d,%d,%d}\\n\",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]);\n  // printf(\"backward_state.filterdimA2hh =\n  // {%d,%d,%d,%d}\\n\",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]);\n  // printf(\"backward_state.padA2 = {%d,%d}\\n\",backward_state.padA2[0],backward_state.padA2[1]);\n  run_dconv(backward_state.outdimA1hh,  // N,C,1,W\n            backward_state.padA2,       // 0, 1\n            backward_state.convstrideA, backward_state.dilationA,\n            backward_state.filterdimA2hh,  // Cin,Cout,1,3\n            backward_state.outdimA2hh,     // N,C,1,W\n            CUDNN_DATA_HALF, conv_in, dw2, dy2, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  return wgrad2_halo;\n}\n\nvoid bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                                std::vector<at::Tensor> outputs, at::Tensor grad_out1) {\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* dy1 = grad_out1.data_ptr<at::Half>();\n\n  // dconv1+add\n  // wgrad\n  
auto wgrad1 = outputs[1];\n  at::Half* dw1 = wgrad1.data_ptr<at::Half>();\n  run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,\n            backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, x, dw1, dy1,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n}\n\nvoid bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs,\n                              std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::cout << std::fixed;\n  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;\n\n  // dgrad\n  at::Half* dy2 = grad_out2.data_ptr<at::Half>();\n  at::Half* dy1 = grad_out1.data_ptr<at::Half>();\n\n  /*\n    // backward strided conv cannot be fused\n    // if stride == 1 but channel changes, we can fuse here\n    if (stride_1X1 != 1){\n      // dgrad\n      run_dconv(outdimA1,\n                padA1,\n                convstride1X1,\n                dilationA,\n                filterdimA2,\n                outdimA2,\n                CUDNN_DATA_HALF,\n                dy1,\n                w,\n                dy2,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n\n      // mul fused mask\n      grad_out1.mul_(inputs[15]);\n    }\n    else {\n      at::Half* relu1 = inputs[12].data_ptr<at::Half>();\n      // fused dgrad\n      run_dconv_drelu_dscale(outdimA1,\n                             padA1,\n                             convstride1X1,\n                             dilationA,\n                             filterdimA2,\n                             outdimA2,\n                             CUDNN_DATA_HALF,\n                             dy1,\n                             w,\n                             dy2,\n                             z,\n                             
relu1);\n    }\n  */\n  DEBUG_MSG(\"[DEBUG] new dconv1 : \" << grad_out1.to(at::kFloat).sum().item<float>());\n\n  // create grads of conv4 that may exist\n  auto grad_x_conv4 = at::empty_like(inputs[0]);\n  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();\n  at::Tensor wgrad4;\n\n  // x used for dconv1 and dconv4 wgrad\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n\n  at::Half* w = NULL;\n\n  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {\n    w = inputs[14].data_ptr<at::Half>();\n    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();\n    if (requires_grad) {\n      run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,\n                backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx\n      // DEBUG_MSG(\"[DEBUG] new dx_identity : \" << grad_x_conv4.to(at::kFloat).sum().item<float>());\n    }\n    // wgrad\n    wgrad4 = outputs[4];\n    at::Half* dw4 = wgrad4.data_ptr<at::Half>();\n    run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,\n              backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4,\n              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n  } else {\n    // if there is no downsample, dx_conv4 is fork of drelu3\n    dx_conv4 = inputs[11].data_ptr<at::Half>();\n  }\n\n  // dgrad\n  w = inputs[1].data_ptr<at::Half>();\n  auto grad_x = outputs[0];\n  at::Half* dx = grad_x.data_ptr<at::Half>();\n\n  // backward strided conv cannot be fused\n  // if stride == 1 but channel changes, we can fuse here\n  if (requires_grad) {\n    if (stride_1X1 != 1) {\n      run_dconv(backward_state.dimA, 
backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,\n                backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, dx, w, dy1,\n                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n      // add 2 together\n      grad_x.add_(grad_x_conv4);\n    } else {\n      run_dconv_add(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,\n                    backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4);\n    }\n  }\n\n  DEBUG_MSG(\"[DEBUG] new dx : \" << grad_x.to(at::kFloat).sum().item<float>());\n  // wgrad1 is not declared in this function; it is outputs[1] (see bottleneck_backward_init)\n  DEBUG_MSG(\"[DEBUG] new wgrad1 : \" << outputs[1].to(at::kFloat).sum().item<float>());\n\n  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {\n    DEBUG_MSG(\"[DEBUG] new wgrad4 : \" << wgrad4.to(at::kFloat).sum().item<float>());\n  }\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &bottleneck_forward, \"Bottleneck block forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &bottleneck_backward, \"Bottleneck block backward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_init\", &bottleneck_forward_init, \"Bottleneck block init\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_out1\", &bottleneck_forward_out1, \"Bottleneck block forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_out2\", &bottleneck_forward_out2, \"Bottleneck block forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_out2_mask\", &bottleneck_forward_out2_mask, \"Bottleneck block forward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_out2_halo\", &bottleneck_forward_out2_halo, \"Bottleneck block forward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_out2_halo_corr\", &bottleneck_forward_out2_halo_corr, \"Bottleneck block forward\",\n        
py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_out2_pad\", &bottleneck_forward_out2_pad, \"Bottleneck block forward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_rest\", &bottleneck_forward_rest, \"Bottleneck block forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_init\", &bottleneck_backward_init, \"Bottleneck block backward init\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_grad_out2\", &bottleneck_backward_grad_out2, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_grad_out1\", &bottleneck_backward_grad_out1, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_grad_out1_mask\", &bottleneck_backward_grad_out1_mask, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_grad_out1_halo\", &bottleneck_backward_grad_out1_halo, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_grad_out1_halo_corr\", &bottleneck_backward_grad_out1_halo_corr, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_wgrad2_pad\", &bottleneck_backward_wgrad2_pad, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_wgrad2\", &bottleneck_backward_wgrad2, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_wgrad2_halo\", &bottleneck_backward_wgrad2_halo, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_wgrad3\", &bottleneck_backward_wgrad3, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_wgrad1\", &bottleneck_backward_wgrad1, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_rest\", 
&bottleneck_backward_rest, \"Bottleneck block backward\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cudnn/Handle.h>  // for getcudnnhandle\n#include <cudnn_frontend.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include <iostream>\n#include <vector>\n\n#ifdef DEBUG\n#define DEBUG_MSG(str)             \\\n  do {                             \\\n    std::cout << str << std::endl; \\\n  } while (false)\n#else\n#define DEBUG_MSG(str) \\\n  do {                 \\\n  } while (false)\n#endif\n\n#ifdef DEBUG_CUDNN\n#define DEBUG_CUDNN_MSG(buf, str) \\\n  do {                            \\\n    buf << str << std::endl;      \\\n  } while (false)\n#else\n#define DEBUG_CUDNN_MSG(buf, str) \\\n  do {                            \\\n  } while (false)\n#endif\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\n#define checkCudnnErr(...)                                                    
\\\n  do {                                                                        \\\n    int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \\\n    if (err) {                                                                \\\n      return;                                                                 \\\n    }                                                                         \\\n  } while (0)\n\nint checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {\n  if (code) {\n    printf(\"CUDNN error at %s:%d, code=%d (%s) in '%s'\\n\", file, line, (int)code, cudnnGetErrorString(code), expr);\n    return 1;\n  }\n  return 0;\n}\n\nvoid checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort = true);\n#define checkCUDAError(val)                      \\\n  {                                              \\\n    checkError((val), #val, __FILE__, __LINE__); \\\n  }  // in-line regular function\n\nvoid checkError(cudaError_t code, char const* func, const char* file, const int line, bool abort) {\n  if (code != cudaSuccess) {\n    const char* errorMessage = cudaGetErrorString(code);\n    fprintf(stderr, \"CUDA error returned from \\\"%s\\\" at %s:%d, Error code: %d (%s)\\n\", func, file, line, code,\n            errorMessage);\n    if (abort) {\n      cudaDeviceReset();\n      exit(code);\n    }\n  }\n}\n\nvoid generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {\n  // For INT8x4 and INT8x32 we still compute standard strides here to input\n  // into the cuDNN functions. 
We will manually scale by resizeFactor in the cpu ref.\n  if (filterFormat == CUDNN_TENSOR_NCHW) {\n    strideA[nbDims - 1] = 1;\n    for (int64_t d = nbDims - 2; d >= 0; d--) {\n      strideA[d] = strideA[d + 1] * dimA[d + 1];\n    }\n  } else {\n    // Here we assume that the format is CUDNN_TENSOR_NHWC\n    strideA[1] = 1;\n    strideA[nbDims - 1] = strideA[1] * dimA[1];\n    for (int64_t d = nbDims - 2; d >= 2; d--) {\n      strideA[d] = strideA[d + 1] * dimA[d + 1];\n    }\n    strideA[0] = strideA[2] * dimA[2];\n  }\n}\n\nint getFwdConvDilatedFilterDim(int filterDim, int dilation) { return ((filterDim - 1) * dilation) + 1; }\n\nint getFwdConvPaddedImageDim(int tensorDim, int pad) { return tensorDim + (2 * pad); }\n\nint getFwdConvOutputDim(int tensorDim, int pad, int filterDim, int stride, int dilation) {\n  int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;\n  return (p);\n}\n\n// create a cache for plan\nstd::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;\n\nstd::string getConvFusionString(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA,\n                                int64_t* w_dim_padded, cudnnDataType_t dataType, std::string fusion_string) {\n  for (int i = 0; i < 4; i++) {\n    fusion_string += 'X';\n    fusion_string += std::to_string(x_dim_padded[i]);\n  }\n  for (int i = 0; i < 4; i++) {\n    fusion_string += 'W';\n    fusion_string += std::to_string(w_dim_padded[i]);\n  }\n  for (int i = 0; i < 2; i++) {\n    fusion_string += 'P';\n    fusion_string += std::to_string(padA[i]);\n  }\n  for (int i = 0; i < 2; i++) {\n    fusion_string += 'S';\n    fusion_string += std::to_string(convstrideA[i]);\n  }\n  for (int i = 0; i < 2; i++) {\n    fusion_string += 'D';\n    fusion_string += std::to_string(dilationA[i]);\n  }\n  fusion_string += 'T';\n  fusion_string += std::to_string(dataType);\n  return 
fusion_string;\n}\n\ncudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::stringstream& log_buf,\n                                               cudnn_frontend::OperationGraph& opGraph, std::string cache_string,\n                                               bool use_heuristic = true) {\n  auto it = plan_cache.find(cache_string);\n  if (it != plan_cache.end()) {\n    DEBUG_CUDNN_MSG(log_buf, \"Found plan in cache\");\n    return it->second;\n  } else {\n    DEBUG_CUDNN_MSG(log_buf, \"No plan in cache\");\n    if (use_heuristic) {\n      // TODO: confirm which mode to use\n      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()\n                            .setOperationGraph(opGraph)\n                            .setHeurMode(CUDNN_HEUR_MODE_INSTANT)\n                            .build();\n      auto engine_config_count = heuristics.getEngineConfigCount();\n      auto& engine_configs = heuristics.getEngineConfig(engine_config_count);\n      for (int64_t count = 0; count < engine_config_count; count++) {\n        try {\n          plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()\n                                                         .setHandle(handle_)\n                                                         .setEngineConfig(engine_configs[count], opGraph.getTag())\n                                                         .build()));\n          break;\n        } catch (const cudnn_frontend::cudnnException&) {\n          // Rethrow the original exception if every engine config failed\n          if (count == (engine_config_count - 1)) {\n            throw;\n          } else {\n            continue;\n          }\n        }\n      }\n    } else {\n      // How many engines support this operation graph?\n      auto total_engines = opGraph.getEngineCount();\n      DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << \" has \" << total_engines << \" engines.\");\n      // We have to randomly pick one engine from [0, total_engines)\n      
// Selecting \"0\" by default\n      auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();\n      DEBUG_CUDNN_MSG(log_buf, engine.describe());\n      auto& knobs = engine.getSupportedKnobs();\n      for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {\n        DEBUG_CUDNN_MSG(log_buf, it->describe());\n      }\n      if (knobs.begin() != knobs.end()) {\n        DEBUG_CUDNN_MSG(log_buf, \"Updated knob choice\");\n        knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);\n        DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());\n      }\n\n      // Create the requisite engine config\n      auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();\n      DEBUG_CUDNN_MSG(log_buf, engine_config.describe());\n      plan_cache.emplace(\n          cache_string,\n          std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));\n    }\n\n    return plan_cache.find(cache_string)->second;\n  }\n}\n\nvoid run_conv_bias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* convstride,\n                   int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB,\n                   at::Half* devPtrY) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n\n  try {\n    int convDim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n    int64_t b_dim[] = {1, y_dim[1], 1, 1};\n\n    // Creates the necessary tensor descriptors\n    int64_t stride[4];\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto xTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, x_dim)\n                       .setStrides(4, stride)\n                       .setId('x')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());\n\n    generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto wTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, w_dim)\n                       .setStrides(4, stride)\n                       .setId('w')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, wTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterConvTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('c')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());\n\n    generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto bTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, b_dim)\n                       .setStrides(4, stride)\n                       .setId('b')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterBiasTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('y')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());\n\n    // Define the bias operation\n    auto biasDesc =\n        
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, conv_pad)\n                        .setPostPadding(convDim, conv_pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(xTensor)\n                       .setwDesc(wTensor)\n                       .setyDesc(afterConvTensor)\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create a Bias Node.\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       .setxDesc(conv_op.getOutputTensor())\n                       .setbDesc(bTensor)\n                       .setyDesc(afterBiasTensor)\n                       .setpwDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution followed by bias\n    std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &bias_op};\n\n    auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(2, ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};\n    int64_t uids[] = {'x', 'w', 'b', 'y'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(4, data_ptrs)\n                           .setUids(4, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride,\n     
                        int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW,\n                             at::Half* devPtrB, int8_t* devPtrM, at::Half* devPtrY) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n\n  try {\n    int conv_dim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n    int64_t b_dim[] = {1, y_dim[1], 1, 1};\n\n    // Creates the necessary tensor descriptors\n    int64_t stride[4];\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto xTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, x_dim)\n                       .setStrides(4, stride)\n                       .setId('x')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, xTensor.describe());\n\n    generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto wTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, w_dim)\n                       .setStrides(4, stride)\n                       .setId('w')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, wTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto mTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, y_dim)\n                       .setStrides(4, stride)\n                       .setId('m')\n                       .setAlignment(16)\n                       .setDataType(CUDNN_DATA_INT8)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, mTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterConvTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               
.setId('c')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());\n\n    generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto bTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, b_dim)\n                       .setStrides(4, stride)\n                       .setId('b')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterBiasTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('B')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterMaskTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('M')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterMaskTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterReLUTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                             
  .setId('y')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(conv_dim)\n                        .setStrides(conv_dim, conv_stride)\n                        .setPrePadding(conv_dim, conv_pad)\n                        .setPostPadding(conv_dim, conv_pad)\n                        .setDilation(conv_dim, conv_dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Define the bias operation\n    auto biasDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Define the mask operation\n    auto maskDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n\n    // Define the activation operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_FWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(xTensor)\n                       .setwDesc(wTensor)\n                       .setyDesc(afterConvTensor)\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, 
conv_op.describe());\n\n    // Create a Bias Node\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       .setxDesc(conv_op.getOutputTensor())\n                       .setbDesc(bTensor)\n                       .setyDesc(afterBiasTensor)\n                       .setpwDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // create a Mask Node\n    auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       .setxDesc(bias_op.getOutputTensor())\n                       .setbDesc(mTensor)\n                       .setyDesc(afterMaskTensor)\n                       .setpwDesc(maskDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, mask_op.describe());\n\n    // Create an Activation Node\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(mask_op.getOutputTensor())\n                      .setyDesc(afterReLUTensor)\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution bias mask activation\n    std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &bias_op, &mask_op, &act_op};\n\n    auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(4, ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY};\n    int64_t uids[] = {'x', 'w', 'b', 'm', 'y'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(5, data_ptrs)\n                           .setUids(5, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* 
y_dim, int64_t* conv_pad, int64_t* conv_stride,\n                                int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW,\n                                at::Half* devPtrS, at::Half* devPtrB, at::Half* devPtrY) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n\n  try {\n    int conv_dim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n    int64_t s_dim[] = {1, y_dim[1], 1, 1};\n    int64_t b_dim[] = {1, y_dim[1], 1, 1};\n\n    // Creates the necessary tensor descriptors\n    int64_t stride[4];\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto xTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, x_dim)\n                       .setStrides(4, stride)\n                       .setId('x')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, xTensor.describe());\n\n    generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto wTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, w_dim)\n                       .setStrides(4, stride)\n                       .setId('w')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, wTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterConvTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('c')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());\n\n    generateStrides(s_dim, stride, 4, 
CUDNN_TENSOR_NHWC);\n    auto sTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, s_dim)\n                       .setStrides(4, stride)\n                       .setId('s')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, sTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterScaleTensor = cudnn_frontend::TensorBuilder()\n                                .setDim(4, y_dim)\n                                .setStrides(4, stride)\n                                .setId('S')\n                                .setAlignment(16)\n                                .setDataType(CUDNN_DATA_FLOAT)\n                                .setVirtual()\n                                .build();\n    DEBUG_CUDNN_MSG(log_buf, afterScaleTensor.describe());\n\n    generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto bTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, b_dim)\n                       .setStrides(4, stride)\n                       .setId('b')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterBiasTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('B')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterReLUTensor = cudnn_frontend::TensorBuilder()\n                          
     .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('y')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(conv_dim)\n                        .setStrides(conv_dim, conv_stride)\n                        .setPrePadding(conv_dim, conv_pad)\n                        .setPostPadding(conv_dim, conv_pad)\n                        .setDilation(conv_dim, conv_dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Define the scale operation\n    auto scaleDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    // Define the bias operation\n    auto biasDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Define the activation operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_FWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(xTensor)\n                       .setwDesc(wTensor)\n                       .setyDesc(afterConvTensor)\n                       
.setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create a scale Node.\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(conv_op.getOutputTensor())\n                        .setbDesc(sTensor)\n                        .setyDesc(afterScaleTensor)\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create a Bias Node.\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       .setxDesc(scale_op.getOutputTensor())\n                       .setbDesc(bTensor)\n                       .setyDesc(afterBiasTensor)\n                       .setpwDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an Activation Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(bias_op.getOutputTensor())\n                      .setyDesc(afterReLUTensor)\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution scale bias activation\n    std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &scale_op, &bias_op, &act_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrW, devPtrS, devPtrB, devPtrY};\n    int64_t uids[] = {'x', 'w', 's', 'b', 'y'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(5, data_ptrs)\n                           .setUids(5, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, 
int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride,\n                        int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW,\n                        at::Half* devPtrB, at::Half* devPtrY) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n\n  try {\n    int conv_dim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n    int64_t b_dim[] = {1, y_dim[1], 1, 1};\n\n    // Creates the necessary tensor descriptors\n    int64_t stride[4];\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto xTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, x_dim)\n                       .setStrides(4, stride)\n                       .setId('x')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, xTensor.describe());\n\n    generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto wTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, w_dim)\n                       .setStrides(4, stride)\n                       .setId('w')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, wTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterConvTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('c')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());\n\n    generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto bTensor = cudnn_frontend::TensorBuilder()\n       
                .setDim(4, b_dim)\n                       .setStrides(4, stride)\n                       .setId('b')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterBiasTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('B')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto afterReLUTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, y_dim)\n                               .setStrides(4, stride)\n                               .setId('y')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(conv_dim)\n                        .setStrides(conv_dim, conv_stride)\n                        .setPrePadding(conv_dim, conv_pad)\n                        .setPostPadding(conv_dim, conv_pad)\n                        .setDilation(conv_dim, conv_dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Define the bias operation\n    auto biasDesc =\n        
cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Define the activation operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_FWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)\n                       .setxDesc(xTensor)\n                       .setwDesc(wTensor)\n                       .setyDesc(afterConvTensor)\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create a Bias Node.\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                       .setxDesc(conv_op.getOutputTensor())\n                       .setbDesc(bTensor)\n                       .setyDesc(afterBiasTensor)\n                       .setpwDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an Activation Node.\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setxDesc(bias_op.getOutputTensor())\n                      .setyDesc(afterReLUTensor)\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution bias activation\n    std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &bias_op, &act_op};\n\n    auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(3, ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};\n    int64_t uids[] = {'x', 'w', 'b', 'y'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(4, data_ptrs)\n                           .setUids(4, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* 
devPtrR,\n                      at::Half* devPtrS, at::Half* devPtrDX) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n\n  try {\n    int convDim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n    int64_t s_dim[] = {1, dy_dim[1], 1, 1};\n\n    // Creates the necessary tensor descriptors\n    int64_t stride[4];\n    generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto dyTensor = cudnn_frontend::TensorBuilder()\n                        .setDim(4, dy_dim)\n                        .setStrides(4, stride)\n                        .setId('y')\n                        .setAlignment(16)\n                        .setDataType(dataType)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, dyTensor.describe());\n\n    generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto rTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, dy_dim)\n                       .setStrides(4, stride)\n                       .setId('r')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, rTensor.describe());\n\n    generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto inActGradTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, dy_dim)\n                               .setStrides(4, stride)\n                               .setId('R')\n                               .setAlignment(16)\n                               .setDataType(CUDNN_DATA_FLOAT)\n                               .setVirtual()\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe());\n\n    generateStrides(s_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto scaleTensor = cudnn_frontend::TensorBuilder()\n                           .setDim(4, s_dim)\n                           .setStrides(4, stride)\n                           .setId('s')\n              
             .setAlignment(16)\n                           .setDataType(dataType)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, scaleTensor.describe());\n\n    generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto dxTensor = cudnn_frontend::TensorBuilder()\n                        .setDim(4, dy_dim)\n                        .setStrides(4, stride)\n                        .setId('x')\n                        .setAlignment(16)\n                        .setDataType(dataType)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, dxTensor.describe());\n\n    // Define the activation backward operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_BWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the scale (pointwise multiply) operation\n    auto scaleDesc =\n        cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();\n    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());\n\n    // Create a ReLU backward node\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setdyDesc(dyTensor)\n                      .setxDesc(rTensor)\n                      .setdxDesc(inActGradTensor)\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create scale node\n    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                        .setxDesc(inActGradTensor)\n                        .setbDesc(scaleTensor)\n                        .setyDesc(dxTensor)\n                        .setpwDesc(scaleDesc)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());\n\n    // Create an Operation Graph. 
In this case it is ReLU backward followed by a per-channel scale\n    std::array<cudnn_frontend::Operation const*, 2> ops = {&act_op, &scale_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    // creating unique dummy values\n    int64_t pad_dummy[] = {40, 40};\n    int64_t stride_dummy[] = {40, 40};\n    int64_t dilation_dummy[] = {40, 40};\n    auto cache_string =\n        getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, s_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrDY, devPtrR, devPtrS, devPtrDX};\n    int64_t uids[] = {'y', 'r', 's', 'x'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(4, data_ptrs)\n                           .setUids(4, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (const cudnn_frontend::cudnnException& e) {\n    std::cout << log_buf.str() << 
\"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR,\n                     at::Half* devPtrDR, float* devPtrDB) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n\n  try {\n    int convDim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n    int64_t b_dim[] = {1, dy_dim[1], 1, 1};\n\n    // Creates the necessary tensor descriptors\n    int64_t stride[4];\n    generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto dyTensor = cudnn_frontend::TensorBuilder()\n                        .setDim(4, dy_dim)\n                        .setStrides(4, stride)\n                        .setId('x')\n                        .setAlignment(16)\n                        .setDataType(dataType)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, dyTensor.describe());\n\n    generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto rTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, dy_dim)\n                       .setStrides(4, stride)\n                       .setId('r')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, rTensor.describe());\n\n    generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto inActGradTensor = cudnn_frontend::TensorBuilder()\n                               .setDim(4, dy_dim)\n                               .setStrides(4, stride)\n                               .setId('R')\n                               .setAlignment(16)\n                               .setDataType(dataType)\n                               .build();\n    DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe());\n\n    generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto biasGradTensor = cudnn_frontend::TensorBuilder()\n                              .setDim(4, b_dim)\n  
                            .setStrides(4, stride)\n                              .setId('y')\n                              .setAlignment(16)\n                              .setDataType(CUDNN_DATA_FLOAT)\n                              .build();\n    DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe());\n\n    // Define the activation backward operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_BWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the bias backward operation\n    auto biasDesc = cudnn_frontend::ReductionDescBuilder()\n                        .setMathPrecision(CUDNN_DATA_FLOAT)\n                        .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Create an relu backward Node\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setdyDesc(dyTensor)\n                      .setxDesc(rTensor)\n                      .setdxDesc(inActGradTensor)\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create bias node\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)\n                       .setxDesc(inActGradTensor)\n                       .setyDesc(biasGradTensor)\n                       .setreductionDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an Operation Graph. 
In this case it is ReLU backward followed by a bias-gradient reduction\n    std::array<cudnn_frontend::Operation const*, 2> ops = {&act_op, &bias_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    // creating unique dummy values\n    int64_t pad_dummy[] = {20, 20};\n    int64_t stride_dummy[] = {20, 20};\n    int64_t dilation_dummy[] = {20, 20};\n    auto cache_string =\n        getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB};\n    int64_t uids[] = {'x', 'r', 'R', 'y'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(4, data_ptrs)\n                           .setUids(4, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (const cudnn_frontend::cudnnException& e) {\n    std::cout << log_buf.str() << 
\"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* pad, int64_t* convstride,\n                           int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW,\n                           at::Half* devPtrR, at::Half* devPtrRg, float* devPtrY) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n    int64_t b_dim[] = {1, x_dim[1], 1, 1};\n\n    int64_t stride[4];\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto outConvGradTensor = cudnn_frontend::TensorBuilder()\n                                 .setDim(4, y_dim)\n                                 .setStrides(4, stride)\n                                 .setId('x')\n                                 .setAlignment(16)\n                                 .setDataType(dataType)\n                                 .build();\n    DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe());\n\n    generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto wTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, w_dim)\n                       .setStrides(4, stride)\n                       .setId('w')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, wTensor.describe());\n\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto inConvGradTensor = cudnn_frontend::TensorBuilder()\n                                .setDim(4, x_dim)\n                                .setStrides(4, stride)\n                                .setId('A')\n                                .setAlignment(16)\n                                .setDataType(CUDNN_DATA_FLOAT)\n                                .setVirtual()\n                                .build();\n    
DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe());\n\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto rTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, x_dim)\n                       .setStrides(4, stride)\n                       .setId('r')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, rTensor.describe());\n\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto inReLUGradTensor = cudnn_frontend::TensorBuilder()\n                                .setDim(4, x_dim)\n                                .setStrides(4, stride)\n                                .setId('R')\n                                .setAlignment(16)\n                                .setDataType(dataType)\n                                .build();\n    DEBUG_CUDNN_MSG(log_buf, inReLUGradTensor.describe());\n\n    generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto inBiasGradTensor = cudnn_frontend::TensorBuilder()\n                                .setDim(4, b_dim)\n                                .setStrides(4, stride)\n                                .setId('y')\n                                .setAlignment(16)\n                                .setDataType(CUDNN_DATA_FLOAT)\n                                .build();\n    DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(convDim)\n                        .setStrides(convDim, convstride)\n                        .setPrePadding(convDim, pad)\n                        .setPostPadding(convDim, pad)\n                        .setDilation(convDim, dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, 
convDesc.describe());\n\n    // Define the activation backward operation\n    auto actDesc = cudnn_frontend::PointWiseDescBuilder()\n                       .setMode(CUDNN_POINTWISE_RELU_BWD)\n                       .setMathPrecision(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());\n\n    // Define the bias backward operation\n    auto biasDesc = cudnn_frontend::ReductionDescBuilder()\n                        .setMathPrecision(CUDNN_DATA_FLOAT)\n                        .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Create a convolution Node\n    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)\n                       .setdyDesc(outConvGradTensor)\n                       .setwDesc(wTensor)\n                       .setdxDesc(inConvGradTensor)\n                       .setcDesc(convDesc)\n                       .setAlpha(alpha)\n                       .setBeta(beta)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create an relu backward Node\n    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)\n                      .setdyDesc(inConvGradTensor)\n                      .setxDesc(rTensor)\n                      .setdxDesc(inReLUGradTensor)\n                      .setpwDesc(actDesc)\n                      .build();\n    DEBUG_CUDNN_MSG(log_buf, act_op.describe());\n\n    // Create bias node\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)\n                       .setxDesc(inReLUGradTensor)\n                       .setyDesc(inBiasGradTensor)\n                       .setreductionDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an Operation Graph. 
In this case it is convolution dgrad, ReLU backward, and bias reduction\n    std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &bias_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY};\n    int64_t uids[] = {'x', 'w', 'r', 'R', 'y'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(5, data_ptrs)\n                           .setUids(5, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (const cudnn_frontend::cudnnException& e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* 
conv_stride,\n               int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW,\n               at::Half* devPtrY, cudnnBackendDescriptorType_t mode) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n\n  try {\n    int conv_dim = 2;\n    float alpha = 1.0f;\n    float beta = 0.0f;\n\n    // Define the convolution problem\n    int64_t stride[4];\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto xTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, x_dim)\n                       .setStrides(4, stride)\n                       .setId('x')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, xTensor.describe());\n\n    generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto wTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, w_dim)\n                       .setStrides(4, stride)\n                       .setId('w')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, wTensor.describe());\n\n    generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto yTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, y_dim)\n                       .setStrides(4, stride)\n                       .setId('y')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, yTensor.describe());\n\n    // Define the convolution problem\n    auto convDesc = cudnn_frontend::ConvDescBuilder()\n                        .setDataType(CUDNN_DATA_FLOAT)\n                        .setMathMode(CUDNN_CROSS_CORRELATION)\n                        .setNDims(conv_dim)\n                        .setStrides(conv_dim, conv_stride)\n                  
      .setPrePadding(conv_dim, conv_pad)\n                        .setPostPadding(conv_dim, conv_pad)\n                        .setDilation(conv_dim, conv_dilation)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());\n\n    // Create a convolution node\n    // mode should be one of the following:\n    // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR\n    // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR\n    auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);\n    if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {\n      conv_op_builder.setdxDesc(xTensor).setwDesc(wTensor).setdyDesc(yTensor).setcDesc(convDesc);\n    } else {\n      conv_op_builder.setxDesc(xTensor).setdwDesc(wTensor).setdyDesc(yTensor).setcDesc(convDesc);\n    }\n    auto conv_op = conv_op_builder.setAlpha(alpha).setBeta(beta).build();\n    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());\n\n    // Create an Operation Graph. In this case it is a single backward convolution (dgrad or wgrad)\n    std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    auto cache_string =\n        getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      
workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrW, devPtrY};\n    int64_t uids[] = {'x', 'w', 'y'};\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(3, data_ptrs)\n                           .setUids(3, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nvoid run_dbias(int64_t* x_dim, cudnnDataType_t dataType, at::Half* devPtrX, float* devPtrY) {\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n  std::stringstream log_buf;\n  try {\n    int convDim = 2;\n    int64_t b_dim[] = {1, x_dim[1], 1, 1};\n\n    int64_t stride[4];\n    generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto xTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, x_dim)\n                       .setStrides(4, stride)\n                       .setId('x')\n                       .setAlignment(16)\n                       .setDataType(dataType)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, xTensor.describe());\n\n    generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);\n    auto yTensor = cudnn_frontend::TensorBuilder()\n                       .setDim(4, b_dim)\n                       .setStrides(4, stride)\n                       .setId('y')\n                       .setAlignment(16)\n                       .setDataType(CUDNN_DATA_FLOAT)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, 
yTensor.describe());\n\n    // Define the bias backward operation\n    auto biasDesc = cudnn_frontend::ReductionDescBuilder()\n                        .setMathPrecision(CUDNN_DATA_FLOAT)\n                        .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)\n                        .build();\n    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());\n\n    // Create bias node\n    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)\n                       .setxDesc(xTensor)\n                       .setyDesc(yTensor)\n                       .setreductionDesc(biasDesc)\n                       .build();\n    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());\n\n    // Create an Operation Graph. In this case it is bias only\n    std::array<cudnn_frontend::Operation const*, 1> ops = {&bias_op};\n\n    auto opGraph =\n        cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();\n\n    // Create string encoding for plan caching\n    int64_t pad_dummy[] = {10, 10};\n    int64_t stride_dummy[] = {10, 10};\n    int64_t dilation_dummy[] = {10, 10};\n    auto cache_string =\n        getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());\n    DEBUG_CUDNN_MSG(log_buf, \"[convstring] \" << cache_string);\n\n    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);\n    DEBUG_CUDNN_MSG(log_buf, \"Plan tag: \" << plan.getTag());\n\n    auto workspace_size = plan.getWorkspaceSize();\n    DEBUG_CUDNN_MSG(log_buf, plan.describe() << \" requires workspace \" << workspace_size);\n\n    void* workspace_ptr = nullptr;\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    if (workspace_size > 0) {\n      workspace_ptr = workspace_tensor.data_ptr<float>();\n    }\n    void* data_ptrs[] = {devPtrX, devPtrY};\n    int64_t uids[] = {'x', 'y'};\n    auto variantPack = 
cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workspace_ptr)\n                           .setDataPointers(2, data_ptrs)\n                           .setUids(2, uids)\n                           .build();\n    DEBUG_CUDNN_MSG(log_buf, \"variantPack \" << variantPack.describe());\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    checkCudnnErr(status);\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n  } catch (cudnn_frontend::cudnnException e) {\n    std::cout << log_buf.str() << \"[ERROR] Exception \" << e.what() << std::endl;\n  }\n}\n\nstd::vector<at::Tensor> conv_bias_mask_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {\n  std::cout << std::fixed;\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t x_dim[] = {0, 0, 0, 0};\n  int64_t w_dim[] = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] = {0, 1, 2, 3};\n  for (int dim = 0; dim < 4; dim++) {\n    x_dim[dim] = inputs[0].size(axis[dim]);\n    w_dim[dim] = inputs[1].size(axis[dim]);\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t y_dim[] = {0, 0, 0, 0};\n\n  // use these fixed values\n  int64_t conv_pad[] = {padding, padding};\n  int64_t conv_stride[] = {stride, stride};\n  int64_t conv_dilation[] = {1, 1};\n\n  // compute output from pad/stride/dilation\n  y_dim[0] = x_dim[0];\n  y_dim[1] = w_dim[0];\n  for (int dim = 0; dim < 2; dim++) {\n    y_dim[dim + 2] =\n        getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);\n  }\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* b = inputs[2].data_ptr<at::Half>();\n  int8_t* m = 
inputs[3].data_ptr<int8_t>();\n  auto out = at::empty(y_dim, inputs[0].type(), output_format);\n  at::Half* y = out.data_ptr<at::Half>();\n\n  run_conv_bias_mask_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, m, y);\n\n  DEBUG_MSG(\"[DEBUG] conv-bias-mask-relu : \" << y.to(at::kFloat).sum().item<float>());\n\n  outputs.push_back(out);\n\n  return outputs;\n}\n\nat::Tensor conv_cscale_cbias_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {\n  std::cout << std::fixed;\n\n  // setup dimensions\n  int64_t x_dim[] = {0, 0, 0, 0};\n  int64_t w_dim[] = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] = {0, 1, 2, 3};\n  for (int dim = 0; dim < 4; dim++) {\n    x_dim[dim] = inputs[0].size(axis[dim]);\n    w_dim[dim] = inputs[1].size(axis[dim]);\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t y_dim[] = {0, 0, 0, 0};\n\n  // use these fixed values\n  int64_t conv_pad[] = {padding, padding};\n  int64_t conv_stride[] = {stride, stride};\n  int64_t conv_dilation[] = {1, 1};\n\n  // compute output from pad/stride/dilation\n  y_dim[0] = x_dim[0];\n  y_dim[1] = w_dim[0];\n  for (int dim = 0; dim < 2; dim++) {\n    y_dim[dim + 2] =\n        getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);\n  }\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* s = inputs[2].data_ptr<at::Half>();\n  at::Half* b = inputs[3].data_ptr<at::Half>();\n  auto out = at::empty(y_dim, inputs[0].type(), at::MemoryFormat::ChannelsLast);\n  at::Half* y = out.data_ptr<at::Half>();\n\n  run_conv_cscale_cbias_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, s, b, y);\n\n  DEBUG_MSG(\"[DEBUG] conv-cscale-cbias-relu : \" << y.to(at::kFloat).sum().item<float>());\n\n  return out;\n}\n\nstd::vector<at::Tensor> 
conv_cscale_cbias_relu_backward(std::vector<at::Tensor> inputs, int64_t padding,\n                                                        int64_t stride) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  for (int i = 0; i <= 4; i++) {\n    CHECK_INPUT(inputs[i]);\n  }\n\n  std::cout << std::fixed;\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t x_dim[] = {0, 0, 0, 0};\n  int64_t w_dim[] = {0, 0, 0, 0};\n  int64_t y_dim[] = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] = {0, 1, 2, 3};\n  for (int dim = 0; dim < 4; dim++) {\n    x_dim[dim] = inputs[0].size(axis[dim]);\n    w_dim[dim] = inputs[1].size(axis[dim]);\n    y_dim[dim] = inputs[3].size(axis[dim]);\n  }\n\n  int64_t b_dim[] = {1, y_dim[1], 1, 1};\n\n  int64_t conv_pad[] = {padding, padding};\n  int64_t conv_stride[] = {stride, stride};\n  int64_t conv_dilation[] = {1, 1};\n\n  // run\n  // drelu-dbias\n  at::Half* dy = inputs[4].data_ptr<at::Half>();\n  at::Half* r = inputs[3].data_ptr<at::Half>();\n  auto s = inputs[2].data_ptr<at::Half>();\n  auto dscale = at::empty_like(inputs[4]);\n  at::Half* ds = dscale.data_ptr<at::Half>();\n\n  auto options =\n      at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);\n  run_drelu_dscale(y_dim, CUDNN_DATA_HALF, dy, r, s, ds);\n\n  // conv wgrad\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  auto wgrad = at::empty_like(inputs[1]);\n  at::Half* dw = wgrad.data_ptr<at::Half>();\n  run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, ds,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // conv dgrad\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  auto dgrad = at::empty_like(inputs[0]);\n  at::Half* dx = dgrad.data_ptr<at::Half>();\n  run_dconv(x_dim, w_dim, y_dim, conv_pad, 
conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, ds,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n\n  outputs.push_back(dgrad);\n  outputs.push_back(wgrad);\n\n  return outputs;\n}\n\nstd::vector<at::Tensor> conv_bias_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {\n  std::cout << std::fixed;\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t x_dim[] = {0, 0, 0, 0};\n  int64_t w_dim[] = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] = {0, 1, 2, 3};\n  for (int dim = 0; dim < 4; dim++) {\n    x_dim[dim] = inputs[0].size(axis[dim]);\n    w_dim[dim] = inputs[1].size(axis[dim]);\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t y_dim[] = {0, 0, 0, 0};\n\n  // use these fixed values\n  int64_t conv_pad[] = {padding, padding};\n  int64_t conv_stride[] = {stride, stride};\n  int64_t conv_dilation[] = {1, 1};\n\n  // compute output from pad/stride/dilation\n  y_dim[0] = x_dim[0];\n  y_dim[1] = w_dim[0];\n  for (int dim = 0; dim < 2; dim++) {\n    y_dim[dim + 2] =\n        getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);\n  }\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* b = inputs[2].data_ptr<at::Half>();\n  auto out = at::empty(y_dim, inputs[0].type(), output_format);\n  at::Half* y = out.data_ptr<at::Half>();\n\n  run_conv_bias_relu(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, y);\n\n  DEBUG_MSG(\"[DEBUG] conv-bias-relu : \" << out.to(at::kFloat).sum().item<float>());\n\n  outputs.push_back(out);\n\n  return outputs;\n}\n\nstd::vector<at::Tensor> conv_bias_relu_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {\n  bool requires_grad = 
inputs[0].requires_grad();\n\n  for (int i = 0; i <= 3; i++) {\n    CHECK_INPUT(inputs[i]);\n  }\n\n  std::cout << std::fixed;\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t x_dim[] = {0, 0, 0, 0};\n  int64_t w_dim[] = {0, 0, 0, 0};\n  int64_t y_dim[] = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] = {0, 1, 2, 3};\n  for (int dim = 0; dim < 4; dim++) {\n    x_dim[dim] = inputs[0].size(axis[dim]);\n    w_dim[dim] = inputs[1].size(axis[dim]);\n    y_dim[dim] = inputs[3].size(axis[dim]);\n  }\n\n  int64_t b_dim[] = {1, y_dim[1], 1, 1};\n\n  int64_t conv_pad[] = {padding, padding};\n  int64_t conv_stride[] = {stride, stride};\n  int64_t conv_dilation[] = {1, 1};\n\n  // run\n  // drelu-dbias\n  at::Half* dy = inputs[3].data_ptr<at::Half>();\n  at::Half* r = inputs[2].data_ptr<at::Half>();\n  auto drelu = at::empty_like(inputs[2]);\n  at::Half* dr = drelu.data_ptr<at::Half>();\n  auto options =\n      at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);\n  auto bgrad = at::empty(b_dim, options, output_format);\n  float* db = bgrad.data_ptr<float>();\n  run_drelu_dbias(y_dim, CUDNN_DATA_HALF, dy, r, dr, db);\n\n  // conv wgrad\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  auto wgrad = at::empty_like(inputs[1]);\n  at::Half* dw = wgrad.data_ptr<at::Half>();\n  run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, dr,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // conv dgrad\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  auto dgrad = at::empty_like(inputs[0]);\n  at::Half* dx = dgrad.data_ptr<at::Half>();\n  run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, dr,\n            
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n\n  outputs.push_back(dgrad);\n  outputs.push_back(wgrad);\n  outputs.push_back(bgrad);\n\n  return outputs;\n}\n\nstd::vector<at::Tensor> conv_bias_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {\n  std::cout << std::fixed;\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t x_dim[] = {0, 0, 0, 0};\n  int64_t w_dim[] = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] = {0, 1, 2, 3};\n  for (int dim = 0; dim < 4; dim++) {\n    x_dim[dim] = inputs[0].size(axis[dim]);\n    w_dim[dim] = inputs[1].size(axis[dim]);\n  }\n\n  // output dim in n,c,h,w used by backend\n  int64_t y_dim[] = {0, 0, 0, 0};\n\n  // use these fixed values\n  int64_t conv_pad[] = {padding, padding};\n  int64_t conv_stride[] = {stride, stride};\n  int64_t conv_dilation[] = {1, 1};\n\n  // compute output from pad/stride/dilation\n  y_dim[0] = x_dim[0];\n  y_dim[1] = w_dim[0];\n  for (int dim = 0; dim < 2; dim++) {\n    y_dim[dim + 2] =\n        getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);\n  }\n\n  // run\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  at::Half* b = inputs[2].data_ptr<at::Half>();\n  auto out = at::empty(y_dim, inputs[0].type(), output_format);\n  at::Half* y = out.data_ptr<at::Half>();\n\n  run_conv_bias(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, w, b, y);\n\n  DEBUG_MSG(\"[DEBUG] conv-bias : \" << out.to(at::kFloat).sum().item<float>());\n\n  outputs.push_back(out);\n\n  return outputs;\n}\n\nstd::vector<at::Tensor> conv_bias_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {\n  bool requires_grad = inputs[0].requires_grad();\n\n  for (int i = 0; i <= 2; i++) {\n    
CHECK_INPUT(inputs[i]);\n  }\n\n  std::cout << std::fixed;\n\n  // create output vector\n  std::vector<at::Tensor> outputs;\n  auto output_format = at::MemoryFormat::ChannelsLast;\n\n  // setup dimensions\n  int64_t x_dim[] = {0, 0, 0, 0};\n  int64_t w_dim[] = {0, 0, 0, 0};\n  int64_t y_dim[] = {0, 0, 0, 0};\n\n  // All dim calculation after this order of n,c,h,w\n  int axis[] = {0, 1, 2, 3};\n  for (int dim = 0; dim < 4; dim++) {\n    x_dim[dim] = inputs[0].size(axis[dim]);\n    w_dim[dim] = inputs[1].size(axis[dim]);\n    y_dim[dim] = inputs[2].size(axis[dim]);\n  }\n\n  int64_t b_dim[] = {1, y_dim[1], 1, 1};\n\n  int64_t conv_pad[] = {padding, padding};\n  int64_t conv_stride[] = {stride, stride};\n  int64_t conv_dilation[] = {1, 1};\n\n  // run\n  // dbias\n  at::Half* dy = inputs[2].data_ptr<at::Half>();\n  auto options =\n      at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);\n  auto bgrad = at::empty(b_dim, options, output_format);\n  float* db = bgrad.data_ptr<float>();\n  run_dbias(y_dim, CUDNN_DATA_HALF, dy, db);\n\n  // conv wgrad\n  at::Half* x = inputs[0].data_ptr<at::Half>();\n  auto wgrad = at::empty_like(inputs[1]);\n  at::Half* dw = wgrad.data_ptr<at::Half>();\n  run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, x, dw, dy,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);\n\n  // conv dgrad\n  at::Half* w = inputs[1].data_ptr<at::Half>();\n  auto dgrad = at::empty_like(inputs[0]);\n  at::Half* dx = dgrad.data_ptr<at::Half>();\n  run_dconv(x_dim, w_dim, y_dim, conv_pad, conv_stride, conv_dilation, CUDNN_DATA_HALF, dx, w, dy,\n            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);\n\n  outputs.push_back(dgrad);\n  outputs.push_back(wgrad);\n  outputs.push_back(bgrad);\n\n  return outputs;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &conv_bias_relu_forward, \"Fused 
Conv-Bias-ReLU forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &conv_bias_relu_backward, \"Fused Conv-Bias-ReLU backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_no_relu\", &conv_bias_forward, \"Fused Conv-Bias forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_no_relu\", &conv_bias_backward, \"Fused Conv-Bias backward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_mask\", &conv_bias_mask_relu_forward, \"Fused Conv-Bias-Mask-ReLU forward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward_cscale_cbias_relu\", &conv_cscale_cbias_relu_forward,\n        \"Fused Conv-(const)Scale-(const)Bias-ReLU forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_cscale_cbias_relu\", &conv_cscale_cbias_relu_backward,\n        \"Fused Conv-(const)Scale-(const)Bias-ReLU backward\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp",
    "content": "#include <ATen/ATen.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"norm_sample.h\"\n\n// define this enum:\nenum bn_type { BN_FWD, BN_BWD };\n\n// this is a global variable\nstatic std::map<std::vector<int64_t>, cudnn_frontend::ExecutionPlan> gbn_plan_cache;\n\nat::Tensor gbn_forward(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,\n                       const at::Tensor& running_mean, const at::Tensor& running_var, const at::Tensor& minibatch_mean,\n                       const at::Tensor& minibatch_inv_var, const float momentum, const float epsilon,\n                       const int64_t bn_group, const int rank_id, const std::vector<int64_t>& peer_buffers) {\n  int64_t N = x.size(0);\n  int64_t C = x.size(1);\n  int64_t H = x.size(2);\n  int64_t W = x.size(3);\n\n  int64_t tensorDims[] = {N, C, H, W};\n  int64_t peerDims[] = {bn_group, 4 * C, 1, 1};\n  int64_t perChannelDims[] = {1, C, 1, 1};\n  int64_t epsilonDims[] = {1, 1, 1, 1};\n\n  // Allocate output tensor\n  at::Tensor y = at::empty_like(x);\n\n  std::vector<void*> void_peer_buffers;\n  for (int64_t addr : peer_buffers) {\n    void_peer_buffers.push_back((void*)addr);\n  }\n\n  // we need the peer size for the buffer reset\n  size_t peer_size = 1;\n  for (size_t i = 0; i < 4; ++i) {\n    peer_size *= peerDims[i];\n  }\n\n  // sanity check\n  assert(bn_group == void_peer_buffers.size());\n\n  // check if plan already exists\n  std::vector<int64_t> fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};\n  if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) {\n    auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);\n    gbn_plan_cache.emplace(fv, std::move(plan));\n  }\n\n  // get plan and handle\n  auto plan = gbn_plan_cache.find(fv)->second;\n\n  // execute\n  execute_batch_norm_forward(plan, x.data_ptr(), y.data_ptr(), 
scale.data_ptr(), bias.data_ptr(),\n                             running_mean.data_ptr(), running_var.data_ptr(), running_mean.data_ptr(),\n                             running_var.data_ptr(), minibatch_mean.data_ptr(), minibatch_inv_var.data_ptr(),\n                             void_peer_buffers, static_cast<double>(epsilon), static_cast<double>(momentum), peer_size,\n                             rank_id);\n\n  return y;\n}\n\nstd::vector<at::Tensor> gbn_backward(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale,\n                                     const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,\n                                     const float epsilon, const int64_t bn_group, const int rank_id,\n                                     const std::vector<int64_t>& peer_buffers) {\n  int64_t N = x.size(0);\n  int64_t C = x.size(1);\n  int64_t H = x.size(2);\n  int64_t W = x.size(3);\n\n  int64_t tensorDims[] = {N, C, H, W};\n  int64_t peerDims[] = {bn_group, 4 * C, 1, 1};\n  int64_t perChannelDims[] = {1, C, 1, 1};\n  int64_t epsilonDims[] = {1, 1, 1, 1};\n\n  // Allocate output tensor\n  // outputs\n  at::Tensor x_grad, scale_grad, bias_grad;\n\n  // Allocate outputs\n  x_grad = at::empty_like(x);\n  scale_grad = at::empty_like(scale);\n  bias_grad = at::empty_like(scale);\n\n  std::vector<void*> void_peer_buffers;\n  for (int64_t addr : peer_buffers) {\n    void_peer_buffers.push_back((void*)addr);\n  }\n\n  // we need the peer size for the buffer reset\n  size_t peer_size = 1;\n  for (size_t i = 0; i < 4; ++i) {\n    peer_size *= peerDims[i];\n  }\n\n  assert(bn_group == void_peer_buffers.size());\n\n  std::vector<int64_t> fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};\n  if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) {\n    auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);\n    gbn_plan_cache.emplace(fv, std::move(plan));\n  }\n\n  // get 
plan and handle\n  auto plan = gbn_plan_cache.find(fv)->second;\n\n  // execute\n  execute_batch_norm_backward(plan, x.data_ptr(), dy.data_ptr(), scale.data_ptr(), minibatch_mean.data_ptr(),\n                              minibatch_inv_var.data_ptr(), void_peer_buffers, x_grad.data_ptr(), scale_grad.data_ptr(),\n                              bias_grad.data_ptr(), static_cast<double>(epsilon), peer_size, rank_id);\n\n  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &gbn_forward, \"Group batch norm forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &gbn_backward, \"Group batch norm backward\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/cudnn_gbn/norm_sample.cpp",
    "content": "/*\n * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n *\n * Permission is hereby granted, free of charge, to any person obtaining a\n * copy of this software and associated documentation files (the \"Software\"),\n * to deal in the Software without restriction, including without limitation\n * the rights to use, copy, modify, merge, publish, distribute, sublicense,\n * and/or sell copies of the Software, and to permit persons to whom the\n * Software is furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL\n * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n * DEALINGS IN THE SOFTWARE.\n */\n\n#include \"norm_sample.h\"\n\n#include <ATen/cudnn/Handle.h>  // for getcudnnhandle\n#include <cudnn_frontend.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include \"cudnn_backend.h\"\n\n// some helpers\nint64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line) {\n  if (code) {\n    printf(\"CUDA error at %s:%d, code=%d (%s) in '%s'\", file, line, (int)code, cudaGetErrorString(code), expr);\n    return 1;\n  }\n  return 0;\n}\n\nint64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {\n  if (code) {\n    printf(\"CUDNN error at %s:%d, code=%d (%s) in '%s'\\n\", file, line, (int)code, cudnnGetErrorString(code), expr);\n    return 1;\n  }\n  return 0;\n}\n\nbool AllowAll(cudnnBackendDescriptor_t engine_config) {\n  
(void)engine_config;\n  return false;\n}\n\nvoid generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat) {\n  // For INT8x4 and INT8x32 we still compute standard strides here to input\n  // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.\n  if (filterFormat == CUDNN_TENSOR_NCHW) {\n    strideA[nbDims - 1] = 1;\n    for (int64_t d = nbDims - 2; d >= 0; d--) {\n      strideA[d] = strideA[d + 1] * dimA[d + 1];\n    }\n  } else {\n    // Here we assume that the format is CUDNN_TENSOR_NHWC\n    strideA[1] = 1;\n    strideA[nbDims - 1] = strideA[1] * dimA[1];\n    for (int64_t d = nbDims - 2; d >= 2; d--) {\n      strideA[d] = strideA[d + 1] * dimA[d + 1];\n    }\n    strideA[0] = strideA[2] * dimA[2];\n  }\n}\n\n// runtime\ncudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon,\n                                                     int64_t* peerDims, cudnnDataType_t data_type) {\n  // get the cudnn handle\n  cudnnHandle_t handle = torch::native::getCudnnHandle();\n\n  // Creates the necessary tensor descriptors\n  int64_t tensor_stride[4];\n  int64_t stride[4];\n  int64_t peer_stride[4];\n\n  // NHWC format. GenerateStrides() takes care of this. 
However, tensor dims should still be NCHW\n  generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n  generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n\n  auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, tensorDims)\n        .setStrides(4, tensor_stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n        .build();\n  };\n\n  auto peer_tensor_create = [&peer_stride, &tensorDims](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, tensorDims)\n        .setStrides(4, peer_stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n        .build();\n  };\n\n  generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n\n  auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, perChannelSum)\n        .setStrides(4, stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n        .build();\n  };\n\n  auto xTensor = tensor_create(data_type, 100);\n  auto yTensor = tensor_create(data_type, 101);\n  auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102);\n  auto biasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103);\n  auto inMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104);\n  auto inVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 105);\n  auto outMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106);\n  auto outVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107);\n  auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 108);\n  auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 109);\n\n  int64_t epsilon_stride[4];\n  generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n  
auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, epsilon)\n        .setStrides(4, epsilon_stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n        .setByValue(true)\n        .build();\n  };\n\n  auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 110);\n  auto expDecayTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 111);\n\n  // Create one peer stat tensor per group member (peerDims[0] of them).\n  // Jump IDs in case we need to add more tensors with UIDs\n  std::vector<cudnn_frontend::Tensor_v8> peerStatTensors;\n  for (size_t i = 112; i < 112 + peerDims[0]; ++i) {\n    peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i));\n  }\n\n#if (CUDNN_VERSION >= 8500)\n  // Batch normalization\n  cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM;\n\n  // Forward training\n  cudnnBackendNormFwdPhase_t phase = CUDNN_NORM_FWD_TRAINING;\n\n  // Create the batch norm forward operation node\n  auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)\n                           .setNormalizationMode(normalizationMode)\n                           .setNormFwdPhase(phase)\n                           .setxDesc(xTensor)\n                           .setScaleAndBias(scaleTensor, biasTensor)\n                           .setPrevRunningMeanAndVar(inMeanTensor, inVarTensor)\n                           .setNextRunningMeanAndVar(outMeanTensor, outVarTensor)\n                           .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor)\n                           .setEpsilonTensor(epsilonTensor)\n                           .setExpDecayFactorTensor(expDecayTensor)\n                           .setPeerStatTensor(peerStatTensors)\n                           .setyDesc(yTensor)\n                           .build();\n\n  std::array<cudnn_frontend::Operation const*, 1> ops = {&batch_norm_op};\n#else\n  
std::array<cudnn_frontend::Operation const*, 0> ops = {};\n#endif\n  auto opGraph =\n      cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build();\n  // std::cout << opGraph.describe() << std::endl;\n\n  cudnn_frontend::EngineConfigList filtered_configs;\n  auto statuses = cudnn_frontend::get_heuristics_list<2>({\"heuristics_instant\", \"heuristics_fallback\"}, opGraph,\n                                                         ::AllowAll, filtered_configs, true);\n\n  // std::cout << \"get_heuristics_list Statuses: \";\n  // for (auto i = 0u ; i < statuses.size(); i++) {\n  //   std::cout << cudnn_frontend::to_string(statuses[i]) << \" \";\n  // }\n  // std::cout << std::endl;\n  // std::cout << \"Filter config list has \" << filtered_configs.size() << \" configurations \" << std::endl;\n\n  // some verbose printing:\n  // std::cout << \"Tensor shape: (\" << tensorDims[0] << \", \" << tensorDims[1] << \", \" << tensorDims[2] << \", \" <<\n  // tensorDims[3] << \")\" << std::endl;\n\n  auto plan_builder = [&filtered_configs, &opGraph, &handle]() {\n    for (auto i = 0u; i < filtered_configs.size(); i++) {\n      try {\n        auto plan = cudnn_frontend::ExecutionPlanBuilder()\n                        .setHandle(handle)\n                        .setEngineConfig(filtered_configs[i], opGraph.getTag())\n                        .build();\n        return plan;\n      } catch (cudnn_frontend::cudnnException& e) {\n        continue;\n      }\n    }\n    return cudnn_frontend::ExecutionPlanBuilder()\n        .setHandle(handle)\n        .setEngineConfig(filtered_configs[0], opGraph.getTag())\n        .build();\n  };\n\n  assert(filtered_configs.size() > 0);\n  auto plan = plan_builder();\n\n  return plan;\n}\n\nvoid execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* yDevPtr, void* scaledevPtr,\n                                void* biasdevPtr, void* in_meandevPtr, void* in_vardevPtr, void* 
out_meandevPtr,\n                                void* out_vardevPtr, void* saved_meandevPtr, void* saved_inv_vardevPtr,\n                                const std::vector<void*>& peer_devPtrs, double epsilon_val,\n                                double exponential_decay_factor, size_t peer_size, int rank_id) {\n  // get handle\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n\n  // get stream\n  cudaStream_t stream;\n  cudnnGetStream(handle_, &stream);\n\n  try {\n    // allocate workspace\n    auto workspace_size = plan.getWorkspaceSize();\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    void* workPtr = nullptr;\n    if (workspace_size > 0) {\n      workPtr = workspace_tensor.data_ptr<float>();\n    }\n\n    // first the data pointers\n    std::vector<void*> data_ptrs{\n        xDevPtr,        yDevPtr,       scaledevPtr,      biasdevPtr,          in_meandevPtr, in_vardevPtr,\n        out_meandevPtr, out_vardevPtr, saved_meandevPtr, saved_inv_vardevPtr, &epsilon_val,  &exponential_decay_factor};\n    data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end());\n    // then the uids\n    std::vector<int64_t> uids;\n    for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) {\n      uids.push_back(i);\n    }\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workPtr)\n                           .setDataPointers(data_ptrs.size(), data_ptrs.data())\n                           .setUids(uids.size(), uids.data())\n                           .build();\n    // std::cout << \"variantPack \" << variantPack.describe() << std::endl;\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n\n    // Reset local communication buffer\n    
cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size * 4, stream);\n\n  } catch (cudnn_frontend::cudnnException& e) {\n    struct cudaDeviceProp prop;\n    checkCudaErr(cudaGetDeviceProperties(&prop, 0));\n    if (prop.major == 8) {\n      std::cout << \"[ERROR] Exception \" << e.what() << std::endl;\n      assert(false);\n    }\n  }\n}\n\ncudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon,\n                                                      int64_t* peerDims, cudnnDataType_t data_type) {\n  // get cudnn handle\n  cudnnHandle_t handle = torch::native::getCudnnHandle();\n\n  // Creates the necessary tensor descriptors\n  int64_t tensor_stride[4];\n  int64_t stride[4];\n  int64_t peer_stride[4];\n\n  // NHWC format. GenerateStrides() takes care of this. However, tensor dims should still be NCHW\n  generateStrides(tensorDims, tensor_stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n  generateStrides(peerDims, peer_stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n\n  auto tensor_create = [&tensor_stride, &tensorDims](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, tensorDims)\n        .setStrides(4, tensor_stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n        .build();\n  };\n\n  auto peer_tensor_create = [&peer_stride, &peerDims](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, peerDims)\n        .setStrides(4, peer_stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n        .build();\n  };\n\n  generateStrides(perChannelSum, stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n\n  auto per_channel_tensor_create = [&stride, &perChannelSum](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, perChannelSum)\n        .setStrides(4, stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n  
      .build();\n  };\n\n  auto xTensor = tensor_create(data_type, 100);\n  auto dyTensor = tensor_create(data_type, 101);\n  auto scaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 102);\n  auto savedMeanTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 103);\n  auto savedInvVarTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 104);\n  auto dxTensor = tensor_create(data_type, 105);\n  auto dScaleTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 106);\n  auto dBiasTensor = per_channel_tensor_create(CUDNN_DATA_FLOAT, 107);\n\n  int64_t epsilon_stride[4];\n  generateStrides(epsilon, epsilon_stride, (int64_t)4, CUDNN_TENSOR_NHWC);\n  auto scalar_tensor_create = [&epsilon_stride, &epsilon](cudnnDataType_t type, int64_t id) {\n    return cudnn_frontend::TensorBuilder()\n        .setDim(4, epsilon)\n        .setStrides(4, epsilon_stride)\n        .setId(id)\n        .setAlignment(16)\n        .setDataType(type)\n        .setByValue(true)\n        .build();\n  };\n\n  auto epsilonTensor = scalar_tensor_create(CUDNN_DATA_DOUBLE, 108);\n\n  std::vector<cudnn_frontend::Tensor_v8> peerStatTensors;\n  for (size_t i = 109; i < 109 + peerDims[0]; ++i) {\n    peerStatTensors.push_back(peer_tensor_create(CUDNN_DATA_FLOAT, i));\n  }\n\n#if (CUDNN_VERSION >= 8500)\n  // Batch normalization\n  cudnnBackendNormMode_t normalizationMode = CUDNN_BATCH_NORM;\n\n  // Create a Finalize node\n  auto batch_norm_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)\n                           .setNormalizationMode(normalizationMode)\n                           .setxDesc(xTensor)\n                           .setSavedMeanAndInvVar(savedMeanTensor, savedInvVarTensor)\n                           .setdyDesc(dyTensor)\n                           .setScale(scaleTensor)\n                           .setEpsilonTensor(epsilonTensor)\n                           .setDScaleAndDBias(dScaleTensor, dBiasTensor)\n                           .setdxDesc(dxTensor)\n 
                          .setPeerStatTensor(peerStatTensors)\n                           .build();\n\n  std::array<cudnn_frontend::Operation const*, 1> ops = {&batch_norm_op};\n#else\n  std::array<cudnn_frontend::Operation const*, 0> ops = {};\n#endif\n\n  auto opGraph =\n      cudnn_frontend::OperationGraphBuilder().setHandle(handle).setOperationGraph(ops.size(), ops.data()).build();\n  // std::cout << opGraph.describe() << std::endl;\n\n  cudnn_frontend::EngineConfigList filtered_configs;\n  auto statuses = cudnn_frontend::get_heuristics_list<2>({\"heuristics_instant\", \"heuristics_fallback\"}, opGraph,\n                                                         ::AllowAll, filtered_configs, true);\n\n  auto plan_builder = [&filtered_configs, &opGraph, &handle]() {\n    for (auto i = 0u; i < filtered_configs.size(); i++) {\n      try {\n        auto plan = cudnn_frontend::ExecutionPlanBuilder()\n                        .setHandle(handle)\n                        .setEngineConfig(filtered_configs[i], opGraph.getTag())\n                        .build();\n        return plan;\n      } catch (cudnn_frontend::cudnnException& e) {\n        continue;\n      }\n    }\n    return cudnn_frontend::ExecutionPlanBuilder()\n        .setHandle(handle)\n        .setEngineConfig(filtered_configs[0], opGraph.getTag())\n        .build();\n  };\n\n  assert(filtered_configs.size() > 0);\n  auto plan = plan_builder();\n\n  return plan;\n}\n\nvoid execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* dyDevPtr, void* scaledevPtr,\n                                 void* saved_meandevPtr, void* saved_inv_vardevPtr,\n                                 const std::vector<void*>& peer_devPtrs, void* dxDevPtr, void* dscaledevPtr,\n                                 void* dbiasdevPtr, double epsilon_val, size_t peer_size, int rank_id) {\n  // get handle\n  cudnnHandle_t handle_ = torch::native::getCudnnHandle();\n\n  // get stream\n  cudaStream_t stream;\n  
cudnnGetStream(handle_, &stream);\n\n  try {\n    // allocate workspace\n    auto workspace_size = plan.getWorkspaceSize();\n    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));\n    void* workPtr = nullptr;\n    if (workspace_size > 0) {\n      workPtr = workspace_tensor.data_ptr<float>();\n    }\n\n    // create helper arrays\n    std::vector<void*> data_ptrs{xDevPtr,  dyDevPtr,     scaledevPtr, saved_meandevPtr, saved_inv_vardevPtr,\n                                 dxDevPtr, dscaledevPtr, dbiasdevPtr, &epsilon_val};\n    data_ptrs.insert(data_ptrs.end(), peer_devPtrs.begin(), peer_devPtrs.end());\n    std::vector<int64_t> uids;\n    for (size_t i = 100; i < 100 + data_ptrs.size(); ++i) {\n      uids.push_back(i);\n    }\n\n    auto variantPack = cudnn_frontend::VariantPackBuilder()\n                           .setWorkspacePointer(workPtr)\n                           .setDataPointers(data_ptrs.size(), data_ptrs.data())\n                           .setUids(uids.size(), uids.data())\n                           .build();\n    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());\n\n    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, \"Plan execute error\", status);\n\n    // Reset local communication buffer\n    cudaMemsetAsync(peer_devPtrs[rank_id], 0, peer_size * 4, stream);\n\n  } catch (cudnn_frontend::cudnnException& e) {\n    struct cudaDeviceProp prop;\n    checkCudaErr(cudaGetDeviceProperties(&prop, 0));\n    if (prop.major == 8) {\n      std::cout << \"[ERROR] Exception \" << e.what() << std::endl;\n      assert(false);\n    }\n  }\n}\n"
  },
  {
    "path": "apex/contrib/csrc/cudnn_gbn/norm_sample.h",
    "content": "#pragma once\n\n/*\n * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n *\n * Permission is hereby granted, free of charge, to any person obtaining a\n * copy of this software and associated documentation files (the \"Software\"),\n * to deal in the Software without restriction, including without limitation\n * the rights to use, copy, modify, merge, publish, distribute, sublicense,\n * and/or sell copies of the Software, and to permit persons to whom the\n * Software is furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL\n * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n * DEALINGS IN THE SOFTWARE.\n */\n\n#pragma once\n\n#include <assert.h>\n#include <ctype.h>\n#include <cudnn.h>\n#include <cudnn_frontend.h>\n#include <inttypes.h>\n#include <stdlib.h>\n#include <string.h>\n\n#include <functional>\n#include <iostream>\n#include <tuple>\n\n/* some helpers\n */\nvoid generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudnnTensorFormat_t filterFormat);\n\nint64_t checkCudaError(cudaError_t code, const char* expr, const char* file, int line);\nint64_t checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line);\n\n#define checkCudaErr(...)                                                        
\\\n  do {                                                                           \\\n    int64_t err = checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \\\n    assert(err == 0);                                                            \\\n  } while (0)\n\n#define checkCudnnErr(...)                                                        \\\n  do {                                                                            \\\n    int64_t err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \\\n    assert(err == 0);                                                             \\\n  } while (0)\n\n/**\n * @brief Run a Group BN forward sample with 2 peer stat tensors.\n *\n * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of\n memory format\n * @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor\n * @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN\n * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in\n GBN\n\n *\n */\ncudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon,\n                                                     int64_t* peerDims, cudnnDataType_t in_out_data_type);\n/**\n * @param xDevPtr input tensor device pointer\n * @param yDevPtr output tensor device pointer\n * @param scaledevPtr input scale device pointer for BN scaling\n * @param biasdevPtr input scale device pointer for BN bias\n * @param in_meandevPtr Input mean device pointer\n * @param in_vardevPtr Input variance device pointer\n * @param out_meandevPtr output mean device pointer\n * @param out_vardevPtr output variance device pointer\n * @param saved_meandevPtr saved mean device pointer for BN backward\n * @param saved_inv_vardevPtr saved inverse variance device 
pointer for BN backward\n * @param peer_devPtrs peer stat tensor device pointers\n * @param epsilon_val epsilon value as a double\n * @param exponential_decay_factor exponential decay factor as a double\n *\n **/\nvoid execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* yDevPtr, void* scaledevPtr,\n                                void* biasdevPtr, void* in_meandevPtr, void* in_vardevPtr, void* out_meandevPtr,\n                                void* out_vardevPtr, void* saved_meandevPtr, void* saved_inv_vardevPtr,\n                                const std::vector<void*>& peer_devPtrs, double epsilon_val,\n                                double exponential_decay_factor, size_t peer_size, int rank_id);\n\n/**\n * @brief Run a Group BN backward sample with 2 peer stat tensors.\n *\n * @param tensorDims an array with shape (N, C, H, W) for input tensor dims. Stride in NHWC or NCHW will take care of\n * memory format\n * @param perChannelSum an array with shape (1, C, 1, 1) to denote the sum values for each channel in the input tensor\n * @param epsilon a scalar array with shape (1, 1, 1, 1) to represent the epsilon value for the BN\n * @param peerDims an array with shape (num GPUs, 2 * C, 1, 1) to denote the tensor dimensions for peer stat tensor in\n * GBN\n *\n */\ncudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon,\n                                                      int64_t* peerDims, cudnnDataType_t data_type);\n\n/**\n * @brief Execute a Group BN backward plan.\n *\n * @param xDevPtr input tensor device pointer\n * @param dyDevPtr output gradient tensor device pointer\n * @param scaledevPtr input scale device pointer for BN scaling\n 
* @param saved_meandevPtr saved mean device pointer for BN backward\n * @param saved_inv_vardevPtr saved inverse variance device pointer for BN backward\n * @param peer_devPtrs peer stat tensor device pointers\n * @param dxDevPtr input gradient tensor device pointer\n * @param dscaledevPtr scale gradient device pointer\n * @param dbiasdevPtr bias gradient device pointer\n * @param epsilon_val epsilon value as a double\n * @param peer_size number of 4-byte elements in the local peer buffer, which is zeroed after execution\n * @param rank_id index of the local buffer in peer_devPtrs\n *\n */\nvoid execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void* xDevPtr, void* dyDevPtr, void* scaledevPtr,\n                                 void* saved_meandevPtr, void* saved_inv_vardevPtr,\n                                 const std::vector<void*>& peer_devPtrs, void* dxDevPtr, void* dscaledevPtr,\n                                 void* dbiasdevPtr, double epsilon_val, size_t peer_size, int rank_id);\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/fmha_api.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\n#include \"fmha.h\"\n\nextern at::Tensor& mha_fill(at::Tensor& self, const at::Tensor& start_index);\nvoid set_params(Fused_multihead_attention_fprop_params& params,\n                // sizes\n                const size_t b, const size_t s, const size_t h, const size_t d,\n                // device pointers\n                void* qkv_packed_d, void* cu_seqlens_d, void* o_packed_d, void* s_d, float p_dropout) {\n  Data_type acc_type = DATA_TYPE_FP32;\n  Data_type data_type = DATA_TYPE_FP16;\n\n  // Reset the parameters\n  memset(&params, 0, sizeof(params));\n\n  // Set the pointers and strides.\n  params.qkv_ptr = qkv_packed_d;\n  params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);\n  params.o_ptr = o_packed_d;\n  params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);\n\n  params.cu_seqlens = static_cast<int*>(cu_seqlens_d);\n\n  // S = softmax(P)\n  params.s_ptr = s_d;\n  params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);\n\n  // Set the dimensions.\n  params.b = b;\n  params.h = h;\n  params.s = s;\n  params.d = d;\n\n  // Set the different scale values.\n  const float scale_bmm1 = 1.f / sqrtf(d);\n  constexpr float scale_softmax = 1.f;\n  constexpr float scale_bmm2 = 1.f;\n\n  set_alpha(params.scale_bmm1, scale_bmm1, data_type);\n  
set_alpha(params.scale_softmax, scale_softmax, acc_type);\n  set_alpha(params.scale_bmm2, scale_bmm2, data_type);\n\n  // Set this to probability of keeping an element to simplify things.\n  params.p_dropout = 1.f - p_dropout;\n  params.rp_dropout = 1.f / params.p_dropout;\n  TORCH_CHECK(p_dropout < 1.f);\n  set_alpha(params.scale_dropout, params.rp_dropout, data_type);\n}\n\nstd::vector<at::Tensor> mha_fwd(\n    const at::Tensor& qkv,         // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n    const at::Tensor& cu_seqlens,  // b+1\n    const float p_dropout, const int max_seq_len, const bool is_training, const bool is_nl, const bool zero_tensors,\n    c10::optional<at::Generator> gen_) {\n  using namespace torch::indexing;\n  auto dprops = at::cuda::getCurrentDeviceProperties();\n  TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) ||\n              (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0));\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);\n\n  int seq_len = 512;\n  auto launch = &run_fmha_fp16_512_64_sm80;\n  if (max_seq_len <= 128) {\n    seq_len = 128;\n    launch = &run_fmha_fp16_128_64_sm80;\n  } else if (max_seq_len <= 256) {\n    seq_len = 256;\n    launch = &run_fmha_fp16_256_64_sm80;\n  } else if (max_seq_len <= 384) {\n    seq_len = 384;\n    launch = &run_fmha_fp16_384_64_sm80;\n  } else if (max_seq_len <= 512) {\n    seq_len = 512;\n    launch = &run_fmha_fp16_512_64_sm80;\n  } else {\n    TORCH_CHECK(false);\n  }\n\n  TORCH_CHECK(qkv.is_cuda())\n  TORCH_CHECK(cu_seqlens.is_cuda())\n\n  TORCH_CHECK(qkv.is_contiguous())\n  TORCH_CHECK(cu_seqlens.is_contiguous())\n\n  TORCH_CHECK(cu_seqlens.dim() == 1);\n  TORCH_CHECK(qkv.dim() == 4);\n\n  const auto sizes = qkv.sizes();\n\n  TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n  
const int batch_size = cu_seqlens.numel() - 1;\n  const int total = sizes[TOTAL_DIM];\n  const int num_heads = sizes[H_DIM];\n  const int head_size = sizes[D_DIM];\n  TORCH_CHECK(batch_size > 0);\n  TORCH_CHECK(head_size == 64);\n  auto opts = qkv.options();\n\n  auto ctx = torch::empty({total, num_heads, head_size}, opts);\n\n  auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts);\n\n  if (zero_tensors) {\n    mha_fill(ctx, cu_seqlens.index({Slice(-1, None)}));\n  }\n\n  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n  set_params(launch_params.params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(),\n             ctx.data_ptr(), s.data_ptr(), p_dropout);\n\n  launch(launch_params, /*configure=*/true);\n  // number of times random will be generated per thread, to offset philox counter in thc random\n  // state\n  int64_t counter_offset = launch_params.elts_per_thread;\n  at::PhiloxCudaState rng_engine_inputs;\n\n  if (is_training) {\n    // See Note [Acquire lock when using random generators]\n    std::lock_guard<std::mutex> lock(gen->mutex_);\n    launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);\n  }\n\n  launch(launch_params, /*configure=*/false);\n\n  return {ctx, s};\n}\n\nstd::vector<at::Tensor> mha_bwd(\n    const at::Tensor& dout,        // total x num_heads, x head_size\n    const at::Tensor& qkv,         // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n    at::Tensor& softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP\n    const at::Tensor& cu_seqlens,  // b+1\n    const float p_dropout,         // probability to drop\n    const int max_seq_len,         // max sequence length to choose the kernel\n    const bool zero_tensors) {\n  using namespace torch::indexing;\n  auto dprops = at::cuda::getCurrentDeviceProperties();\n  TORCH_CHECK((dprops->major == 8 && 
dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) ||\n              (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0));\n  int seq_len = 512;\n  auto launch = &run_fmha_dgrad_fp16_512_64_sm80;\n  if (max_seq_len <= 128) {\n    seq_len = 128;\n    launch = &run_fmha_dgrad_fp16_128_64_sm80;\n  } else if (max_seq_len <= 256) {\n    seq_len = 256;\n    launch = &run_fmha_dgrad_fp16_256_64_sm80;\n  } else if (max_seq_len <= 384) {\n    seq_len = 384;\n    launch = &run_fmha_dgrad_fp16_384_64_sm80;\n  } else if (max_seq_len <= 512) {\n    seq_len = 512;\n    launch = &run_fmha_dgrad_fp16_512_64_sm80;\n  } else {\n    TORCH_CHECK(false);\n  }\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n  TORCH_CHECK(qkv.dtype() == torch::kFloat16);\n  TORCH_CHECK(dout.dtype() == torch::kFloat16);\n  TORCH_CHECK(softmax.dtype() == torch::kFloat16);\n  TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);\n\n  TORCH_CHECK(qkv.is_cuda());\n  TORCH_CHECK(cu_seqlens.is_cuda());\n\n  TORCH_CHECK(qkv.is_contiguous());\n  TORCH_CHECK(cu_seqlens.is_contiguous());\n\n  TORCH_CHECK(cu_seqlens.dim() == 1);\n  TORCH_CHECK(qkv.dim() == 4);\n\n  const auto sizes = qkv.sizes();\n\n  TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n  const int batch_size = cu_seqlens.numel() - 1;\n  const int num_heads = sizes[H_DIM];\n  const int head_size = sizes[D_DIM];\n  TORCH_CHECK(batch_size > 0);\n  TORCH_CHECK(head_size == 64);\n\n  auto dqkv = torch::empty_like(qkv);\n\n  if (zero_tensors) {\n    mha_fill(dqkv, cu_seqlens.index({Slice(-1, None)}));\n  }\n\n  Fused_multihead_attention_fprop_params params;\n\n  set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(),\n             dout.data_ptr(),     // we set o_ptr to dout\n             softmax.data_ptr(),  // softmax gets overwritten by dP!\n             p_dropout);\n\n  // we're re-using these scales\n  Data_type acc_type = DATA_TYPE_FP32;\n  
set_alpha(params.scale_bmm1, 1.f, acc_type);\n  set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);\n  set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);\n  params.dqkv_ptr = dqkv.data_ptr();\n\n  launch(params, stream);\n  return {dqkv, softmax};\n}\n\nstd::vector<at::Tensor> mha_bwd_nl(\n    const at::Tensor& dout,        // total x num_heads, x head_size\n    const at::Tensor& qkv,         // total x num_heads x 3 x head_size, total := \\sum_{i=0}^{b} s_i\n    at::Tensor& softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP\n    const at::Tensor& cu_seqlens,  // b+1\n    const float p_dropout,         // probability to drop\n    const int max_seq_len,         // max sequence length to choose the kernel\n    const bool zero_tensors) {\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n  TORCH_CHECK(qkv.is_cuda())\n  TORCH_CHECK(cu_seqlens.is_cuda())\n\n  TORCH_CHECK(qkv.is_contiguous())\n  TORCH_CHECK(cu_seqlens.is_contiguous())\n\n  TORCH_CHECK(cu_seqlens.dim() == 1);\n\n  TORCH_CHECK(qkv.dim() == 4);\n\n  const auto sizes = qkv.sizes();\n\n  TORCH_CHECK(sizes[THREE_DIM] == 3);\n\n  const int batch_size = cu_seqlens.numel() - 1;\n\n  const int total = sizes[TOTAL_DIM];\n  const int num_heads = sizes[H_DIM];\n  const int head_size = sizes[D_DIM];\n  TORCH_CHECK(batch_size > 0);\n  TORCH_CHECK(head_size == 64);\n\n  int seq_len = 512;\n  auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;\n\n  auto opts = qkv.options();\n\n  auto dqkv = torch::empty_like(qkv);\n\n  if (zero_tensors) {\n    dqkv.zero_();\n  }\n\n  int num_chunks = 2;\n  if (batch_size == 1) {\n    num_chunks = 4;\n  } else if (batch_size == 2) {\n    num_chunks = 3;\n  }\n  auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);\n\n  Fused_multihead_attention_fprop_params params;\n\n  set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(),\n             dout.data_ptr(),     // 
o_ptr = dout\n             softmax.data_ptr(),  // softmax gets overwritten by dP!\n             p_dropout);\n\n  params.dkv_ptr = dkv.data_ptr();\n\n  Data_type acc_type = DATA_TYPE_FP32;\n  set_alpha(params.scale_bmm1, 1.f, acc_type);\n  set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);\n  set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);\n  params.dqkv_ptr = dqkv.data_ptr();\n\n  launch(params, num_chunks, stream);\n\n  // SPLIT-K reduction of num_chunks dK, dV parts\n\n  // The equivalent of the following Pytorch code:\n  // using namespace torch::indexing;\n  // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});\n  // torch::sum_out(view_out, dkv, 1);\n\n  const int hidden_size = num_heads * head_size;\n  fmha_run_noloop_reduce(dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total,\n                         num_chunks, stream);\n\n  return {dqkv, softmax, dkv};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.doc() = \"Fused Multi-head Self-attention for BERT\";\n  m.def(\"fwd\", &mha_fwd, \"Forward pass\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"bwd\", &mha_bwd, \"Backward pass\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"bwd_nl\", &mha_bwd_nl, \"Backward pass (small-batch)\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha/gemm.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/utils.h>\n\n#define FMHA_DIV_UP(m, n) (((m) + (n) - 1) / (n))\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_>\nstruct Fragment_base_ {\n  // The data type.\n  using Data_type = Data_type_;\n  // default input type\n  using Input_type_ = Data_type_;\n  // Does it store the array of elements.\n  enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };\n  // The number of elements.\n  enum { NUM_ELTS = NUM_ELTS_ };\n  // The size of element in bits.\n  enum { BITS_PER_ELT = BITS_PER_ELT_ };\n  // The size in bytes of a single register.\n  enum { BYTES_PER_REG = 4 };\n  // The size in bits.\n  enum { BITS_PER_REG = BYTES_PER_REG * 8 };\n  // The number of registers needed to store the fragment.\n  enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };\n  // The size in bytes (as returned by sizeof(Fragment_base<>)).\n  enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };\n  // The alignment.\n  enum { ALIGNMENT = ALIGNMENT_ > 0 ? 
ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The type of the elements.\n    typename Data_type_,\n    // The number of elements.\n    int NUM_ELTS_,\n    // The alignment if you want to force a value -- use 0 otherwise.\n    int ALIGNMENT_ = 0,\n    // The base class.\n    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_> >\nstruct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {\n  // The size of a load/store.\n  enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };\n\n  // Clear the fragment. Using PTX in that code seems to produce better SASS...\n  inline __device__ void clear() {\n#pragma unroll\n    for (int ii = 0; ii < Base_::NUM_REGS; ++ii) {\n      asm volatile(\"mov.u32 %0, 0; \\n\" : \"=r\"(this->reg(ii)) :);\n    }\n  }\n\n  // Immutable access to a register.\n  inline __device__ const uint32_t& reg(int ii) const { return this->regs_[ii]; }\n\n  // Mutable access to a register.\n  inline __device__ uint32_t& reg(int ii) { return this->regs_[ii]; }\n\n  uint32_t regs_[Base_::NUM_REGS];\n\n  // Immutable access to the elements.\n  inline __device__ const Data_type_& elt(int ii) const {\n    return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];\n  }\n\n  // Mutable access to the elements.\n  inline __device__ Data_type_& elt(int ii) { return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii]; }\n\n  // Immutable access to the elements with a cast.\n  template <typename Cast_type>\n  inline __device__ const Cast_type& elt_as(int ii) const {\n    return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];\n  }\n\n  // Mutable access to the elements.\n  template <typename Cast_type>\n  inline __device__ Cast_type& elt_as(int ii) {\n    return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];\n  }\n\n  // Add another fragment.\n  
inline __device__ void add(const Fragment& other) {\n#pragma unroll\n    for (int ii = 0; ii < NUM_ELTS_; ++ii) {\n      this->elt(ii) += other.elt(ii);\n    }\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Layout>\nstruct Fragment_a : public Fragment<uint16_t, 8> {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Layout>\nstruct Fragment_b : public Fragment<uint16_t, 8> {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fragment_accumulator : public Fragment<float, 8> {\n  // The base class.\n  using Base = Fragment<float, 8>;\n\n  // Add two fragments.\n  template <typename Other_fragment_>\n  inline __device__ void add(const Other_fragment_& other) {\n    for (int ii = 0; ii < Base::NUM_ELTS; ++ii) {\n      this->elt(ii) = this->elt(ii) + other.elt(ii);\n    }\n  }\n\n  // Do the HMMA.\n  template <typename Layout_a, typename Layout_b>\n  inline __device__ void mma(const Fragment_a<Layout_a>& a, const Fragment_b<Layout_b>& b) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \\n\"\n        \"    {%0, %1, %2, %3}, \\n\"\n        \"    {%4, %5, %6, %7}, \\n\"\n        \"    {%8, %9}, \\n\"\n        \"    {%0, %1, %2, %3}; \\n\"\n        : \"+f\"(elt(0)), \"+f\"(elt(1)), \"+f\"(elt(2)), \"+f\"(elt(3))\n        : \"r\"(a.reg(0)), \"r\"(a.reg(1)), \"r\"(a.reg(2)), \"r\"(a.reg(3)), \"r\"(b.reg(0)), \"r\"(b.reg(1)));\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \\n\"\n        \"    {%0, %1, %2, %3}, \\n\"\n        \"    {%4, %5, %6, %7}, \\n\"\n        \"    {%8, %9}, \\n\"\n        \"    {%0, %1, %2, %3}; \\n\"\n        : \"+f\"(elt(4)), \"+f\"(elt(5)), \"+f\"(elt(6)), \"+f\"(elt(7))\n        : \"r\"(a.reg(0)), \"r\"(a.reg(1)), \"r\"(a.reg(2)), \"r\"(a.reg(3)), 
\"r\"(b.reg(2)), \"r\"(b.reg(3)));\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Fragment, int M, int N>\ninline __device__ void clear(Fragment (&frag)[M][N]) {\n#pragma unroll\n  for (int mi = 0; mi < M; ++mi) {\n#pragma unroll\n    for (int ni = 0; ni < N; ++ni) {\n      frag[mi][ni].clear();\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Accumulator_type, int WARPS_K>\nstruct Clear_accumulator {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int WARPS_K>\nstruct Clear_accumulator<float, WARPS_K> {\n  template <typename Acc, int M, int N>\n  static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {\n    fmha::clear(acc);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Acc, typename A, typename B, int M, int N>\ninline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {\n#pragma unroll\n  for (int mi = 0; mi < M; ++mi) {\n#pragma unroll\n    for (int ni = 0; ni < N; ++ni) {\n      acc[mi][ni].mma(a[mi], b[ni]);\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The number of rows in the CTA tile.\n    int M_,\n    // The number of cols in the CTA tile.\n    int N_,\n    // The number of elements in the the K dimension of the GEMM loop.\n    int K_,\n    // The number of rows of warps.\n    int WARPS_M_,\n    // The number of cols of warps.\n    int WARPS_N_,\n    // The number of warps in the K dimension of the GEMM loop.\n    int WARPS_K_>\nstruct Cta_tile_ {\n  enum { M = M_, N = N_, K = K_ };\n  // The number of warps.\n  enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };\n  // The 
number of warps per CTA.\n  enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };\n  // The number of threads per warp.\n  enum { THREADS_PER_WARP = 32 };\n  // The number of threads per CTA.\n  enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile>\nstruct Hmma_tile {\n  // The number of elements computed with a single warp-MMA.\n  enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };\n\n  // The number of elements computed with a single CTA-MMA.\n  enum {\n    M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,\n    N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,\n    K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K\n  };\n\n  // The number of MMAs needed to compute the GEMM.\n  enum {\n    MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,\n    MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,\n    MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,\n  };\n\n  // The number of elements computed per warp.\n  enum {\n    M_PER_WARP = MMAS_M * M_PER_MMA,\n    N_PER_WARP = MMAS_N * N_PER_MMA,\n    K_PER_WARP = MMAS_K * K_PER_MMA,\n  };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nusing A_type = uint16_t;\nusing B_type = uint16_t;\nusing C_type = uint16_t;\nusing Accumulator_type = float;\nusing Epilogue_type = float;\n\nconstexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;\nconstexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;\nconstexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>\nusing Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, 
WARPS_K>;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile_>\nusing Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M, Cta_tile_::N, Next_power_of_two<Cta_tile_::K>::VALUE,\n                                                   Cta_tile_::WARPS_M, Cta_tile_::WARPS_N, Cta_tile_::WARPS_K>;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha/gmem_tile.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The number of bits per element.\n    int BITS_PER_ELEMENT,\n    // The number of rows of Q, K or V loaded by this tile.\n    int ROWS,\n    // The number of columns.\n    int COLS,\n    // The number of matrics.\n    int NUM_MATS = 3>\nstruct Gmem_tile_qkv {\n  // The size of each LDG.\n  enum { BYTES_PER_LDG = 16 };\n  // The size of a row in bytes.\n  enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };\n\n  // The number of threads to load a \"row\" of the matrix.\n  enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };\n\n  // The number of \"rows\" loaded per LDG.\n  enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n  // The number of LDGs needed to load a chunk of the Q matrix.\n  enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };\n\n  // Ctor.\n  template <typename Params, typename BInfo>\n  inline __device__ Gmem_tile_qkv(const Params& params, const int qkv_offset, const BInfo& binfo, const int tidx)\n      : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes),\n        actual_seqlen(binfo.actual_seqlen),\n        qkv_ptr_(reinterpret_cast<char*>(params.qkv_ptr)) {\n    // Compute the 
position in the sequence (within the CTA for the moment).\n    int row = tidx / THREADS_PER_ROW;\n    // Compute the position of the thread in the row.\n    int col = tidx % THREADS_PER_ROW;\n\n    // Store the row as we need it to disable the loads.\n    row_ = row;\n\n    // The row offset in the batched GEMM. For each seq element, we store QKV in that order.\n    int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;\n    // Add the block index.\n    row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;\n\n    // Assemble the final pointer.\n    qkv_ptr_ += row_offset + col * BYTES_PER_LDG;\n  }\n\n  // Store data to shared memory.\n  template <typename Smem_tile>\n  inline __device__ void commit(Smem_tile& smem_tile) {\n    smem_tile.store(fetch_);\n  }\n\n  // Load data from memory.\n  template <typename Smem_tile>\n  inline __device__ void load(Smem_tile& smem_tile) {\n    const void* ptrs[LDGS];\n    uint32_t preds[LDGS];\n#pragma unroll\n    for (int ii = 0; ii < LDGS; ++ii) {\n      ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;\n      preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));\n      fetch_[ii] = make_uint4(0, 0, 0, 0);\n    }\n\n    // not packing predicates removes restrictions (e.g. 
FP16 384, 4 warps)\n    Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);\n#pragma unroll\n    for (int ii = 0; ii < LDGS; ++ii) {\n      fct.load(ii, preds[ii]);\n    }\n  }\n\n  // Store data to memory.\n  inline __device__ void store(const uint4 (&data)[LDGS]) {\n#pragma unroll\n    for (int ii = 0; ii < LDGS; ++ii) {\n      char* ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;\n      if ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {\n        fmha::stg(ptr, data[ii]);\n      }\n    }\n  }\n\n  // Move the pointer to the next location.\n  inline __device__ void move() {\n    qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;\n    actual_seqlen -= ROWS;\n  }\n\n  inline __device__ void move(int steps) {\n    qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;\n    actual_seqlen -= ROWS * steps;\n  }\n\n  // The stride between rows for the QKV matrice.\n  int64_t params_qkv_stride_in_bytes_;\n  // The pointer.\n  char* qkv_ptr_;\n  // The fetch registers.\n  uint4 fetch_[LDGS];\n  // Keep track of the row the thread is processing as we move the tile.\n  int row_;\n  // The length of the sequence loaded by that memory tile.\n  int actual_seqlen;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile>\nstruct Gmem_tile_o {\n  // The mma tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n  // The size of each element.\n  enum { BYTES_PER_ELEMENT = 2 };\n  // The size of a row in bytes.\n  enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };\n\n  // The number of threads to store a \"row\" of the matrix.\n  enum { THREADS_PER_ROW = 16 };\n  // The size of each STG.\n  enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW };\n\n  // The number of \"rows\" stored per iteration of the loop. The output of 1 MMA.\n  enum { ROWS = Cta_tile::M };\n  // The number of \"rows\" stored per iteration of the loop. 
The output of 1 MMA.\n  enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };\n  // The number of outter loop for the stores.\n  enum { LOOPS = ROWS / ROWS_PER_LOOP };\n\n  // The number of \"rows\" stored per STG.\n  enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n  // Do we have to guard against partial writes/reads.\n  enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };\n  // The number of STGs needed to store a chunk of the Q matrix.\n  enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };\n  // The number of STGs needed to store a chunk of the Q matrix in total.\n  enum { STGS = STGS_PER_LOOP * LOOPS };\n\n  // Ctor.\n  template <typename Params, typename BInfo>\n  inline __device__ Gmem_tile_o(const Params& params, const BInfo& binfo, int tidx)\n      : params_o_stride_in_bytes_(params.o_stride_in_bytes),\n        actual_seqlen_(binfo.actual_seqlen),\n        o_ptr_(reinterpret_cast<char*>(params.o_ptr)) {\n    // Compute the position in the sequence (within the CTA for the moment).\n    int row = tidx / THREADS_PER_ROW;\n    // Compute the position of the thread in the row.\n    int col = tidx % THREADS_PER_ROW;\n\n    // Store the row as we need it to disable loads.\n    row_ = row;\n\n    // The row offset in the batched GEMM.\n    int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;\n    // Assemble the final pointer.\n    o_ptr_ += row_offset + col * BYTES_PER_STG;\n\n    // Is that thread active on the last STG?\n    if (HAS_INCOMPLETE_STG) {\n      is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;\n    }\n  }\n\n  // Store data to global memory.\n  inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {\n#pragma unroll\n    for (int ii = 0; ii < STGS_PER_LOOP; ++ii) {\n      int jj = mi * STGS_PER_LOOP + ii;\n      if (this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_) {\n        break;\n     
 }\n\n      float x = reinterpret_cast<const float&>(src[ii].x);\n      float y = reinterpret_cast<const float&>(src[ii].y);\n      float z = reinterpret_cast<const float&>(src[ii].z);\n      float w = reinterpret_cast<const float&>(src[ii].w);\n      uint2 out = float4_to_half4(x, y, z, w);\n      if (!HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_)) {\n        fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);\n      }\n    }\n  }\n\n  // Move the pointer to the next location.\n  inline __device__ void move() {\n    row_ += ROWS;\n    o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;\n  }\n\n  inline __device__ void move(const int steps) {\n    row_ += ROWS * steps;\n    o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;\n  }\n\n  // The stride between rows for the QKV matrice.\n  int64_t params_o_stride_in_bytes_;\n  // The pointer.\n  char* o_ptr_;\n  // Is the thread active for the last STG?\n  int is_active_for_last_stg_;\n  // Keep track of the row to disable loads.\n  int row_;\n  // The length of the sequence loaded by that memory tile.\n  int actual_seqlen_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile, int BYTES_PER_ELEMENT>\nstruct Gmem_tile_mma_sd {\n  // The mma tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n  // Each STG stores 8 elements.\n  enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };\n  // The number of MMAs in the M dimension.\n  enum { MMAS_M = Mma_tile::MMAS_M };\n  // The number of MMAs in the N dimension.\n  enum { MMAS_N = Mma_tile::MMAS_N };\n  // The number of rows computed per MMA per thread block.\n  enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };\n  // The number of cols computed per MMA per thread block.\n  enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA };\n  // The number of threads per block.\n  enum { THREADS_PER_CTA = 
Cta_tile::THREADS_PER_CTA };\n  // The size of each row in bytes. I.e. how many bytes are stored per STG.\n  enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };\n  // The fixed sequence length.\n  enum { SEQLEN = Cta_tile::N };\n  // The distance between two blocks (in bytes).\n  enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };\n  // The distance between elements stored per loop (in bytes).\n  enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };\n\n  // The type of elements stored per STG.\n  using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;\n\n  // Ctor.\n  template <typename Params>\n  inline __device__ Gmem_tile_mma_sd(void* ptr, const Params& params, const int bidb, const int bidh, const int tidx)\n      : ptr_(static_cast<char*>(ptr)) {\n    // The block index.\n    size_t bidx = bidb * params.h + bidh;\n\n    // Set store location for each thread at the beginning of the loop\n    ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;\n  }\n\n  // Store to global memory.\n  inline __device__ void store(const Type& data, const int mi, const int ni) {\n    size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;\n    fmha::stg(ptr_ + offset, data);\n  }\n\n  // Load from global memory.\n  inline __device__ void load(Type& data, const int mi, const int ni) {\n    size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;\n    fmha::ldg(data, ptr_ + offset);\n  }\n\n  // Move to the next tile.\n  inline __device__ void move() { ptr_ += LOOP_STRIDE_BYTES; }\n  inline __device__ void move(const int steps) { ptr_ += LOOP_STRIDE_BYTES * steps; }\n\n  // The pointer in global memory.\n  char* ptr_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >\nstruct Gmem_tile_mma_s : public Base {\n  // The number of mmas in the vertical dimension.\n  enum { M = Base::MMAS_M };\n  
// The number of mmas in the horizontal dimension.\n  enum { N = Base::MMAS_N };\n  // The type of the vectors stored by each STG.\n  using Type = typename Base::Type;\n\n  // Ctor.\n  template <typename Params, typename Block_info>\n  inline __device__ Gmem_tile_mma_s(const Params& params, const Block_info& binfo, const int tidx)\n      : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {}\n\n  // Store to global memory.\n  template <typename Mask>\n  inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask& mask) {\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        float tmp00 = softmax[2 * mi + 0][4 * ni + 0];\n        float tmp01 = softmax[2 * mi + 0][4 * ni + 1];\n        float tmp02 = softmax[2 * mi + 0][4 * ni + 2];\n        float tmp03 = softmax[2 * mi + 0][4 * ni + 3];\n\n        float tmp10 = softmax[2 * mi + 1][4 * ni + 0];\n        float tmp11 = softmax[2 * mi + 1][4 * ni + 1];\n        float tmp12 = softmax[2 * mi + 1][4 * ni + 2];\n        float tmp13 = softmax[2 * mi + 1][4 * ni + 3];\n\n        uint4 dst;\n        dst.x = fmha::float2_to_half2(tmp00, tmp01);\n        dst.y = fmha::float2_to_half2(tmp02, tmp03);\n        dst.z = fmha::float2_to_half2(tmp10, tmp11);\n        dst.w = fmha::float2_to_half2(tmp12, tmp13);\n        if (mask.is_valid(mi, ni, 0, 0)) {\n          Base::store(dst, mi, ni);\n        }\n      }\n    }\n  }\n\n  // Store to global memory.\n  template <typename Mask, typename Fragment>\n  inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask) {\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        uint4 dst;\n        dst.x = frag[ni][mi].reg(0);\n        dst.y = frag[ni][mi].reg(2);\n        dst.z = frag[ni][mi].reg(1);\n        dst.w = frag[ni][mi].reg(3);\n        if (mask.any_valid(mi, ni)) {\n          Base::store(dst, mi, ni);\n        }\n      
}\n    }\n  }\n\n  // Load from global memory.\n  template <typename Mask>\n  inline __device__ void load(uint4 (&regs)[M][N], const Mask& mask) {\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        regs[mi][ni] = make_uint4(0, 0, 0, 0);\n        if (mask.any_valid(mi, ni)) {\n          Base::load(regs[mi][ni], mi, ni);\n        }\n      }\n    }\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The base class.\n    typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K> >\nstruct Gmem_tile_dout : public Base {\n  // Ctor.\n  template <typename Params, typename BInfo>\n  inline __device__ Gmem_tile_dout(const Params& params, const BInfo& binfo, int tidx) : Base(params, 0, binfo, tidx) {\n    this->qkv_ptr_ = reinterpret_cast<char*>(params.o_ptr);\n    this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes;  // needed for move\n\n    // Compute the position of the thread in the row.\n    int col = tidx % Base::THREADS_PER_ROW;\n\n    // The row offset in the batched GEMM. 
For each seq element, we store O in that order.\n    int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;\n\n    // Assemble the final pointer.\n    this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >\nstruct Gmem_tile_dq : public Base {\n  // Ctor.\n  template <typename Params, typename BInfo>\n  inline __device__ Gmem_tile_dq(const Params& params, const BInfo& binfo, int tidx) : Base(params, binfo, tidx) {\n    this->o_ptr_ = reinterpret_cast<char*>(params.dqkv_ptr);\n    this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes;  // needed for move\n\n    // Compute the position of the thread in the row.\n    int col = tidx % Base::THREADS_PER_ROW;\n\n    // The row offset in the batched GEMM. For each seq element, we store O in that order.\n    int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +\n                         (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;\n\n    // Assemble the final pointer.\n    this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha/kernel_traits.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include \"gmem_tile.h\"\n#include \"smem_tile.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>\nstruct FMHA_kernel_traits {\n  // The CTA description for the 1st GEMM.\n  using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;\n  // The CTA description for the 2nd GEMM.\n  using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;\n\n  // Do we use one buffer for K and V.\n  enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };\n  // Do we keep K in registers.\n  enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };\n\n  // The global memory tile to load Q.\n  using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;\n\n  // The shared memory tile to swizzle Q.\n  using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n  // The global memory tile to load K.\n  using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;\n  // The shared memory tile to swizzle K.\n  using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;\n\n  // The global memory tile to load V.\n  using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;\n  // The shared memory tile to 
swizzle V.\n  using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;\n\n  // The global memory tile to store O.\n  using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;\n  // The shared memory tile for O.\n  using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;\n\n  // The global memory tile to load/store S.\n  using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;\n\n  // The shared memory tile to transpose S.\n  using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;\n\n  using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;\n\n  // Make sure the number of threads per row matches between the global and shared memory tiles for O.\n  static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, \"\");\n\n  // The number of threads.\n  enum { THREADS = Cta_tile_p::THREADS_PER_CTA };\n  // Make sure the number of threads matches both CTAs.\n  static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, \"\");\n\n  // The amount of shared memory needed to load Q and K.\n  enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };\n  // The extra amount of shared memory needed to load V.\n  enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };\n  // The amount of shared memory needed for Q, K and V.\n  enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };\n  // The amount of shared memory needed to load Q and store O.\n  enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };\n\n  // The amount of shared memory needed for Q, K, V and O.\n  enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };\n  // Make sure we have enough shared memory.\n  static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, \"\");\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha/mask.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\ntemplate <typename Cta_tile>\nstruct Mask {\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n  template <typename Params, typename BInfo>\n  __device__ Mask(const Params& params, const BInfo& blockInfo, int tidx) {\n    actual_seqlen = blockInfo.actual_seqlen;\n\n    const int warp = tidx / Cta_tile::THREADS_PER_WARP;\n    const int lane = tidx % Cta_tile::THREADS_PER_WARP;\n\n    static_assert(Cta_tile::WARPS_K == 1, \"\");\n\n    // find the warp in the Cta tile\n    const int warp_n = (warp / Cta_tile::WARPS_M);\n    const int warp_m = (warp % Cta_tile::WARPS_M);\n    // decompose warp into 8x4 tile\n    const int quad = lane / 4;\n    const int tid = (lane % 4) * 2;\n    row = warp_m * 16 + quad;\n    col = warp_n * 16 + tid;\n  }\n\n  inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {\n    // ii and jj iterate over the 2x4 fragment\n    const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;\n    //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;\n    return col_valid;\n    // return row_valid && col_valid;\n  }\n\n  // BERT Mask: if upper left is invalid, none are valid\n  inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); }\n\n  inline __device__ void load(int it) { 
row_offset = it * Cta_tile::M + row; }\n  int row_offset;\n\n  int row;\n  int col;\n  int actual_seqlen;\n};\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha/smem_tile.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/gemm.h>\n#include <fmha/utils.h>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The description of the tile computed by this CTA.\n    typename Cta_tile,\n    // The number of rows in the 2D shared memory buffer.\n    int M_,\n    // The number of cols.\n    int N_,\n    // The size in bits of each element.\n    int BITS_PER_ELEMENT_,\n    // The number of bytes per STS.\n    int BYTES_PER_STS_ = 16,\n    // The number of buffers. 
(Used in multistage and double buffer cases.)\n    int BUFFERS_PER_TILE_ = 1,\n    // Do we enable the fast path for LDS.128 and friends.\n    int ENABLE_LDS_FAST_PATH_ = 0,\n    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.\n    int ROWS_PER_XOR_PATTERN_ = 8,\n    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.\n    int COLS_PER_XOR_PATTERN_ = 1,\n    // Use or not predicates\n    bool USE_PREDICATES_ = true>\nstruct Smem_tile_without_skews {\n  // The size in bits of each element.\n  enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };\n  // The size in bytes of a single STS.\n  enum { BYTES_PER_STS = BYTES_PER_STS_ };\n  // The number of elements per STS.\n  enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };\n  // To support arbitrary N, we pad some values to a power-of-2.\n  enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };\n  // The number of bytes per row without packing of rows.\n  enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };\n  // The number of bytes per row -- we want at least 128B per row.\n  enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };\n  // The number of rows in shared memory (two rows may be packed into a single one).\n  enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };\n\n  // The number of threads per row.\n  enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };\n  // The number of threads per row.\n  enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };\n\n  // The number of STS per row.\n  enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };\n  // It must be at least one.\n  static_assert(STS_PER_ROW >= 1, \"\");\n  // The number of rows written with a single STS.\n  enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n  // Make sure we write to at least one row per STS. Thanks Dr. 
Obvious ;)\n  static_assert(ROWS_PER_STS >= 1, \"\");\n  // The number of STS needed to store all rows.\n  enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };\n  // The number of STS in total.\n  enum { STS = STS_PER_COL * STS_PER_ROW };\n\n  // The size of one buffer in bytes in shared memory.\n  enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };\n  // The number of buffers.\n  enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };\n  // The size in bytes of total buffers.\n  enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };\n  // The boundary for smem_read_offset and smem_write_offset increment.\n  enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };\n\n  // Do we enable the LDS.128 fast path?\n  enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };\n  static_assert(ENABLE_LDS_FAST_PATH == 0);\n  // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.\n  enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };\n  // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.\n  enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };\n  // Use or not predicates\n  enum { USE_PREDICATES = USE_PREDICATES_ };\n\n  // The type of elements that are stored in shared memory by each thread.\n  using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;\n\n  // Ctor.\n  inline __device__ Smem_tile_without_skews(void* smem, int tidx) : smem_(__nvvm_get_smem_pointer(smem)) {\n    // The row written by a thread. 
See doc/mma_smem_layout.xlsx.\n    int smem_write_row = tidx / THREADS_PER_ROW;\n\n    // The XOR pattern.\n    int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;\n    // Compute the column and apply the XOR pattern.\n    int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;\n\n    // The offset.\n    this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS;\n\n    // TODO: Why not merge it with the read offset?\n    this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);\n    this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);\n  }\n\n  // Compute the store pointers.\n  template <int N>\n  inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {\n#pragma unroll\n    for (int ii = 0; ii < N; ++ii) {\n      // Decompose the STS into row/col.\n      int row = ii / STS_PER_ROW;\n      int col = ii % STS_PER_ROW;\n\n      // Assemble the offset.\n      int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW;\n\n      // Take the column into account.\n      if (STS_PER_ROW > 1) {\n        offset += col * THREADS_PER_ROW * BYTES_PER_STS;\n      }\n\n      // Apply the XOR pattern if needed.\n      if (ROWS_PER_STS < ROWS_PER_XOR_PATTERN) {\n        const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;\n        offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;\n      }\n\n      // Assemble the final pointer :)\n      ptrs[ii] = smem_ + offset + smem_write_buffer_;\n    }\n  }\n\n  inline __device__ void debug_reset() {\n    for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {\n      for (int row = 0; row < ROWS; ++row) {\n        for (int col = 0; col < BYTES_PER_ROW; col += 4) {\n          if (threadIdx.x == 0) {\n            uint32_t val = 0x0;\n            sts(val, smem_ + row * BYTES_PER_ROW + col + buffer);\n          }\n        }\n      }\n    }\n  }\n\n  // Print the content of the tile (only for debug ;)).\n  inline __device__ 
void debug_print() const {\n    for (int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {\n      for (int row = 0; row < ROWS; ++row) {\n        for (int col = 0; col < BYTES_PER_ROW; col += 4) {\n          if (threadIdx.x == 0) {\n            uint32_t val;\n            lds(val, smem_ + row * BYTES_PER_ROW + col + buffer);\n            printf(\"block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\\n\", blockIdx.x,\n                   blockIdx.y, blockIdx.z, smem_, buffer, row, col, val);\n          }\n        }\n      }\n    }\n  }\n\n  // Move the read offset to next buffer.\n  inline __device__ void move_to_next_read_buffer() {\n    if (BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) {\n      this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;\n    } else if (BUFFERS_PER_TILE > 1) {\n      this->smem_read_buffer_ += BYTES_PER_BUFFER;\n    }\n  }\n\n  // Move the read offset to next buffer. TODO: Remove this member function!!!\n  inline __device__ void move_next_read_buffer() { this->move_to_next_read_buffer(); }\n\n  // Move the read offset to next N buffer (circular-buffer).\n  inline __device__ void move_to_next_read_buffer(int N) {\n    if (BUFFERS_PER_TILE > 1) {\n      this->smem_read_buffer_ += N * BYTES_PER_BUFFER;\n      this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;\n    }\n  }\n\n  // Move the read offset to next N buffer (circular-buffer). 
TODO: Remove this member function!!!\n  inline __device__ void move_next_read_buffer(int N) { this->move_to_next_read_buffer(N); }\n\n  // Move the write offset to next buffer.\n  inline __device__ void move_to_next_write_buffer() {\n    if (BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY) {\n      this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;\n    } else if (BUFFERS_PER_TILE > 1) {\n      this->smem_write_buffer_ += BYTES_PER_BUFFER;\n    }\n  }\n\n  // Move the write offset to next buffer. TODO: Remove that member function!\n  inline __device__ void move_next_write_buffer() { this->move_to_next_write_buffer(); }\n\n  // Move the read offset.\n  inline __device__ void move_read_offset(int delta) { this->smem_read_offset_ += delta; }\n\n  // Move the write offset.\n  inline __device__ void move_write_offset(int delta) { this->smem_write_offset_ += delta; }\n\n  // Store to the tile in shared memory.\n  template <int N>\n  inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {\n    uint32_t smem_ptrs[N];\n    this->compute_store_pointers(smem_ptrs);\n    sts(smem_ptrs, data);\n  }\n\n  // Store to the tile in shared memory.\n  template <int N, int M>\n  inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {\n    uint32_t smem_ptrs[N];\n    this->compute_store_pointers(smem_ptrs);\n    sts(smem_ptrs, data, preds);\n  }\n\n  // Store to the tile in shared memory.\n  template <int N>\n  inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {\n    this->store(data, preds);\n  }\n\n  // Store to the tile in shared memory.\n  template <int N>\n  inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {\n    uint32_t tmp[1] = {preds};\n    this->store(gmem_ptrs, tmp);\n  }\n\n  // The shared memory pointer.\n  uint32_t smem_;\n  // The read offset. 
Reserve 4 offsets if needed.\n  int smem_read_offset_;\n  // The write offset.\n  int smem_write_offset_;\n  // The buffer base offset for read.\n  int smem_read_buffer_;\n  // The buffer base offset for write.\n  int smem_write_buffer_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The layout of the tile.\n    typename Layout,\n    // The size of the STS.\n    int BYTES_PER_STS = 16,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE = 1,\n    // Use or not predicates\n    bool USE_PREDICATES = true>\nstruct Smem_tile_a {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int MMAS_K, int MMAS_K_WITH_PADDING>\nstruct Compute_reset_mask {\n  // The potential mask.\n  enum { HALF = MMAS_K_WITH_PADDING / 2 };\n  // The remainder.\n  enum { MOD = MMAS_K % HALF };\n  // The final value.\n  enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int MMAS_K_WITH_PADDING>\nstruct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {\n  enum { VALUE = 0 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int MMAS_K>\nstruct Compute_reset_mask<MMAS_K, MMAS_K> {\n  enum { VALUE = MMAS_K - 1 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nstruct Rows_per_xor_pattern_a {\n  // The size in bits.\n  enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };\n  // The number of rows.\n  enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 
4 : 8) };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nstruct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE>\nstruct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, fmha::BITS_PER_ELEMENT_A,\n                                                        BYTES_PER_STS, BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1> {\n  // The MMA tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n  // The base class.\n  using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K, fmha::BITS_PER_ELEMENT_A, BYTES_PER_STS,\n                                       BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;\n  // The fragment.\n  using Fragment = Fragment_a<Row>;\n\n  // When we use padding to reach a power of two, special care has to be taken.\n  using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;\n  // The number of MMAs.\n  using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;\n\n  // The size of a single LDS in bytes.\n  enum { BYTES_PER_LDS = 16 };\n\n  // Ctor.\n  inline __device__ Smem_tile_row_a(void* smem, int tidx) : Base(smem, tidx) {\n    // For documentation on the layout, see doc/mma_smem_layout.xlsx.\n\n    // The number of warps.\n    const int WARPS_M = Cta_tile::WARPS_M;\n    const int WARPS_N = Cta_tile::WARPS_N;\n    const int WARPS_K = Cta_tile::WARPS_K;\n\n    static_assert(WARPS_M == 1);\n    static_assert(WARPS_N == 
4 || WARPS_N == 8);\n    static_assert(WARPS_K == 1);\n    static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n\n    // The row and column read by the thread.\n    int smem_read_row = (tidx & 0x0f);\n    int smem_read_col = (tidx & 0x07);\n    smem_read_col ^= (tidx & 0x10) / 16;\n\n    // The shared memory offset.\n    this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;\n  }\n\n  // Rewind smem_read_offset for last LDS phase in main loop.\n  inline __device__ void reverse_smem_read_offset(int ki = 0) {\n    // Undo the pointer increment for the next ni.\n    // Should match the load function below for ki = 0.\n    if (Mma_tile_with_padding::MMAS_K >= 2) {\n      this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n    }\n  }\n\n  // Load from shared memory.\n  inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {\n#pragma unroll\n    for (int mi = 0; mi < Mma_tile::MMAS_M; ++mi) {\n      // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).\n      int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;\n\n      // Load using LDSM.M88.4.\n      uint4 tmp;\n      ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);\n\n      // Store the value into the fragment.\n      a[mi].reg(0) = tmp.x;\n      a[mi].reg(1) = tmp.y;\n      a[mi].reg(2) = tmp.z;\n      a[mi].reg(3) = tmp.w;\n    }\n\n    // Move the offset to the next position. 
See doc/mma_smem_layout.xlsx.\n    static_assert(Mma_tile_with_padding::MMAS_K < 64, \"Not implemented\");\n    if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) {\n      this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) {\n      this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) {\n      this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) {\n      this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 2) {\n      this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;\n    }\n  }\n\n  // Reset the read offset.\n  inline __device__ void reset_read_offset() {\n    // The number of MMAs in the K dimension.\n    enum { MMAS_K = Mma_tile::MMAS_K };\n    // The number of MMAs in the K dimension when we include padding.\n    enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };\n    // Assemble the mask.\n    enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };\n\n    // Reset the read offset.\n    this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE>\nstruct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>\n    : public Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {\n  // The base class.\n  using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;\n\n  // Ctor.\n  inline __device__ Smem_tile_a(void* smem, int tidx) : Base(smem, tidx) 
{}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The layout of the tile.\n    typename Layout,\n    // The size of the STS.\n    int BYTES_PER_STS = 16,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE = 1,\n    // Use or not predicates\n    bool USE_PREDICATES = true>\nstruct Smem_tile_b {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nstruct Rows_per_xor_pattern_b {\n  // The size in bits.\n  enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };\n  // The number of rows.\n  enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nstruct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE>\nstruct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, fmha::BITS_PER_ELEMENT_B,\n                                                        BYTES_PER_STS, BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1> {\n  // The MMA tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n  // The base class.\n  using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::N, Cta_tile::K, fmha::BITS_PER_ELEMENT_B, BYTES_PER_STS,\n                                       BUFFERS_PER_TILE, 0, 
ROWS_PER_XOR_PATTERN_, 1>;\n  // The fragment.\n  using Fragment = Fragment_b<Col>;\n\n  // When we use padding to reach a power of two, special care has to be taken.\n  using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;\n  // The number of MMAs.\n  using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;\n\n  // The size of a single LDS in bytes.\n  enum { BYTES_PER_LDS = 16 };\n\n  // The number of STS per thread\n  enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };\n  // The number of STS per thread must be at least 1.\n  enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };\n\n  // Ctor.\n  inline __device__ Smem_tile_col_b(void* smem, int tidx) : Base(smem, tidx) {\n    // For documentation on the layout, see doc/mma_smem_layout.xlsx.\n\n    // The number of warps.\n    const int WARPS_M = Cta_tile::WARPS_M;\n    const int WARPS_N = Cta_tile::WARPS_N;\n    const int WARPS_K = Cta_tile::WARPS_K;\n    static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n    static_assert(WARPS_M == 1);\n    static_assert(WARPS_N == 4 || WARPS_N == 8);\n    static_assert(WARPS_K == 1);\n\n    // The masks to select the warps.\n    const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;\n\n    // The divisor for the warps.\n    const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;\n\n    // The row and column read by the thread.\n    int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + (tidx & 0x07) + (tidx & 0x10) / 2;\n    int smem_read_col = (tidx & 0x07);\n    smem_read_col ^= (tidx & 0x08) / 8;\n    // The shared memory offset.\n    this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;\n  }\n\n  // Rewind smem_read_offset for last LDS phase in main loop.\n  inline __device__ void reverse_smem_read_offset(int ki = 0) {\n    // Undo the pointer increment for the next ni.\n    // Should match the load function below for ki 
= 0.\n    if (Mma_tile_with_padding::MMAS_K >= 2) {\n      this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n    }\n  }\n\n  // Load from shared memory.\n  inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n#pragma unroll\n    for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) {\n      // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).\n      int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;\n\n      // Load using LDSM.M88.4.\n      uint4 tmp;\n      ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);\n\n      // Store the value into the fragment.\n      b[ni].reg(0) = tmp.x;\n      b[ni].reg(1) = tmp.y;\n      b[ni].reg(2) = tmp.z;\n      b[ni].reg(3) = tmp.w;\n    }\n\n    // Move the offset to the next position. See doc/mma_smem_layout.xlsx.\n    static_assert(Mma_tile_with_padding::MMAS_K < 64, \"Not implemented\");\n    if (Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15) {\n      this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7) {\n      this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3) {\n      this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1) {\n      this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;\n    } else if (Mma_tile_with_padding::MMAS_K >= 2) {\n      this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;\n    }\n  }\n\n  // Reset the read offset.\n  inline __device__ void reset_read_offset() {\n    // The number of MMAs in the K dimension.\n    enum { MMAS_K = Mma_tile::MMAS_K };\n    // The number of MMAs in the K dimension when we include padding.\n    enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };\n    // Assemble the mask.\n    enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };\n\n    // 
Reset the read offset.\n    this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE>\nstruct Smem_tile_b<Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>\n    : public Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {\n  // The base class.\n  using Base = Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;\n\n  // Ctor.\n  inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nstruct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b<N> {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE,\n    // How many rows to use for the XOR pattern to avoid bank conflicts?\n    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,\n    // How many cols to use for the XOR pattern to avoid bank conflicts?\n    int COLS_PER_XOR_PATTERN_ = 1>\nstruct Smem_tile_row_b\n    : public Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, fmha::BITS_PER_ELEMENT_B, BYTES_PER_STS,\n                                     BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_> {\n  // The MMA tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n  // The base class.\n  using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, fmha::BITS_PER_ELEMENT_B, BYTES_PER_STS,\n                      
                 BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, COLS_PER_XOR_PATTERN_>;\n  // The fragment.\n  using Fragment = Fragment_b<Row>;\n\n  // Can we use LDSM? No if the data type is 32-bit large.\n  enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };\n  // The size of a single LDS in bytes.\n  enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };\n  // The number of elements per LDS.\n  enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };\n\n  // The number of STS per thread\n  enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };\n  // The number of STS per thread must be at least 1.\n  enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };\n\n  // Ctor.\n  inline __device__ Smem_tile_row_b(void* smem, int tidx) : Base(smem, tidx) {\n    // The number of warps.\n    const int WARPS_M = Cta_tile::WARPS_M;\n    const int WARPS_N = Cta_tile::WARPS_N;\n    const int WARPS_K = Cta_tile::WARPS_K;\n    static_assert(WARPS_K == 1);\n    static_assert(WARPS_M == 4 || WARPS_M == 8);\n    static_assert(WARPS_N == 1);\n\n    // The masks to select the warps.\n    const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;\n    const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;\n\n    // The divisor for the warps.\n    const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;\n    const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;\n\n    // The row/col read by the thread.\n    int smem_read_row, smem_read_col;\n\n    static_assert(USE_LDSMT);\n    static_assert(Base::ROWS_PER_XOR_PATTERN == 8);\n\n    smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + (tidx & 0x07) + (tidx & 0x08);\n    smem_read_col = (tidx & 0x07);\n    smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;\n\n    // The shared memory offset.\n    this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;\n\n    // Fill 
zeroes for group conv\n  }\n\n  // Rewind smem_read_offset for last LDS phase in main loop.\n  inline __device__ void reverse_smem_read_offset(int ki = 0) {\n    // The size of each element in bits.\n    const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;\n    // The size in bytes of the data needed to compute an MMA per CTA.\n    const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;\n\n#pragma unroll\n    for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) {\n      // Undo the pointer increment for the next ni.\n      // Should match the load function below for ki = 0.\n      if (BYTES_PER_MMA_PER_CTA >= 128) {\n        // Nothing to do!\n      } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) {\n        this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n      } else if (BYTES_PER_MMA_PER_CTA == 64) {\n        // Nothing to do!\n      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) {\n        this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6);\n      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) {\n        this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n      }\n    }\n\n    // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)\n    if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) {\n      this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n    }\n  }\n\n  // Load from shared memory.\n  inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n    // The size of each element in bits.\n    const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;\n    // The size in bytes of the data needed to compute an MMA per CTA.\n    const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;\n\n#pragma unroll\n    for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) {\n      // Prepare the offset.\n      int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;\n      if (BYTES_PER_MMA_PER_CTA == 32) {\n        offset += this->smem_read_offset_;\n      } else if (BYTES_PER_MMA_PER_CTA == 64) {\n        offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2;\n      } else {\n        offset += this->smem_read_offset_ + (ni)*BYTES_PER_MMA_PER_CTA;\n      }\n\n      // Load the data using LDSM.MT88.2.\n      uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;\n      uint4 tmp;\n      if (USE_LDSMT) {\n        ldsmt(tmp, ptr);\n      } else {\n        lds(tmp.x, (ptr) + 0 * Base::BYTES_PER_ROW);\n        lds(tmp.y, (ptr) + 4 * Base::BYTES_PER_ROW);\n        lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW);\n        lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW);\n      }\n\n      // Store those values in the fragment.\n      b[ni].reg(0) = tmp.x;\n      b[ni].reg(1) = tmp.y;\n      b[ni].reg(2) = tmp.z;\n      b[ni].reg(3) = tmp.w;\n\n      // Move the pointer for the next ni. 
I expect the compiler to not recompute those.\n      if (BYTES_PER_MMA_PER_CTA >= 128) {\n        // Nothing to do!\n      } else if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1) {\n        this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n      } else if (BYTES_PER_MMA_PER_CTA == 64) {\n        // Nothing to do!\n      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4) {\n        this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);\n      } else if (BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2) {\n        this->smem_read_offset_ ^= BYTES_PER_LDS * 2;\n      }\n    }\n\n    // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)\n    if (BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1) {\n      this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;\n    }\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <\n    // The dimensions of the tile computed by the CTA.\n    typename Cta_tile,\n    // The size of the STS.\n    int BYTES_PER_STS,\n    // The number of buffers per tile.\n    int BUFFERS_PER_TILE>\nstruct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>\n    : public Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {\n  // The base class.\n  using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;\n\n  // Ctor.\n  inline __device__ Smem_tile_b(void* smem, int tidx) : Base(smem, tidx) {}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile>\nstruct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1> {\n  // The base class.\n  using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1>;\n  // The MMA tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n  // The fragment.\n  using Fragment = 
Fragment_b<fmha::Col>;\n\n  // The size of a single LDS in bytes.\n  enum { BYTES_PER_LDS = 16 };\n\n  // Ctor.\n  inline __device__ Smem_tile_v(void* smem, int tidx) : Base(smem, tidx) {\n    // The row/col read by the thread.\n    int read_row, read_col;\n\n    static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 &&\n                  (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));\n\n    read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);\n    read_col = (tidx & 0x07);\n    read_col ^= (tidx & 0x10) / 16;\n\n    // The shared memory offset.\n    this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n  }\n\n  // Load from shared memory.\n  inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {\n#pragma unroll\n    for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) {\n      // Jump by 16 * #warps row.\n      int row = ki * 16 * Cta_tile::WARPS_K;\n\n      // Load the data using LDSM.MT88.2.\n      uint4 tmp;\n      fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW);\n      b[ni].reg(0) = tmp.x;\n      b[ni].reg(1) = tmp.y;\n      b[ni].reg(2) = tmp.z;\n      b[ni].reg(3) = tmp.w;\n\n      // Move the pointer for the next ni. I expect the compiler to not recompute those.\n      if (Mma_tile::MMAS_N == 4) {\n        this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 
2 : 6);\n      } else {\n        assert(false);  // Not implemented!\n      }\n    }\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile>\nstruct Smem_tile_o {\n  // The MMA tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n  // The accumulators.\n  using Accumulator = fmha::Fragment_accumulator;\n  // The accumulators.\n  using Data_type = typename Accumulator::Data_type;\n\n  // The size of each element.\n  enum { BYTES_PER_ELEMENT = sizeof(Data_type) };\n  // The size of each STS.\n  enum { BYTES_PER_STS = 8 };\n  // The size of each row in shared memory.\n  enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT };\n\n  // The size of each LDS.\n  enum { BYTES_PER_LDS = 16 };\n  enum { THREADS_PER_ROW = 16 };\n\n  // The number of rows.\n  enum { ROWS = Cta_tile::M };\n  // The number of \"rows\" to process per loop iteration (in the \"epilogue\").\n  enum { ROWS_PER_LOOP = ROWS <= 64 ? 
ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };\n  // The number of outer loops.\n  enum { LOOPS = ROWS / ROWS_PER_LOOP };\n  // Make sure it matches our expectations.\n  static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, \"\");\n\n  // The number of rows loaded per LDS.\n  enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n  // Do we have to guard against partial writes/reads.\n  enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 };\n  // The total number of LDS per loop.\n  enum { LDS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_LDS>::VALUE };\n\n  // The amount of shared memory.\n  enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW };\n\n  // The write pointer.\n  uint32_t smem_write_, smem_read_;\n  // Is the thread active for the last LDS of the series?\n  int is_active_for_last_lds_;\n\n  static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);\n  static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, \"\");\n\n  // Ctor.\n  inline __device__ Smem_tile_o(void* smem, int tidx) {\n    // Get a 32-bit value for the shared memory address.\n    uint32_t smem_ = __nvvm_get_smem_pointer(smem);\n\n    static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 &&\n                  (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));\n\n    int write_row = (tidx & 0x1c) / 4;\n    int write_col = (tidx);\n\n    // Assemble the write pointer.\n    smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;\n\n    // The element read by each thread.\n    int read_row = tidx / THREADS_PER_ROW;\n    int read_col = tidx % THREADS_PER_ROW;\n\n    // Take the XOR pattern into account for the column.\n    read_col ^= 2 * (read_row & 0x7);\n\n    // Assemble the read pointer.\n    this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n\n    // Is that thread active on the last LDS?\n    if (HAS_INCOMPLETE_LDS) {\n      this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * 
ROWS_PER_LDS < Cta_tile::M;\n    }\n  }\n\n  // Load the output fragments.\n  inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {\n#pragma unroll\n    for (int ii = 0; ii < LDS_PER_LOOP; ++ii) {\n      // Load the elements before the reduction (split-K).\n      uint4 tmp[Cta_tile::WARPS_K];\n#pragma unroll\n      for (int jj = 0; jj < Cta_tile::WARPS_K; ++jj) {\n        int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;\n        if (!HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_)) {\n          fmha::lds(tmp[jj], this->smem_read_ + imm);\n        }\n      }\n\n      // Perform the reduction.\n      out[ii] = tmp[0];\n#pragma unroll\n      for (int jj = 1; jj < Cta_tile::WARPS_K; ++jj) {\n        out[ii] = fmha::fadd4(out[ii], tmp[jj]);\n      }\n    }\n  }\n  // Store the accumulators.\n  template <int M, int N>\n  inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {\n    enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };\n#pragma unroll\n    for (int ni = 0; ni < Mma_tile::MMAS_N; ++ni) {\n      // The number of MMAs that are stored per loop iteration.\n      enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };\n\n// Store 1st column of the different MMAs.\n#pragma unroll\n      for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) {\n        // Precompute the immediates to jump between rows.\n        int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;\n        int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;\n        uint2 tmp0, tmp1;\n        tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);\n        tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);\n\n        tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);\n        tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);\n\n        // Store.\n        fmha::sts(this->smem_write_ + row_0, tmp0);\n        fmha::sts(this->smem_write_ + row_1, tmp1);\n      }\n\n      // Swizzle the write pointer using a XOR of 16B.\n      
this->smem_write_ ^= 32;\n\n// Store 2nd column of the different MMAs.\n#pragma unroll\n      for (int mj = 0; mj < MMAS_M_PER_LOOP; ++mj) {\n        // Precompute the immediates to jump between rows.\n        int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;\n        int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;\n\n        uint2 tmp0, tmp1;\n        tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);\n        tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);\n\n        tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);\n        tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);\n        // Store.\n        fmha::sts(this->smem_write_ + row_0, tmp0);\n        fmha::sts(this->smem_write_ + row_1, tmp1);\n      }\n\n      // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.\n      this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;\n    }\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile>\nstruct Smem_tile_mma {\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n  using Fragment = fmha::Fragment_a<fmha::Col>;\n\n  enum { COLS = Cta_tile::N };\n  enum { BYTES_PER_ELT = 2 };\n  enum { BYTES_PER_STS = 4 };\n  enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };  // TODO\n  enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };\n\n  enum { WARPS_M = Cta_tile::WARPS_M };\n  enum { WARPS_N = Cta_tile::WARPS_N };\n  enum { WARPS_K = Cta_tile::WARPS_K };\n\n  static_assert(WARPS_K == 1);\n  inline __device__ Smem_tile_mma(char* smem, int tidx) {\n    smem_ = __nvvm_get_smem_pointer(smem);\n\n    int write_col, write_row;\n    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);\n    if (WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)) {\n      write_row = (tidx & 0x1c) / 4;\n      write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);\n    } else {\n      write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 
4;\n      write_col = (tidx & 0x03);\n    }\n    write_col ^= (write_row & 0x07) * 4;\n\n    write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;\n  }\n\n  template <int M, int N>\n  inline __device__ void store(const uint4 (&regs)[M][N]) {\n    static_assert(COLS == Cta_tile::N);\n    for (int mi = 0; mi < M; mi++) {\n      for (int ni = 0; ni < N; ni++) {\n        size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;\n        fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);\n        fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);\n        offset ^= 4 * BYTES_PER_STS;\n        fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);\n        fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);\n      }\n    }\n  }\n\n  uint32_t smem_;\n  uint32_t write_offset_;\n  uint32_t warp_m;\n  uint32_t warp_n;\n  uint32_t lane;\n};\n\ntemplate <typename Cta_tile, typename Base = Smem_tile_mma<Cta_tile>>\nstruct Smem_tile_mma_transposed : public Base {\n  enum { BYTES_PER_LDS = 16 };\n  enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };\n  enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };\n  enum { WARPS_M = Base::WARPS_M };\n  enum { WARPS_N = Base::WARPS_N };\n  static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));\n  using Fragment = typename Base::Fragment;\n  inline __device__ Smem_tile_mma_transposed(char* smem, int tidx) : Base(smem, tidx) {\n    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));\n    int read_row, read_col;\n    read_row = (tidx & 0x0f);\n    read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;\n\n    read_col ^= (read_row & 0x07);\n    read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n  }\n\n  template <int M, int N>\n  inline __device__ void load(Fragment (&frag)[M][N]) {\n    static_assert(Base::COLS == Cta_tile::N);\n    for (int mi = 0; mi < M; mi++) {\n      for (int ni = 0; ni < N; ni++) {\n        
size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;\n        uint4 dst;\n        fmha::ldsmt(dst, this->smem_ + offset);\n        frag[mi][ni].reg(0) = dst.x;\n        frag[mi][ni].reg(1) = dst.z;  // Fragment A regs col major!\n        frag[mi][ni].reg(2) = dst.y;\n        frag[mi][ni].reg(3) = dst.w;\n      }\n    }\n  }\n\n  uint32_t read_offset_;\n};\n\ntemplate <typename Cta_tile, typename Base = Smem_tile_mma<Cta_tile>>\nstruct Smem_tile_mma_epilogue : public Base {\n  enum { BYTES_PER_LDS = 16 };\n  enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };\n  enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };\n  enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };\n  static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);\n  enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };\n  enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };\n  static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);\n  enum { WARPS_M = Base::WARPS_M };\n  enum { WARPS_N = Base::WARPS_N };\n  static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);\n\n  using Acc = fmha::Fragment_accumulator;\n\n  inline __device__ Smem_tile_mma_epilogue(char* smem, int tidx) : Base(smem, tidx) {\n    const int read_row = tidx / THREADS_PER_ROW;\n    int read_col = tidx % THREADS_PER_ROW;\n    read_col ^= (read_row & 0x07);\n    read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;\n  }\n\n  inline __device__ void load(uint4 (&data)[NUM_LDS]) {\n    for (int ii = 0; ii < NUM_LDS; ii++) {\n      size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;\n      fmha::lds(data[ii], this->smem_ + offset);\n    }\n  }\n\n  template <int M, int N>\n  inline __device__ void store(const Acc (&acc)[M][N]) {\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        // 1st row - 4 elements per row.\n        float tmp00 = acc[mi][ni].elt(0);\n        float tmp01 = 
acc[mi][ni].elt(1);\n        float tmp02 = acc[mi][ni].elt(4);\n        float tmp03 = acc[mi][ni].elt(5);\n        // 2nd row - 4 elements per row.\n        float tmp10 = acc[mi][ni].elt(2);\n        float tmp11 = acc[mi][ni].elt(3);\n        float tmp12 = acc[mi][ni].elt(6);\n        float tmp13 = acc[mi][ni].elt(7);\n\n        uint32_t x = fmha::float2_to_half2(tmp00, tmp01);\n        uint32_t y = fmha::float2_to_half2(tmp02, tmp03);\n        uint32_t z = fmha::float2_to_half2(tmp10, tmp11);\n        uint32_t w = fmha::float2_to_half2(tmp12, tmp13);\n\n        size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;\n        fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);\n        fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);\n        offset ^= 4 * Base::BYTES_PER_STS;\n        fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);\n        fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);\n      }\n    }\n  }\n\n  template <int M, int N>\n  inline __device__ void store(const uint4 (&regs)[M][N]) {\n    for (int mi = 0; mi < M; mi++) {\n      for (int ni = 0; ni < N; ni++) {\n        size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;\n        fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);\n        fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);\n        offset ^= 4 * Base::BYTES_PER_STS;\n        fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);\n        fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);\n      }\n    }\n  }\n\n  uint32_t read_offset_;\n};\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha/softmax.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Sum_ {\n  enum { IS_SUM = 1 };\n  static inline __device__ float apply(float x, float y) { return x + y; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Max_ {\n  enum { IS_SUM = 0 };\n  static inline __device__ float apply(float x, float y) { return x > y ? 
x : y; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ float apply_exp_(float x, float max) { return __expf(x - max); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int COLS>\nstruct ReadType {};\ntemplate <>\nstruct ReadType<4> {\n  using T = float;\n};\ntemplate <>\nstruct ReadType<8> {\n  using T = float2;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile, typename Kernel_traits>\nstruct Smem_tile_reduce {\n  // Helper class to distribute MMA tiles reduced over rows per warp over quads.\n\n  // The Mma tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n  // The number of MMAs in M/N dimensions.\n  enum { MMAS_M = Mma_tile::MMAS_M };\n  enum { MMAS_N = Mma_tile::MMAS_N };\n\n  enum { WARPS_M = Cta_tile::WARPS_M };\n  enum { WARPS_N = Cta_tile::WARPS_N };\n\n  static constexpr int ROWS = WARPS_M * MMAS_M * 16;\n  static constexpr int COLS = WARPS_N;\n  static_assert(COLS == 4 || COLS == 8);\n  static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 
4 : 8;\n  static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);\n  static constexpr int ELTS_PER_TILE = ROWS * COLS;\n\n  static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;\n  static_assert(THREADS_PER_GROUP == 16);  // DEBUG\n  static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;\n  static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;\n  static_assert(LOOPS == 1);\n\n  using read_t = typename ReadType<COLS>::T;\n\n  __device__ inline Smem_tile_reduce(float* smem_, const int tidx) {\n    int lane = tidx % 32;\n    int warp = tidx / 32;\n\n    int warp_m = warp % WARPS_M;\n    int warp_n = warp / WARPS_M;\n\n    qid_ = lane % 4;\n    int qp = lane / 4;\n\n    // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.\n    // This won't affect reading as we assume commutative reduction ops.\n    const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);\n    smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];\n    smem_read_ = &reinterpret_cast<read_t*>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];\n  }\n\n  __device__ inline void store(float (&frag)[2 * MMAS_M]) {\n    if (qid_ == 0) {\n#pragma unroll\n      for (int mi = 0; mi < MMAS_M; mi++) {\n        int offset = mi * 16 * WARPS_N;\n        smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];\n        smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];\n      }\n    }\n  }\n\n  __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {\n#pragma unroll\n    for (int mi = 0; mi < MMAS_M; mi++) {\n      int offset = mi * 16 * 4;\n      frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];\n      frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];\n    }\n  }\n\n  int qid_;\n  float* smem_write_;\n  read_t* smem_read_;\n};\n\ntemplate <typename Cta_tile, typename Kernel_traits>\nstruct Softmax_base {\n  // The Mma tile.\n  using Mma_tile = fmha::Hmma_tile<Cta_tile>;\n\n  // The number of MMAs in M/N 
dimensions.\n  enum { MMAS_M = Mma_tile::MMAS_M };\n  enum { MMAS_N = Mma_tile::MMAS_N };\n\n  // The number of groups of warp such that we have at most 4 warps writing consecutive elements.\n  enum { GROUPS = fmha::Div_up<Cta_tile::WARPS_N, 4>::VALUE };\n  // The number of elements that we are going to store per row.\n  enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS };\n  // The number of rows.\n  enum { ROWS = Cta_tile::M * GROUPS };\n  // The total number of elements.\n  enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW };\n\n  // Ctor.\n  template <typename Params>\n  inline __device__ Softmax_base(const Params& params, void* smem, int bidb, int tidx)\n      :  // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),\n        smem_(reinterpret_cast<float*>(smem)),\n        tidx_(tidx) {\n    // Move to the 1st mask loaded by the thread+ tidx;\n    // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);\n\n    // Extract the position in the warp.\n    int warp = tidx / Cta_tile::THREADS_PER_WARP;\n    int lane = tidx % Cta_tile::THREADS_PER_WARP;\n\n    // Decompose the warp index into M and N.\n    int warp_m = warp % Cta_tile::WARPS_M;\n    int warp_n = warp / Cta_tile::WARPS_M;\n\n    // Decompose the warp-n index into group/position-inside-the-group.\n    int warp_g = warp_n / ELEMENTS_PER_ROW;\n    int warp_i = warp_n % ELEMENTS_PER_ROW;\n\n    // The location written by the threads.\n    int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;\n    int write_col = warp_i;\n\n    // Assemble the write pointer.\n    smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];\n\n    // Assemble the read pointer.\n    smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];\n  }\n\n  template <typename Mask>\n  inline __device__ void apply_mask(const Mask& mask) {\n#pragma unroll\n    for (int mi = 0; mi < MMAS_M; ++mi) {\n#pragma unroll\n      for (int ii = 0; ii < 2; 
++ii) {\n#pragma unroll\n        for (int ni = 0; ni < MMAS_N; ++ni) {\n#pragma unroll\n          for (int jj = 0; jj < 4; ++jj) {\n            if (!mask.is_valid(mi, ni, ii, jj)) {\n              elt_[2 * mi + ii][4 * ni + jj] = -INFINITY;\n            }\n          }\n        }\n      }\n    }\n  }\n\n  // Apply the exp to all the elements.\n  inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {\n#pragma unroll\n    for (int mi = 0; mi < MMAS_M * 2; ++mi) {\n#pragma unroll\n      for (int ni = 0; ni < MMAS_N * 4; ++ni) {\n        elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);\n      }\n    }\n  }\n\n  // Scale all the elements.\n  inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {\n    // Precompute the inverse sum to normalize. Without -use_fast_math, this makes a big difference.\n    float inv_sum[MMAS_M * 2];\n#pragma unroll\n    for (int mi = 0; mi < MMAS_M * 2; ++mi) {\n      inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];\n    }\n\n// Update the values.\n#pragma unroll\n    for (int mi = 0; mi < MMAS_M * 2; ++mi) {\n#pragma unroll\n      for (int ni = 0; ni < MMAS_N * 4; ++ni) {\n        elt_[mi][ni] *= inv_sum[mi];\n      }\n    }\n  }\n\n  // The pointer to the mask.\n  const char* packed_mask_ptr_;\n  // Shared memory for the CTA-wide reduction.\n  float *smem_, *smem_write_, *smem_read_;\n  // The current thread index.\n  int tidx_;\n  // The elements.\n  float elt_[MMAS_M * 2][MMAS_N * 4];\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Cta_tile, typename Kernel_traits>\nstruct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {\n  // The base class.\n  using Base = Softmax_base<Cta_tile, Kernel_traits>;\n  // The fragment.\n  using Fragment_a = fmha::Fragment_a<fmha::Row>;\n\n  static_assert(Fragment_a::NUM_REGS == 4);\n\n  enum { WARPS_M = Cta_tile::WARPS_M };\n  enum { WARPS_N = Cta_tile::WARPS_N };\n  // The 
MMAs.\n  enum { MMAS_M = Base::MMAS_M };\n  enum { MMAS_N = Base::MMAS_N };\n\n  // The accumulators.\n  using Accumulator = fmha::Fragment_accumulator;\n  using Accumulator_out = Fragment<uint16_t, 8>;\n  static_assert(Accumulator_out::NUM_REGS == 4);\n\n  static_assert(std::is_same<Accumulator::Data_type, float>::value);\n\n  using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;\n  static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);\n  // Ctor.\n  template <typename Params>\n  inline __device__ Softmax(const Params& params, void* smem, int bidb, int tidx)\n      : Base(params, smem, bidb, tidx),\n        params_scale_bmm1_(params.scale_bmm1),\n        smem_sum_(static_cast<float*>(smem), tidx),\n        smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {}\n\n  // Pack the data to a fragment for the next GEMM.\n  template <int K, int M>\n  inline __device__ void pack(Fragment_a (&dst)[K][M]) const {\n#pragma unroll\n    for (int mi = 0; mi < M; ++mi) {\n#pragma unroll\n      for (int ki = 0; ki < K; ++ki) {\n        // 1st row - 4 elements per row.\n        float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];\n        float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];\n        float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];\n        float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];\n\n        // 2nd row - 4 elements per row.\n        float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];\n        float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];\n        float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];\n        float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];\n\n        // Pack to 4 registers.\n        dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);\n        dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);\n        dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);\n        dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);\n      }\n    }\n  }\n\n  // Scale FP32 
fragments\n  inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {\n    const float scalef = reinterpret_cast<const float&>(this->params_scale_bmm1_);\n\n#pragma unroll\n    for (int mi = 0; mi < MMAS_M; ++mi) {\n#pragma unroll\n      for (int ni = 0; ni < MMAS_N; ++ni) {\n        // 1st row - 4 elements per row.\n        this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;\n        this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;\n        this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;\n        this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;\n        // 2nd row - 4 elements per row.\n        this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;\n        this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;\n        this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;\n        this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;\n      }\n    }\n  }\n  // Unpack FP32 fragments without scaling\n  inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {\n#pragma unroll\n    for (int mi = 0; mi < MMAS_M; ++mi) {\n#pragma unroll\n      for (int ni = 0; ni < MMAS_N; ++ni) {\n        // 1st row - 4 elements per row.\n        this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);\n        this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);\n        this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);\n        this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);\n        // 2nd row - 4 elements per row.\n        this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);\n        this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);\n        this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);\n        this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);\n      }\n    }\n  }\n\n  template <typename Operator>\n  __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator& op, Smem_tile_red& smem_red) 
{\n    for (int mi = 0; mi < 2 * MMAS_M; mi++) {\n      frag[mi] = this->elt_[mi][0];\n      for (int ni = 1; ni < 4 * MMAS_N; ni++) {\n        frag[mi] = op(frag[mi], this->elt_[mi][ni]);\n      }\n    }\n    quad_reduce(frag, frag, op);\n\n    smem_red.store(frag);\n    __syncthreads();\n    typename Smem_tile_red::read_t tmp[2 * MMAS_M];\n    smem_red.load(tmp);\n\n    quad_allreduce(frag, tmp, op);\n  }\n\n  __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) {\n    MaxOp<float> max;\n    reduce_(frag, max, smem_max_);\n  }\n\n  __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) {\n    SumOp<float> sum;\n    reduce_(frag, sum, smem_sum_);\n  }\n\n  const uint32_t params_scale_bmm1_;\n  Smem_tile_red smem_max_;\n  Smem_tile_red smem_sum_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha/utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdint.h>\n#include <stdlib.h>\n\nextern \"C\" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Row {};\nstruct Col {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int M, bool = (M & (M - 1)) == 0>\nstruct Next_power_of_two {};\n\ntemplate <int M>\nstruct Next_power_of_two<M, true> {\n  enum { VALUE = M };\n};\ntemplate <>\nstruct Next_power_of_two<3, false> {\n  enum { VALUE = 4 };\n};\ntemplate <>\nstruct Next_power_of_two<5, false> {\n  enum { VALUE = 8 };\n};\ntemplate <>\nstruct Next_power_of_two<6, false> {\n  enum { VALUE = 8 };\n};\ntemplate <>\nstruct Next_power_of_two<7, false> {\n  enum { VALUE = 8 };\n};\ntemplate <>\nstruct Next_power_of_two<9, false> {\n  enum { VALUE = 16 };\n};\ntemplate <>\nstruct Next_power_of_two<10, false> {\n  enum { VALUE = 16 };\n};\ntemplate <>\nstruct Next_power_of_two<11, false> {\n  enum { VALUE = 16 };\n};\ntemplate <>\nstruct Next_power_of_two<12, false> {\n  enum { VALUE = 16 };\n};\ntemplate <>\nstruct Next_power_of_two<13, false> {\n  enum { VALUE = 16 
};\n};\ntemplate <>\nstruct Next_power_of_two<14, false> {\n  enum { VALUE = 16 };\n};\ntemplate <>\nstruct Next_power_of_two<15, false> {\n  enum { VALUE = 16 };\n};\ntemplate <>\nstruct Next_power_of_two<24, false> {\n  enum { VALUE = 32 };\n};\ntemplate <>\nstruct Next_power_of_two<48, false> {\n  enum { VALUE = 64 };\n};\ntemplate <>\nstruct Next_power_of_two<80, false> {\n  enum { VALUE = 128 };\n};\ntemplate <>\nstruct Next_power_of_two<96, false> {\n  enum { VALUE = 128 };\n};\ntemplate <>\nstruct Next_power_of_two<112, false> {\n  enum { VALUE = 128 };\n};\ntemplate <>\nstruct Next_power_of_two<144, false> {\n  enum { VALUE = 256 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N, bool = (N & (N - 1)) == 0>\nstruct Prev_power_of_two {};\n\ntemplate <int N>\nstruct Prev_power_of_two<N, true> {\n  enum { VALUE = N };\n};\ntemplate <>\nstruct Prev_power_of_two<3, false> {\n  enum { VALUE = 2 };\n};\ntemplate <>\nstruct Prev_power_of_two<5, false> {\n  enum { VALUE = 4 };\n};\ntemplate <>\nstruct Prev_power_of_two<6, false> {\n  enum { VALUE = 4 };\n};\ntemplate <>\nstruct Prev_power_of_two<7, false> {\n  enum { VALUE = 4 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int M, int N>\nstruct Div_up {\n  enum { VALUE = (M + N - 1) / N };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int A, int B>\nstruct Max {\n  enum { VALUE = A >= B ? A : B };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int A, int B, int C>\nstruct Max_3 {\n  enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int A, int B>\nstruct Min {\n  enum { VALUE = A <= B ? 
A : B };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int SIZE_IN_BYTES>\nstruct Uint_from_size_in_bytes {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct Uint_from_size_in_bytes<1> {\n  using Type = uint8_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct Uint_from_size_in_bytes<2> {\n  using Type = uint16_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct Uint_from_size_in_bytes<4> {\n  using Type = uint32_t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct Uint_from_size_in_bytes<8> {\n  using Type = uint2;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct Uint_from_size_in_bytes<16> {\n  using Type = uint4;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int WARPS_M, int WARPS_N, int WARPS_K>\nstruct Warp_masks {};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct Warp_masks<8, 1, 1> {\n  enum { M = 0xe0, N = 0x00, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<4, 2, 1> {\n  enum { M = 0x60, N = 0x80, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<4, 1, 2> {\n  enum { M = 0x60, N = 0x00, K = 0x80 };\n};\ntemplate <>\nstruct Warp_masks<4, 1, 1> {\n  enum { M = 0x60, N = 0x00, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<2, 4, 1> {\n  enum { M = 0x20, N = 0xc0, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<2, 2, 2> {\n  enum { M = 0x20, N = 0x40, K = 0x80 };\n};\ntemplate <>\nstruct Warp_masks<2, 2, 1> {\n  enum { M = 0x20, N = 
0x40, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<2, 1, 2> {\n  enum { M = 0x20, N = 0x00, K = 0x40 };\n};\ntemplate <>\nstruct Warp_masks<2, 1, 1> {\n  enum { M = 0x20, N = 0x00, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<1, 8, 1> {\n  enum { M = 0x00, N = 0xe0, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<1, 4, 2> {\n  enum { M = 0x00, N = 0x60, K = 0x80 };\n};\ntemplate <>\nstruct Warp_masks<1, 4, 1> {\n  enum { M = 0x00, N = 0x60, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<1, 2, 2> {\n  enum { M = 0x00, N = 0x20, K = 0x40 };\n};\ntemplate <>\nstruct Warp_masks<1, 2, 1> {\n  enum { M = 0x00, N = 0x20, K = 0x00 };\n};\ntemplate <>\nstruct Warp_masks<1, 1, 4> {\n  enum { M = 0x00, N = 0x00, K = 0x60 };\n};\ntemplate <>\nstruct Warp_masks<1, 1, 2> {\n  enum { M = 0x00, N = 0x00, K = 0x20 };\n};\ntemplate <>\nstruct Warp_masks<1, 1, 1> {\n  enum { M = 0x00, N = 0x00, K = 0x00 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\ninline __device__ __host__ T div_up(T m, T n) {\n  return (m + n - 1) / n;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline int clz(int x) {\n  for (int i = 31; i >= 0; --i) {\n    if ((1 << i) & x) {\n      return 31 - i;\n    }\n  }\n  return 32;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline int find_log_2(int x, bool round_up = false) {\n  int a = 31 - clz(x);\n  if (round_up) {\n    a += (x & (x - 1)) ? 
1 : 0;\n  }\n  return a;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {\n  uint32_t c;\n  asm volatile(\"add.f16x2 %0, %1, %2;\\n\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {\n  uint32_t c;\n  asm volatile(\"min.f16x2 %0, %1, %2;\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {\n  uint32_t c;\n  asm volatile(\"mul.f16x2 %0, %1, %2;\\n\" : \"=r\"(c) : \"r\"(a), \"r\"(b));\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hmul4(uint2 a, uint2 b) {\n  uint2 c;\n  c.x = hmul2(a.x, b.x);\n  c.y = hmul2(a.y, b.y);\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hmul8(uint4 a, uint4 b) {\n  uint4 c;\n  c.x = hmul2(a.x, b.x);\n  c.y = hmul2(a.y, b.y);\n  c.z = hmul2(a.z, b.z);\n  c.w = hmul2(a.w, b.w);\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hmul8(uint32_t a, uint4 b) {\n  uint4 c;\n  c.x = hmul2(a, b.x);\n  c.y = hmul2(a, b.y);\n  c.z = hmul2(a, b.z);\n  c.w = hmul2(a, b.w);\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {\n  uint32_t res;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  asm volatile(\"max.f16x2 %0, %1, 
%2;\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(lb));\n#else\n  const uint32_t zero = 0u;\n  asm volatile(\n      \"{\\n\"\n      \"\\t .reg .f16x2 sela;\\n\"\n      \"\\t set.gtu.u32.f16x2 sela, %1, %2;\\n\"\n      \"\\t and.b32 %0, sela, %1;\\n\"\n      \"}\\n\"\n      : \"=r\"(res)\n      : \"r\"(x), \"r\"(zero));\n#endif\n  return res;\n}\nstatic inline __device__ uint32_t habs2(uint32_t x) {\n  uint32_t res;\n  asm volatile(\"abs.f16x2 %0, %1;\\n\" : \"=r\"(res) : \"r\"(x));\n  return res;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\ntemplate <typename T>\nstatic inline __device__ T clamp(T x, T lb, T ub) {\n  return x < lb ? lb : (x > ub ? ub : x);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t clamp_to_zero(uint16_t x) {\n  uint16_t mask;\n  asm volatile(\"set.gtu %0, %1, 0;\" : \"=h\"(mask) : \"h\"(x));\n  return mask & x;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t float_to_half(float f) {\n  uint16_t h;\n  asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(h) : \"f\"(f));\n  return h;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float2_to_half2(float a, float b) {\n  uint32_t c;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  asm volatile(\"cvt.rn.f16x2.f32 %0, %1, %2;\\n\" : \"=r\"(c) : \"f\"(b), \"f\"(a));\n#else\n  uint16_t lo = float_to_half(a);\n  uint16_t hi = float_to_half(b);\n  asm volatile(\"mov.b32 %0, {%1, %2};\\n\" : \"=r\"(c) : \"h\"(lo), \"h\"(hi));\n#endif\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a, a); 
}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t float2_to_half2(const float2& f) { return float2_to_half2(f.x, f.y); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {\n  uint2 d;\n  d.x = float2_to_half2(x, y);\n  d.y = float2_to_half2(z, w);\n  return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {\n  uint32_t d;\n  asm volatile(\"fma.rn.f16x2 %0, %1, %2, %3;\\n\" : \"=r\"(d) : \"r\"(a), \"r\"(b), \"r\"(c));\n  return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {\n  uint32_t d;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  asm volatile(\"fma.rn.f16x2.relu %0, %1, %2, %3;\" : \"=r\"(d) : \"r\"(a), \"r\"(b), \"r\"(c));\n#else\n  d = hrelu2(hfma2(a, b, c));\n#endif\n  return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t h0_h0(uint32_t x) {\n  uint32_t y;\n  asm volatile(\"{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\\n\" : \"=r\"(y) : \"r\"(x));\n  return y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float h0_to_float(uint32_t h2) {\n  float f;\n  asm volatile(\n      \"{\\n\"\n      \".reg .f16 lo, hi;\\n\"\n      \"mov.b32 {lo, hi}, %1;\\n\"\n      \"cvt.f32.f16 %0, lo;\\n\"\n      \"}\\n\"\n      : \"=f\"(f)\n      : \"r\"(h2));\n  return 
f;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t h1_h1(uint32_t x) {\n  uint32_t y;\n  asm volatile(\"{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\\n\" : \"=r\"(y) : \"r\"(x));\n  return y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {\n  uint16_t d;\n  asm volatile(\"add.f16 %0, %1, %2;\" : \"=h\"(d) : \"h\"(a), \"h\"(b));\n  return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { return hadd2(a, b); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hadd4(uint2 a, uint2 b) {\n  uint2 c;\n  c.x = hadd2(a.x, b.x);\n  c.y = hadd2(a.y, b.y);\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint2 hadd(uint2 a, uint2 b) { return hadd4(a, b); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hadd8(uint4 a, uint4 b) {\n  uint4 c;\n  c.x = hadd2(a.x, b.x);\n  c.y = hadd2(a.y, b.y);\n  c.z = hadd2(a.z, b.z);\n  c.w = hadd2(a.w, b.w);\n  return c;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 fadd4(uint4 a, uint4 b) {\n  float4 c;\n  c.x = reinterpret_cast<const float&>(a.x) + reinterpret_cast<const float&>(b.x);\n  c.y = reinterpret_cast<const float&>(a.y) + reinterpret_cast<const float&>(b.y);\n  c.z = reinterpret_cast<const float&>(a.z) + reinterpret_cast<const float&>(b.z);\n  c.w = reinterpret_cast<const float&>(a.w) + 
reinterpret_cast<const float&>(b.w);\n  return reinterpret_cast<const uint4&>(c);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint4 hadd(uint4 a, uint4 b) { return hadd8(a, b); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float half_to_float(uint16_t h) {\n  float f;\n  asm volatile(\"cvt.f32.f16 %0, %1;\\n\" : \"=f\"(f) : \"h\"(h));\n  return f;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float2 half2_to_float2(uint32_t x) {\n  uint16_t lo, hi;\n  asm volatile(\"mov.b32 {%0, %1}, %2;\\n\" : \"=h\"(lo), \"=h\"(hi) : \"r\"(x));\n  return make_float2(half_to_float(lo), half_to_float(hi));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ void half2_to_float2(float& x, float& y, uint32_t h) {\n  float2 tmp = half2_to_float2(h);\n  x = tmp.x;\n  y = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {\n  uint16_t d;\n  asm volatile(\"fma.rn.f16 %0, %1, %2, %3;\" : \"=h\"(d) : \"h\"(a), \"h\"(b), \"h\"(c));\n  return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {\n  uint16_t d;\n  asm volatile(\"mul.f16 %0, %1, %2;\" : \"=h\"(d) : \"h\"(a), \"h\"(b));\n  return d;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); 
}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint16_t& dst) { dst = uint16_t(0); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint32_t& dst) { dst = 0u; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint2& dst) { dst = make_uint2(0u, 0u); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void clear(uint4& dst) { dst = make_uint4(0u, 0u, 0u, 0u); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// P R E D I C A T E   P A C K I N G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\nenum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// G E N E R I C   P R E D I C A T E D   L D G S T S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N, int M, typename Functor>\ninline __device__ void load_(Functor& fct, const uint32_t (&preds)[M]) {\n  // The number of complete bytes (where we use all the predicates in a byte).\n  enum { COMPLETE = N / PREDS_PER_BYTE };\n  // Make sure we did allocate enough predicates.\n  static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, \"\");\n  // The remainder.\n  enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };\n  // Make sure we got the math right and the remainder is between 0 and 3.\n  static_assert(REMAINDER >= 0 && REMAINDER <= 3, \"\");\n  // The mask to extract the predicates.\n  enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };\n\n// 
Clear the fetch registers.\n#pragma unroll\n  for (int ii = 0; ii < N; ++ii) {\n    fct.clear(ii);\n  }\n\n  // Run complete steps.\n  bool p[PREDS_PER_BYTE];\n#pragma unroll\n  for (int ii = 0; ii < COMPLETE; ++ii) {\n    // The predicate.\n    uint32_t reg = preds[ii / BYTES_PER_REG];\n\n// Extract the predicates.\n#pragma unroll\n    for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) {\n      uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);\n      p[jj] = (reg & mask) != 0u;\n    }\n\n// Issue the loads.\n#pragma unroll\n    for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) {\n      fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);\n    }\n  }\n\n  // Skip the rest of the code if we do not have a remainder.\n  if (REMAINDER > 0) {\n    // The mask to extract the predicates.\n    enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };\n\n    // The predicate register.\n    uint32_t reg = preds[COMPLETE / BYTES_PER_REG];\n\n// Extract the predicates.\n#pragma unroll\n    for (int jj = 0; jj < PREDS_PER_BYTE; ++jj) {\n      uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);\n      p[jj] = (reg & mask) != 0u;\n    }\n\n// Issue the loads.\n#pragma unroll\n    for (int ii = 0; ii < REMAINDER; ++ii) {\n      fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int M, typename Functor>\ninline __device__ void load_(Functor& fct, uint32_t preds) {\n  uint32_t tmp[1] = {preds};\n  load_<M>(fct, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint8_t& dst, const void* ptr) { dst = *reinterpret_cast<const uint8_t*>(ptr); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void 
ldg(uint16_t& dst, const void* ptr) { dst = *reinterpret_cast<const uint16_t*>(ptr); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint32_t& dst, const void* ptr) { dst = *reinterpret_cast<const uint32_t*>(ptr); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint2& dst, const void* ptr) { dst = *reinterpret_cast<const uint2*>(ptr); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldg(uint4& dst, const void* ptr) { dst = *reinterpret_cast<const uint4*>(ptr); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Data_type, int N>\nstruct Ldg_functor {\n  // Ctor.\n  inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) : fetch_(fetch), ptrs_(ptrs) {}\n\n  // Clear the element.\n  inline __device__ void clear(int ii) { fmha::clear(fetch_[ii]); }\n\n  // Trigger the loads.\n  inline __device__ void load(int ii, bool p) {\n    if (p) {\n      ldg(fetch_[ii], ptrs_[ii]);\n    }\n  }\n\n  // The fetch registers.\n  Data_type (&fetch_)[N];\n  // The pointers.\n  const void* (&ptrs_)[N];\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Data_type, int N, int M>\ninline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n  Ldg_functor<Data_type, N> fct(fetch, ptrs);\n  load_<N>(fct, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N, int M>\ninline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n  ldg_<uint8_t, N>(fetch, ptrs, 
preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N, int M>\ninline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n  ldg_<uint16_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N, int M>\ninline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n  ldg_<uint32_t, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N, int M>\ninline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n  ldg_<uint2, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N, int M>\ninline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {\n  ldg_<uint4, N>(fetch, ptrs, preds);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint16_t& dst, uint32_t ptr) {\n  asm volatile(\"ld.shared.b16 %0, [%1];\\n\" : \"=h\"(dst) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint32_t& dst, uint32_t ptr) {\n  asm volatile(\"ld.shared.b32 %0, [%1];\\n\" : \"=r\"(dst) : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint2& dst, uint32_t ptr) {\n  asm volatile(\"ld.shared.v2.b32 {%0, %1}, [%2];\\n\" : \"=r\"(dst.x), \"=r\"(dst.y) : 
\"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void lds(uint4& dst, uint32_t ptr) {\n  asm volatile(\"ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\\n\"\n               : \"=r\"(dst.x), \"=r\"(dst.y), \"=r\"(dst.z), \"=r\"(dst.w)\n               : \"r\"(ptr));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// L D S M\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint32_t& dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n  asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\\n\" : \"=r\"(dst) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint32_t& dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n  asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\\n\" : \"=r\"(dst) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint2& dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n  asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\\n\" : \"=r\"(dst.x), \"=r\"(dst.y) : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint2& dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n  asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\\n\"\n               : \"=r\"(dst.x), \"=r\"(dst.y)\n               : 
\"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsm(uint4& dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n  asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\\n\"\n               : \"=r\"(dst.x), \"=r\"(dst.y), \"=r\"(dst.z), \"=r\"(dst.w)\n               : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void ldsmt(uint4& dst, uint32_t ptr) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730\n  asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\\n\"\n               : \"=r\"(dst.x), \"=r\"(dst.y), \"=r\"(dst.z), \"=r\"(dst.w)\n               : \"r\"(ptr));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// S T G\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void* ptr, uint8_t val) { *reinterpret_cast<uint8_t*>(ptr) = val; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void* ptr, uint16_t val) { *reinterpret_cast<uint16_t*>(ptr) = val; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void* ptr, uint32_t val) { *reinterpret_cast<uint32_t*>(ptr) = val; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void* ptr, uint2 val) { *reinterpret_cast<uint2*>(ptr) = val; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void stg(void* ptr, uint4 val) { *reinterpret_cast<uint4*>(ptr) = val; 
}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// S T S\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint16_t val) {\n  asm volatile(\"st.shared.b16 [%0], %1;\\n\" : : \"r\"(ptr), \"h\"(val));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint32_t val) {\n  asm volatile(\"st.shared.b32 [%0], %1;\\n\" : : \"r\"(ptr), \"r\"(val));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint2 val) {\n  asm volatile(\"st.shared.v2.b32 [%0], {%1, %2};\\n\" : : \"r\"(ptr), \"r\"(val.x), \"r\"(val.y));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void sts(uint32_t ptr, uint4 val) {\n  asm volatile(\"st.shared.v4.b32 [%0], {%1, %2, %3, %4};\\n\"\n               :\n               : \"r\"(ptr), \"r\"(val.x), \"r\"(val.y), \"r\"(val.z), \"r\"(val.w));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Data_type, int N>\ninline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {\n#pragma unroll\n  for (int ii = 0; ii < N; ++ii) {\n    sts(ptrs[ii], data[ii]);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {\n  sts_<uint16_t, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {\n  sts_<uint32_t, N>(ptrs, 
data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {\n  sts_<uint2, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\ninline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {\n  sts_<uint4, N>(ptrs, data);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nstruct MaxOp {\n  __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nstruct SumOp {\n  __device__ inline T operator()(T const& x, T const& y) { return x + y; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int THREADS>\nstruct Allreduce {\n  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);\n  template <typename T, typename Operator>\n  static __device__ inline T run(T x, Operator& op) {\n    constexpr int OFFSET = THREADS / 2;\n    x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));\n    return Allreduce<OFFSET>::run(x, op);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct Allreduce<2> {\n  template <typename T, typename Operator>\n  static __device__ inline T run(T x, Operator& op) {\n    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));\n    return x;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Operator, int M>\n__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator& op) {\n#pragma unroll\n  for (int mi = 0; 
mi < M; mi++) {\n    dst[mi] = src[mi];\n    dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));\n    dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Operator, int M>\n__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator& op) {\n  float tmp[M];\n#pragma unroll\n  for (int mi = 0; mi < M; mi++) {\n    tmp[mi] = op(src[mi].x, src[mi].y);\n  }\n  quad_reduce(dst, tmp, op);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Operator, int M>\n__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator& op) {\n#pragma unroll\n  for (int mi = 0; mi < M; mi++) {\n    dst[mi] = src[mi];\n    dst[mi] = Allreduce<4>::run(dst[mi], op);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Operator, int M>\n__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator& op) {\n  float tmp[M];\n#pragma unroll\n  for (int mi = 0; mi < M; mi++) {\n    tmp[mi] = op(src[mi].x, src[mi].y);\n  }\n  quad_allreduce(dst, tmp, op);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <cuda.h>\n\n#include <vector>\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl.h>\n#endif\n\n#include <fmha_utils.h>\n\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n\nconstexpr int TOTAL_DIM = 0;\nconstexpr int THREE_DIM = 1;\nconstexpr int H_DIM = 2;\nconstexpr int D_DIM = 3;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Qkv_params {\n  // The QKV matrices.\n  void* __restrict__ qkv_ptr;\n\n  // The stride between rows of the Q, K and V matrices.\n  size_t qkv_stride_in_bytes;\n\n  // The number of heads.\n  int h;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fused_multihead_attention_fprop_params : public Qkv_params {\n  // The dQKV matrices.\n  void* __restrict__ dqkv_ptr;\n\n  // Temporary for dKV.\n  void* __restrict__ dkv_ptr;\n\n  // The O matrix (output).\n  void* __restrict__ o_ptr;\n\n  // The stride between rows of O.\n  int64_t o_stride_in_bytes;\n\n  // The pointer to the S matrix, overwritten by the dP matrix (bwd).\n  void* __restrict__ s_ptr;\n  // The stride between rows of the S matrix.\n  int64_t s_stride_in_bytes;\n\n  // The dimensions.\n  int b, s, d;\n\n  // The scaling factors for the kernel.\n  uint32_t scale_bmm1, 
scale_softmax, scale_bmm2;\n\n  // Array of length b + 1 holding the starting offset of each sequence.\n  int* __restrict__ cu_seqlens;\n\n  // The dropout probability (probability of keeping an activation).\n  float p_dropout;\n\n  // Scale factor of 1 / p_dropout, i.e. the reciprocal of the keep probability.\n  float rp_dropout;\n\n  // Scale factor of 1 / p_dropout (reciprocal of the keep probability), in half2.\n  uint32_t scale_dropout;\n\n  // Random state.\n  at::PhiloxCudaState philox_args;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_params>\nstruct Launch_params {\n  Launch_params(cudaDeviceProp* props_, cudaStream_t stream_, bool is_training_, bool is_nl_)\n      : elts_per_thread(0), props(props_), stream(stream_), is_training(is_training_), is_nl(is_nl_) {}\n\n  size_t elts_per_thread;\n\n  cudaDeviceProp* props;\n\n  cudaStream_t stream;\n\n  bool is_training;\n\n  Kernel_params params;\n  int num_full_heads;\n  int num_main_groups;\n  int heads_last_wave;\n  int main_steps;\n  int rest_steps;\n  bool is_nl;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure);\nvoid run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure);\nvoid run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure);\nvoid run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure);\n\nvoid run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params& params, 
cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream);\nvoid run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream);\n\nvoid run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params& params, const bool is_training,\n                                  const int num_chunks, cudaStream_t stream);\n\nvoid run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params& params, const int num_chunks,\n                                        cudaStream_t stream);\n\nvoid fmha_run_noloop_reduce(void* out, const void* in, const int* cu_seqlens, const int hidden_size,\n                            const int batch_size, const int total, const int num_chunks, cudaStream_t stream);\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n  fmha::compute_dv_1xN<Kernel_traits>(params);\n  fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) {\n  constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n  constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n  constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n  constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n  using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;\n  constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n  static_assert(smem_size_s == 16 * 128 * 2);\n  static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n  constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n  constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n  constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n  if (smem_size >= 48 * 1024) {\n    
FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_128_64_sm80_kernel,\n                                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n  dim3 grid(params.h, params.b);\n  fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n  fmha::compute_dv_1xN<Kernel_traits>(params);\n  fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) {\n  constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n  constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n  constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n  constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n  using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;\n  constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n  static_assert(smem_size_s == 16 * 256 * 2);\n  static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n  constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n  constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n  constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n  if (smem_size >= 48 * 1024) {\n    
FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_256_64_sm80_kernel,\n                                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n  dim3 grid(params.h, params.b);\n  fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n  fmha::compute_dv_1xN<Kernel_traits>(params);\n  fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) {\n  constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n  constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n  constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n  constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n  using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;\n  constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n  static_assert(smem_size_s == 16 * 384 * 2);\n  static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n  constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n  constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n  constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n  if (smem_size >= 48 * 1024) {\n    
FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_384_64_sm80_kernel,\n                                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n  dim3 grid(params.h, params.b);\n  fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_dgrad_kernel_1xN_reload.h\"\n#include \"fmha_dgrad_kernel_1xN_reload_nl.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;\n\nextern \"C\" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {\n  fmha::compute_dv_1xN<Kernel_traits>(params);\n  fmha::compute_dq_dk_1xN<Kernel_traits>(params);\n}\n\ntemplate <int CHUNKS>\n__global__ void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params) {\n  fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);\n  fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);\n}\n\nvoid run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params& params, cudaStream_t stream) {\n  constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n  constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n  constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n  constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n  using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;\n  constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n  static_assert(smem_size_s == 16 * 512 * 2);\n  static_assert(smem_size_o == 16 * 64 * 4 * 
Kernel_traits::Cta_tile_p::WARPS_N);\n\n  constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n  constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n  constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n  if (smem_size >= 48 * 1024) {\n    FMHA_CHECK_CUDA(cudaFuncSetAttribute(fmha_dgrad_fp16_512_64_sm80_kernel,\n                                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n  dim3 grid(params.h, params.b);\n  fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n}\n\nvoid run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params& params, const int num_chunks,\n                                        cudaStream_t stream) {\n  constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);\n  constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;\n  constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;\n  constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;\n\n  using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;\n  constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;\n  static_assert(smem_size_s == 16 * 512 * 2);\n  static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);\n\n  constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;\n  constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;\n  constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);\n\n  auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;\n\n  if (num_chunks == 2) {\n    kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;\n  } else if (num_chunks == 3) {\n    kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;\n  } else {\n    assert(false && \"Unsupported number 
of chunks\");\n  }\n\n  if (smem_size >= 48 * 1024) {\n    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n\n  dim3 grid(params.h, params.b, num_chunks);\n\n  kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);\n\n  FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/gemm.h>\n#include <fmha/kernel_traits.h>\n\n#include \"fmha_kernel.h\"\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits, typename Params>\ninline __device__ void compute_dv_1xN(const Params& params) {\n  // The description of the CTA tile for the 1st batched GEMM.\n  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n  // The description of the CTA tile for the 2nd batched GEMM.\n  using Cta_tile_dv =\n      fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n\n  static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);\n  static_assert(Cta_tile_dv::N == 64);\n  static_assert(Cta_tile_dv::K == 16);\n\n  // The MMA tile for the 1st GEMM.\n  using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n  // The MMA tile for the 2nd GEMM.\n  using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;\n\n  // The global memory tile to load Q.\n  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n  // The shared memory tile to swizzle Q.\n  // using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n  using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n  // The shared memory tile to 
reload Q as fragment b.\n  using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n\n  // The global memory tile to load K.\n  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n  // The shared memory tile to swizzle K.\n  using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n  // The global memory tile to load V.\n  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle V.\n  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n  // The global memory tile to store O.\n  using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n  // The shared memory tile to swizzle O.\n  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n  // The global memory tile to store dV.\n  using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle dV.\n  using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;\n  static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);\n  static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);\n\n  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n  using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n  using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;\n\n  // Shared memory.\n  extern __shared__ char smem_[];\n\n  // The block index for the batch.\n  const int bidb = blockIdx.y;\n  // The block index for the head.\n  const int bidh = blockIdx.x;\n  // The thread index.\n  const int tidx = threadIdx.x;\n\n  const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n  if (binfo.stop_early()) return;\n  Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n  // Allocate the global memory tile loader for Q.\n  Gmem_tile_do gmem_q(params, binfo, tidx);  // treating dout as Q\n  // Allocate the shared memory tile loader for Q.\n  Smem_tile_q smem_q(&smem_[0], tidx);\n  Smem_tile_qt smem_qt(&smem_[0], tidx);\n  Smem_tile_st 
smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n  // Allocate the global memory tile loader for K.\n  Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K\n  // Allocate the shared memory tile loader for K.\n  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n  // Trigger the loads for Q.\n  gmem_q.load(smem_q);\n  // Trigger the loads for K.\n  gmem_k.load(smem_k);\n\n  // Commit the data for Q and K to shared memory.\n  gmem_q.commit(smem_q);\n  gmem_k.commit(smem_k);\n\n  // Make sure the data is in shared memory.\n  __syncthreads();\n\n  // Load the fragments for Q.\n  typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n  smem_q.load(frag_q[0], 0);\n\n  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];\n  static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);\n  static_assert(Mma_tile_dv::MMAS_K == 1);\n  smem_qt.load(frag_qt[0], 0);\n\n  // Load the fragments for K. We keep the data in registers during the entire kernel.\n  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];\n  smem_k.load(frag_k[0], 0);\n\n  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n  Gmem_tile_s gmem_s(params, binfo, tidx);\n\n  // Create the object to do the softmax.\n  using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n  Softmax softmax(params,\n                  &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE],\n                  bidb, tidx);\n\n  enum { THREADS_PER_ROW = 32 };\n  enum { M = Mma_tile_p::MMAS_M };\n  enum { N = Mma_tile_p::MMAS_N };\n\n  // Declare the accumulators for the 2nd gemm.\n  fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];\n  fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);\n\n  enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n  // Load over the entire sequence length.\n  for (int l = 0; l < STEPS; l++) {\n    const int loop = l * 
Cta_tile_p::M;\n    if (loop >= binfo.actual_seqlen) break;\n\n    // Load S\n    uint4 s_regs[M][N];\n    gmem_s.load(s_regs, mask);\n    fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n// Do this part of P^T = (Q * K^T)^T.\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      smem_q.load(frag_q[ki & 1], ki);\n      smem_k.load(frag_k[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n    }\n\n    // Store s * dmask to smem for transpose\n    smem_s.store(s_regs);\n\n    // Declare the accumulators for the 1st gemm.\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_p::MMAS_K;\n      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n    }\n    // Trigger the load for the next Q values. 
We're using double buffering, so reading qt is safe\n    if (l < STEPS - 1) {\n      smem_q.move_to_next_write_buffer();\n      gmem_q.move();\n      gmem_q.load(smem_q);\n    }\n\n    // Convert from the accumulator type to FP32 for Softmax.\n    softmax.unpack(acc_p);\n\n    float s_mat[2 * M][4 * N];\n\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        uint4& dst = s_regs[mi][ni];\n        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);\n        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);\n        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);\n        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);\n      }\n    }\n\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ii = 0; ii < 2; ii++) {\n#pragma unroll\n        for (int ni = 0; ni < N; ni++) {\n#pragma unroll\n          for (int jj = 0; jj < 4; jj++) {\n            float& s_dmask = s_mat[2 * mi + ii][4 * ni + jj];\n            const bool drop = reinterpret_cast<const uint32_t&>(s_dmask) & 0x80000000;\n            const float d_s = drop ? 
0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;\n            s_dmask = fabsf(s_dmask);\n            softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);\n          }\n        }\n      }\n    }\n\n    float p_sum[2 * M];\n    softmax.reduce_sum(p_sum);\n\n    const float scalef = reinterpret_cast<const float&>(params.scale_softmax);\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ii = 0; ii < 2; ii++) {\n#pragma unroll\n        for (int ni = 0; ni < N; ni++) {\n#pragma unroll\n          for (int jj = 0; jj < 4; jj++) {\n            softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]);\n            softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;\n          }\n        }\n      }\n    }\n    typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];\n    smem_s.load(frag_s);\n    for (int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++) {\n      for (int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++) {\n        for (int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++) {\n          frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);\n          frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));\n        }\n      }\n    }\n\n    gmem_s.store(softmax.elt_, mask);\n    gmem_s.move();\n\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      smem_qt.load(frag_qt[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n    }\n\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_dv::MMAS_K;\n      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n    }\n    // Commit the values for Q into shared memory.\n    if (l < STEPS - 1) {\n      gmem_q.commit(smem_q);\n    }\n\n    // Make sure we are reading 
from the correct buffer.\n    smem_q.move_to_next_read_buffer();\n    smem_qt.move_to_next_read_buffer();\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Trigger the loads for the values of Q for the next iteration.\n    smem_q.load(frag_q[0], 0);\n    smem_k.load(frag_k[0], 0);\n    smem_qt.load(frag_qt[0], 0);\n\n  }  // Outer loop over the sequence length.\n\n  // Epilogue swizzle for dV\n  Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);\n  smem_dv.store(acc_dv);\n\n  __syncthreads();\n  uint4 dv_out[Smem_tile_dv::NUM_LDS];\n  smem_dv.load(dv_out);\n  Qkv_params dv_params;\n  dv_params.qkv_ptr = params.dqkv_ptr;\n  dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;\n  dv_params.h = params.h;\n  Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);\n  gmem_dv.store(dv_out);\n}\n\ntemplate <typename Kernel_traits, typename Params>\ninline __device__ void compute_dq_dk_1xN(const Params& params) {\n  // The description of the CTA tile for the 1st batched GEMM.\n  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n  using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n  // The description of the CTA tile for the 2nd batched GEMM.\n  using Cta_tile_dk =\n      fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n  static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);\n  static_assert(Cta_tile_dk::N == 64);\n  static_assert(Cta_tile_dk::K == 16);\n\n  // The MMA tile for the 1st GEMM.\n  using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n  using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n  // The MMA tile for the 2nd GEMM.\n  using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;\n\n  // The global memory tile to load Q.\n  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n  // The shared memory tile to swizzle Q.\n  using Smem_tile_q = typename 
Kernel_traits::Smem_tile_q;\n\n  // The global memory tile to load K.\n  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle K.\n  using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop\n\n  // The global memory tile to load V.\n  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle V.\n  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n  // The global memory tile to store O.\n  // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n  using Gmem_tile_o = fmha::Gmem_tile_dq<Cta_tile_o>;\n  // The shared memory tile to swizzle O.\n  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n  // The global memory tile to store dK.\n  using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle dK.\n  using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;\n  static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);\n  static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);\n\n  // The shared memory tile to reload Q transposed.\n  using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n  using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n\n  enum { M = Mma_tile_p::MMAS_M };\n  enum { N = Mma_tile_p::MMAS_N };\n  static_assert(M == Mma_tile_o::MMAS_M);\n  static_assert(N == Mma_tile_o::MMAS_K);\n  // Shared memory.\n  extern __shared__ char smem_[];\n\n  // The block index for the batch.\n  const int bidb = blockIdx.y;\n  // The block index for the head.\n  const int bidh = blockIdx.x;\n  // The thread index.\n  const int tidx = threadIdx.x;\n\n  const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n  if (binfo.stop_early()) return;\n\n  Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n  // Allocate the global memory tile loader 
for Q.\n  Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n  // Allocate the shared memory tile loader for Q.\n  Smem_tile_q smem_q(&smem_[0], tidx);\n  Smem_tile_qt smem_qt(&smem_[0], tidx);\n  Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE],\n                      tidx);\n\n  // Allocate the global memory tile loader for K.\n  Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n  // Allocate the shared memory tile loader for K.\n  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n  // Allocate the global memory tile loader for O.\n  Gmem_tile_o gmem_o(params, binfo, tidx);\n  // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n  Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n  // Trigger the loads for Q.\n  gmem_q.load(smem_q);\n  // Trigger the loads for K.\n  gmem_k.load(smem_k);\n\n  Gmem_tile_s gmem_s(params, binfo, tidx);\n  // Load dP\n  uint4 s_regs[M][N];\n  gmem_s.load(s_regs, mask);\n  gmem_s.move();\n\n  // Commit the data for Q and K to shared memory.\n  gmem_q.commit(smem_q);\n  gmem_k.commit(smem_k);\n\n  // Make sure the data is in shared memory.\n  __syncthreads();\n\n  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];\n  smem_qt.load(frag_qt[0], 0);\n  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];\n  smem_k.load(frag_k[0], 0);\n\n  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n  enum { THREADS_PER_ROW = 32 };\n  enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };\n\n  // Declare the accumulators for the 2nd gemm.\n  fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];\n  fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);\n\n  // Load over the entire sequence length.\n  for (int l = 0; l < STEPS; l++) {\n    const int loop = l * Cta_tile_p::M;\n    if (loop >= binfo.actual_seqlen) break;\n\n    // 
Pack dP as Fragment_a\n    fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        uint4& dst = s_regs[mi][ni];\n        frag_p[ni][mi].reg(0) = dst.x;  // row 0, cols 0,1\n        frag_p[ni][mi].reg(1) = dst.z;  // row 8, cols 0,1\n        frag_p[ni][mi].reg(2) = dst.y;  // row 0, cols 8,9\n        frag_p[ni][mi].reg(3) = dst.w;  // row 8, cols 8,9\n      }\n    }\n\n    // Declare the accumulators for the 1st gemm.\n    fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n// Do this part of O = P^T * V^T. dQ = dP x dK\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      smem_k.load(frag_k[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n    }\n\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_o::MMAS_K;\n      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n    }\n\n    // Store dP to smem for transpose\n    smem_s.store(s_regs);\n    if (l < STEPS - 1) {\n      // Load next part of S\n      gmem_s.load(s_regs, mask);\n      gmem_s.move();\n      smem_q.move_to_next_write_buffer();\n      gmem_q.move();\n      gmem_q.load(smem_q);\n    }\n// Loop over MMAS_M.\n#pragma unroll\n    for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) {\n      // Swizzle the elements and do the final reduction.\n      smem_o.store(acc_o, ii);\n\n      // Make sure the data is in shared memory.\n      __syncthreads();\n\n      // Load from shared memory.\n      uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n      smem_o.load(out);\n\n      // Make sure the data was read from shared memory.\n      if (ii < Gmem_tile_o::LOOPS - 1) 
{\n        __syncthreads();\n      }\n\n      // Output the values.\n      gmem_o.store(out, ii);\n    }\n\n    // Move to the next part of the output.\n    gmem_o.move();\n\n    typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];\n    smem_s.load(frag_s);\n\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      smem_qt.load(frag_qt[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n    }\n\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_dk::MMAS_K;\n      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n    }\n\n    // Commit the values for Q into shared memory.\n    if (l < STEPS - 1) {\n      gmem_q.commit(smem_q);\n    }\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Trigger the loads for the values of Q for the next iteration.\n    smem_qt.load(frag_qt[0], 0);\n    smem_k.load(frag_k[0], 0);\n\n  }  // Outer loop over the sequence length.\n\n  // Epilogue swizzle for dK\n  Smem_tile_dk smem_dk(&smem_[0], tidx);\n  smem_dk.store(acc_dk);\n  __syncthreads();\n  uint4 dk_out[Smem_tile_dk::NUM_LDS];\n  smem_dk.load(dk_out);\n  Qkv_params dk_params;\n  dk_params.qkv_ptr = params.dqkv_ptr;\n  dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;\n  dk_params.h = params.h;\n  Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);\n  gmem_dk.store(dk_out);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/gemm.h>\n#include <fmha/kernel_traits.h>\n\n#include \"fmha_kernel.h\"\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int CHUNKS, typename Kernel_traits, typename Params>\ninline __device__ void compute_dv_1xN_nl(const Params& params) {\n  // The description of the CTA tile for the 1st batched GEMM.\n  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n  // The description of the CTA tile for the 2nd batched GEMM.\n  using Cta_tile_dv =\n      fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n\n  static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);\n  static_assert(Cta_tile_dv::N == 64);\n  static_assert(Cta_tile_dv::K == 16);\n\n  // The MMA tile for the 1st GEMM.\n  using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n  // The MMA tile for the 2nd GEMM.\n  using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;\n\n  // The global memory tile to load Q.\n  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n  // The shared memory tile to swizzle Q.\n  using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n  // The shared memory tile to reload Q as fragment b.\n  using Smem_tile_qt = 
fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;\n\n  // The global memory tile to load K.\n  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n  // The shared memory tile to swizzle K.\n  using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n\n  // The global memory tile to load V.\n  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle V.\n  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n  // The global memory tile to store dV.\n  using Gmem_tile_dv = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o, fmha::BITS_PER_ELEMENT_B,\n                                           Cta_tile_p::N,  // S,\n                                           Cta_tile_p::K,  // D,\n                                           2 * CHUNKS>;\n\n  // The shared memory tile to swizzle dV.\n  using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;\n  static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);\n  static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);\n\n  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n  using Smem_tile_st = typename Kernel_traits::Smem_tile_st;\n  using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;\n\n  // Shared memory.\n  extern __shared__ char smem_[];\n\n  // The block index for the chunk.\n  const int bidc = blockIdx.z;\n  // The block index for the batch.\n  const int bidb = blockIdx.y;\n  // The block index for the head.\n  const int bidh = blockIdx.x;\n  // The thread index.\n  const int tidx = threadIdx.x;\n\n  const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n  if (binfo.stop_early()) return;\n  fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n  // Allocate the global memory tile loader for Q.\n  Gmem_tile_do gmem_q(params, binfo, tidx);  // treating dout as Q\n  // Allocate the shared memory tile loader for Q.\n  Smem_tile_q smem_q(&smem_[0], tidx);\n  Smem_tile_qt 
smem_qt(&smem_[0], tidx);\n  Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n  // Allocate the global memory tile loader for K.\n  Gmem_tile_k gmem_k(params, 2, binfo, tidx);  // treating V as K\n  // Allocate the shared memory tile loader for K.\n  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n  Gmem_tile_s gmem_s(params, binfo, tidx);\n\n  using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n  Noloop nl_traits(bidc, binfo);\n  nl_traits.move_all(gmem_q, gmem_s);\n\n  // Trigger the loads for Q.\n  gmem_q.load(smem_q);\n  // Trigger the loads for K.\n  gmem_k.load(smem_k);\n\n  // Commit the data for Q and K to shared memory.\n  gmem_q.commit(smem_q);\n  gmem_k.commit(smem_k);\n\n  // Make sure the data is in shared memory.\n  __syncthreads();\n\n  // Load the fragments for Q.\n  typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];\n  smem_q.load(frag_q[0], 0);\n\n  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];\n  static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);\n  static_assert(Mma_tile_dv::MMAS_K == 1);\n  smem_qt.load(frag_qt[0], 0);\n\n  // Load the fragments for K. 
We keep the data in registers during the entire kernel.\n  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];\n  smem_k.load(frag_k[0], 0);\n\n  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n  // Create the object to do the softmax.\n  using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n  Softmax softmax(params,\n                  &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE],\n                  bidb, tidx);\n\n  enum { THREADS_PER_ROW = 32 };\n  enum { M = Mma_tile_p::MMAS_M };\n  enum { N = Mma_tile_p::MMAS_N };\n\n  // Declare the accumulators for the 2nd gemm.\n  fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];\n  fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);\n\n  // Loop over the entire sequence length.\n  for (int l = 0; l < nl_traits.num_steps_; l++) {\n    uint4 s_regs[M][N];\n    gmem_s.load(s_regs, mask);\n    // Declare the accumulators for the 1st gemm.\n    fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n// Do this part of P^T = (Q * K^T)^T.\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q and K values.\n      smem_q.load(frag_q[ki & 1], ki);\n      smem_k.load(frag_k[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n    }\n\n    smem_s.store(s_regs);\n\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_p::MMAS_K;\n      fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n    }\n    // Trigger the load for the next Q values. 
We're using double buffering, so reading qt is safe\n    if (l < nl_traits.num_steps_ - 1) {\n      smem_q.move_to_next_write_buffer();\n      gmem_q.move();\n      gmem_q.load(smem_q);\n    }\n    // Convert from the accumulator type to FP32 for Softmax.\n    softmax.unpack(acc_p);\n\n    float s_mat[2 * M][4 * N];\n\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        uint4& dst = s_regs[mi][ni];\n        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);\n        fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);\n        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);\n        fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);\n      }\n    }\n\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ii = 0; ii < 2; ii++) {\n#pragma unroll\n        for (int ni = 0; ni < N; ni++) {\n#pragma unroll\n          for (int jj = 0; jj < 4; jj++) {\n            float& s_dmask = s_mat[2 * mi + ii][4 * ni + jj];\n            const bool drop = reinterpret_cast<const uint32_t&>(s_dmask) & 0x80000000;\n            const float d_s = drop ? 
0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;\n            s_dmask = fabsf(s_dmask);\n            softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask);\n          }\n        }\n      }\n    }\n\n    float p_sum[2 * M];\n    softmax.reduce_sum(p_sum);\n\n    const float scalef = reinterpret_cast<const float&>(params.scale_softmax);\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ii = 0; ii < 2; ii++) {\n#pragma unroll\n        for (int ni = 0; ni < N; ni++) {\n#pragma unroll\n          for (int jj = 0; jj < 4; jj++) {\n            softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]);\n            softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;\n          }\n        }\n      }\n    }\n\n    typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];\n    smem_s.load(frag_s);\n    for (int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++) {\n      for (int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++) {\n        for (int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++) {\n          frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);\n          frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));\n        }\n      }\n    }\n\n    gmem_s.store(softmax.elt_, mask);\n    gmem_s.move();\n\n    static_assert(Mma_tile_dv::MMAS_K == 1);  // DEBUG\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      smem_qt.load(frag_qt[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n    }\n\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_dv::MMAS_K;\n      fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n    }\n    // Commit the values for Q into shared memory.\n    if (l < nl_traits.num_steps_ - 1) {\n      
gmem_q.commit(smem_q);\n    }\n\n    // Make sure we are reading from the correct buffer.\n    smem_q.move_to_next_read_buffer();\n    smem_qt.move_to_next_read_buffer();\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n\n    // Trigger the loads for the values of Q for the next iteration.\n    smem_q.load(frag_q[0], 0);\n    smem_k.load(frag_k[0], 0);\n    smem_qt.load(frag_qt[0], 0);\n\n  }  // Outer loop over the sequence length.\n\n  // Epilogue for dV = (S * D)' * dout'. We're fully exposed to this!\n\n  // Epilogue swizzle for dV\n  Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);\n  smem_dv.store(acc_dv);\n\n  __syncthreads();\n\n  uint4 dv_out[Smem_tile_dv::NUM_LDS];\n  smem_dv.load(dv_out);\n  Qkv_params dv_params;\n  dv_params.qkv_ptr = params.dkv_ptr;\n  dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);\n  dv_params.h = params.h;\n  Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx);\n  gmem_dv.store(dv_out);\n}\n\ntemplate <int CHUNKS, typename Kernel_traits, typename Params>\ninline __device__ void compute_dq_dk_1xN_nl(const Params& params) {\n  // The description of the CTA tile for the 1st batched GEMM.\n  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n  using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n  // The description of the CTA tile for the 2nd batched GEMM.\n  using Cta_tile_dk =\n      fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;\n\n  static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);\n  static_assert(Cta_tile_dk::N == 64);\n  static_assert(Cta_tile_dk::K == 16);\n\n  // The MMA tile for the 1st GEMM.\n  using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n  using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n  // The MMA tile for the 2nd GEMM.\n  using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;\n\n  // 
The global memory tile to load Q.\n  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n  // The shared memory tile to swizzle Q.\n  using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n\n  // The global memory tile to load K.\n  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle K.\n  using Smem_tile_k = typename Kernel_traits::Smem_tile_v;  // K is used like V in fprop\n\n  // The global memory tile to load V.\n  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle V.\n  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n  // The global memory tile to store O.\n  using Gmem_tile_o = Gmem_tile_dq<Cta_tile_o>;\n  // The shared memory tile to swizzle O.\n  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n  // The global memory tile to store dK.\n  using Gmem_tile_dk = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o, fmha::BITS_PER_ELEMENT_B,\n                                           Cta_tile_p::N,  // S,\n                                           Cta_tile_p::K,  // D,\n                                           2 * CHUNKS>;\n\n  // The shared memory tile to swizzle dK.\n  using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;\n  static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);\n  static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);\n\n  // The shared memory tile to reload Q transposed.\n  using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;\n\n  // The global memory tile to load dP, stored in S\n  using Gmem_tile_s = Gmem_tile_mma_s<Cta_tile_p>;\n  // The shared memory tile to transpose dP.\n  using Smem_tile_st = Smem_tile_mma_transposed<Cta_tile_p>;\n\n  using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;\n\n  enum { M = Mma_tile_p::MMAS_M };\n  enum { N = Mma_tile_p::MMAS_N };\n  static_assert(M == Mma_tile_o::MMAS_M);\n  static_assert(N == 
Mma_tile_o::MMAS_K);\n  // Shared memory.\n  extern __shared__ char smem_[];\n\n  const int bidc = blockIdx.z;\n  // The block index for the batch.\n  const int bidb = blockIdx.y;\n  // The block index for the head.\n  const int bidh = blockIdx.x;\n  // The thread index.\n  const int tidx = threadIdx.x;\n\n  const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);\n  if (binfo.stop_early()) return;\n\n  fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n  // Allocate the global memory tile loader for Q.\n  Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n  // Allocate the shared memory tile loader for Q (as B).\n  Smem_tile_qt smem_qt(&smem_[0], tidx);\n  // Allocate the global memory tile loader for dP.\n  Gmem_tile_s gmem_s(params, binfo, tidx);\n  // Allocate the shared memory tile loader for dP.\n  Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE],\n                      tidx);\n\n  // Allocate the global memory tile loader for K.\n  Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n  // Allocate the shared memory tile loader for K.\n  Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);\n\n  // Allocate the global memory tile loader for O.\n  Gmem_tile_o gmem_o(params, binfo, tidx);\n  // Allocate the shared memory tile loader for O. 
We use the same as K so be careful!!!\n  Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);\n\n  Noloop nl_traits(bidc, binfo);\n\n  nl_traits.move_all(gmem_q, gmem_o, gmem_s);\n\n  // Trigger the loads for Q.\n  gmem_q.load(smem_qt);\n  // Trigger the loads for K.\n  gmem_k.load(smem_k);\n\n  uint4 s_regs[M][N];\n  gmem_s.load(s_regs, mask);\n\n  // Commit the data for Q and K to shared memory.\n  gmem_q.commit(smem_qt);\n  gmem_k.commit(smem_k);\n\n  // Make sure the data is in shared memory.\n  __syncthreads();\n\n  typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];\n  smem_qt.load(frag_qt[0], 0);\n  typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];\n  smem_k.load(frag_k[0], 0);\n\n  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n  enum { THREADS_PER_ROW = 32 };\n\n  // Declare the accumulators for the 2nd gemm.\n  fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];\n  fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);\n\n  // Load over the entire sequence length.\n  for (int l = 0; l < nl_traits.num_steps_; l++) {\n    // Pack dP as Fragment_a\n    fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n#pragma unroll\n    for (int mi = 0; mi < M; mi++) {\n#pragma unroll\n      for (int ni = 0; ni < N; ni++) {\n        uint4& dst = s_regs[mi][ni];\n        frag_p[ni][mi].reg(0) = dst.x;\n        frag_p[ni][mi].reg(1) = dst.z;\n        frag_p[ni][mi].reg(2) = dst.y;\n        frag_p[ni][mi].reg(3) = dst.w;\n      }\n    }\n    smem_s.store(s_regs);\n    if (l < nl_traits.num_steps_ - 1) {\n      // Load next part of S\n      gmem_s.move();\n      gmem_s.load(s_regs, mask);\n      // Trigger the load for the next Q values.\n      smem_qt.move_to_next_write_buffer();\n      gmem_q.move();\n      gmem_q.load(smem_qt);\n    }\n    // Declare the accumulators for the 1st gemm.\n    fmha::Fragment_accumulator 
acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n// Do this part of dQ = dP x K (computed with the fprop O-gemm tiles).\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of K values.\n      smem_k.load(frag_k[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n    }\n\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_o::MMAS_K;\n      fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);\n    }\n\n    static_assert(Gmem_tile_o::LOOPS == 1);  // DEBUG\n// Loop over MMAS_M.\n#pragma unroll\n    for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) {\n      // Swizzle the elements and do the final reduction.\n      smem_o.store(acc_o, ii);\n\n      // Make sure the data is in shared memory.\n      __syncthreads();\n\n      // Load from shared memory.\n      uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n      smem_o.load(out);\n\n      // Make sure the data was read from shared memory.\n      if (ii < Gmem_tile_o::LOOPS - 1) {\n        __syncthreads();\n      }\n\n      // Output the values.\n      gmem_o.store(out, ii);\n    }\n\n    // Move to the next part of the output.\n    gmem_o.move();\n\n    typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];\n    smem_s.load(frag_s);\n\n    static_assert(Mma_tile_dk::MMAS_K == 1);  // DEBUG\n\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      smem_qt.load(frag_qt[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);\n    }\n\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_dk::MMAS_K;\n      fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) 
& 1]);\n    }\n\n    // Commit the values for Q into shared memory.\n    if (l < nl_traits.num_steps_ - 1) {\n      gmem_q.commit(smem_qt);\n      __syncthreads();\n      // Trigger the loads for the values of Q and K for the next iteration.\n      smem_qt.load(frag_qt[0], 0);\n      smem_k.load(frag_k[0], 0);\n    }\n\n  }  // Outer loop over the sequence length.\n\n  // Epilogue for dK = dP' * Q. We're fully exposed to this!\n\n  // Epilogue swizzle for dK\n  Smem_tile_dk smem_dk(&smem_[0], tidx);\n  smem_dk.store(acc_dk);\n\n  __syncthreads();\n\n  uint4 dk_out[Smem_tile_dk::NUM_LDS];\n  smem_dk.load(dk_out);\n  Qkv_params dk_params;\n  dk_params.qkv_ptr = params.dkv_ptr;\n  dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);\n  dk_params.h = params.h;\n  Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx);\n  gmem_dk.store(dk_out);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_fill.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include <ATen/Dispatch.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\nconstexpr int block_size = 512;\nconstexpr int ctas_per_sm = 4;\n\ntemplate <typename scalar_t>\n__global__ void __launch_bounds__(block_size)\n    mha_fill_kernel(scalar_t* out_tensor, const int32_t* const start_row, const size_t num_rows) {\n  size_t row_stride = gridDim.y * blockDim.x;\n  size_t row_index = blockIdx.x + (size_t)start_row[0];\n  size_t col_index = blockIdx.y * blockDim.x + threadIdx.x;\n  while (row_index < num_rows) {\n    out_tensor[row_index * row_stride + col_index] = 0;\n    row_index += gridDim.x;\n  }\n}\n\nat::Tensor& mha_fill(at::Tensor& self, const at::Tensor& start_index) {\n  auto max_tokens = self.size(0);\n  auto self_2d = self.view({max_tokens, -1});\n  auto fcd_size = self_2d.size(1);\n  TORCH_CHECK(self.is_contiguous(), \"input not contiguous\");\n  TORCH_CHECK(fcd_size % block_size == 0, \"input size not aligned to block size\");\n  const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;\n  uint64_t num_blk_y = (uint64_t)(fcd_size / block_size);\n  // Round up with integer arithmetic so that the grid always has at least one block in x.\n  uint64_t num_blk_x = (num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y;\n  dim3 dim_grid(num_blk_x, num_blk_y);\n  dim3 dim_block(block_size);\n\n  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(\n      at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), 
\"mha_padding_fill_\", [&]() {\n        mha_fill_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(\n            self_2d.data_ptr<scalar_t>(), start_index.data_ptr<int32_t>(), max_tokens);\n        C10_CUDA_KERNEL_LAUNCH_CHECK();\n      });\n  return self;\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;\n\ntemplate <bool Is_training>\n__global__ void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params,\n                                                   const int total_heads) {\n  fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);\n}\n\nvoid run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure) {\n  auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_128_64_sm80_kernel<true>\n                                          : &fmha_fprop_fp16_128_64_sm80_kernel<false>;\n\n  constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();\n\n  if (smem_size >= 48 * 1024) {\n    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n\n  const int sm_count = launch_params.props->multiProcessorCount;\n  int ctas_per_sm;\n  FMHA_CHECK_CUDA(\n      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));\n  int total_ctas = sm_count * ctas_per_sm;\n\n  const int heads_total = launch_params.params.b * launch_params.params.h;\n  if (configure) {\n    using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;\n    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;\n    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;\n    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;\n\n    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);\n    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;\n    launch_params.elts_per_thread = heads_per_cta * elts_per_head;\n    return;\n  }\n\n  dim3 grid(total_ctas);\n  kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params, heads_total);\n\n  FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;\n\ntemplate <bool Is_training>\n__global__ void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,\n                                                   const int total_heads) {\n  fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);\n}\n\nvoid run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure) {\n  auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_256_64_sm80_kernel<true>\n                                          : &fmha_fprop_fp16_256_64_sm80_kernel<false>;\n\n  constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();\n\n  if (smem_size >= 48 * 1024) {\n    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n\n  const int sm_count = launch_params.props->multiProcessorCount;\n  int ctas_per_sm;\n  FMHA_CHECK_CUDA(\n      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));\n  int total_ctas = sm_count * ctas_per_sm;\n\n  const int heads_total = launch_params.params.b * launch_params.params.h;\n  if (configure) {\n    using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;\n    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;\n    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;\n    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;\n\n    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);\n    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;\n    launch_params.elts_per_thread = heads_per_cta * elts_per_head;\n    return;\n  }\n\n  dim3 grid(total_ctas);\n  kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params, heads_total);\n\n  FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;\n\ntemplate <bool Is_training>\n__global__ void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params,\n                                                   const int total_heads) {\n  fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);\n}\n\nvoid run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure) {\n  auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_384_64_sm80_kernel<true>\n                                          : &fmha_fprop_fp16_384_64_sm80_kernel<false>;\n\n  constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();\n\n  if (smem_size >= 48 * 1024) {\n    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n\n  const int sm_count = launch_params.props->multiProcessorCount;\n  int ctas_per_sm;\n  FMHA_CHECK_CUDA(\n      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));\n  int total_ctas = sm_count * ctas_per_sm;\n\n  const int heads_total = launch_params.params.b * launch_params.params.h;\n  if (configure) {\n    using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;\n    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;\n    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;\n    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;\n\n    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);\n    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;\n    launch_params.elts_per_thread = heads_per_cta * elts_per_head;\n    return;\n  }\n\n  dim3 grid(total_ctas);\n  kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params, heads_total);\n\n  FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n#include \"fmha_fprop_kernel_1xN.h\"\n\nusing Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;\n\ntemplate <bool Is_training>\n__global__ void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params,\n                                                   const int total_heads) {\n  fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);\n}\n\ntemplate <bool Is_training>\n__global__ void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params,\n                                                      const int num_full_heads, const int num_main_groups,\n                                                      const int main_group_size, const int main_steps,\n                                                      const int rest_steps) {\n  fmha::device_1xN<Kernel_traits, Is_training>(params, num_full_heads, num_main_groups, main_group_size, main_steps,\n                                               rest_steps);\n}\n\nvoid run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                                const bool configure) {\n  auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_512_64_sm80_kernel<true>\n                                          : &fmha_fprop_fp16_512_64_sm80_kernel<false>;\n\n  constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();\n\n  if (smem_size >= 48 * 1024) {\n    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n\n  const int sm_count = launch_params.props->multiProcessorCount;\n  int ctas_per_sm;\n  FMHA_CHECK_CUDA(\n      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));\n  int total_ctas = sm_count * ctas_per_sm;\n\n  const int heads_total = launch_params.params.b * launch_params.params.h;\n  if (configure) {\n    using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;\n    constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;\n    constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;\n    constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;\n\n    size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);\n    size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;\n    launch_params.elts_per_thread = heads_per_cta * elts_per_head;\n    return;\n  }\n\n  dim3 grid(total_ctas);\n  kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params, heads_total);\n\n  FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n\nvoid run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                                   const bool configure) {\n  auto kernel = launch_params.is_training ? 
&fmha_fprop_fp16_512_64_sm80_kernel_nl<true>\n                                          : &fmha_fprop_fp16_512_64_sm80_kernel_nl<false>;\n\n  constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();\n\n  if (smem_size >= 48 * 1024) {\n    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n  }\n\n  const int sm_count = launch_params.props->multiProcessorCount;\n  int ctas_per_sm;\n  FMHA_CHECK_CUDA(\n      cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));\n  int total_ctas = sm_count * ctas_per_sm;\n\n  if (configure) {\n    const int heads_total = launch_params.params.b * launch_params.params.h;\n    std::tie(launch_params.num_full_heads, launch_params.num_main_groups, launch_params.heads_last_wave,\n             launch_params.main_steps, launch_params.rest_steps, launch_params.elts_per_thread) =\n        fmha::work_dist<Kernel_traits>(total_ctas, heads_total);\n    return;\n  }\n\n  dim3 grid(total_ctas);\n  kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(\n      launch_params.params, launch_params.num_full_heads, launch_params.num_main_groups, launch_params.heads_last_wave,\n      launch_params.main_steps, launch_params.rest_steps);\n\n  FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n\nvoid run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params>& launch_params,\n                               const bool configure) {\n  if (launch_params.is_nl) {\n    run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);\n  } else {\n    run_fmha_fp16_512_64_sm80_(launch_params, configure);\n  }\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h",
    "content": "/***************************************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha/gemm.h>\n#include <fmha/kernel_traits.h>\n\n#include \"fmha_kernel.h\"\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits>\nstruct Gemm_Q_K_base {\n  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n  using Smem_tile_q = typename Kernel_traits::Smem_tile_q;\n  using Smem_tile_k = typename Kernel_traits::Smem_tile_k;\n  using Fragment_q = typename Smem_tile_q::Fragment;\n  using Fragment_k = typename Smem_tile_k::Fragment;\n\n  // The description of the CTA tile for the 1st batched GEMM.\n  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n\n  // The MMA tile for the 1st GEMM.\n  using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n\n  static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;\n\n  __device__ inline Gemm_Q_K_base(char* smem_ptr_q, char* smem_ptr_k, const int tidx)\n      : smem_q(smem_ptr_q, tidx), smem_k(smem_ptr_k, tidx) {}\n\n  __device__ inline void load_q() { smem_q.load(frag_q[0], 0); }\n\n  __device__ inline void reload_q() { smem_q.load(frag_q[0], 0); }\n\n  Fragment_q frag_q[2][Mma_tile_p::MMAS_M];\n  Smem_tile_q smem_q;\n  Smem_tile_k smem_k;\n};\n\ntemplate <typename Kernel_traits, bool K_in_regs>\nstruct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> 
{\n  using Base = Gemm_Q_K_base<Kernel_traits>;\n  using Smem_tile_o = typename Base::Smem_tile_o;\n  using Smem_tile_q = typename Base::Smem_tile_q;\n  using Smem_tile_k = typename Base::Smem_tile_k;\n  using Fragment_k = typename Base::Fragment_k;\n  using Mma_tile_p = typename Base::Mma_tile_p;\n\n  enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };\n\n  enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE };\n  enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };\n\n  // Q | K / V\n  //   | O | SOFTMAX\n  static constexpr int SMEM_BYTES =\n      Smem_tile_q::BYTES_PER_TILE + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,\n                                             Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);\n\n  __device__ inline Gemm_Q_K(char* smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {}\n\n  __device__ inline void load_k() {\n#pragma unroll\n    for (int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) {\n      Base::smem_k.load(frag_k[ki], ki);\n    }\n  }\n\n  template <typename Acc, int M, int N>\n  __device__ inline void operator()(Acc (&acc_p)[M][N]) {\n// Do this part of P^T = (Q * K^T)^T.\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      Base::smem_q.load(Base::frag_q[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n    }\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_p::MMAS_K;\n      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);\n    }\n  }\n\n  __device__ inline void reload_k() {\n    // Noop.\n  }\n\n  Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];\n};\n\ntemplate <typename Kernel_traits>\nstruct Gemm_Q_K<Kernel_traits, false> : public 
Gemm_Q_K_base<Kernel_traits> {\n  using Base = Gemm_Q_K_base<Kernel_traits>;\n  using Smem_tile_o = typename Base::Smem_tile_o;\n  using Smem_tile_q = typename Base::Smem_tile_q;\n  using Smem_tile_k = typename Base::Smem_tile_k;\n  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n  using Fragment_k = typename Base::Fragment_k;\n  using Mma_tile_p = typename Base::Mma_tile_p;\n  Fragment_k frag_k[2][Mma_tile_p::MMAS_N];\n\n  enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V };\n\n  enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) };\n  static_assert(Smem_tile_v::BYTES_PER_TILE == (int)Smem_tile_k::BYTES_PER_TILE);\n  enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE };\n\n  // Q | K/V + O + SOFTMAX\n  static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE +\n                                    (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE +\n                                    Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;\n\n  __device__ inline Gemm_Q_K(char* smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {}\n\n  __device__ inline void load_k() { Base::smem_k.load(frag_k[0], 0); }\n\n  template <typename Acc, int M, int N>\n  __device__ inline void operator()(Acc (&acc_p)[M][N]) {\n// Do this part of P^T = (Q * K^T)^T.\n#pragma unroll\n    for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {\n      // Trigger the load from shared memory for the next series of Q values.\n      Base::smem_q.load(Base::frag_q[ki & 1], ki);\n      Base::smem_k.load(frag_k[ki & 1], ki);\n      // Do the math for the values already in registers.\n      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n    }\n    // Do the final stage of math.\n    {\n      int ki = Mma_tile_p::MMAS_K;\n      fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);\n    }\n  }\n\n  __device__ inline void 
reload_k() { Base::smem_k.load(frag_k[0], 0); }\n};\n\ntemplate <typename Kernel_traits>\nconstexpr size_t get_dynamic_smem_size() {\n  return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;\n}\n\ntemplate <typename Kernel_traits, bool Is_training, typename Params, typename Prng>\ninline __device__ void device_1xN_(const Params& params, const int bidb, const int bidh, const int begin,\n                                   const int steps, Prng& ph) {\n  // The description of the CTA tile for the 1st batched GEMM.\n  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n  // The description of the CTA tile for the 2nd batched GEMM.\n  using Cta_tile_o = typename Kernel_traits::Cta_tile_o;\n\n  // The MMA tile for the 1st GEMM.\n  using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n  // The MMA tile for the 2nd GEMM.\n  using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;\n\n  // The global memory tile to load Q.\n  using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;\n\n  // The global memory tile to load K.\n  using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;\n\n  // The global memory tile to load V.\n  using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;\n  // The shared memory tile to swizzle V.\n  using Smem_tile_v = typename Kernel_traits::Smem_tile_v;\n\n  // The global memory tile to store O.\n  using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;\n  // The shared memory tile to swizzle O.\n  using Smem_tile_o = typename Kernel_traits::Smem_tile_o;\n\n  using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;\n\n  using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;\n\n  using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;\n\n  // The number of threads per row.\n  enum { THREADS_PER_ROW = 32 };\n\n  enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };\n\n  // Shared memory.\n  extern __shared__ char smem_[];\n\n  // The thread index.\n  const int tidx = threadIdx.x;\n\n  const BlockInfoPadded<Kernel_traits::THREADS> 
binfo(params, bidb, bidh, tidx);\n  if (binfo.stop_early()) return;\n\n  Gemm1 gemm_q_k(smem_, tidx);\n  // Allocate the global memory tile loader for Q.\n  Gmem_tile_q gmem_q(params, 0, binfo, tidx);\n  // Allocate the global memory tile loader for O.\n  Gmem_tile_o gmem_o(params, binfo, tidx);\n  // Allocate the global memory tile loader for S.\n  Gmem_tile_s gmem_s(params, binfo, tidx);\n  // Wind gmem tiles to the correct position.\n  for (int it = 0; it < begin; it++) {\n    gmem_q.move();\n    gmem_s.move();\n    gmem_o.move();\n  }\n\n  fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);\n\n  // Allocate the global memory tile loader for K.\n  Gmem_tile_k gmem_k(params, 1, binfo, tidx);\n  // Allocate the global memory tile loader for V.\n  Gmem_tile_v gmem_v(params, 2, binfo, tidx);\n  // The base pointer of smem_v;\n  char* smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];\n\n  // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!\n  Smem_tile_v smem_v(smem_v_, tidx);\n\n  // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!\n  Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);\n\n  // Trigger the loads for K.\n  gmem_k.load(gemm_q_k.smem_k);\n  // Trigger the loads for Q.\n  gmem_q.load(gemm_q_k.smem_q);\n  // Trigger the loads for V.\n  gmem_v.load(smem_v);\n\n  const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);\n#pragma unroll\n  for (int it = 0; it < Gmem_tile_k::LDGS; it++) {\n    gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);\n  }\n\n  // Commit the data for Q and V to shared memory.\n  gmem_q.commit(gemm_q_k.smem_q);\n  gmem_v.commit(smem_v);\n\n  // Commit the data for K to shared memory.\n  if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {\n    gmem_k.commit(gemm_q_k.smem_k);\n  }\n\n  __syncthreads();\n\n  // Load the fragments for Q.\n  gemm_q_k.load_q();\n\n  // Load the fragments for V. 
We keep the data in registers during the entire kernel.\n  typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];\n#pragma unroll\n  for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) {\n    smem_v.load(frag_v[ki], ki);\n  }\n\n  // Commit the data for K to shared memory if it has not been done already.\n  if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {\n    // Make sure we are done loading the fragments for V.\n    __syncthreads();\n\n    // Commit the data to shared memory for K.\n    gmem_k.commit(gemm_q_k.smem_k);\n\n    // Make sure the data is in shared memory.\n    __syncthreads();\n  }\n\n  // Load the fragments for K.\n  gemm_q_k.load_k();\n  uint32_t p_scaled = (uint32_t)256.0 * params.p_dropout;\n\n  // Create the object to do the softmax.\n  Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);\n\n  // Loop over the entire sequence length.\n  for (int l = 0; l < steps; l++) {\n    if (begin + l * Cta_tile_p::M >= binfo.actual_seqlen) break;\n\n    // Declare the accumulators for the 1st gemm.\n    fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];\n    fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);\n\n    // Do this part of P^T = (Q * K^T)^T.\n    gemm_q_k(acc_p);\n\n    // Trigger the load for the next Q values.\n    if (l < steps - 1) {\n      gemm_q_k.smem_q.move_to_next_write_buffer();\n      gmem_q.move();\n      gmem_q.load(gemm_q_k.smem_q);\n    }\n\n    // Load the mask for that iteration.\n    mask.load(begin + l);\n\n    // Convert from the accumulator type to FP32 for Softmax.\n    softmax.unpack_noscale(acc_p);\n\n    // Apply the mask.\n    softmax.apply_mask(mask);\n\n    if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0) {\n      // If we share K and V, it could be that V was not fully read yet but we write into smem for reduction.\n      __syncthreads();\n    }\n    // Compute the max.\n    float 
p_max[Mma_tile_p::MMAS_M * 2];\n    // softmax.template reduce<fmha::Max_>(p_max);\n    softmax.reduce_max(p_max);\n\n    // Compute the exponential value.\n    softmax.apply_exp(p_max);\n\n    // Compute the sum.\n    float p_sum[Mma_tile_p::MMAS_M * 2];\n    softmax.reduce_sum(p_sum);\n\n    // Finalize softmax on the accumulators of P^T.\n    softmax.scale(p_sum);\n\n    using Frag_p = fmha::Fragment_a<fmha::Row>;\n    Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];\n    if (Is_training) {\n      auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };\n#pragma unroll\n      for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {\n#pragma unroll\n        for (int ii = 0; ii < 2; ii++) {\n#pragma unroll\n          for (int ni = 0; ni < Mma_tile_p::MMAS_N / 4; ni++) {\n            uint8_t* rand_arr = (uint8_t*)&ph();\n            // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from\n            // pre-existing zeros\n            for (int ind = 0; ind < 16; ind++) {\n              softmax.elt_[2 * mi + ii][16 * ni + ind] =\n                  encode_dropout(rand_arr[ind] <= p_scaled, softmax.elt_[2 * mi + ii][16 * ni + ind]);\n            }\n          }\n        }\n      }\n      softmax.pack(frag_p);\n      gmem_s.store(frag_p, mask);\n      gmem_s.move();\n    } else {\n      softmax.pack(frag_p);\n    }\n\n    // Commit the values for Q into shared memory.\n    if (l < steps - 1) {\n      gmem_q.commit(gemm_q_k.smem_q);\n    }\n\n    if (Is_training) {\n#pragma unroll\n      for (int ki = 0; ki < Mma_tile_o::MMAS_K; ki++) {\n#pragma unroll\n        for (int mi = 0; mi < Mma_tile_o::MMAS_M; mi++) {\n#pragma unroll\n          for (int ii = 0; ii < Frag_p::NUM_REGS; ii++) {\n            //\"Apply\" the dropout.\n            frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);\n            frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));\n          }\n       
 }\n      }\n    }\n\n    // Declare the accumulators for the 2nd gemm.\n    fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];\n    fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);\n\n// Do this part of O = P^T * V^T.\n#pragma unroll\n    for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) {\n      fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);\n    }\n\n// Loop over MMAS_M.\n#pragma unroll\n    for (int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii) {\n      // Swizzle the elements and do the final reduction.\n      smem_o.store(acc_o, ii);\n\n      // Make sure the data is in shared memory.\n      __syncthreads();\n\n      // Load from shared memory.\n      uint4 out[Gmem_tile_o::STGS_PER_LOOP];\n      smem_o.load(out);\n\n      // Make sure the data was read from shared memory.\n      if (ii < Gmem_tile_o::LOOPS - 1) {\n        __syncthreads();\n      }\n\n      // Output the values.\n      gmem_o.store(out, ii);\n    }\n\n    // Move to the next part of the output.\n    gmem_o.move();\n    gemm_q_k.reload_k();\n\n    // Reload the fragments for Q from shared memory.\n    if (l < steps - 1) {\n      gemm_q_k.reload_q();\n    }\n\n  }  // Outer loop over the sequence length.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits, bool Is_training, typename Params>\ninline __device__ void device_1xN(const Params& params, const int num_full_heads, const int num_main_groups,\n                                  const int main_group_size, const int main_steps, const int rest_steps) {\n  constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;\n  const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;\n  auto seeds = at::cuda::philox::unpack(params.philox_args);\n  Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));\n  for (int it = 0; it < num_full_heads; it++) {\n    const int bidx = it * 
gridDim.x + blockIdx.x;\n    const int bidh = bidx % params.h;\n    const int bidb = bidx / params.h;\n    fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);\n    __syncthreads();\n  }\n  if (main_group_size == 0) return;\n  const int head_offset = num_full_heads * gridDim.x;\n\n  if (blockIdx.x < main_group_size * num_main_groups) {\n    // process within heads\n    const int group = blockIdx.x % num_main_groups;\n    const int bidx = blockIdx.x / num_main_groups;\n    const int bidh = (head_offset + bidx) % params.h;\n    const int bidb = (head_offset + bidx) / params.h;\n    const int offset = group * main_steps;\n    fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph);\n  } else {\n    if (rest_steps == 0) return;\n    // process across heads\n    const int bidx = blockIdx.x - main_group_size * num_main_groups;\n    const int offset = num_main_groups * main_steps;\n    const int total_heads = params.b * params.h;\n    const int rest_ctas = gridDim.x - main_group_size * num_main_groups;\n    for (int it = head_offset + bidx; it < total_heads; it += rest_ctas) {\n      const int bidh = it % params.h;\n      const int bidb = it / params.h;\n      fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph);\n      __syncthreads();\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits, bool Is_training, typename Params>\ninline __device__ void device_1xN(const Params& params, const int total_heads) {\n  const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;\n  auto seeds = at::cuda::philox::unpack(params.philox_args);\n  Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));\n  constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;\n\n  for (int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x) {\n    const int bidh = 
bidx % params.h;\n    const int bidb = bidx / params.h;\n    fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);\n    __syncthreads();\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_kernel.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <fmha.h>\n#include <fmha/gmem_tile.h>\n#include <fmha/mask.h>\n#include <fmha/smem_tile.h>\n#include <fmha/softmax.h>\n#include <fmha/utils.h>\n\n#include <multihead_attn/philox.cuh>\n\nnamespace fmha {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int THREADS_PER_CTA>\nstruct BlockInfoPadded {\n  template <typename Params>\n  __device__ BlockInfoPadded(const Params& params, const int bidb, const int bidh, const int tidx)\n      : bidb(bidb), bidh(bidh), h(params.h) {\n    // The block index.\n    sum_s = params.cu_seqlens[bidb];\n    actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;\n    bidx = sum_s * params.h + bidh;\n\n    tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;\n  }\n\n  __device__ bool stop_early() const { return actual_seqlen == 0; }\n\n  int actual_seqlen;\n  int bidx;\n  int sum_s;\n  int bidh;\n  int bidb;\n  int tidx_global;\n  int h;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int CHUNKS, typename Cta_tile>\nstruct Noloop_traits {\n  // Interpretation of Cta_tile dims, i.e. 
Cta_tile_p:\n  enum { STEP = Cta_tile::M };\n  enum { SEQLEN = Cta_tile::N };\n\n  template <typename Block_info>\n  inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) : bidc_(bidc) {\n    const int seqlen = binfo.actual_seqlen;\n    const int steps = (seqlen + STEP - 1) / STEP;\n    const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;\n\n    const int step_begin = bidc_ * steps_per_chunk;\n    const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);\n    const int actual_steps = max(0, step_end - step_begin);\n    loop_offset_ = step_begin;\n    num_steps_ = actual_steps;\n  }\n\n  template <typename... Tiles>\n  inline __device__ void move_all(Tiles&... tiles) const {\n    using expand_type = int[];\n    for (int s = 0; s < loop_offset_; s++) {\n      expand_type{(tiles.move(), 0)...};\n    }\n  }\n\n  inline __device__ int get_idx_dk() const {\n    // return bidc_;\n    return bidc_ * 2 + 0;\n  }\n\n  inline __device__ int get_idx_dv() const {\n    // return CHUNKS + bidc_;\n    return bidc_ * 2 + 1;\n  }\n\n  inline __device__ int offset_loop_count(const int l) {\n    // convert loop counter to position in the outer sequence\n    return (loop_offset_ + l) * STEP;\n  }\n\n  const uint32_t bidc_;\n  int loop_offset_;\n  int num_steps_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Kernel_traits>\nstd::tuple<int, int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {\n  constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;\n\n  const int num_full_heads = heads_total / total_ctas;\n  const int heads_last_wave = heads_total % total_ctas;\n\n  int num_main_groups = 0;\n  int main_steps = 0;\n  int rest_steps = 0;\n  if (heads_last_wave > 0) {\n    // Number of CTA groups that process within heads.\n    num_main_groups = total_ctas / heads_last_wave;\n    // Remaining CTAs that process 
between heads.\n    const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);\n    if (rest_ctas == 0) {\n      // We have exactly \"num_main_groups\" CTAs to process each of the remaining heads.\n      main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;\n      num_main_groups = STEPS_PER_HEAD / main_steps;  // Here: main_step > 0\n      rest_steps = STEPS_PER_HEAD % main_steps;\n\n    } else {\n      // Ideal number of steps if we could load-balance as evenly as possible.\n      const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;\n      // Iterations that a \"rest\" CTA has to do at most.\n      const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;\n      // Find the first step distribution, s.t. the maximum work of the \"rest\" CTAs is less than the work of the main\n      // CTAs.\n      main_steps = steps_ideal;\n      rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;\n      for (; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++) {\n        rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;\n        const int max_rest_total_steps = rest_steps * max_rest_iters;\n        if (max_rest_total_steps < main_steps) break;\n      }\n      rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;\n    }\n  }\n\n  using Cta_tile_p = typename Kernel_traits::Cta_tile_p;\n  using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;\n\n  const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);\n  const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;\n  const int elts_per_thread = max_steps * elts_per_thread_per_step;\n\n  return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace fmha\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#include \"fmha.h\"\n\ninline __device__ float4 ldg128(const void* ptr) { return *static_cast<const float4*>(ptr); }\n\ninline __device__ void stg128(void* ptr, const float4& data) { *static_cast<float4*>(ptr) = data; }\n\ntemplate <typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>\n__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void* __restrict__ out,\n                                                                     const void* __restrict__ in,\n                                                                     const int* __restrict__ cu_seqlens,\n                                                                     const int batch_size) {\n  enum { BYTES_PER_LDG = 16 };\n  enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };\n\n  // One CTA hidden vector for K and V\n  enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };\n  // The stride in bytes in dQKV\n  enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };\n  // The offset in bytes in dQKV to the dKV part for non-interleaved heads\n  enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };\n\n  static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));\n\n  // Size in bytes of the input tile\n  enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };\n\n  enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };\n\n  enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };\n  static_assert(BYTES_PER_CTA * 
LDGS == BYTES_PER_ROW);\n\n  union Vec_t {\n    float4 raw;\n    T elt[NUM_ELTS];\n  };\n\n  // ZERO-OUT invalid positions in dQKV\n  const int total = cu_seqlens[batch_size];\n  if (blockIdx.x >= total) {\n    enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };\n    enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };\n\n    const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);\n\n    char* base_ptr = static_cast<char*>(out) + blockIdx.x * OUT_STRIDE_BYTES;\n\n    for (int tidx = threadIdx.x; tidx < STGS; tidx += THREADS) {\n      stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);\n    }\n\n    return;\n  }\n\n  // SETUP\n  const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;\n  const char* ptr_in = static_cast<const char*>(in) + offset_in;\n\n  const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;\n  char* ptr_out = static_cast<char*>(out) + OUT_OFFSET_KV_BYTES + offset_out;\n\n  // LOAD\n\n  Vec_t local_in[CHUNKS][LDGS];\n\n#pragma unroll\n  for (int c = 0; c < CHUNKS; c++) {\n#pragma unroll\n    for (int l = 0; l < LDGS; l++) {\n      int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;\n      local_in[c][l].raw = ldg128(ptr_in + offset);\n    }\n  }\n\n  // UNPACK\n  float acc[LDGS][NUM_ELTS];\n\n#pragma unroll\n  for (int l = 0; l < LDGS; l++) {\n#pragma unroll\n    for (int e = 0; e < NUM_ELTS; e++) {\n      acc[l][e] = float(local_in[0][l].elt[e]);\n    }\n  }\n\n// COMPUTE\n#pragma unroll\n  for (int c = 1; c < CHUNKS; c++) {\n#pragma unroll\n    for (int l = 0; l < LDGS; l++) {\n#pragma unroll\n      for (int e = 0; e < NUM_ELTS; e++) {\n        acc[l][e] += float(local_in[c][l].elt[e]);\n      }\n    }\n  }\n\n  // PACK\n  Vec_t local_out[LDGS];\n\n#pragma unroll\n  for (int l = 0; l < LDGS; l++) {\n#pragma unroll\n    for (int e = 0; e < NUM_ELTS; e++) {\n      local_out[l].elt[e] = T(acc[l][e]);\n    }\n  }\n\n// STORE\n#pragma unroll\n  for (int l = 0; l < LDGS; l++) {\n    const int offset = l 
* BYTES_PER_CTA;\n    stg128(ptr_out + offset, local_out[l].raw);\n  }\n}\n\nvoid fmha_run_noloop_reduce(void* out, const void* in, const int* cu_seqlens, const int hidden_size,\n                            const int batch_size, const int total, const int num_chunks, cudaStream_t stream) {\n  const int blocks = total;\n\n  if (hidden_size == 1024) {\n    constexpr int HIDDEN_SIZE = 1024;\n    constexpr int THREADS = 256;\n\n    if (num_chunks == 2) {\n      fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2>\n          <<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);\n    } else if (num_chunks == 3) {\n      fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3>\n          <<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);\n    } else {\n      assert(false && \"Unsupported num_chunks\");\n    }\n\n  } else {\n    assert(false && \"Unsupported hidden_size\");\n  }\n\n  FMHA_CHECK_CUDA(cudaPeekAtLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/fmha/src/fmha_utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime_api.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define FMHA_CHECK_CUDA(call)                                                                       \\\n  do {                                                                                              \\\n    cudaError_t status_ = call;                                                                     \\\n    if (status_ != cudaSuccess) {                                                                   \\\n      fprintf(stderr, \"CUDA error (%s:%d): %s\\n\", __FILE__, __LINE__, cudaGetErrorString(status_)); \\\n      exit(1);                                                                                      \\\n    }                                                                                               \\\n  } while (0)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nenum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype) {\n  if (dtype == 
DATA_TYPE_FP16) {\n    half x = __float2half_rn(norm);\n    uint16_t h = reinterpret_cast<const uint16_t&>(x);\n    ushort2 h2 = {h, h};\n    alpha = reinterpret_cast<const uint32_t&>(h2);\n  } else if (dtype == DATA_TYPE_FP32) {\n    alpha = reinterpret_cast<const uint32_t&>(norm);\n  } else if (dtype == DATA_TYPE_INT32) {\n    int32_t inorm = static_cast<int32_t>(norm);\n    alpha = reinterpret_cast<const uint32_t&>(inorm);\n  } else {\n    assert(false);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline size_t get_size_in_bytes(size_t n, Data_type dtype) {\n  switch (dtype) {\n    case DATA_TYPE_FP32:\n      return n * 4;\n    case DATA_TYPE_FP16:\n      return n * 2;\n    case DATA_TYPE_INT32:\n      return n * 4;\n    case DATA_TYPE_INT8:\n      return n;\n    default:\n      assert(false);\n      return 0;\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp",
    "content": "#include <torch/torch.h>\n\n#include <cstdint>\n#include <vector>\n\n// CUDA forward declarations\n\nstd::vector<at::Tensor> focal_loss_forward_cuda(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level,\n                                                const at::Tensor& num_positives_sum, const int64_t num_real_classes,\n                                                const float alpha, const float gamma, const float smoothing_factor);\n\nat::Tensor focal_loss_backward_cuda(const at::Tensor& grad_output, const at::Tensor& partial_grad,\n                                    const at::Tensor& num_positives_sum);\n\n// C++ interface\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\nstd::vector<at::Tensor> focal_loss_forward(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level,\n                                           const at::Tensor& num_positives_sum, const int64_t num_real_classes,\n                                           const float alpha, const float gamma, const float smoothing_factor) {\n  CHECK_INPUT(cls_output);\n  CHECK_INPUT(cls_targets_at_level);\n  CHECK_INPUT(num_positives_sum);\n\n  return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma,\n                                 smoothing_factor);\n}\n\nat::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& partial_grad,\n                               const at::Tensor& num_positives_sum) {\n  CHECK_INPUT(grad_output);\n  CHECK_INPUT(partial_grad);\n\n  return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &focal_loss_forward, \"Focal loss calculation forward (CUDA)\",\n        
py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &focal_loss_backward, \"Focal loss calculation backward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n\n// Use 128-bit vectorization\ntypedef uint4 vector_t;\n\n#define ASSERT_ALIGNED(DTYPE, PTR) \\\n  TORCH_INTERNAL_ASSERT(is_aligned<DTYPE>(PTR), \"Tensor \" #PTR \" is not \" #DTYPE \" aligned\")\n\ntemplate <class T>\nbool is_aligned(const void* ptr) noexcept {\n  auto iptr = reinterpret_cast<std::uintptr_t>(ptr);\n  return !(iptr % alignof(T));\n}\n\ntemplate <bool SMOOTHING, int ILP, typename scalar_t, typename labelscalar_t, typename accscalar_t,\n          typename outscalar_t>\n__global__ void focal_loss_forward_cuda_kernel(outscalar_t* loss, scalar_t* partial_grad,\n                                               const scalar_t* __restrict__ cls_output,\n                                               const labelscalar_t* __restrict__ cls_targets_at_level,\n                                               const float* __restrict__ num_positives_sum, const int64_t num_examples,\n                                               const int64_t num_classes, const int64_t num_real_classes,\n                                               const float alpha, const float gamma, const float smoothing_factor) {\n  extern __shared__ unsigned char shm[];\n  accscalar_t* loss_shm = reinterpret_cast<accscalar_t*>(shm);\n  loss_shm[threadIdx.x] = 0;\n  accscalar_t loss_acc = 0;\n\n  accscalar_t one = accscalar_t(1.0);\n  accscalar_t K = accscalar_t(2.0);\n  accscalar_t normalizer = one / static_cast<accscalar_t>(num_positives_sum[0]);\n  accscalar_t nn_norm, np_norm, pn_norm, pp_norm;\n\n  // *_norm is used for label smoothing only\n  if (SMOOTHING) {\n    nn_norm = one - smoothing_factor / K;\n    np_norm = smoothing_factor / K;\n    pn_norm = smoothing_factor - smoothing_factor / K;\n    pp_norm = one - smoothing_factor + smoothing_factor / K;\n  }\n\n  vector_t p_vec, grad_vec;\n\n  // Accumulate loss on each thread\n  for (int64_t i = (blockIdx.x * blockDim.x + 
threadIdx.x) * ILP; i < num_examples * num_classes;\n       i += gridDim.x * blockDim.x * ILP) {\n    int64_t idy = i / num_classes;\n    labelscalar_t y = cls_targets_at_level[idy];\n    int64_t base_yid = i % num_classes;\n\n    int64_t pos_idx = idy * num_classes + y;\n    p_vec = *(vector_t*)&cls_output[i];  // Vectorized load\n\n    // Skip ignored matches\n    if (y == -2) {\n#pragma unroll\n      for (int j = 0; j < ILP; j++) {\n        *((scalar_t*)(&grad_vec) + j) = 0;\n      }\n      *(vector_t*)&partial_grad[i] = grad_vec;\n      continue;\n    }\n\n#pragma unroll\n    for (int j = 0; j < ILP; j++) {\n      // Skip the pad classes\n      if (base_yid + j >= num_real_classes) {\n        *((scalar_t*)(&grad_vec) + j) = 0;\n        continue;\n      }\n\n      accscalar_t p = static_cast<accscalar_t>(*((scalar_t*)(&p_vec) + j));\n      accscalar_t exp_np = ::exp(-p);\n      accscalar_t exp_pp = ::exp(p);\n      accscalar_t sigma = one / (one + exp_np);\n      accscalar_t logee = (p >= 0) ? exp_np : exp_pp;\n      accscalar_t addee = (p >= 0) ? 0 : -p;\n      accscalar_t off_a = addee + ::log(one + logee);\n\n      // Negative matches\n      accscalar_t base = SMOOTHING ? nn_norm * p : p;\n      accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma;\n      accscalar_t coeff_f1 = one - alpha;\n      accscalar_t coeff_f2 = sigma;\n      accscalar_t coeff_b1 = gamma;\n      accscalar_t coeff_b2 = one - sigma;\n\n      // Positive matches\n      if (y >= 0 && (i + j == pos_idx)) {\n        base = SMOOTHING ? pn_norm * p : 0;\n        off_b = (SMOOTHING ? 
pp_norm : one) - sigma;\n        coeff_f1 = alpha;\n        coeff_f2 = one - sigma;\n        coeff_b1 = -gamma;\n        coeff_b2 = sigma;\n      }\n\n      accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma);\n      accscalar_t coeff_b = coeff_b1 * coeff_b2;\n\n      accscalar_t loss_t = coeff_f * (base + off_a);\n      accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b);\n\n      // Delay the normalize of partial gradient by num_positives_sum to back\n      // propagation because scalar_t reduces precision. Focal loss is very\n      // sensitive to the small gradient. No worry on overflow here since\n      // gradient has relative smaller range than input.\n      loss_acc += loss_t;\n      *((scalar_t*)(&grad_vec) + j) = static_cast<scalar_t>(grad);\n    }\n\n    // This may generate two vectorized stores instead of one\n    *(vector_t*)&partial_grad[i] = grad_vec;\n  }\n  loss_shm[threadIdx.x] = loss_acc;\n\n  // Intra-CTA reduction\n  __syncthreads();\n  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {\n    if (threadIdx.x < s) {\n      loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s];\n    }\n    __syncthreads();\n  }\n\n  // Inter-CTA reduction\n  if (threadIdx.x == 0) {\n    loss_acc = loss_shm[0] * normalizer;\n    atomicAdd(loss, loss_acc);\n  }\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>\n__global__ void focal_loss_backward_cuda_kernel(scalar_t* partial_grad, const outscalar_t* __restrict__ grad_output,\n                                                const float* __restrict__ num_positives_sum, const uint64_t numel) {\n  int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;\n\n  accscalar_t normalizer = static_cast<accscalar_t>(grad_output[0]) / static_cast<accscalar_t>(num_positives_sum[0]);\n\n  // The input is enforced to pad to use vector load, thus there's no need to\n  // check whether the last element of ILP can out of bound.\n  if (idx >= numel) return;\n\n  vector_t 
grad_vec;\n  grad_vec = *(vector_t*)&partial_grad[idx];\n#pragma unroll(ILP)\n  for (int i = 0; i < ILP; i++) {\n    auto grad = static_cast<accscalar_t>(*((scalar_t*)(&grad_vec) + i));\n    grad *= normalizer;\n    *((scalar_t*)(&grad_vec) + i) = static_cast<scalar_t>(grad);\n  }\n  *(vector_t*)&partial_grad[idx] = grad_vec;\n}\n\nstd::vector<at::Tensor> focal_loss_forward_cuda(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level,\n                                                const at::Tensor& num_positives_sum, const int64_t num_real_classes,\n                                                const float alpha, const float gamma, const float smoothing_factor) {\n  // Checks required for correctness\n  TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes, \"Incorrect number of real classes.\");\n  TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong, \"Invalid label type.\");\n  TORCH_INTERNAL_ASSERT((num_positives_sum.numel() == 1) && (num_positives_sum.scalar_type() == at::kFloat),\n                        \"Expect num_positives_sum to be a float32 tensor with only one element.\");\n  TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1,\n                        \"Mis-matched dimensions between class output and label.\");\n  for (int64_t i = 0; i < cls_targets_at_level.dim(); i++)\n    TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i),\n                          \"Mis-matched shape between class output and label.\");\n\n  // Checks required for better performance\n  const int ILP = sizeof(vector_t) / cls_output.element_size();\n  ASSERT_ALIGNED(vector_t, cls_output.data_ptr());\n  TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0,\n                        \"Pad number of classes first to take advantage of vectorized load.\");\n  TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, \"Too few classes.\");\n\n  int64_t num_classes = cls_output.size(-1);\n  int64_t num_examples = 
cls_output.numel() / num_classes;\n  at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat));\n\n  // Compute the incomplete gradient during fprop, since most of the heavy\n  // functions of bprop are the same as in fprop; trading memory for compute\n  // helps with focal loss.\n  at::Tensor partial_grad = at::empty_like(cls_output);\n\n  // Set the number of CTAs per SM according to the compute capability.\n  // Each CTA loops on input with stride till the last item.\n  cudaDeviceProp props;\n  cudaGetDeviceProperties(&props, at::cuda::current_device());\n  int cta_per_sm = 2;\n  if (props.major >= 10) {\n    cta_per_sm = 8;\n  }\n  dim3 block(512);\n  dim3 grid(cta_per_sm * props.multiProcessorCount);\n\n  // Specialize on label smoothing or not to reduce redundant operations\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  if (smoothing_factor == 0.0f) {\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), \"focal_loss_fprop\", [&] {\n      using accscalar_t = at::acc_type<scalar_t, true>;\n      using labelscalar_t = int64_t;\n      using outscalar_t = float;\n      const int ILP = sizeof(vector_t) / sizeof(scalar_t);\n      focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t, accscalar_t, outscalar_t>\n          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n              loss.data_ptr<outscalar_t>(), partial_grad.data_ptr<scalar_t>(), cls_output.data_ptr<scalar_t>(),\n              cls_targets_at_level.data_ptr<labelscalar_t>(), num_positives_sum.data_ptr<float>(), num_examples,\n              num_classes, num_real_classes, alpha, gamma, smoothing_factor);\n    });\n  } else {\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(cls_output.scalar_type(), \"focal_loss_fprop\", [&] {\n      using accscalar_t = at::acc_type<scalar_t, true>;\n      using labelscalar_t = int64_t;\n      using outscalar_t = float;\n      const int ILP = sizeof(vector_t) / sizeof(scalar_t);\n      
focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t, accscalar_t, outscalar_t>\n          <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n              loss.data_ptr<outscalar_t>(), partial_grad.data_ptr<scalar_t>(), cls_output.data_ptr<scalar_t>(),\n              cls_targets_at_level.data_ptr<labelscalar_t>(), num_positives_sum.data_ptr<float>(), num_examples,\n              num_classes, num_real_classes, alpha, gamma, smoothing_factor);\n    });\n  }\n\n  AT_CUDA_CHECK(cudaGetLastError());\n  return {loss, partial_grad};\n}\n\nat::Tensor focal_loss_backward_cuda(const at::Tensor& grad_output, const at::Tensor& partial_grad,\n                                    const at::Tensor& num_positives_sum) {\n  // Each thread process ILP elements\n  const int ILP = sizeof(vector_t) / partial_grad.element_size();\n  dim3 block(512);\n  dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(partial_grad.scalar_type(), \"focal_loss_bprop\", [&] {\n    using accscalar_t = at::acc_type<scalar_t, true>;\n    using outscalar_t = float;\n    const int ILP = sizeof(vector_t) / sizeof(scalar_t);\n    focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>\n        <<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(), grad_output.data_ptr<outscalar_t>(),\n                                     num_positives_sum.data_ptr<float>(), partial_grad.numel());\n  });\n\n  AT_CUDA_CHECK(cudaGetLastError());\n  return partial_grad;\n}\n"
  },
  {
    "path": "apex/contrib/csrc/gpu_direct_storage/gds.cpp",
    "content": "// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n\n#include <gds.h>\n\n// torch\n#include <c10/cuda/CUDAGuard.h>\n#include <torch/torch.h>\n\n// cuda\n#include <cuda_runtime.h>\n#include <cufile.h>\n\n// file io\n#include <fcntl.h>\n\nnamespace apex::contrib::gds {\n\n// POSIX\ntemplate <class T, typename std::enable_if<std::is_integral<T>::value, std::nullptr_t>::type = nullptr>\nstd::string cuFileGetErrorString(T status) {\n  status = std::abs(status);\n  return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) : std::string(std::strerror(errno));\n}\n\n// CUfileError_t\ntemplate <class T, typename std::enable_if<!std::is_integral<T>::value, std::nullptr_t>::type = nullptr>\nstd::string cuFileGetErrorString(T status) {\n  std::string errStr = cuFileGetErrorString(static_cast<int>(status.err));\n  if (IS_CUDA_ERR(status)) errStr.append(\".\").append(cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));\n  return errStr;\n}\n\nFile::File() : is_open(false) {};\n\nFile::File(const std::string& filename, const std::string& mode) : filename(filename), mode(mode), is_open(false) {\n  open(filename, mode);\n}\n\nFile::~File() {\n  if (is_open) {\n    close();\n  }\n}\n\nvoid File::open(const std::string& other_filename, const std::string& other_mode) {\n  TORCH_CHECK(is_open == false, \"file\", filename, \"is already open\");\n  if (!filename.empty()) {\n    TORCH_CHECK(other_filename == filename, \"file\", filename, \"is already open with mode\", mode);\n  }\n  if (!mode.empty()) {\n    TORCH_CHECK(other_mode == mode, \"file\", filename, \"is already open with mode\", mode);\n  }\n\n  maybe_register = true;\n  // Open the binary file\n  if (mode == \"r\") {\n    // for reading\n    fd = ::open(filename.c_str(), O_RDONLY | O_DIRECT);\n  } else if (mode == \"w\") {\n    // for writing\n    fd = ::open(filename.c_str(), O_CREAT | O_WRONLY | O_DIRECT, 0664);\n  } else if (mode == \"rn\") {\n    // for reading\n    fd = 
::open(filename.c_str(), O_RDONLY);\n    maybe_register = false;\n  } else if (mode == \"wn\") {\n    // for writing\n    fd = ::open(filename.c_str(), O_CREAT | O_WRONLY, 0664);\n    maybe_register = false;\n  } else {\n    TORCH_CHECK(false, \"only r, w, rn, and wn modes are currently supported, but got: \", mode);\n  }\n  TORCH_CHECK(fd >= 0, \"fcntl cannot open file: \", filename);\n\n  // Register cuFile handle\n  if (maybe_register) {\n    memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));\n    cf_descr.handle.fd = fd;\n    cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;\n    status = cuFileHandleRegister(&cf_handle, &cf_descr);\n    if (status.err != CU_FILE_SUCCESS) {\n      TORCH_CHECK(false, \"cuFileHandleRegister failed: \", cuFileGetErrorString(status));\n    }\n  }\n  is_open = true;\n}\n\nvoid File::close() {\n  // Deregister cuFile handle and close the file\n  if (is_open) {\n    if (maybe_register) {\n      cuFileHandleDeregister(cf_handle);\n    }\n    ::close(fd);\n    fd = -1;\n  }\n  is_open = false;\n}\n\nvoid File::load_data(const torch::Tensor& tensor) {\n  TORCH_CHECK(mode == \"r\", filename, \" is not open for GDS reads (expected mode r), got mode: \", mode);\n  c10::cuda::CUDAGuard gpuGuard(tensor.device());\n\n  void* dataPtr = tensor.data_ptr();\n  const size_t nbytes = tensor.nbytes();\n\n  // Read the binary file\n  ssize_t ret = cuFileRead(cf_handle, (void*)dataPtr, nbytes, 0, 0);\n  TORCH_CHECK(ret >= 0, \"cuFileRead failed: \", cuFileGetErrorString(ret));\n}\n\nvoid File::save_data(const torch::Tensor& tensor) {\n  TORCH_CHECK(mode == \"w\", filename, \" is not open for GDS writes (expected mode w), got mode: \", mode);\n  c10::cuda::CUDAGuard gpuGuard(tensor.device());\n\n  void* dataPtr = tensor.data_ptr();\n  const size_t nbytes = tensor.nbytes();\n\n  // Register device memory\n  status = cuFileBufRegister(dataPtr, nbytes, 0);\n  TORCH_CHECK(status.err == CU_FILE_SUCCESS, \"cuFileBufRegister failed: \", cuFileGetErrorString(status));\n\n  // Write device memory contents to the file\n  ssize_t ret = 
cuFileWrite(cf_handle, dataPtr, nbytes, 0, 0);\n  status = cuFileBufDeregister(dataPtr);\n\n  TORCH_CHECK(ret >= 0, \"cuFileWrite failed: \", cuFileGetErrorString(ret));\n  TORCH_CHECK(status.err == CU_FILE_SUCCESS, \"cuFileBufDeregister failed: \", cuFileGetErrorString(status));\n}\n\n// Just for benchmarking purposes\n\nvoid File::load_data_no_gds(const torch::Tensor& tensor) {\n  TORCH_CHECK(mode == \"rn\", filename, \" is not open for non-GDS reads (expected mode rn), got mode: \", mode);\n  c10::cuda::CUDAGuard gpuGuard(tensor.device());\n\n  void* dataPtrCPU = nullptr;\n  void* dataPtr = tensor.data_ptr();\n  const size_t nbytes = tensor.nbytes();\n  dataPtrCPU = malloc(nbytes);\n  TORCH_CHECK(dataPtrCPU != nullptr, \"malloc failed\");\n\n  const ssize_t nbytes_read = pread(fd, dataPtrCPU, nbytes, 0);\n  TORCH_CHECK(nbytes_read == nbytes || nbytes_read == 0, \"fcntl pread failed\");\n  C10_CUDA_CHECK(cudaMemcpy(dataPtr, dataPtrCPU, nbytes, cudaMemcpyHostToDevice));\n  free(dataPtrCPU);\n}\n\nvoid File::save_data_no_gds(const torch::Tensor& tensor) {\n  TORCH_CHECK(mode == \"wn\", filename, \" is not open for non-GDS writes (expected mode wn), got mode: \", mode);\n  c10::cuda::CUDAGuard gpuGuard(tensor.device());\n\n  void* dataPtrCPU = nullptr;\n  void* dataPtr = tensor.data_ptr();\n  const size_t nbytes = tensor.nbytes();\n  dataPtrCPU = malloc(nbytes);\n  TORCH_CHECK(dataPtrCPU != nullptr, \"malloc failed\");\n  C10_CUDA_CHECK(cudaMemcpy(dataPtrCPU, dataPtr, nbytes, cudaMemcpyDeviceToHost));\n\n  const ssize_t nbytes_written = pwrite(fd, dataPtrCPU, nbytes, 0);\n  TORCH_CHECK(nbytes_written == nbytes, \"fcntl pwrite failed\");\n  free(dataPtrCPU);\n}\n\n}  // namespace apex::contrib::gds\n"
  },
  {
    "path": "apex/contrib/csrc/gpu_direct_storage/gds.h",
    "content": "// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n\n#pragma once\n\n#include <cufile.h>\n#include <torch/torch.h>\n\n#include <string>\n\nnamespace apex::contrib::gds {\nclass File {\n public:\n  File();\n  File(const std::string& filename, const std::string& mode);\n  ~File();\n\n  void open(const std::string& filename, const std::string& mode);\n  void close();\n\n  void load_data(const torch::Tensor& tensor);\n  void save_data(const torch::Tensor& tensor);\n  void load_data_no_gds(const torch::Tensor& tensor);\n  void save_data_no_gds(const torch::Tensor& tensor);\n\n private:\n  std::string filename;\n  std::string mode;\n\n  CUfileDescr_t cf_descr;\n  CUfileHandle_t cf_handle;\n  CUfileError_t status;\n\n  int fd = -1;\n  bool is_open = false;\n  bool maybe_register = true;\n};\n}  // namespace apex::contrib::gds\n"
  },
  {
    "path": "apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp",
    "content": "// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n\n#include <gds.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include <string>\n\n// python bindings\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  py::class_<apex::contrib::gds::File, std::shared_ptr<apex::contrib::gds::File>>(m, \"_GDSFile\")\n      .def(py::init<>())\n      .def(py::init<const std::string&, const std::string&>())\n      .def(\"open\", &apex::contrib::gds::File::open)\n      .def(\"close\", &apex::contrib::gds::File::close)\n      .def(\"load_data\", &apex::contrib::gds::File::load_data)\n      .def(\"save_data\", &apex::contrib::gds::File::save_data)\n      .def(\"load_data_no_gds\", &apex::contrib::gds::File::load_data_no_gds)\n      .def(\"save_data_no_gds\", &apex::contrib::gds::File::save_data_no_gds);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#include <assert.h>\n#include <float.h>\n#include <group_norm_nhwc.h>\n#include <group_norm_nhwc_bwd_one_pass.h>\n#include <group_norm_nhwc_fwd_one_pass.h>\n#include <string.h>\n#include <traits.h>\n\n#include <type_traits>\n\ntemplate <typename T>\nfloat inline unpack(const T& x) {\n  return {};\n}\n\ntemplate <>\nfloat inline unpack(const __half& x) {\n  return __half2float(x);\n}\n\ntemplate <>\nfloat inline unpack(const __nv_bfloat16& x) {\n  return __bfloat162float(x);\n}\n\ntemplate <>\nfloat inline unpack(const float& x) {\n  return x;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nvoid check_results(const char* name, const T* out, const T* ref, size_t elts, float tol) {\n  // The number of errors.\n  int failed = 0;\n  // The number of infinite value.\n  int infs = 0;\n  // The min/max values.\n  float min_val = FLT_MAX, max_val = -FLT_MAX, max_err = 0.f;\n  // The total sum of error.\n  double sum_err = 0.0;\n\n  // The case we are checking.\n  printf(\"\\e[1;34mchecking.....................: %s\\e[0m\\n\", name);\n  fflush(stdout);\n\n  // Iterate over the different values.\n  for (size_t ii = 0; ii < elts; ++ii) {\n    float a = unpack(out[ii]);\n    float b = unpack(ref[ii]);\n\n    // Compute the absolute norms.\n    float abs_a = fabsf(a);\n    float abs_b = fabsf(b);\n\n    // Compute the error.\n    float den = abs_a + abs_b;\n    // Is one of the quantities very small?\n    bool is_small = abs_a <= tol || abs_b <= tol || den <= tol;\n    // The error.\n    float err = is_small ? 
fabsf(a - b) : fabsf(a - b) / den;\n    // Is the result ok?\n    bool ok = !isnan(a) && !isnan(b) && err <= tol;\n\n    // Print the error.\n    if (!ok && (failed < 10 || err > max_err)) {\n      fprintf(stderr, \">> invalid result for ii=%lu:\\n\", ii);\n      if (std::is_same<T, __half>::value || std::is_same<T, __nv_bfloat16>::value) {\n        // The data.\n        fprintf(stderr, \">>   found...: 0x%04x (%10.6f)\\n\", reinterpret_cast<const uint16_t&>(out[ii]), a);\n        fprintf(stderr, \">>   expected: 0x%04x (%10.6f)\\n\", reinterpret_cast<const uint16_t&>(ref[ii]), b);\n      } else if (std::is_same<T, float>::value) {\n        fprintf(stderr, \">>   found...: 0x%08x (%10.6f)\\n\", reinterpret_cast<const uint32_t&>(a), a);\n        fprintf(stderr, \">>   expected: 0x%08x (%10.6f)\\n\", reinterpret_cast<const uint32_t&>(b), b);\n      } else {\n        fprintf(stderr, \"\\e[1;34mUnknown type of check_results\\e[0m\\n\");\n        exit(1);\n      }\n      fprintf(stderr, \">>   error...: %.6f\\n\", err);\n    }\n\n    // Update the number of failures.\n    failed += ok ? 0 : 1;\n\n    // Measure min/max errors.\n    min_val = fminf(min_val, a);\n    max_val = fmaxf(max_val, a);\n    max_err = fmaxf(max_err, err);\n\n    // Accumulate the sum.\n    sum_err = sum_err + (double)err;\n\n    infs += !isfinite(a);\n    infs += !isfinite(b);\n  }\n\n  if (!failed && infs < 10) {\n    printf(\"\\e[1;32mcheck........................: OK\\e[0m\\n\");\n  } else {\n    printf(\"\\e[1;31mcheck........................: FAILED\\e[0m\\n\");\n  }\n\n  printf(\"tested.......................: %lu\\n\", elts);\n  printf(\"failures.....................: %d\\n\", failed);\n  printf(\"failure rate.................: %.2lf%%\\n\", (double)failed * 100.0 / (double)elts);\n  printf(\"infs.........................: %d\\n\", infs);\n  printf(\"tolerance....................: %.8f\\n\", tol);\n  printf(\"\\n\");\n\n  printf(\"min. 
value...................: %.6f\\n\", min_val);\n  printf(\"max. value...................: %.6f\\n\", max_val);\n  printf(\"max. error...................: %.6f\\n\", max_err);\n  printf(\"sum. error...................: %.6lf\\n\", sum_err);\n  printf(\"avg. error...................: %.6lf\\n\", sum_err / (double)elts);\n  printf(\"\\n\");\n}\n\ntemplate void check_results(const char* name, const __half* out, const __half* ref, size_t elts, float tol);\n\ntemplate void check_results(const char* name, const __nv_bfloat16* out, const __nv_bfloat16* ref, size_t elts,\n                            float tol);\n\ntemplate void check_results(const char* name, const float* out, const float* ref, size_t elts, float tol);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic void group_norm_nhwc_bwd_(void* dx_h, float* dgamma_h, float* dbeta_h, const void* dy_h, const void* x_h,\n                                 const float* gamma_h, const float* beta_h, const float2* sums_h, float epsilon, int n,\n                                 int h, int w, int c, int groups, bool with_swish, bool use_fp32, bool use_bf16) {\n  // The number of channels in each group.\n  int channels_per_group = c / groups;\n  // The normalization term to compute the means.\n  float rcp_hwc_per_group = 1.f / (float)(h * w * channels_per_group);\n\n  // The array to compute gamma.\n  float* dgamma = (float*)malloc(c * sizeof(float));\n  // The array to compute beta.\n  float* dbeta = (float*)malloc(c * sizeof(float));\n\n  // Set gamma/beta to 0.\n  memset(dgamma, 0, c * sizeof(float));\n  memset(dbeta, 0, c * sizeof(float));\n\n  // Normalize the activations.\n  for (int ni = 0; ni < n; ++ni) {\n    for (int gi = 0; gi < groups; ++gi) {\n      // The sums from the fwd pass.\n      float2 sums = sums_h[ni * groups + gi];\n      // The mean of X (computed during the fwd pass -- one value per batch*group).\n      float x_mean = sums.x;\n      // The 
mean of squares of X (computed during the fwd pass -- one value per batch*group).\n      float x_sq_mean = sums.y;\n      // The variance.\n      float x_var = x_sq_mean - x_mean * x_mean;\n      // The reciprocal of the standard deviation (i.e. 1.f / sqrt(var + epsilon)).\n      float rcp_x_stddev = x_var <= 0.f ? 1.f : 1.f / sqrtf(x_var + epsilon);\n\n      // TODO: We should store rcp_x_stddev instead of the sums of squares.\n\n      // The following nested loops compute 2 means.\n      float mean_1 = 0.f, mean_2 = 0.f;\n\n      // Iterate over the activations in the group.\n      for (int hi = 0; hi < h; ++hi) {\n        for (int wi = 0; wi < w; ++wi) {\n          for (int ii = 0; ii < channels_per_group; ++ii) {\n            // The channel.\n            int ci = gi * channels_per_group + ii;\n            // Compute the src/dst offset.\n            size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci;\n            // Convert the element at that position to float.\n            float x;\n            if (use_fp32) {\n              x = reinterpret_cast<const float*>(x_h)[offset];\n            } else if (use_bf16) {\n              x = __bfloat162float(reinterpret_cast<const __nv_bfloat16*>(x_h)[offset]);\n            } else {\n              x = __half2float(reinterpret_cast<const __half*>(x_h)[offset]);\n            }\n            // The output.\n            float dy;\n            if (use_fp32) {\n              dy = reinterpret_cast<const float*>(dy_h)[offset];\n            } else if (use_bf16) {\n              dy = __bfloat162float(reinterpret_cast<const __nv_bfloat16*>(dy_h)[offset]);\n            } else {\n              dy = __half2float(reinterpret_cast<const __half*>(dy_h)[offset]);\n            }\n\n            // Gamma.\n            float gamma = gamma_h[ci];\n\n            // X - X_mean.\n            float x_minus_x_mean = x - x_mean;\n            // Normalize X.\n            float x_norm = x_minus_x_mean * 
rcp_x_stddev;\n\n            if (with_swish) {\n              // Beta\n              float beta = beta_h[ci];\n\n              float x_gn = x_norm * gamma + beta;\n              float s = sigmoid(x_gn);\n              dy = dy * s * (1.f + x_gn * (1.f - s));\n            }\n\n            // Compute the gradient for beta.\n            dbeta[ci] += dy;\n\n            // Compute the gradient for gamma.\n            dgamma[ci] += dy * x_norm;\n\n            // The gradient that enters the x_norm node.\n            float dx_norm = dy * gamma;\n\n            // Accumulators over 2 means\n            mean_1 += x_norm * dx_norm;\n            mean_2 += dx_norm;\n\n          }  // ii\n        }  // wi\n      }  // hi\n\n      mean_1 *= rcp_hwc_per_group;\n      mean_2 *= rcp_hwc_per_group;\n\n      // Iterate over the activations in the group.\n      for (int hi = 0; hi < h; ++hi) {\n        for (int wi = 0; wi < w; ++wi) {\n          for (int ii = 0; ii < channels_per_group; ++ii) {\n            // The channel.\n            int ci = gi * channels_per_group + ii;\n            // Compute the src/dst offset.\n            size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci;\n            float x;\n            if (use_fp32) {\n              x = reinterpret_cast<const float*>(x_h)[offset];\n            } else if (use_bf16) {\n              x = __bfloat162float(reinterpret_cast<const __nv_bfloat16*>(x_h)[offset]);\n            } else {\n              x = __half2float(reinterpret_cast<const __half*>(x_h)[offset]);\n            }\n            // The output.\n            float dy;\n            if (use_fp32) {\n              dy = reinterpret_cast<const float*>(dy_h)[offset];\n            } else if (use_bf16) {\n              dy = __bfloat162float(reinterpret_cast<const __nv_bfloat16*>(dy_h)[offset]);\n            } else {\n              dy = __half2float(reinterpret_cast<const __half*>(dy_h)[offset]);\n            }\n\n            // Gamma.\n        
    float gamma = gamma_h[ci];\n\n            // X - X_mean.\n            float x_minus_x_mean = x - x_mean;\n            // Normalize X.\n            float x_norm = x_minus_x_mean * rcp_x_stddev;\n\n            if (with_swish) {\n              // Beta\n              float beta = beta_h[ci];\n\n              float x_gn = x_norm * gamma + beta;\n              float s = sigmoid(x_gn);\n              dy = dy * s * (1.f + x_gn * (1.f - s));\n            }\n\n            // The gradient that enters the x_norm node.\n            float dx_norm = dy * gamma;\n\n            // Input gradient\n            float dx = (dx_norm - (x_norm * mean_1 + mean_2)) * rcp_x_stddev;\n\n            // Set the output gradient.\n            if (use_fp32) {\n              reinterpret_cast<float*>(dx_h)[offset] = dx;\n            } else if (use_bf16) {\n              reinterpret_cast<__nv_bfloat16*>(dx_h)[offset] = __float2bfloat16_rn(dx);\n            } else {\n              reinterpret_cast<__half*>(dx_h)[offset] = __float2half_rn(dx);\n            }\n\n          }  // ii\n        }  // wi\n      }  // hi\n\n    }  // gi\n  }  // ni\n\n  // Store gamma/beta.\n  for (int ci = 0; ci < c; ++ci) {\n    dgamma_h[ci] = dgamma[ci];\n    dbeta_h[ci] = dbeta[ci];\n  }\n\n  // Release temporary memory.\n  free(dgamma);\n  free(dbeta);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic void group_norm_nhwc_fwd_(void* y_h, const void* x_h, const float* gamma_h, const float* beta_h, float epsilon,\n                                 int n, int h, int w, int c, int groups, bool with_swish, bool use_fp32,\n                                 bool use_bf16) {\n  // The number of channels in each group.\n  int channels_per_group = c / groups;\n\n  // The normalization term to compute the means.\n  float inv_hwcg = 1.f / (float)(h * w * channels_per_group);\n\n  // Normalize the activations.\n  for (int ni = 0; ni < n; ++ni) {\n    for (int gi = 0; 
gi < groups; ++gi) {\n      // The sums to compute the mean/variance for that group.\n      float sum = 0.f, sum_sq = 0.f;\n\n      // Iterate over the activations in the group.\n      for (int hi = 0; hi < h; ++hi) {\n        for (int wi = 0; wi < w; ++wi) {\n          for (int ii = 0; ii < channels_per_group; ++ii) {\n            // The channel.\n            int ci = gi * channels_per_group + ii;\n            // Compute the src/dst offset.\n            size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci;\n            // Convert the element at that position to float.\n            float x;\n            if (use_fp32) {\n              x = reinterpret_cast<const float*>(x_h)[offset];\n            } else if (use_bf16) {\n              x = __bfloat162float(reinterpret_cast<const __nv_bfloat16*>(x_h)[offset]);\n            } else {\n              x = __half2float(reinterpret_cast<const __half*>(x_h)[offset]);\n            }\n\n            // Update the sums.\n            sum += x;\n            sum_sq += x * x;\n\n          }  // ii\n        }  // wi\n      }  // hi\n\n      // Compute the mean.\n      float mean = sum * inv_hwcg;\n      // Compute the average value for the squares.\n      float mean_sq = sum_sq * inv_hwcg;\n      // Compute the variance.\n      float var = mean_sq - (mean * mean);\n      // Invert the variance.\n      float inv_stddev = var <= 0.f ? 
1.f : (1.f / sqrtf(var + epsilon));\n\n      // Iterate over the data to normalize the output.\n      for (int hi = 0; hi < h; ++hi) {\n        for (int wi = 0; wi < w; ++wi) {\n          for (int ii = 0; ii < channels_per_group; ++ii) {\n            // The channel.\n            int ci = gi * channels_per_group + ii;\n            // Compute the src/dst offset.\n            size_t offset = (size_t)ni * h * w * c + (size_t)hi * w * c + (size_t)wi * c + (size_t)ci;\n            // Normalize.\n            float x;\n            if (use_fp32) {\n              x = reinterpret_cast<const float*>(x_h)[offset];\n            } else if (use_bf16) {\n              x = __bfloat162float(reinterpret_cast<const __nv_bfloat16*>(x_h)[offset]);\n            } else {\n              x = __half2float(reinterpret_cast<const __half*>(x_h)[offset]);\n            }\n            float y = (x - mean) * inv_stddev;\n            // Scale with gamma and add beta.\n            y = y * gamma_h[ci] + beta_h[ci];\n            // Apply swish (if needed).\n            if (with_swish) {\n              y = y * sigmoid(y);\n            }\n            // Store the result.\n            if (use_fp32) {\n              reinterpret_cast<float*>(y_h)[offset] = y;\n            } else if (use_bf16) {\n              reinterpret_cast<__nv_bfloat16*>(y_h)[offset] = __float2bfloat16_rn(y);\n            } else {\n              reinterpret_cast<__half*>(y_h)[offset] = __float2half_rn(y);\n            }\n\n          }  // ii\n        }  // wi\n      }  // hi\n    }  // gi\n  }  // ni\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nvoid random_data(T* dst_h, size_t n, bool use_1s, int range = 3) {\n  for (size_t ii = 0; ii < n; ++ii) {\n    float x = 1.f;\n    if (!use_1s) {\n      x = (float)(rand() % range - (range / 2));\n    }\n    if (std::is_same<T, __half>::value) {\n      dst_h[ii] = __float2half_rn(x);\n    } else if 
(std::is_same<T, float>::value) {\n      dst_h[ii] = x;\n    } else if (std::is_same<T, __nv_bfloat16>::value) {\n      dst_h[ii] = __float2bfloat16_rn(x);\n    } else {\n      fprintf(stderr, \"\\e[1;34mUnknown type of random_data\\e[0m\\n\");\n      exit(1);\n    }\n  }\n}\n\ntemplate void random_data(float* dst_h, size_t n, bool use_1s, int range);\n\ntemplate void random_data(__half* dst_h, size_t n, bool use_1s, int range);\n\ntemplate void random_data(__nv_bfloat16* dst_h, size_t n, bool use_1s, int range);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nenum class Mode { FWD_INFERENCE, FWD_TRAINING, BWD };\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nint main(int argc, char** argv) {\n  // The tensor size.\n  int n = 2, h = 64, w = 64, c = 320, groups = 32;\n  // The default mode is inference.\n  Mode mode = Mode::FWD_INFERENCE;\n  // The constant epsilon for sqrt(var + epsilon).\n  float epsilon = 1.e-5f;\n  // Do we fuse with the Swish activation function?\n  bool with_swish = false;\n  // Do we use the one-pass kernel?\n  bool use_one_pass = false;\n  // The number of runs to time the code.\n  int runs = 1;\n  // Do we use 1s for the input data.\n  bool use_1s = false;\n  // The tolerance to check the results.\n  float tol = 1.e-3f;\n  // Do we skip the checks?\n  bool skip_checks = false;\n  // Do we output csv format only\n  bool csv_output = false;\n  // Use fp32 IO\n  bool use_fp32 = false;\n  // Use bf16 IO\n  bool use_bf16 = false;\n\n  // Parse the parameters.\n  for (int ii = 1; ii < argc; ++ii) {\n    if (!strcmp(argv[ii], \"-1s\")) {\n      use_1s = true;\n    } else if (!strcmp(argv[ii], \"-bwd\")) {\n      mode = Mode::BWD;\n    } else if (!strcmp(argv[ii], \"-c\") && ++ii < argc) {\n      c = strtol(argv[ii], nullptr, 10);\n    } else if (!strcmp(argv[ii], \"-epsilon\") && ++ii < argc) {\n      epsilon = 
(float)strtod(argv[ii], nullptr);\n    } else if (!strcmp(argv[ii], \"-fwd\")) {\n      mode = Mode::FWD_INFERENCE;\n    } else if (!strcmp(argv[ii], \"-fwd-tr\")) {\n      mode = Mode::FWD_TRAINING;\n    } else if (!strcmp(argv[ii], \"-groups\") && ++ii < argc) {\n      groups = strtol(argv[ii], nullptr, 10);\n    } else if (!strcmp(argv[ii], \"-h\") && ++ii < argc) {\n      h = strtol(argv[ii], nullptr, 10);\n    } else if (!strcmp(argv[ii], \"-n\") && ++ii < argc) {\n      n = strtol(argv[ii], nullptr, 10);\n    } else if (!strcmp(argv[ii], \"-one-pass\")) {\n      use_one_pass = true;\n    } else if (!strcmp(argv[ii], \"-runs\") && ++ii < argc) {\n      runs = strtol(argv[ii], nullptr, 10);\n    } else if (!strcmp(argv[ii], \"-skip-checks\")) {\n      skip_checks = true;\n    } else if (!strcmp(argv[ii], \"-tol\") && ++ii < argc) {\n      tol = (float)strtod(argv[ii], nullptr);\n    } else if (!strcmp(argv[ii], \"-w\") && ++ii < argc) {\n      w = strtol(argv[ii], nullptr, 10);\n    } else if (!strcmp(argv[ii], \"-with-swish\")) {\n      with_swish = true;\n    } else if (!strcmp(argv[ii], \"-csv\")) {\n      csv_output = true;\n    } else if (!strcmp(argv[ii], \"-fp32\")) {\n      use_fp32 = true;\n    } else if (!strcmp(argv[ii], \"-bf16\")) {\n      use_bf16 = true;\n    } else if (ii < argc) {\n      fprintf(stderr, \"Unknown argument: %s\\n\", argv[ii]);\n      return 1;\n    } else {\n      fprintf(stderr, \"Argument %s requires a value\\n\", argv[ii - 1]);\n      return 1;\n    }\n  }\n\n  if (use_bf16 && use_fp32) {\n    fprintf(stderr, \"Can't use fp32 and bf16 IO at the same time\\n\");\n    return 1;\n  }\n\n  // Header.\n  if (!csv_output) {\n    printf(\"\\n\");\n    printf(\"#######################################################################\\n\");\n    printf(\"# Group Norm NHWC + Swish kernel\\n\");\n    printf(\"# --\\n\");\n    printf(\"# Compiled on %s\\n\", __DATE__);\n    
printf(\"#######################################################################\\n\");\n    printf(\"\\n\");\n  }\n\n  // GPU info.\n  cudaDeviceProp props;\n  CHECK_CUDA(cudaGetDeviceProperties(&props, 0));\n  if (!csv_output) {\n    printf(\"device.......................: %s\\n\", props.name);\n    printf(\"cc...........................: %d.%d\\n\", props.major, props.minor);\n    printf(\"# of sms.....................: %d\\n\", props.multiProcessorCount);\n  }\n\n  // Dram peak bandwidth.\n  float dram_clock = props.memoryClockRate / 1.e6f;\n  float dram_peak = 2.f * dram_clock * props.memoryBusWidth / 8.f;\n  if (!csv_output) {\n    printf(\"dram clock...................: %.3f GHz\\n\", dram_clock);\n    printf(\"dram peak....................: %.3f TB/s\\n\", dram_peak * 1.e-3f);\n    printf(\"\\n\");\n  }\n\n  // Output the problem size.\n  if (!csv_output) {\n    printf(\"n............................: %d\\n\", n);\n    printf(\"h............................: %d\\n\", h);\n    printf(\"w............................: %d\\n\", w);\n    printf(\"c............................: %d\\n\", c);\n    printf(\"groups.......................: %d\\n\", groups);\n    printf(\"epsilon......................: %f\\n\", epsilon);\n    printf(\"with swish...................: %s\\n\", with_swish ? \"true\" : \"false\");\n    printf(\"channels per group...........: %d\\n\", c / groups);\n    if (mode == Mode::BWD) {\n      printf(\"mode.........................: bwd\\n\");\n    } else if (mode == Mode::FWD_INFERENCE) {\n      printf(\"mode.........................: fwd inference\\n\");\n    } else if (mode == Mode::FWD_TRAINING) {\n      printf(\"mode.........................: fwd training\\n\");\n    } else {\n      assert(false);\n    }\n    printf(\"\\n\");\n  }\n\n  // Compute the SOL.\n  double bytes = 0;\n  int32_t io_bytes = use_fp32 ? 
sizeof(float) : sizeof(__half);\n  if (mode != Mode::BWD) {\n    bytes = (double)n * h * w * c * io_bytes +  // src\n            (double)c * 4 +                     // gamma\n            (double)c * 4 +                     // beta\n            (double)n * h * w * c * io_bytes;   // out\n  } else {\n    bytes = (double)n * h * w * c * io_bytes * 2 +  // src, dsrc\n            (double)c * 4 * 2 +                     // gamma, dgamma\n            (double)c * 4 * 2 +                     // beta, dbeta\n            (double)n * h * w * c * io_bytes * 1;   // dout\n  }\n  double gbytes = bytes * 1.e-9;\n  double dram_sol = gbytes / dram_peak * 1.e3;\n  if (!csv_output) {\n    printf(\"bytes........................: %.3lfGB\\n\", gbytes);\n    printf(\"dram sol.....................: %.6lfms\\n\", dram_sol);\n\n    // The number of runs to measure performance.\n    printf(\"runs.........................: %d\\n\", runs);\n    printf(\"\\n\");\n  }\n\n  // The number of elements in the x tensor. The layout is N x H x W x C.\n  size_t x_elts = (size_t)n * h * w * c;\n  // The size of the src in bytes.\n  size_t x_sz = x_elts * io_bytes;\n\n  // Allocate the src/dst on the host.\n  void* x_h = malloc(x_sz);\n  void* y_h = malloc(x_sz);\n\n  // Allocate src/dst on the device.\n  void *x_d, *y_d;\n  CHECK_CUDA(cudaMalloc((void**)&x_d, x_sz));\n  CHECK_CUDA(cudaMalloc((void**)&y_d, x_sz));\n\n  // The number of elements in the gamma/beta array.\n  size_t gamma_elts = (size_t)c;\n  // The size of the gamma/beta array in bytes.\n  size_t gamma_sz = gamma_elts * sizeof(float);\n  // Allocate gamma/beta on the host.\n  float* gamma_h = (float*)malloc(gamma_sz);\n  // Allocate gamma/beta on the device.\n  float* gamma_d;\n  CHECK_CUDA(cudaMalloc((void**)&gamma_d, gamma_sz));\n\n  // Allocate gamma/beta on the host.\n  float* beta_h = (float*)malloc(gamma_sz);\n  // Allocate gamma/beta on the device.\n  float* beta_d;\n  CHECK_CUDA(cudaMalloc((void**)&beta_d, gamma_sz));\n\n  // 
Allocate the reference on the host (to be computed on the host).\n  void* y_ref_h = nullptr;\n  if (!skip_checks) {\n    y_ref_h = malloc(x_sz);\n  }\n\n  // Allocate the src/dst on the host for the gradients (bwd).\n  void *dx_h = nullptr, *dy_h = nullptr;\n  if (mode == Mode::BWD) {\n    dx_h = malloc(x_sz);\n    dy_h = malloc(x_sz);\n  }\n\n  // Allocate src/dst on the device.\n  void *dx_d = nullptr, *dy_d = nullptr;\n  if (mode == Mode::BWD) {\n    CHECK_CUDA(cudaMalloc((void**)&dx_d, x_sz));\n    CHECK_CUDA(cudaMalloc((void**)&dy_d, x_sz));\n  }\n\n  // The gradients for gamma and beta on the host.\n  float *dgamma_h = nullptr, *dbeta_h = nullptr;\n  if (mode == Mode::BWD) {\n    dgamma_h = (float*)malloc(gamma_sz);\n    dbeta_h = (float*)malloc(gamma_sz);\n  }\n\n  // The gradients for gamma and beta on the device.\n  float *dgamma_d = nullptr, *dbeta_d = nullptr;\n  if (mode == Mode::BWD) {\n    CHECK_CUDA(cudaMalloc((void**)&dgamma_d, gamma_sz));\n    CHECK_CUDA(cudaMalloc((void**)&dbeta_d, gamma_sz));\n  }\n\n  // The number of sums for the bwd pass.\n  size_t sums_elts = mode == Mode::FWD_INFERENCE ? 
0 : n * groups;\n  // The size needed to store that array.\n  size_t sums_sz = sums_elts * sizeof(float2);\n\n  // The sums for the bwd pass on the host.\n  float2* sums_h = nullptr;\n  if (sums_sz > 0) {\n    sums_h = (float2*)malloc(sums_sz);\n  }\n\n  // The sums for the bwd pass on the device.\n  float2* sums_d = nullptr;\n  if (sums_sz > 0) {\n    CHECK_CUDA(cudaMalloc((void**)&sums_d, sums_sz));\n  }\n\n  // Allocate the reference on the host (to be computed on the host).\n  void* dx_ref_h = nullptr;\n  if (mode == Mode::BWD && !skip_checks) {\n    dx_ref_h = malloc(x_sz);\n  }\n\n  // Allocate the reference on the host (to be computed on the host).\n  float *dgamma_ref_h = nullptr, *dbeta_ref_h = nullptr;\n  if (mode == Mode::BWD && !skip_checks) {\n    dgamma_ref_h = (float*)malloc(gamma_sz);\n    dbeta_ref_h = (float*)malloc(gamma_sz);\n  }\n\n  // Generate random input data for the forward pass.\n  if (use_fp32) {\n    random_data<float>(reinterpret_cast<float*>(x_h), x_elts, use_1s);\n  } else if (use_bf16) {\n    random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(x_h), x_elts, use_1s);\n  } else {\n    random_data<__half>(reinterpret_cast<__half*>(x_h), x_elts, use_1s);\n  }\n  random_data<float>(gamma_h, gamma_elts, use_1s);\n  random_data<float>(beta_h, gamma_elts, use_1s);\n\n  // Generate the gradients for the bwd pass.\n  if (mode == Mode::BWD) {\n    if (use_fp32) {\n      random_data<float>(reinterpret_cast<float*>(dy_h), x_elts, use_1s);\n    } else if (use_bf16) {\n      random_data<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16*>(dy_h), x_elts, use_1s);\n    } else {\n      random_data<__half>(reinterpret_cast<__half*>(dy_h), x_elts, use_1s);\n    }\n  }\n\n  // Precompute the sums (from the fwd pass) for bwd.\n  if (mode == Mode::BWD) {\n    // Clear the array of sums (all the elements are set to 0.f).\n    memset(sums_h, 0, sums_sz);\n\n    // The number of channels in each group.\n    int channels_per_group = c / groups;\n    // 
Iterate over the different groups.\n    for (int ni = 0; ni < n; ++ni) {\n      for (int gi = 0; gi < groups; ++gi) {\n        for (int hi = 0; hi < h; ++hi) {\n          for (int wi = 0; wi < w; ++wi) {\n            for (int ii = 0; ii < channels_per_group; ++ii) {\n              // The position of the channel.\n              int ci = gi * channels_per_group + ii;\n              // The offset to the element.\n              int64_t offset = (int64_t)ni * h * w * c + hi * w * c + wi * c + ci;\n              // The element in float.\n              float x;\n              if (use_fp32) {\n                x = reinterpret_cast<float*>(x_h)[offset];\n              } else if (use_bf16) {\n                x = __bfloat162float(reinterpret_cast<__nv_bfloat16*>(x_h)[offset]);\n              } else {\n                x = __half2float(reinterpret_cast<__half*>(x_h)[offset]);\n              }\n\n              // Update the sums (sum of X and sum of squares).\n              sums_h[ni * groups + gi].x += x;\n              sums_h[ni * groups + gi].y += x * x;\n            }\n          }\n        }\n      }\n    }\n\n    // The normalization term to compute the means.\n    float rcp_hwc_per_group = 1.f / (float)(h * w * channels_per_group);\n    // Normalize the sums.\n    for (int ngi = 0; ngi < n * groups; ++ngi) {\n      sums_h[ngi].x *= rcp_hwc_per_group;\n      sums_h[ngi].y *= rcp_hwc_per_group;\n    }\n  }\n\n  // Compute the golden reference on the host.\n  if (!skip_checks) {\n    if (mode == Mode::BWD) {\n      group_norm_nhwc_bwd_(dx_ref_h, dgamma_ref_h, dbeta_ref_h, dy_h, x_h, gamma_h, beta_h, sums_h, epsilon, n, h, w, c,\n                           groups, with_swish, use_fp32, use_bf16);\n    } else {\n      group_norm_nhwc_fwd_(y_ref_h, x_h, gamma_h, beta_h, epsilon, n, h, w, c, groups, with_swish, use_fp32, use_bf16);\n    }\n  }\n\n  // Copy to the device.\n  CHECK_CUDA(cudaMemcpyAsync(x_d, x_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault));\n  
CHECK_CUDA(cudaMemcpyAsync(gamma_d, gamma_h, gamma_sz, cudaMemcpyHostToDevice, cudaStreamDefault));\n  CHECK_CUDA(cudaMemcpyAsync(beta_d, beta_h, gamma_sz, cudaMemcpyHostToDevice, cudaStreamDefault));\n\n  if (mode == Mode::BWD) {\n    CHECK_CUDA(cudaMemcpyAsync(dy_d, dy_h, x_sz, cudaMemcpyHostToDevice, cudaStreamDefault));\n\n    // // DEBUG.\n    // printf(\"sums_h[0] = %8.3f, %8.3f\\n\", sums_h[0].x, sums_h[0].y);\n    // // END OF DEBUG.\n\n    CHECK_CUDA(cudaMemcpyAsync(sums_d, sums_h, sums_sz, cudaMemcpyHostToDevice, cudaStreamDefault));\n  }\n\n  // Reset the output buffer with garbage to detect invalid results.\n  if (mode == Mode::BWD) {\n    CHECK_CUDA(cudaMemsetAsync(dx_d, 0xdc, x_sz, cudaStreamDefault));\n    CHECK_CUDA(cudaMemsetAsync(dgamma_d, 0xdc, gamma_sz, cudaStreamDefault));\n    CHECK_CUDA(cudaMemsetAsync(dbeta_d, 0xdc, gamma_sz, cudaStreamDefault));\n  } else {\n    CHECK_CUDA(cudaMemsetAsync(y_d, 0xdc, x_sz, cudaStreamDefault));\n  }\n\n  // Declare the parameters.\n  Group_norm_nhwc_fwd_params params_fwd;\n  memset(&params_fwd, 0, sizeof(params_fwd));\n  Group_norm_nhwc_bwd_params params_bwd;\n  memset(&params_bwd, 0, sizeof(params_bwd));\n\n  const auto precision = [&]() -> PrecisionMode {\n    if (use_fp32) {\n      return PrecisionMode::FP32IOFP32W;\n    } else if (use_bf16) {\n      return PrecisionMode::BF16IOFP32W;\n    } else {\n      return PrecisionMode::FP16IOFP32W;\n    }\n  }();\n\n  // Initialize the parameters.\n  if (mode == Mode::BWD) {\n    params_bwd.dx = dx_d;\n    params_bwd.dgamma = dgamma_d;\n    params_bwd.dbeta = dbeta_d;\n    params_bwd.sums = sums_d;\n    params_bwd.dy = dy_d;\n    params_bwd.x = x_d;\n    params_bwd.gamma = gamma_d;\n    params_bwd.beta = beta_d;\n    params_bwd.epsilon = epsilon;\n    params_bwd.n = n;\n    params_bwd.h = h;\n    params_bwd.w = w;\n    params_bwd.c = c;\n    params_bwd.groups = groups;\n    params_bwd.with_swish = with_swish;\n    params_bwd.precision = precision;\n  } else {\n    
params_fwd.y = y_d;\n    params_fwd.sums = sums_d;\n    params_fwd.x = x_d;\n    params_fwd.gamma = gamma_d;\n    params_fwd.beta = beta_d;\n    params_fwd.epsilon = epsilon;\n    params_fwd.n = n;\n    params_fwd.h = h;\n    params_fwd.w = w;\n    params_fwd.c = c;\n    params_fwd.groups = groups;\n    params_fwd.with_swish = with_swish;\n    params_fwd.precision = precision;\n  }\n\n  // The number of barriers.\n  size_t barriers_elts = 0;\n  // The number of elements in the reduction buffer.\n  size_t red_buffer_elts = 0;\n  // The number of elements in the reduction buffer that must be zeroed.\n  size_t zeroed_red_buffer_elts = 0;\n\n  // Finalize the parameters.\n  dim3 grid;\n  if (mode == Mode::BWD && use_one_pass) {\n    group_norm_nhwc_bwd_one_pass_setup(params_bwd, barriers_elts, red_buffer_elts, zeroed_red_buffer_elts, grid, props);\n  } else if (mode == Mode::BWD) {\n    group_norm_nhwc_bwd_two_passes_setup(params_bwd, zeroed_red_buffer_elts);\n  } else if (use_one_pass) {\n    group_norm_nhwc_fwd_one_pass_setup(params_fwd, barriers_elts, red_buffer_elts, grid, props);\n  } else {\n    group_norm_nhwc_fwd_two_passes_setup(params_fwd, zeroed_red_buffer_elts);\n  }\n\n  // The size in bytes for the reduction buffer.\n  size_t red_buffer_sz = red_buffer_elts * sizeof(float);\n  // Allocate on the device.\n  if (red_buffer_sz > 0) {\n    float** ptr = mode == Mode::BWD ? &params_bwd.red_buffer : &params_fwd.red_buffer;\n    CHECK_CUDA(cudaMalloc((void**)ptr, red_buffer_sz));\n  }\n\n  // The size of the array of barriers.\n  size_t barriers_sz = barriers_elts * sizeof(int);\n  // The size in bytes for the reduction buffer that must be zeroed.\n  size_t zeroed_red_buffer_sz = barriers_sz + zeroed_red_buffer_elts * sizeof(float);\n\n  // Allocate the buffer if needed.\n  void* zeroed_red_buffer_d_ = nullptr;\n  if (zeroed_red_buffer_sz > 0) {\n    CHECK_CUDA(cudaMalloc((void**)&zeroed_red_buffer_d_, zeroed_red_buffer_sz));\n  }\n\n  // The buffer of barriers. 
DO NOT CALL cudaFree on it!!!\n  int* barriers_d = reinterpret_cast<int*>(zeroed_red_buffer_d_);\n  // The zeroed red buffer. DO NOT CALL cudaFree on it!!!\n  float* zeroed_red_buffer_d = reinterpret_cast<float*>(&barriers_d[barriers_elts]);\n  // Must be aligned on 4B for floats. It obviously is (unless someone changes the code ;)).\n  assert(reinterpret_cast<const int64_t&>(zeroed_red_buffer_d) % sizeof(float) == 0);\n\n  // Set the barriers if needed.\n  if (mode == Mode::BWD) {\n    params_bwd.barriers = barriers_d;\n    params_bwd.zeroed_red_buffer = zeroed_red_buffer_d;\n  } else {\n    params_fwd.barriers = barriers_d;\n    params_fwd.zeroed_red_buffer = zeroed_red_buffer_d;\n  }\n\n  // Create events to time the reference code.\n  cudaEvent_t start, stop;\n  CHECK_CUDA(cudaEventCreate(&start));\n  CHECK_CUDA(cudaEventCreate(&stop));\n\n  // Time the reference code.\n  CHECK_CUDA(cudaEventRecord(start));\n  for (int ii = 0; ii < runs; ++ii) {\n    // Clear the zeroed buffer if needed.\n    if (zeroed_red_buffer_sz > 0) {\n      CHECK_CUDA(cudaMemsetAsync(zeroed_red_buffer_d_, 0, zeroed_red_buffer_sz, cudaStreamDefault));\n    }\n    if (use_one_pass && mode == Mode::BWD) {\n      group_norm_nhwc_bwd_one_pass_run(params_bwd, grid, cudaStreamDefault);\n    } else if (use_one_pass) {\n      group_norm_nhwc_fwd_one_pass_run(params_fwd, grid, cudaStreamDefault);\n    } else if (mode == Mode::BWD) {\n      group_norm_nhwc_bwd_two_passes_sum(params_bwd, cudaStreamDefault);\n      group_norm_nhwc_bwd_two_passes_scale(params_bwd, cudaStreamDefault);\n    } else {\n      group_norm_nhwc_fwd_two_passes_sum(params_fwd, cudaStreamDefault);\n      group_norm_nhwc_fwd_two_passes_scale(params_fwd, cudaStreamDefault);\n    }\n  }\n  CHECK_CUDA(cudaEventRecord(stop));\n  CHECK_CUDA(cudaDeviceSynchronize());\n\n  // Print the runtime.\n  float elapsed = 0.f;\n  CHECK_CUDA(cudaEventElapsedTime(&elapsed, start, stop));\n  if (!csv_output) {\n    
printf(\"elapsed......................: %.3fms\\n\", elapsed);\n    printf(\"elapsed per run..............: %.3fms\\n\", elapsed / (float)runs);\n    printf(\"efficiency...................: %.3lf%%\\n\", dram_sol * runs / elapsed * 100.0);\n    printf(\"\\n\");\n  }\n\n  // Copy the results to the host.\n  if (mode == Mode::BWD) {\n    CHECK_CUDA(cudaMemcpyAsync(dx_h, dx_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));\n    CHECK_CUDA(cudaMemcpyAsync(dgamma_h, dgamma_d, gamma_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));\n    CHECK_CUDA(cudaMemcpyAsync(dbeta_h, dbeta_d, gamma_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));\n  } else {\n    CHECK_CUDA(cudaMemcpyAsync(y_h, y_d, x_sz, cudaMemcpyDeviceToHost, cudaStreamDefault));\n  }\n\n  // Make sure the data has been transferred.\n  CHECK_CUDA(cudaStreamSynchronize(cudaStreamDefault));\n\n  // Check the results.\n  if (!csv_output) {\n    if (mode == Mode::BWD && !skip_checks) {\n      if (use_fp32) {\n        check_results<float>(\"dx\", reinterpret_cast<float*>(dx_h), reinterpret_cast<float*>(dx_ref_h), x_elts, tol);\n      } else if (use_bf16) {\n        check_results<__nv_bfloat16>(\"dx\", reinterpret_cast<__nv_bfloat16*>(dx_h),\n                                     reinterpret_cast<__nv_bfloat16*>(dx_ref_h), x_elts, tol);\n      } else {\n        check_results<__half>(\"dx\", reinterpret_cast<__half*>(dx_h), reinterpret_cast<__half*>(dx_ref_h), x_elts, tol);\n      }\n      check_results<float>(\"dgamma\", dgamma_h, dgamma_ref_h, gamma_elts, tol);\n      check_results<float>(\"dbeta\", dbeta_h, dbeta_ref_h, gamma_elts, tol);\n    } else if (!skip_checks) {\n      if (use_fp32) {\n        check_results<float>(\"y\", reinterpret_cast<float*>(y_h), reinterpret_cast<float*>(y_ref_h), x_elts, tol);\n      } else if (use_bf16) {\n        check_results<__nv_bfloat16>(\"y\", reinterpret_cast<__nv_bfloat16*>(y_h),\n                                     reinterpret_cast<__nv_bfloat16*>(y_ref_h), x_elts, tol);\n   
   } else {\n        check_results<__half>(\"y\", reinterpret_cast<__half*>(y_h), reinterpret_cast<__half*>(y_ref_h), x_elts, tol);\n      }\n    }\n  } else {\n    printf(\"%d,%d,%d,%d,%d,%d,%d,%f\\n\", n, h, w, c, groups, (uint32_t)use_one_pass, (uint32_t)mode,\n           elapsed / (float)runs);\n  }\n\n  // Destroy the cuda events.\n  CHECK_CUDA(cudaEventDestroy(start));\n  CHECK_CUDA(cudaEventDestroy(stop));\n\n  // Release device memory.\n  CHECK_CUDA(cudaFree(x_d));\n  CHECK_CUDA(cudaFree(y_d));\n  CHECK_CUDA(cudaFree(gamma_d));\n  CHECK_CUDA(cudaFree(beta_d));\n  CHECK_CUDA(cudaFree(dx_d));\n  CHECK_CUDA(cudaFree(dy_d));\n  CHECK_CUDA(cudaFree(dgamma_d));\n  CHECK_CUDA(cudaFree(dbeta_d));\n  CHECK_CUDA(cudaFree(sums_d));\n  CHECK_CUDA(cudaFree(zeroed_red_buffer_d_));\n  CHECK_CUDA(cudaFree(params_bwd.red_buffer));\n  CHECK_CUDA(cudaFree(params_fwd.red_buffer));\n\n  // Release host memory.\n  free(x_h);\n  free(y_h);\n  free(gamma_h);\n  free(beta_h);\n  free(dx_h);\n  free(dy_h);\n  free(dgamma_h);\n  free(dbeta_h);\n  free(sums_h);\n  free(y_ref_h);\n  free(dx_ref_h);\n  free(dgamma_ref_h);\n  free(dbeta_ref_h);\n\n  // Release the GPU.\n  CHECK_CUDA(cudaDeviceReset());\n  return 0;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime_api.h>\n#include <math.h>\n#include <stdint.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define CHECK_CUDA(call)                                                                            \\\n  do {                                                                                              \\\n    cudaError_t status_ = call;                                                                     \\\n    if (status_ != cudaSuccess) {                                                                   \\\n      fprintf(stderr, \"CUDA error (%s:%d): %s\\n\", __FILE__, __LINE__, cudaGetErrorString(status_)); \\\n      exit(1);                                                                                      \\\n    }                                                                                               \\\n  } while (0)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ __host__ int div_up(int m, int n) { return (m + n - 1) / n; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ __host__ float sigmoid(float x) { return 1.f / (1.f + expf(-x)); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline __device__ void spin_wait_(int* barrier, int step, int expected) {\n  // THE FOLLOWING CODE MUST BE EXECUTED BY A SINGLE THREAD IN THE CTA.\n\n  // Update the global counter. 
Make sure prior writes are visible.\n  asm volatile(\"red.release.gpu.global.add.s32 [%0], %1;\" ::\"l\"(barrier), \"r\"(step));\n\n  // Busy wait. We could use found = old + step with old = atomicAdd(...) but it's not faster.\n  for (volatile int found = -1; found != expected;) {\n    asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\" : \"=r\"(found) : \"l\"(barrier));\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Input type followed by parameter type\nenum PrecisionMode {\n  FP32IOFP16W,\n  FP32IOBF16W,\n  FP32IOFP32W,\n  FP16IOFP16W,\n  FP16IOBF16W,\n  FP16IOFP32W,\n  BF16IOFP16W,\n  BF16IOBF16W,\n  BF16IOFP32W,\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Group_sums {\n  // Is it the 1st element of the group?\n  int flag;\n  // The sum.\n  float sum;\n  // The sum of squares.\n  float sum_sq;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Group_sums_op {\n  inline __device__ Group_sums operator()(const Group_sums& a, const Group_sums& b) {\n    Group_sums dst;\n    dst.sum = b.flag ? b.sum : (a.sum + b.sum);\n    dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);\n    dst.flag = a.flag + b.flag;\n    return dst;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Group_norm_nhwc_fwd_params {\n  // The output buffer. Layout NHWC.\n  void* y;\n  // The sums for the bwd pass. Not written if it is a nullptr.\n  float2* sums;\n  // The input buffer. 
Layout NHWC.\n  const void* x;\n  // The gamma scaling factor.\n  const void* gamma;\n  // The beta term to add in GN.\n  const void* beta;\n  // The constant epsilon for sqrt(var + epsilon).\n  float epsilon;\n  // The barriers for the persistent kernel.\n  int* barriers;\n  // The extra storage for multi-CTA reductions as well as to pass data to the bwd.\n  float *red_buffer, *zeroed_red_buffer;\n\n  // The number of instances in the batch.\n  int n;\n  // The height and width of each activation map. The number of channels.\n  int64_t h, w, c, hw, hwc;\n  // The number of groups.\n  int groups;\n  // Do we apply the Swish activation function?\n  bool with_swish;\n\n  // Precomputed values and parameters to control the execution of the kernels.\n\n  // The number of batch instances per block.\n  int instances_per_block;\n  // The number of activations computed per block.\n  int acts_per_block;\n  // The number of groups in each block.\n  int groups_per_block;\n  // The number of channels per group = c / groups.\n  int channels_per_group;\n  // The number of channels per block = groups_per_block * channels_per_group.\n  int channels_per_block;\n  // The inverse of hwc in floats (to compute mean/var).\n  float inv_hwc_per_group;\n  // IO precision\n  PrecisionMode precision;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params&, size_t& red_buffer_elts);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params&, cudaStream_t);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params&, 
cudaStream_t);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Group_norm_nhwc_bwd_params {\n  // The output buffer. Layout NHWC.\n  void* dx;\n  // The output buffer. Layout NHWC.\n  void* dgamma;\n  // The output buffer. Layout NHWC.\n  void* dbeta;\n  // The input buffer. Layout NHWC.\n  const void* dy;\n  // The input buffer. Layout NHWC.\n  const void* x;\n  // The gamma scaling factor.\n  const void* gamma;\n  // The beta term to add in GN.\n  const void* beta;\n  // The sums from the fwd pass.\n  const float2* sums;\n  // The constant epsilon for sqrt(var + epsilon).\n  float epsilon;\n  // The barriers for the persistent kernel.\n  int* barriers;\n  // The extra storage for multi-CTA reductions as well as to pass data to the bwd.\n  float *red_buffer, *zeroed_red_buffer;\n\n  // The number of instances in the batch.\n  int n;\n  // The height and width of each activation map. The number of channels.\n  int64_t h, w, c, hw, hwc;\n  // The number of groups.\n  int groups;\n  // Do we apply the Swish activation function?\n  bool with_swish;\n\n  // Precomputed values and parameters to control the execution of the kernels.\n\n  // The number of batch instances per block.\n  int instances_per_block;\n  // The number of activations computed per block.\n  int acts_per_block;\n  // The number of groups in each block.\n  int groups_per_block;\n  // The number of channels per group = c / groups.\n  int channels_per_group;\n  // The number of channels per block = groups_per_block * channels_per_group.\n  int channels_per_block;\n  // The inverse of hwc in floats (to compute mean/var).\n  float inv_hwc_per_group;\n  // IO precision\n  PrecisionMode precision;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params&, size_t& 
red_buffer_elts);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params&, cudaStream_t);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params&, cudaStream_t);\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#include <assert.h>\n\n#include <algorithm>\n\n#include \"group_norm_nhwc.h\"\n#include \"macros.h\"\n#include \"traits.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// B A C K W A R D\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_BWD_SELECT(FUNC_POSTFIX, function)                                                    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function)     \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function)     \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function)    \\\n  
GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function)    \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function)   \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function)   \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function)   \\\n  GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) { \\\n    assert(false && \"Not implemented\");                                                          \\\n  }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_BWD_RUNNER_SELECT(function) GN_BWD_SELECT(_run, function)\n\n#define GN_BWD_BLOCKS_PER_SM_SELECT(function) GN_BWD_SELECT(_blocks_per_sm, function)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP 
*/ 8)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 32)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 112)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 120)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 128)\nGN_BWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 160)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_bwd_one_pass_setup(Group_norm_nhwc_bwd_params& params, size_t& barriers_elts,\n                                        size_t& red_buffer_elts, size_t& zeroed_red_buffer_elts, dim3& grid,\n                                        const cudaDeviceProp& props) {\n  // The pre-computed dimensions.\n  params.hw = params.h * params.w;\n  params.hwc = params.c * params.hw;\n\n  // The number of channels per 
group.\n  params.channels_per_group = params.c / params.groups;\n  // The inverse to compute the mean/variance.\n  params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group);\n\n  // Define how many activations are computed per block.\n  if ((params.hw >= 1024 && params.channels_per_group >= 80) ||\n      (params.hw >= 256 && params.channels_per_group >= 160)) {\n    params.acts_per_block = 8 * 16;\n  } else if (params.hw >= 512) {\n    params.acts_per_block = 32 * 16;\n  } else if (params.hw >= 256) {\n    params.acts_per_block = 16 * 16;\n  } else if (params.hw >= 128) {\n    params.acts_per_block = 8 * 16;\n  } else if (params.hw > 0) {\n    params.acts_per_block = 8 * 8;\n  } else {\n    // We should never be here if params are set correctly.\n    assert(false);\n  }\n\n  // Define the number of blocks per activation map. TODO: Make sure it matches the kernel sizes.\n  int blocks_per_slice = div_up(params.hw, params.acts_per_block);\n\n  // Select the kernel.\n  using Function_t = int (*)();\n\n  Function_t blocks_per_sm_function;\n  GN_BWD_BLOCKS_PER_SM_SELECT(blocks_per_sm_function);\n  // The number of blocks that can be run per SM.\n  int blocks_per_sm = blocks_per_sm_function();\n\n  // The number of blocks per grid.\n  int max_blocks_per_grid = blocks_per_sm * props.multiProcessorCount;\n\n  // Make sure we are safe to run that many blocks.\n  assert(blocks_per_slice <= max_blocks_per_grid);\n\n  // The number of blocks per slice is the X dimension of the grid.\n  grid.x = blocks_per_slice;\n  // The number of groups * batch instances (capped by the available blocks) is the Y dimension of the grid.\n  grid.y = std::min(max_blocks_per_grid / blocks_per_slice, params.groups * params.n);\n\n  // The number of barriers.\n  barriers_elts = blocks_per_slice > 1 ? 
grid.y * 2 : 0;\n\n  // Add 1 for the final conversion for dgamma/dbeta.\n  barriers_elts += 1;\n\n  // The number of elements in the reduction buffer (for the sums and sums of squares).\n  if (blocks_per_slice == 1) {\n    red_buffer_elts = 0;\n  } else {\n    // The first 2 is for double-buffering. The 2nd one is for the fact that we have two floats.\n    red_buffer_elts = 2 * grid.x * grid.y * 2;\n  }\n\n  // The number of elements in the buffer that has to be zeroed.\n  zeroed_red_buffer_elts = params.c * 2;\n\n  // Make sure a group does not span multiple blocks.\n  assert(params.channels_per_block % params.channels_per_group == 0);\n}\n\ninline void group_norm_nhwc_bwd_one_pass_run(const Group_norm_nhwc_bwd_params& params, const dim3& grid,\n                                             cudaStream_t stream) {\n  using Function_t = void (*)(const Group_norm_nhwc_bwd_params&, const dim3&, cudaStream_t);\n\n  Function_t runner;\n  GN_BWD_RUNNER_SELECT(runner);\n\n  runner(params, grid, stream);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_one_pass_kernel.cuh",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#include <assert.h>\n\n#include <cub/cub.cuh>\n\n#include \"group_norm_nhwc.h\"\n#include \"traits.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// B A C K W A R D\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Traits_, int ACTS_PER_BLOCK_, int CHANNELS_PER_GROUP_, int THREADS_PER_BLOCK_>\n__global__ __launch_bounds__(THREADS_PER_BLOCK_) void group_norm_nhwc_bwd_one_pass_kernel(\n    Group_norm_nhwc_bwd_params params) {\n  // The IO traits.\n  using Traits = Traits_;\n  // The IO traits.\n  using IOTraits = typename Traits::IOTraits;\n  // The Weights traits.\n  using WTraits = typename Traits::WTraits;\n\n  // The IO type\n  using IOType = typename IOTraits::Type;\n  // The IO doubled type\n  using IOType2 = typename IOTraits::Type2;\n\n  // Weights type\n  using WType = typename WTraits::Type;\n  // Weights doubled type\n  using WType2 = typename WTraits::Type2;\n\n  // The number of activations per block.\n  constexpr int ACTS_PER_BLOCK = ACTS_PER_BLOCK_;\n  // The number of channels per group.\n  constexpr int CHANNELS_PER_GROUP = CHANNELS_PER_GROUP_;\n  // The number of threads per block.\n  constexpr int THREADS_PER_BLOCK = THREADS_PER_BLOCK_;\n  // The number of channels per thread (load fp16x2 numbers).\n  constexpr int CHANNELS_PER_THREAD = 2;\n\n  // The number of threads needed per activation.\n  constexpr int THREADS_PER_ACT = CHANNELS_PER_GROUP / CHANNELS_PER_THREAD;\n  // The number of activations that are loaded per loop.\n  constexpr int ACTS_PER_LOOP = THREADS_PER_BLOCK / THREADS_PER_ACT;\n  // The number of rows per thread.\n  constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP - 1) / ACTS_PER_LOOP;\n\n  // The number of 
active threads.\n  constexpr int ACTIVE_THREADS = THREADS_PER_BLOCK / THREADS_PER_ACT * THREADS_PER_ACT;\n\n  // The object in charge of doing the sums for the block.\n  typedef cub::BlockReduce<float2, THREADS_PER_BLOCK> Block_reduce;\n  // Allocate shared memory for Block_reduce.\n  __shared__ typename Block_reduce::TempStorage temp_storage;\n  // Allocate shared memory to store the sums.\n  __shared__ float2 smem_sums;\n  // Allocate shared memory to store the gamma/beta gradients.\n  __shared__ float4 smem_dgamma_dbeta[THREADS_PER_BLOCK];\n\n  // Shared memory to store the gradients for gamma and beta.\n\n  // The first activation loaded by that thread.\n  int hwi = blockIdx.x * params.acts_per_block + threadIdx.x / THREADS_PER_ACT;\n  // The first channel loaded by that thread.\n  int ci = threadIdx.x % THREADS_PER_ACT * CHANNELS_PER_THREAD;\n\n  // Is it an active thread?\n  const bool is_active = threadIdx.x < ACTIVE_THREADS;\n\n  // Iterate over the (instance, group) pairs in the batch.\n  for (int ngi = blockIdx.y, step = 0; ngi < params.n * params.groups; ngi += gridDim.y, ++step) {\n    // The instance and the group. TODO: Use fast divmod?\n    int ni = ngi / params.groups;\n    int gi = ngi % params.groups;\n\n    // The sums from the fwd pass.\n    float2 fwd = params.sums[ngi];\n    // The mean of X (computed during the fwd pass -- one value per batch*group).\n    float x_mean = fwd.x;\n    // The mean of squares of X (computed during the fwd pass -- one value per batch*group).\n    float x_sq_mean = fwd.y;\n    // The variance.\n    float x_var = x_sq_mean - x_mean * x_mean;\n    // The reciprocal of the standard deviation (i.e. 1.f / sqrt(var + epsilon)).\n    float rcp_x_stddev = x_var <= 0.f ? 
1.f : 1.f / sqrtf(x_var + params.epsilon);\n\n    // The offset to the first activation loaded by that thread.\n    const int64_t offset = (int64_t)ni * params.hwc + gi * CHANNELS_PER_GROUP + ci;\n    // The pointer to the first activation loaded by that thread.\n    const IOType* x_ptr = &reinterpret_cast<const IOType*>(params.x)[offset];\n    // The pointer to the first gradient loaded by that thread.\n    const IOType* dy_ptr = &reinterpret_cast<const IOType*>(params.dy)[offset];\n\n    // Load the X and dY into registers.\n    IOType2 x[ACTS_PER_THREAD], dy[ACTS_PER_THREAD];\n#pragma unroll\n    for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) {\n      int hwj = hwi + ii * ACTS_PER_LOOP;\n      x[ii] = IOTraits::zero();\n      dy[ii] = IOTraits::zero();\n      if (is_active && hwj < params.hw) {\n        x[ii] = *reinterpret_cast<const IOType2*>(&x_ptr[hwj * params.c]);\n        dy[ii] = *reinterpret_cast<const IOType2*>(&dy_ptr[hwj * params.c]);\n      }\n    }\n\n    // Load gamma as well.\n    float2 gamma_f2 = make_float2(0.f, 0.f);\n    float2 beta_f2 = make_float2(0.f, 0.f);\n    if (is_active) {\n      gamma_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(\n          &reinterpret_cast<const WType*>(params.gamma)[gi * CHANNELS_PER_GROUP + ci]));\n      if (params.with_swish) {\n        beta_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(\n            &reinterpret_cast<const WType*>(params.beta)[gi * CHANNELS_PER_GROUP + ci]));\n      }\n    }\n\n    // Gradients for gamma and beta (for this particular group).\n    float4 dgamma_dbeta = make_float4(0.f, 0.f, 0.f, 0.f);\n    // Accumulated gradients for dgrad calculation.\n    float mean_1 = 0.f, mean_2 = 0.f;\n\n// Compute the sum and the sum of squares for each thread.\n#pragma unroll\n    for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) {\n      // Convert x to float.\n      float2 x_f2 = IOTraits::unpack(x[ii]);\n      // Convert dY to float.\n      float2 dy_f2 = IOTraits::unpack(dy[ii]);\n\n      
// X - X_mean.\n      float x_minus_x_mean_x = x_f2.x - x_mean;\n      float x_minus_x_mean_y = x_f2.y - x_mean;\n\n      // Normalize X.\n      float x_norm_x = x_minus_x_mean_x * rcp_x_stddev;\n      float x_norm_y = x_minus_x_mean_y * rcp_x_stddev;\n\n      if (params.with_swish) {\n        float x_gn_x = x_norm_x * gamma_f2.x + beta_f2.x;\n        float x_gn_y = x_norm_y * gamma_f2.y + beta_f2.y;\n        float s_x = sigmoid(x_gn_x);\n        float s_y = sigmoid(x_gn_y);\n        dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x));\n        dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y));\n      }\n\n      // Update dbeta.\n      dgamma_dbeta.z += dy_f2.x;\n      dgamma_dbeta.w += dy_f2.y;\n\n      // Update dgamma.\n      dgamma_dbeta.x += dy_f2.x * x_norm_x;\n      dgamma_dbeta.y += dy_f2.y * x_norm_y;\n\n      // The gradient that enters the x_norm node.\n      float dx_norm_x = dy_f2.x * gamma_f2.x;\n      float dx_norm_y = dy_f2.y * gamma_f2.y;\n\n      // Add to the 1st mean.\n      mean_1 += dx_norm_x * x_norm_x;\n      mean_1 += dx_norm_y * x_norm_y;\n\n      // Add to the 2nd mean.\n      mean_2 += dx_norm_x;\n      mean_2 += dx_norm_y;\n    }\n\n    // Pack valid gradients.\n    float2 sums = make_float2(0.f, 0.f);\n    if (ACTIVE_THREADS == THREADS_PER_BLOCK || is_active) {\n      sums = make_float2(mean_1, mean_2);\n    }\n\n    // Store dgamma and dbeta to shared memory.\n    smem_dgamma_dbeta[threadIdx.x] = dgamma_dbeta;\n\n    // Compute the sums for the block.\n    sums = Block_reduce(temp_storage).Reduce(sums, [](const float2& a, const float2& b) {\n      return make_float2(a.x + b.x, a.y + b.y);\n    });\n\n    // Make sure we can read dgamma/dbeta from shared memory. 
Block_reduce uses one syncthread already.\n    __syncthreads();\n\n    // Compute dgamma/dbeta for the block.\n    if (threadIdx.x < THREADS_PER_ACT) {\n      for (int ii = 1; ii < ACTS_PER_LOOP; ++ii) {\n        float4 other = smem_dgamma_dbeta[threadIdx.x + ii * THREADS_PER_ACT];\n        dgamma_dbeta.x += other.x;\n        dgamma_dbeta.y += other.y;\n        dgamma_dbeta.z += other.z;\n        dgamma_dbeta.w += other.w;\n      }\n    }\n\n    // The position in the channel dimension - 2 channels per thread.\n    int cj = gi * THREADS_PER_ACT + threadIdx.x;\n    // The reduction buffer for dgamma/dbeta.\n    float* red_buffer_dgamma_dbeta = &params.zeroed_red_buffer[cj];\n\n    // The first threads store their gradients for gamma/beta.\n    if (threadIdx.x < THREADS_PER_ACT) {\n      atomicAdd(&red_buffer_dgamma_dbeta[0 * params.c / 2], dgamma_dbeta.x);\n      atomicAdd(&red_buffer_dgamma_dbeta[1 * params.c / 2], dgamma_dbeta.y);\n      atomicAdd(&red_buffer_dgamma_dbeta[2 * params.c / 2], dgamma_dbeta.z);\n      atomicAdd(&red_buffer_dgamma_dbeta[3 * params.c / 2], dgamma_dbeta.w);\n    }\n\n    // The block leader stores to global memory, if needed.\n    if (gridDim.x > 1) {\n      // The index of the buffer.\n      int red_buffer_idx = step & 1;\n      // The barrier.\n      int* barrier = &params.barriers[red_buffer_idx * gridDim.y + blockIdx.y];\n      // The offset to the reduction buffer.\n      int red_buffer_offset = red_buffer_idx * gridDim.x * gridDim.y * 2;\n      // The reduction buffer.\n      float2* red_buffer = reinterpret_cast<float2*>(&params.red_buffer[red_buffer_offset]);\n\n      // The first thread stores its sums.\n      if (threadIdx.x == 0) {\n        red_buffer[blockIdx.x * gridDim.y + blockIdx.y] = sums;\n      }\n\n      // Make sure the data is in memory.\n      if (threadIdx.x == 0) {\n        spin_wait_(barrier, (step & 2) ? -1 : 1, (step & 2) ? 
0 : gridDim.x);\n      }\n      __syncthreads();\n\n      // Update the sums.\n      for (int ii = 0; ii < gridDim.x; ++ii) {\n        if (ii != blockIdx.x && threadIdx.x == 0) {\n          float2 other_sums = red_buffer[ii * gridDim.y + blockIdx.y];\n          sums.x += other_sums.x;\n          sums.y += other_sums.y;\n        }\n      }\n    }\n\n    // Store the result for other threads.\n    if (threadIdx.x == 0) {\n      smem_sums = sums;\n    }\n\n    // Make sure the sums are in shared memory.\n    __syncthreads();\n\n    // Read the 1st mean from shared memory.\n    mean_1 = smem_sums.x;\n    // Read the 2nd mean from shared memory.\n    mean_2 = smem_sums.y;\n\n    mean_1 *= params.inv_hwc_per_group;\n    mean_2 *= params.inv_hwc_per_group;\n\n    // The pointer to the first activation stored by that thread.\n    IOType* dx_ptr = &reinterpret_cast<IOType*>(params.dx)[offset];\n\n    // Iterate over the activations to normalize the activations and store the results.\n    for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) {\n      // Convert x to float.\n      float2 x_f2 = IOTraits::unpack(x[ii]);\n      // Convert dY to float.\n      float2 dy_f2 = IOTraits::unpack(dy[ii]);\n\n      // X - X_mean.\n      float2 x_minus_x_mean_f2;\n      x_minus_x_mean_f2.x = x_f2.x - x_mean;\n      x_minus_x_mean_f2.y = x_f2.y - x_mean;\n      // Normalize X.\n      float2 x_norm_f2;\n      x_norm_f2.x = x_minus_x_mean_f2.x * rcp_x_stddev;\n      x_norm_f2.y = x_minus_x_mean_f2.y * rcp_x_stddev;\n\n      if (params.with_swish) {\n        float x_gn_x = x_norm_f2.x * gamma_f2.x + beta_f2.x;\n        float x_gn_y = x_norm_f2.y * gamma_f2.y + beta_f2.y;\n        float s_x = sigmoid(x_gn_x);\n        float s_y = sigmoid(x_gn_y);\n        dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x));\n        dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y));\n      }\n\n      // The gradient that enters the x_norm node.\n      float2 dx_norm;\n      dx_norm.x = dy_f2.x * gamma_f2.x;\n 
     dx_norm.y = dy_f2.y * gamma_f2.y;\n\n      // The gradient along the input path.\n      float2 dx;\n      dx.x = (dx_norm.x - (x_norm_f2.x * mean_1 + mean_2)) * rcp_x_stddev;\n      dx.y = (dx_norm.y - (x_norm_f2.y * mean_1 + mean_2)) * rcp_x_stddev;\n\n      // Store the scaled values.\n      int hwj = hwi + ii * ACTS_PER_LOOP;\n      if (is_active && hwj < params.hw) {\n        *reinterpret_cast<IOType2*>(&dx_ptr[hwj * params.c]) = IOTraits::pack(dx);\n      }\n    }\n  }\n\n  // The completion barrier.\n  int* barrier = &params.barriers[gridDim.x == 1 ? 0 : gridDim.y * 2];\n\n  // Mark the completion of the threadblock.\n  if (threadIdx.x == 0) {\n    asm volatile(\"red.release.gpu.global.add.s32 [%0], 1;\" ::\"l\"(barrier));\n  }\n\n  // Exit if that's not the last thread block.\n  if (blockIdx.x != gridDim.x - 1 || blockIdx.y != gridDim.y - 1) {\n    return;\n  }\n\n  // Busy wait. We could use found = old + step with old = atomicAdd(...) but it's not faster.\n  if (threadIdx.x == 0) {\n    for (int found = -1; found != gridDim.x * gridDim.y;) {\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\" : \"=r\"(found) : \"l\"(barrier));\n    }\n  }\n  __syncthreads();\n\n  // The last block converts dgamma and dbeta to half.\n  for (int idx = threadIdx.x; idx < params.c / 2; idx += THREADS_PER_BLOCK) {\n    // Load dgamma.\n    float2 dgamma;\n    dgamma.x = params.zeroed_red_buffer[idx + 0 * params.c / 2];\n    dgamma.y = params.zeroed_red_buffer[idx + 1 * params.c / 2];\n\n    // Load dbeta.\n    float2 dbeta;\n    dbeta.x = params.zeroed_red_buffer[idx + 2 * params.c / 2];\n    dbeta.y = params.zeroed_red_buffer[idx + 3 * params.c / 2];\n\n    // Store to global memory.\n    *reinterpret_cast<WType2*>(&reinterpret_cast<WType*>(params.dgamma)[idx * 2]) = WTraits::pack(dgamma);\n    *reinterpret_cast<WType2*>(&reinterpret_cast<WType*>(params.dbeta)[idx * 2]) = WTraits::pack(dbeta);\n  
}\n}\n\n//////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_bwd_two_pass.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#include <assert.h>\n\n#include <cub/cub.cuh>\n\n#include \"group_norm_nhwc.h\"\n#include \"macros.h\"\n#include \"traits.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// B A C K W A R D\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Traits_, int THREADS_PER_BLOCK>\n__global__ void group_norm_nhwc_bwd_sum_kernel(Group_norm_nhwc_bwd_params params) {\n  // The IO traits.\n  using Traits = Traits_;\n  // The IO traits.\n  using IOTraits = typename Traits::IOTraits;\n  // The Weights traits.\n  using WTraits = typename Traits::WTraits;\n\n  // The IO type\n  using IOType = typename IOTraits::Type;\n  // The IO doubled type\n  using IOType2 = typename IOTraits::Type2;\n\n  // Weights type\n  using WType = typename WTraits::Type;\n  // Weights doubled type\n  using WType2 = typename WTraits::Type2;\n\n  // The object in charge of doing the sums for the different blocks.\n  typedef cub::BlockScan<Group_sums, THREADS_PER_BLOCK> Block_scan;\n\n  // Allocate shared memory for Block_scan.\n  __shared__ typename Block_scan::TempStorage temp_storage;\n  // Allocate shared memory for the groups. 
We could reduce the amount of shared memory reserved.\n  __shared__ float2 smem[THREADS_PER_BLOCK];\n\n  // The instance in the batch.\n  int ni = blockIdx.z;\n  // The channel loaded by that thread (2 channels per thread for F16x2).\n  int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2;\n  // The group that thread works on and the channel in the group (modulus).\n  int gi = ci / params.channels_per_group;\n\n  // The sums from the fwd pass.\n  float2 fwd = params.sums[ni * params.groups + gi];\n  // The mean of X (computed during the fwd pass -- one value per batch*group).\n  float x_mean = fwd.x;\n  // The mean of squares of X (computed during the fwd pass -- one value per batch*group).\n  float x_sq_mean = fwd.y;\n  // The variance.\n  float x_var = x_sq_mean - x_mean * x_mean;\n  // The reciprocal of the standard deviation (i.e. 1.f / sqrt(var + epsilon)).\n  float rcp_x_stddev = x_var <= 0.f ? 1.f : 1.f / sqrtf(x_var + params.epsilon);\n\n  // Load gamma.\n  float2 gamma_f2 = make_float2(0.f, 0.f);\n  float2 beta_f2 = make_float2(0.f, 0.f);\n  if (ci < params.c) {\n    gamma_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.gamma)[ci]));\n    if (params.with_swish) {\n      beta_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.beta)[ci]));\n    }\n  }\n\n  // The group that thread works on and the channel in the group (modulus).\n  int gj = threadIdx.x * 2 / params.channels_per_group;\n  int cj = threadIdx.x * 2 - params.channels_per_group * gj;\n\n  // The first activation loaded by that block.\n  int hw_begin = blockIdx.y * params.acts_per_block;\n  // The last activation loaded by that block.\n  int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw);\n\n  // The gradients for gamma/beta.\n  float2 dgamma = make_float2(0.f, 0.f), dbeta = make_float2(0.f, 0.f);\n  // Accumulated gradients for dgrad calculation\n  float mean_1 = 0.f, mean_2 = 
0.f;\n\n  // Iterate over the activations to compute the sums.\n  for (int hwi = hw_begin; hwi < hw_end; ++hwi) {\n    // The offset.\n    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;\n\n    // Fetch two channels per thread.\n    IOType2 x_v2 = IOTraits::zero();\n    IOType2 dy_v2 = IOTraits::zero();\n    if (ci < params.c) {\n      x_v2 = *reinterpret_cast<const IOType2*>(&reinterpret_cast<const IOType*>(params.x)[offset]);\n      dy_v2 = *reinterpret_cast<const IOType2*>(&reinterpret_cast<const IOType*>(params.dy)[offset]);\n    }\n\n    // Extract the two half values.\n    float2 x_f2 = IOTraits::unpack(x_v2);\n    float2 dy_f2 = IOTraits::unpack(dy_v2);\n\n    // X - X_mean.\n    float x_minus_x_mean_x = x_f2.x - x_mean;\n    float x_minus_x_mean_y = x_f2.y - x_mean;\n\n    // Normalize X.\n    float x_norm_x = x_minus_x_mean_x * rcp_x_stddev;\n    float x_norm_y = x_minus_x_mean_y * rcp_x_stddev;\n\n    if (params.with_swish) {\n      float x_gn_x = x_norm_x * gamma_f2.x + beta_f2.x;\n      float x_gn_y = x_norm_y * gamma_f2.y + beta_f2.y;\n      float s_x = sigmoid(x_gn_x);\n      float s_y = sigmoid(x_gn_y);\n      dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x));\n      dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y));\n    }\n\n    // Update dbeta.\n    dbeta.x += dy_f2.x;\n    dbeta.y += dy_f2.y;\n\n    // Update dgamma.\n    dgamma.x += dy_f2.x * x_norm_x;\n    dgamma.y += dy_f2.y * x_norm_y;\n\n    // The gradient that enters the x_norm node.\n    float dx_norm_x = dy_f2.x * gamma_f2.x;\n    float dx_norm_y = dy_f2.y * gamma_f2.y;\n\n    // Add to the 1st mean.\n    mean_1 += dx_norm_x * x_norm_x;\n    mean_1 += dx_norm_y * x_norm_y;\n\n    // Add to the 2nd mean.\n    mean_2 += dx_norm_x;\n    mean_2 += dx_norm_y;\n  }\n\n  // The data for the summations.\n  Group_sums inp{cj == 0 ? 
1 : 0, mean_1, mean_2};\n\n  // Do the segmented scan.\n  Group_sums out;\n  Block_scan(temp_storage).InclusiveScan(inp, out, Group_sums_op());\n\n  // Store the results for the groups in shared memory (to produce coalesced stores later).\n  if (cj == params.channels_per_group - 2 /* 2 channels per thread */) {\n    smem[gj] = make_float2(out.sum, out.sum_sq);\n  }\n\n  // Make sure the data is in shared memory.\n  __syncthreads();\n\n  // The global group index.\n  int gk = blockIdx.x * params.groups_per_block + threadIdx.x;\n\n  // The first threads (those storing to global memory, load the values).\n  float2 sums = smem[threadIdx.x];\n\n  // Store to global memory.\n  if (threadIdx.x < params.groups_per_block && gk < params.groups) {\n    atomicAdd(&params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gk], sums.x);\n    atomicAdd(&params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gk], sums.y);\n  }\n\n  // The base pointer for the gradients for gamma and beta.\n  float* dgamma_beta_ptr = &params.zeroed_red_buffer[params.n * params.groups * 2];\n\n  // The 1st channel in the output tensor. 
NOTE: Two channels per thread store interleaved.\n  int ck = blockIdx.x * params.channels_per_block + threadIdx.x;\n\n  // Store dgamma and dbeta as well.\n  if (ck < params.c) {\n    atomicAdd(&dgamma_beta_ptr[0 * params.c + 0 * blockDim.x + ck], dgamma.x);\n    atomicAdd(&dgamma_beta_ptr[0 * params.c + 1 * blockDim.x + ck], dgamma.y);\n    atomicAdd(&dgamma_beta_ptr[1 * params.c + 0 * blockDim.x + ck], dbeta.x);\n    atomicAdd(&dgamma_beta_ptr[1 * params.c + 1 * blockDim.x + ck], dbeta.y);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_bwd_two_passes_setup(Group_norm_nhwc_bwd_params& params, size_t& zeroed_red_buffer_elts) {\n  // The pre-computed dimensions.\n  params.hw = params.h * params.w;\n  params.hwc = params.c * params.hw;\n\n  // The number of channels per group.\n  params.channels_per_group = params.c / params.groups;\n  // The inverse to compute the mean/variance.\n  params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group);\n\n  // Define the number of blocks per activation map. 
That's a simple heuristic.\n  int blocks_per_act_slice = 0;\n  if (params.c >= 1280) {\n    blocks_per_act_slice = 128 / params.n;\n  } else if (params.c >= 640) {\n    blocks_per_act_slice = 256 / params.n;\n  } else {\n    blocks_per_act_slice = 512 / params.n;\n  }\n\n  // Clamp to at least 1 to avoid divide-by-zero when batch size is large.\n  blocks_per_act_slice = max(blocks_per_act_slice, 1);\n\n  // Make sure the number of blocks per activation slice does not exceed div_up(hw, n).\n  blocks_per_act_slice = min(blocks_per_act_slice, div_up(params.hw, params.n));\n\n  // Define how many activations are computed per block.\n  params.acts_per_block = div_up(params.hw, blocks_per_act_slice);\n\n  // The number of channels per block.\n  params.channels_per_block = 320;\n  // Special case to deal with 30 channels per group.\n  if (params.channels_per_block % params.channels_per_group != 0) {\n    params.channels_per_block = 240;\n  }\n\n  // Special case to deal with 70 channels per group.\n  if (params.c == 2240) {\n    params.channels_per_block = 280;\n  } else if (params.c == 832) {\n    params.channels_per_block = 208;\n  }\n\n  if (params.c % params.channels_per_block != 0) {\n    if (params.c % 512 == 0 && params.c != 1536 && params.c != 3072 && params.c % 448 != 0) {\n      params.channels_per_block = 512;\n    } else if (params.c % 42 == 0) {\n      params.channels_per_block = 336;\n    } else if (params.c % 384 == 0) {\n      params.channels_per_block = 384;\n    } else if (params.c % 256 == 0 && params.c % 448 != 0 && params.c % 392 != 0) {\n      params.channels_per_block = 256;\n    } else if (params.c % 128 == 0 && params.c % 448 != 0 && params.c % 392 != 0) {\n      params.channels_per_block = 128;\n    } else if (params.c % 448 == 0 && params.c % 392 != 0) {\n      params.channels_per_block = 448;\n    } else if (params.c % 392 == 0) {\n      params.channels_per_block = 392;\n    }\n  }\n\n  // The number of groups per block.\n  params.groups_per_block = 
params.channels_per_block / params.channels_per_group;\n\n  // Make sure the number of channels is a multiple of the number of channels per block.\n  assert(params.c % params.channels_per_block == 0);\n  // Make sure a group does not span multiple blocks.\n  assert(params.channels_per_block % params.channels_per_group == 0);\n\n  // The number of elements in the reduction buffer (for the sums and sums of squared).\n  zeroed_red_buffer_elts = params.n * params.groups * 2 + params.c * 2;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_bwd_two_passes_sum(const Group_norm_nhwc_bwd_params& params, cudaStream_t stream) {\n  // The dimension of the grid.\n  dim3 grid;\n\n  // The number of blocks to compute all the channels.\n  grid.x = params.c / params.channels_per_block;\n  // The number of blocks to compute all the activations in a given instance.\n  grid.y = div_up(params.hw, params.acts_per_block);\n  // The number of instances.\n  grid.z = params.n;\n\n  if (params.precision == PrecisionMode::FP16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp16IOFp16W)\n  } else if (params.precision == PrecisionMode::FP16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp16IOBf16W)\n  } else if (params.precision == PrecisionMode::FP16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp16IOFp32W)\n  } else if (params.precision == PrecisionMode::BF16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Bf16IOFp16W)\n  } else if (params.precision == PrecisionMode::BF16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Bf16IOBf16W)\n  } else if (params.precision == PrecisionMode::BF16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Bf16IOFp32W)\n  } else if (params.precision == PrecisionMode::FP32IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp32IOFp16W)\n  } else if 
(params.precision == PrecisionMode::FP32IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp32IOBf16W)\n  } else {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_sum_kernel, Fp32IOFp32W)\n  }\n\n  // Make sure it launched ok.\n  CHECK_CUDA(cudaGetLastError());\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Traits_, int THREADS_PER_BLOCK>\n__global__ void group_norm_nhwc_bwd_scale_kernel(Group_norm_nhwc_bwd_params params) {\n  // The IO traits.\n  using Traits = Traits_;\n  // The IO traits.\n  using IOTraits = typename Traits::IOTraits;\n  // The Weights traits.\n  using WTraits = typename Traits::WTraits;\n\n  // The IO type\n  using IOType = typename IOTraits::Type;\n  // The IO doubled type\n  using IOType2 = typename IOTraits::Type2;\n\n  // Weights type\n  using WType = typename WTraits::Type;\n  // Weights doubled type\n  using WType2 = typename WTraits::Type2;\n\n  // The instance in the batch.\n  int ni = blockIdx.z;\n  // The channel loaded by that thread (2 channels per thread for F16x2).\n  int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2;\n  // The group that thread works on and the channel in the group (modulus).\n  int gi = ci / params.channels_per_group;\n\n  // Load the gradients for the group.\n  float mean_1 = 0.f, mean_2 = 0.f;\n  if (gi < params.groups) {\n    mean_1 = params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gi];\n    mean_2 = params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gi];\n  }\n\n  // The sums from the fwd pass.\n  float2 fwd = params.sums[ni * params.groups + gi];\n  // The mean of X (computed during the fwd pass -- one value per batch*group).\n  float x_mean = fwd.x;\n  // The mean of squares of X (computed during the fwd pass -- one value per batch*group).\n  float x_sq_mean = fwd.y;\n  // The variance.\n  float x_var = x_sq_mean - x_mean * x_mean;\n  // The reciprocal of the standard deviation 
(i.e. 1.f / sqrt(var + epsilon)).\n  float rcp_x_stddev = x_var <= 0.f ? 1.f : 1.f / sqrtf(x_var + params.epsilon);\n\n  // Multiply by 1/(HWC) to get the real mean.\n  mean_1 *= params.inv_hwc_per_group;\n  mean_2 *= params.inv_hwc_per_group;\n\n  // Load gamma.\n  float2 gamma_f2 = make_float2(0.f, 0.f);\n  float2 beta_f2 = make_float2(0.f, 0.f);\n  if (ci < params.c) {\n    gamma_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.gamma)[ci]));\n    if (params.with_swish) {\n      beta_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.beta)[ci]));\n    }\n  }\n\n  // The first activation loaded by that block.\n  int hw_begin = blockIdx.y * params.acts_per_block;\n  // The last activation loaded by that block.\n  int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw);\n\n  // Iterate over the activations to compute and store the input gradients.\n  for (int hwi = hw_begin; hwi < hw_end; ++hwi) {\n    // The src/dst offset.\n    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;\n\n    // Fetch two channels per thread.\n    IOType2 x_v2 = IOTraits::zero();\n    IOType2 dy_v2 = IOTraits::zero();\n    if (ci < params.c) {\n      x_v2 = *reinterpret_cast<const IOType2*>(&reinterpret_cast<const IOType*>(params.x)[offset]);\n      dy_v2 = *reinterpret_cast<const IOType2*>(&reinterpret_cast<const IOType*>(params.dy)[offset]);\n    }\n\n    // Extract the two half values.\n    float2 x_f2 = IOTraits::unpack(x_v2);\n    float2 dy_f2 = IOTraits::unpack(dy_v2);\n\n    // X - X_mean.\n    float2 x_minus_x_mean_f2;\n    x_minus_x_mean_f2.x = x_f2.x - x_mean;\n    x_minus_x_mean_f2.y = x_f2.y - x_mean;\n\n    // Normalize X.\n    float2 x_norm_f2;\n    x_norm_f2.x = x_minus_x_mean_f2.x * rcp_x_stddev;\n    x_norm_f2.y = x_minus_x_mean_f2.y * rcp_x_stddev;\n\n    if (params.with_swish) {\n      float x_gn_x = x_norm_f2.x * gamma_f2.x + beta_f2.x;\n      float x_gn_y = x_norm_f2.y * gamma_f2.y 
+ beta_f2.y;\n      float s_x = sigmoid(x_gn_x);\n      float s_y = sigmoid(x_gn_y);\n      dy_f2.x = dy_f2.x * s_x * (1.f + x_gn_x * (1.f - s_x));\n      dy_f2.y = dy_f2.y * s_y * (1.f + x_gn_y * (1.f - s_y));\n    }\n\n    // The gradient that enters the x_norm node.\n    float2 dx_norm;\n    dx_norm.x = dy_f2.x * gamma_f2.x;\n    dx_norm.y = dy_f2.y * gamma_f2.y;\n\n    // The gradient along the input path.\n    float2 dx;\n    dx.x = (dx_norm.x - (x_norm_f2.x * mean_1 + mean_2)) * rcp_x_stddev;\n    dx.y = (dx_norm.y - (x_norm_f2.y * mean_1 + mean_2)) * rcp_x_stddev;\n\n    // Store the scaled values.\n    if (ci < params.c) {\n      *reinterpret_cast<IOType2*>(&reinterpret_cast<IOType*>(params.dx)[offset]) = IOTraits::pack(dx);\n    }\n  }\n\n  // Load gamma/beta and convert to half.\n  if (blockIdx.y > 0 || blockIdx.z > 0 || ci >= params.c) {\n    return;\n  }\n\n  // The base pointer for the gradients for gamma and beta.\n  float* dgamma_beta_ptr = &params.zeroed_red_buffer[params.n * params.groups * 2];\n\n  // The 1st channel in the output tensor. 
NOTE: Two channels per thread store interleaved.\n  int ck = blockIdx.x * params.channels_per_block + threadIdx.x;\n\n  // Load the FP32 version of dgamma and dbeta.\n  float2 dgamma, dbeta;\n  if (ck < params.c) {\n    dgamma.x = dgamma_beta_ptr[0 * params.c + 0 * blockDim.x + ck];\n    dgamma.y = dgamma_beta_ptr[0 * params.c + 1 * blockDim.x + ck];\n    dbeta.x = dgamma_beta_ptr[1 * params.c + 0 * blockDim.x + ck];\n    dbeta.y = dgamma_beta_ptr[1 * params.c + 1 * blockDim.x + ck];\n\n    // Convert to half2 and store to memory.\n    *reinterpret_cast<WType2*>(&reinterpret_cast<WType*>(params.dgamma)[ci]) = WTraits::pack(dgamma);\n    *reinterpret_cast<WType2*>(&reinterpret_cast<WType*>(params.dbeta)[ci]) = WTraits::pack(dbeta);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_bwd_two_passes_scale(const Group_norm_nhwc_bwd_params& params, cudaStream_t stream) {\n  // The dimension of the grid.\n  dim3 grid;\n\n  // The number of blocks to compute all the channels.\n  grid.x = params.c / params.channels_per_block;\n  // The number of blocks to compute all the activations in a given instance.\n  grid.y = div_up(params.hw, params.acts_per_block);\n  // The number of instances.\n  grid.z = params.n;\n\n  if (params.precision == PrecisionMode::FP16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp16IOFp16W)\n  } else if (params.precision == PrecisionMode::FP16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp16IOBf16W)\n  } else if (params.precision == PrecisionMode::FP16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp16IOFp32W)\n  } else if (params.precision == PrecisionMode::BF16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Bf16IOFp16W)\n  } else if (params.precision == PrecisionMode::BF16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Bf16IOBf16W)\n  } else if 
(params.precision == PrecisionMode::BF16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Bf16IOFp32W)\n  } else if (params.precision == PrecisionMode::FP32IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp32IOFp16W)\n  } else if (params.precision == PrecisionMode::FP32IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp32IOBf16W)\n  } else {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_bwd_scale_kernel, Fp32IOFp32W)\n  }\n\n  // Make sure it launched ok.\n  CHECK_CUDA(cudaGetLastError());\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#include <assert.h>\n\n#include <algorithm>\n\n#include \"group_norm_nhwc.h\"\n#include \"macros.h\"\n#include \"traits.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// F O R W A R D\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_FWD_SELECT(FUNC_POSTFIX, function)                                                    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(4, FUNC_POSTFIX, function)     \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(8, FUNC_POSTFIX, function)     \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(10, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(12, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(14, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(16, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(20, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(26, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(24, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(28, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(30, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(32, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(40, FUNC_POSTFIX, function)    \\\n  
GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(42, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(48, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(56, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(60, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(64, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(70, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(80, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(84, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(96, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(98, FUNC_POSTFIX, function)    \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(112, FUNC_POSTFIX, function)   \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(120, FUNC_POSTFIX, function)   \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(128, FUNC_POSTFIX, function)   \\\n  GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(160, FUNC_POSTFIX, function) { \\\n    assert(false && \"Not implemented\");                                                          \\\n  }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_FWD_RUNNER_SELECT(function) GN_FWD_SELECT(_run, function)\n\n#define GN_FWD_BLOCKS_PER_SM_SELECT(function) GN_FWD_SELECT(_blocks_per_sm, function)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 4)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP 
*/ 8)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 10)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 12)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 14)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 16)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 20)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 26)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 24)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 28)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 30)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 32)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 40)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 42)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 48)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 56)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 60)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 64)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 70)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 80)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 84)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 96)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 98)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 112)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 120)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 128)\nGN_FWD_ONE_PASS_DECLARATION(/* CHANNELS_PER_GROUP */ 160)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline void group_norm_nhwc_fwd_one_pass_setup(Group_norm_nhwc_fwd_params& params, size_t& barriers_elts,\n                                               size_t& red_buffer_elts, dim3& grid, const cudaDeviceProp& props) {\n  // The pre-computed dimensions.\n  params.hw = params.h * params.w;\n  params.hwc = params.c * params.hw;\n\n  // The number of channels per group.\n  params.channels_per_group = params.c / params.groups;\n  
// The inverse to compute the mean/variance.\n  params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group);\n\n  // Select the kernel.\n  using Function_t = int (*)();\n\n  Function_t blocks_per_sm_function;\n  GN_FWD_BLOCKS_PER_SM_SELECT(blocks_per_sm_function);\n\n  // Define how many activations are computed per block.\n  if (params.hw >= 1024 && params.channels_per_group >= 80 || (params.hw >= 256 && params.channels_per_group >= 160)) {\n    params.acts_per_block = 8 * 16;\n  } else if (params.hw >= 512) {\n    params.acts_per_block = 16 * 32;\n  } else if (params.hw >= 256) {\n    params.acts_per_block = 16 * 16;\n  } else if (params.hw >= 128) {\n    params.acts_per_block = 8 * 16;\n  } else if (params.hw > 0) {\n    params.acts_per_block = 8 * 8;\n  } else {\n    // We should never be here if params are set correctly.\n    assert(false);\n  }\n\n  // Define the number of blocks per activation map. TODO: Make sure it matches the kernel sizes.\n  int blocks_per_slice = div_up(params.hw, params.acts_per_block);\n\n  // The number of blocks that can be run per SM.\n  int blocks_per_sm = blocks_per_sm_function();\n\n  // The number of blocks per grid.\n  int max_blocks_per_grid = blocks_per_sm * props.multiProcessorCount;\n\n  // Make sure we are safe to run that many blocks\n  assert(blocks_per_slice <= max_blocks_per_grid);\n\n  // The number of blocks per slice is the X dimension of the grid.\n  grid.x = blocks_per_slice;\n  // The number of groups * batch size, capped by the blocks available per slice, is the Y dimension of the grid.\n  grid.y = std::min(max_blocks_per_grid / blocks_per_slice, params.groups * params.n);\n\n  // The number of barriers.\n  barriers_elts = blocks_per_slice > 1 ? grid.y * 2 : 0;\n\n  // The number of elements in the reduction buffer (for the sums and sums of squared).\n  if (blocks_per_slice == 1) {\n    red_buffer_elts = 0;\n  } else {\n    // The first 2 is for double-buffering. 
The 2nd one is for the fact that we have two floats.\n    red_buffer_elts = 2 * grid.x * grid.y * 2;\n  }\n}\n\ninline void group_norm_nhwc_fwd_one_pass_run(const Group_norm_nhwc_fwd_params& params, const dim3& grid,\n                                             cudaStream_t stream) {\n  using Function_t = void (*)(const Group_norm_nhwc_fwd_params&, const dim3&, cudaStream_t);\n\n  Function_t runner;\n  GN_FWD_RUNNER_SELECT(runner);\n\n  runner(params, grid, stream);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_one_pass_kernel.cuh",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#include <assert.h>\n\n#include <cub/cub.cuh>\n\n#include \"group_norm_nhwc.h\"\n#include \"traits.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// F O R W A R D\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Traits_, int ACTS_PER_BLOCK_, int CHANNELS_PER_GROUP_, int THREADS_PER_BLOCK_>\n__global__ __launch_bounds__(THREADS_PER_BLOCK_) void group_norm_nhwc_fwd_one_pass_kernel(\n    Group_norm_nhwc_fwd_params params) {\n  // The traits.\n  using Traits = Traits_;\n  // The IO traits.\n  using IOTraits = typename Traits::IOTraits;\n  // The Weights traits.\n  using WTraits = typename Traits::WTraits;\n\n  // The IO type\n  using IOType = typename IOTraits::Type;\n  // The IO doubled type\n  using IOType2 = typename IOTraits::Type2;\n\n  // Weights type\n  using WType = typename WTraits::Type;\n  // Weights doubled type\n  using WType2 = typename WTraits::Type2;\n\n  // The number of activations per block.\n  constexpr int ACTS_PER_BLOCK = ACTS_PER_BLOCK_;\n  // The number of channels per group.\n  constexpr int CHANNELS_PER_GROUP = CHANNELS_PER_GROUP_;\n  // The number of threads per block.\n  constexpr int THREADS_PER_BLOCK = THREADS_PER_BLOCK_;\n  // The number of channels per thread (load fp16x2 numbers).\n  constexpr int CHANNELS_PER_THREAD = 2;\n\n  // The number of threads needed per activation.\n  constexpr int THREADS_PER_ACT = CHANNELS_PER_GROUP / CHANNELS_PER_THREAD;\n  // The number of activations that are loaded per loop.\n  constexpr int ACTS_PER_LOOP = THREADS_PER_BLOCK / THREADS_PER_ACT;\n  // The number of rows per thread.\n  constexpr int ACTS_PER_THREAD = (ACTS_PER_BLOCK + ACTS_PER_LOOP - 1) / ACTS_PER_LOOP;\n\n  // The number of active 
threads.\n  constexpr int ACTIVE_THREADS = THREADS_PER_BLOCK / THREADS_PER_ACT * THREADS_PER_ACT;\n\n  // The object in charge of doing the sums for the block.\n  typedef cub::BlockReduce<float2, THREADS_PER_BLOCK> Block_reduce;\n  // Allocate shared memory for Block_reduce.\n  __shared__ typename Block_reduce::TempStorage temp_storage;\n  // Allocate shared memory to store the sums.\n  __shared__ float2 smem_sums;\n\n  // The first activation loaded by that thread.\n  int hwi = blockIdx.x * params.acts_per_block + threadIdx.x / THREADS_PER_ACT;\n  // The first channel loaded by that thread.\n  int ci = threadIdx.x % THREADS_PER_ACT * CHANNELS_PER_THREAD;\n\n  // Is it an active thread?\n  const bool is_active = threadIdx.x < ACTIVE_THREADS;\n\n  // Iterate over the (instance, group) items in the batch.\n  for (int ngi = blockIdx.y, step = 0; ngi < params.n * params.groups; ngi += gridDim.y, ++step) {\n    // The instance and the group. TODO: Use fast divmod?\n    int ni = ngi / params.groups;\n    int gi = ngi % params.groups;\n\n    // The offset to the first activation loaded by that thread.\n    const int64_t offset = (int64_t)ni * params.hwc + gi * CHANNELS_PER_GROUP + ci;\n    // The pointer to the first activation loaded by that thread.\n    const IOType* x_ptr = &reinterpret_cast<const IOType*>(params.x)[offset];\n\n    // Load the activations into registers.\n    IOType2 x[ACTS_PER_THREAD];\n#pragma unroll\n    for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) {\n      int hwj = hwi + ii * ACTS_PER_LOOP;\n      x[ii] = IOTraits::zero();\n      if (is_active && hwj < params.hw) {\n        x[ii] = *reinterpret_cast<const IOType2*>(&x_ptr[hwj * params.c]);\n      }\n    }\n\n    // Compute the sum and the sum of squares for each thread.\n    float2 sums = make_float2(0.f, 0.f);\n#pragma unroll\n    for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) {\n      float2 f2 = IOTraits::unpack(x[ii]);\n      sums.x += f2.x + f2.y;\n      sums.y += f2.x * f2.x + f2.y * f2.y;\n    }\n\n    // 
Clear invalid threads.\n    if (ACTIVE_THREADS < THREADS_PER_BLOCK && !is_active) {\n      sums = make_float2(0.f, 0.f);\n    }\n\n    // Compute the sums for the block.\n    sums = Block_reduce(temp_storage).Reduce(sums, [](const float2& a, const float2& b) {\n      return make_float2(a.x + b.x, a.y + b.y);\n    });\n\n    // The block leader stores to global memory, if needed.\n    if (gridDim.x > 1) {\n      // The index of the buffer (double-buffering).\n      int red_buffer_idx = step & 1;\n      // The barrier.\n      int* barrier = &params.barriers[red_buffer_idx * gridDim.y + blockIdx.y];\n      // The offset to the reduction buffer.\n      int red_buffer_offset = red_buffer_idx * gridDim.x * gridDim.y * 2;\n      // The reduction buffer.\n      float2* red_buffer = reinterpret_cast<float2*>(&params.red_buffer[red_buffer_offset]);\n\n      // The first thread stores its sums.\n      if (threadIdx.x == 0) {\n        red_buffer[blockIdx.x * gridDim.y + blockIdx.y] = sums;\n      }\n\n      // Make sure the data is in memory.\n      if (threadIdx.x == 0) {\n        spin_wait_(barrier, (step & 2) ? -1 : 1, (step & 2) ? 
0 : gridDim.x);\n      }\n      __syncthreads();\n\n      // Update the sums.\n      for (int ii = 0; ii < gridDim.x; ++ii) {\n        if (ii != blockIdx.x && threadIdx.x == 0) {\n          float2 other_sums = red_buffer[ii * gridDim.y + blockIdx.y];\n          sums.x += other_sums.x;\n          sums.y += other_sums.y;\n        }\n      }\n    }\n\n    // Store the result for other threads.\n    if (threadIdx.x == 0) {\n      smem_sums = sums;\n    }\n\n    // Store the results to global memory as well (for training).\n    if (params.sums != nullptr && blockIdx.x == 0 && threadIdx.x == 0) {\n      sums.x *= params.inv_hwc_per_group;\n      sums.y *= params.inv_hwc_per_group;\n      params.sums[ngi] = sums;\n    }\n\n    // Make sure the sums are in shared memory.\n    __syncthreads();\n\n    // Load gamma/beta.\n    float2 gamma_f2 = WTraits::unpack(\n        *reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.gamma)[gi * CHANNELS_PER_GROUP + ci]));\n    float2 beta_f2 = WTraits::unpack(\n        *reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.beta)[gi * CHANNELS_PER_GROUP + ci]));\n\n    // Compute the mean.\n    float mean = smem_sums.x * params.inv_hwc_per_group;\n    // Compute the variance.\n    float var = smem_sums.y * params.inv_hwc_per_group - (mean * mean);\n    // Compute the inverse of the stddev.\n    float inv_stddev = var <= 0.f ? 
1.f : rsqrtf(var + params.epsilon);\n\n    // The pointer to the first activation stored by that thread.\n    IOType* y_ptr = &reinterpret_cast<IOType*>(params.y)[offset];\n\n    // Iterate over the activations to normalize the activations and store the results.\n    for (int ii = 0; ii < ACTS_PER_THREAD; ++ii) {\n      // Extract the two half values.\n      float2 f2 = IOTraits::unpack(x[ii]);\n\n      // Normalize the channels.\n      f2.x = (f2.x - mean) * inv_stddev;\n      f2.y = (f2.y - mean) * inv_stddev;\n\n      // Scale by gamma and add beta.\n      f2.x = gamma_f2.x * f2.x + beta_f2.x;\n      f2.y = gamma_f2.y * f2.y + beta_f2.y;\n\n      // Apply Swish if needed.\n      if (params.with_swish) {\n        f2.x = f2.x * sigmoid(f2.x);\n        f2.y = f2.y * sigmoid(f2.y);\n      }\n\n      // Store the scaled values.\n      int hwj = hwi + ii * ACTS_PER_LOOP;\n      if (is_active && hwj < params.hw) {\n        *reinterpret_cast<IOType2*>(&y_ptr[hwj * params.c]) = IOTraits::pack(f2);\n      }\n    }\n  }\n}\n\n//////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_fwd_two_pass.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#include <assert.h>\n\n#include <cub/cub.cuh>\n\n#include \"group_norm_nhwc.h\"\n#include \"macros.h\"\n#include \"traits.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n//\n// F O R W A R D\n//\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Traits_, int THREADS_PER_BLOCK>\n__global__ void group_norm_nhwc_fwd_sum_kernel(Group_norm_nhwc_fwd_params params) {\n  // The traits.\n  using Traits = Traits_;\n  // The IO traits.\n  using IOTraits = typename Traits::IOTraits;\n\n  // The IO type\n  using IOType = typename IOTraits::Type;\n  // The IO doubled type\n  using IOType2 = typename IOTraits::Type2;\n\n  // The object in charge of doing the sums for the different blocks.\n  typedef cub::BlockScan<Group_sums, THREADS_PER_BLOCK> Block_scan;\n\n  // Allocate shared memory for Block_scan.\n  __shared__ typename Block_scan::TempStorage temp_storage;\n  // Allocate shared memory for the groups. 
We could reduce the amount of shared memory reserved.\n  __shared__ float2 smem[THREADS_PER_BLOCK];\n\n  // The instance in the batch.\n  int ni = blockIdx.z;\n  // The channel loaded by that thread (2 channels per thread for F16x2).\n  int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2;\n\n  // The first activation loaded by that block.\n  int hw_begin = blockIdx.y * params.acts_per_block;\n  // The last activation loaded by that block.\n  int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw);\n\n  // The sums.\n  float sum = 0.f, sum_sq = 0.f;\n\n  // Iterate over the activations to compute the sums.\n  for (int hwi = hw_begin; hwi < hw_end; ++hwi) {\n    // The offset.\n    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;\n\n    // Fetch two channels per thread.\n    IOType2 v2 = IOTraits::zero();\n    if (ci < params.c) {\n      v2 = *reinterpret_cast<const IOType2*>(&reinterpret_cast<const IOType*>(params.x)[offset]);\n    }\n\n    // Extract the two values.\n    float2 f2 = IOTraits::unpack(v2);\n\n    // Update the sum.\n    sum += f2.x + f2.y;\n    // Update the sum of squares.\n    sum_sq += f2.x * f2.x + f2.y * f2.y;\n  }\n\n  // The group that thread works on and the channel in the group (modulus).\n  int gj = threadIdx.x * 2 / params.channels_per_group;\n  int cj = threadIdx.x * 2 - params.channels_per_group * gj;\n\n  // The data for the summations.\n  Group_sums inp{cj == 0 ? 
1 : 0, sum, sum_sq};\n\n  // Do the segmented scan.\n  Group_sums out;\n  Block_scan(temp_storage).InclusiveScan(inp, out, Group_sums_op());\n\n  // Store the results for the groups in shared memory (to produce coalesced stores later).\n  if (cj == params.channels_per_group - 2 /* 2 channels per thread */) {\n    smem[gj] = make_float2(out.sum, out.sum_sq);\n  }\n\n  // Make sure the data is in shared memory.\n  __syncthreads();\n\n  // The global group index.\n  int gk = blockIdx.x * params.groups_per_block + threadIdx.x;\n\n  // Threads that have nothing left to do, exit.\n  if (threadIdx.x >= params.groups_per_block || gk >= params.groups) {\n    return;\n  }\n\n  // The first threads (those storing to global memory, load the values).\n  float2 sums = smem[threadIdx.x];\n\n  // Store to global memory.\n  atomicAdd(&params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gk], sums.x);\n  atomicAdd(&params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gk], sums.y);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_fwd_two_passes_setup(Group_norm_nhwc_fwd_params& params, size_t& zeroed_red_buffer_elts) {\n  // The pre-computed dimensions.\n  params.hw = params.h * params.w;\n  params.hwc = params.c * params.hw;\n\n  // The number of channels per group.\n  params.channels_per_group = params.c / params.groups;\n  // The inverse to compute the mean/variance.\n  params.inv_hwc_per_group = 1.f / (float)(params.hw * params.channels_per_group);\n\n  // Define the number of blocks per activation map. 
That's a simple heuristic.\n  int blocks_per_act_slice = 0;\n  if (params.c >= 1280) {\n    blocks_per_act_slice = 128 / params.n;\n  } else if (params.c >= 640) {\n    blocks_per_act_slice = 256 / params.n;\n  } else {\n    blocks_per_act_slice = 512 / params.n;\n  }\n\n  // Clamp to at least 1 to avoid divide-by-zero when batch size is large.\n  blocks_per_act_slice = max(blocks_per_act_slice, 1);\n\n  // Do not use more blocks per activation slice than div_up(hw, n).\n  blocks_per_act_slice = min(blocks_per_act_slice, div_up(params.hw, params.n));\n\n  // Define how many activations are computed per block.\n  params.acts_per_block = div_up(params.hw, blocks_per_act_slice);\n  // The number of channels per block.\n  params.channels_per_block = 320;\n  // Special case to deal with 30 channels per group.\n  if (params.channels_per_block % params.channels_per_group != 0) {\n    params.channels_per_block = 240;\n  }\n\n  // Special case to deal with 70 channels per group.\n  if (params.c == 2240) {\n    params.channels_per_block = 280;\n  } else if (params.c == 832) {\n    params.channels_per_block = 208;\n  }\n\n  if (params.c % params.channels_per_block != 0) {\n    if (params.c % 512 == 0 && params.c != 1536 && params.c != 3072 && params.c % 448 != 0) {\n      params.channels_per_block = 512;\n    } else if (params.c % 42 == 0) {\n      params.channels_per_block = 336;\n    } else if (params.c % 384 == 0) {\n      params.channels_per_block = 384;\n    } else if (params.c % 256 == 0 && params.c % 448 != 0 && params.c % 392 != 0) {\n      params.channels_per_block = 256;\n    } else if (params.c % 128 == 0 && params.c % 448 != 0 && params.c % 392 != 0) {\n      params.channels_per_block = 128;\n    } else if (params.c % 448 == 0 && params.c % 392 != 0) {\n      params.channels_per_block = 448;\n    } else if (params.c % 392 == 0) {\n      params.channels_per_block = 392;\n    }\n  }\n\n  // The number of groups per block.\n  params.groups_per_block = 
params.channels_per_block / params.channels_per_group;\n\n  // Make sure the number of channels is a multiple of the number of channels per block.\n  assert(params.c % params.channels_per_block == 0);\n  // Make sure a group does not span multiple blocks.\n  assert(params.channels_per_block % params.channels_per_group == 0);\n\n  // The number of elements in the reduction buffer (for the sums and sums of squared).\n  zeroed_red_buffer_elts = params.n * params.groups * 2;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_fwd_two_passes_sum(const Group_norm_nhwc_fwd_params& params, cudaStream_t stream) {\n  // The dimension of the grid.\n  dim3 grid;\n\n  // The number of blocks to compute all the channels.\n  grid.x = params.c / params.channels_per_block;\n  // The number of blocks to compute all the activations in a given instance.\n  grid.y = div_up(params.hw, params.acts_per_block);\n  // The number of instances.\n  grid.z = params.n;\n\n  // Launch the kernel.\n  if (params.precision == PrecisionMode::FP16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp16IOFp16W)\n  } else if (params.precision == PrecisionMode::FP16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp16IOBf16W)\n  } else if (params.precision == PrecisionMode::FP16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp16IOFp32W)\n  } else if (params.precision == PrecisionMode::BF16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Bf16IOFp16W)\n  } else if (params.precision == PrecisionMode::BF16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Bf16IOBf16W)\n  } else if (params.precision == PrecisionMode::BF16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Bf16IOFp32W)\n  } else if (params.precision == PrecisionMode::FP32IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp32IOFp16W)\n  
} else if (params.precision == PrecisionMode::FP32IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp32IOBf16W)\n  } else {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_sum_kernel, Fp32IOFp32W)\n  }\n\n  // Make sure it launched ok.\n  CHECK_CUDA(cudaGetLastError());\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Traits_, int THREADS_PER_BLOCK>\n__global__ void group_norm_nhwc_fwd_scale_kernel(Group_norm_nhwc_fwd_params params) {\n  // The traits.\n  using Traits = Traits_;\n  // The IO traits.\n  using IOTraits = typename Traits::IOTraits;\n  // The Weights traits.\n  using WTraits = typename Traits::WTraits;\n\n  // The IO type\n  using IOType = typename IOTraits::Type;\n  // The IO doubled type\n  using IOType2 = typename IOTraits::Type2;\n\n  // Weights type\n  using WType = typename WTraits::Type;\n  // Weights doubled type\n  using WType2 = typename WTraits::Type2;\n\n  // The instance in the batch.\n  int ni = blockIdx.z;\n  // The channel loaded by that thread (2 channels per thread for F16x2).\n  int ci = blockIdx.x * params.channels_per_block + threadIdx.x * 2;\n  // The group that thread works on and the channel in the group (modulus).\n  int gi = ci / params.channels_per_group;\n\n  // Load the sum and sum of squares for the group.\n  float sum = 0.f, sum_sq = 0.f;\n  if (gi < params.groups) {\n    sum = params.zeroed_red_buffer[(2 * ni + 0) * params.groups + gi];\n    sum_sq = params.zeroed_red_buffer[(2 * ni + 1) * params.groups + gi];\n  }\n\n  // Load gamma/beta.\n  float2 gamma_f2, beta_f2;\n  if (ci < params.c) {\n    gamma_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.gamma)[ci]));\n    beta_f2 = WTraits::unpack(*reinterpret_cast<const WType2*>(&reinterpret_cast<const WType*>(params.beta)[ci]));\n  }\n\n  // Compute the mean.\n  float mean = sum * params.inv_hwc_per_group;\n  // Compute the 
variance.\n  float var = sum_sq * params.inv_hwc_per_group - (mean * mean);\n  // Compute the inverse of the stddev.\n  float inv_stddev = var <= 0.f ? 1.f : rsqrtf(var + params.epsilon);\n\n  // The first activation loaded by that block.\n  int hw_begin = blockIdx.y * params.acts_per_block;\n  // The last activation loaded by that block.\n  int hw_end = min((int64_t)hw_begin + params.acts_per_block, params.hw);\n\n  // Iterate over the activations to normalize them and store the results.\n  for (int hwi = hw_begin; hwi < hw_end; ++hwi) {\n    // The src/dst offset.\n    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;\n\n    // Fetch two channels per thread.\n    IOType2 v2 = IOTraits::zero();\n    if (ci < params.c) {\n      v2 = *reinterpret_cast<const IOType2*>(&reinterpret_cast<const IOType*>(params.x)[offset]);\n    }\n\n    // Extract the two values.\n    float2 f2 = IOTraits::unpack(v2);\n\n    // Normalize the channels.\n    f2.x = (f2.x - mean) * inv_stddev;\n    f2.y = (f2.y - mean) * inv_stddev;\n\n    // Scale by gamma and add beta.\n    f2.x = gamma_f2.x * f2.x + beta_f2.x;\n    f2.y = gamma_f2.y * f2.y + beta_f2.y;\n\n    // Apply Swish if needed.\n    if (params.with_swish) {\n      f2.x = f2.x * sigmoid(f2.x);\n      f2.y = f2.y * sigmoid(f2.y);\n    }\n\n    // Store the scaled values.\n    if (ci < params.c) {\n      *reinterpret_cast<IOType2*>(&reinterpret_cast<IOType*>(params.y)[offset]) = IOTraits::pack(f2);\n    }\n  }\n\n  // Write the sums if needed.\n  if (params.sums != nullptr && gi < params.groups) {\n    float2 sums;\n    sums.x = sum * params.inv_hwc_per_group;\n    sums.y = sum_sq * params.inv_hwc_per_group;\n    params.sums[ni * params.groups + gi] = sums;\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nvoid group_norm_nhwc_fwd_two_passes_scale(const Group_norm_nhwc_fwd_params& params, cudaStream_t stream) {\n  // The dimension of the grid.\n  dim3 grid;\n\n  // The 
number of blocks to compute all the channels.\n  grid.x = params.c / params.channels_per_block;\n  // The number of blocks to compute all the activations in a given instance.\n  grid.y = div_up(params.hw, params.acts_per_block);\n  // The number of instances.\n  grid.z = params.n;\n\n  if (params.precision == PrecisionMode::FP16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp16IOFp16W)\n  } else if (params.precision == PrecisionMode::FP16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp16IOBf16W)\n  } else if (params.precision == PrecisionMode::FP16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp16IOFp32W)\n  } else if (params.precision == PrecisionMode::BF16IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Bf16IOFp16W)\n  } else if (params.precision == PrecisionMode::BF16IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Bf16IOBf16W)\n  } else if (params.precision == PrecisionMode::BF16IOFP32W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Bf16IOFp32W)\n  } else if (params.precision == PrecisionMode::FP32IOFP16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp32IOFp16W)\n  } else if (params.precision == PrecisionMode::FP32IOBF16W) {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp32IOBf16W)\n  } else {\n    CALL_TWO_PASS_KERNEL(group_norm_nhwc_fwd_scale_kernel, Fp32IOFp32W)\n  }\n\n  // Make sure it launched ok.\n  CHECK_CUDA(cudaGetLastError());\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 10, /* THREADS_PER_BLOCK */ 640)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 112, /* THREADS_PER_BLOCK */ 448)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 12, /* THREADS_PER_BLOCK */ 384)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 120, /* THREADS_PER_BLOCK */ 480)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 128, /* THREADS_PER_BLOCK */ 512)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 14, /* THREADS_PER_BLOCK */ 224)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 16, /* THREADS_PER_BLOCK */ 256)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 160, /* THREADS_PER_BLOCK */ 640)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 20, /* THREADS_PER_BLOCK */ 640)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 24, /* THREADS_PER_BLOCK */ 384)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 26, /* THREADS_PER_BLOCK */ 416)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 28, /* THREADS_PER_BLOCK */ 448)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 30, /* THREADS_PER_BLOCK */ 480)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 32, /* THREADS_PER_BLOCK */ 512)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 4, /* THREADS_PER_BLOCK */ 128)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 40, /* THREADS_PER_BLOCK */ 640)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 42, /* THREADS_PER_BLOCK */ 672)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 48, /* THREADS_PER_BLOCK */ 384)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 56, /* THREADS_PER_BLOCK */ 448)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 60, /* THREADS_PER_BLOCK */ 480)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 64, /* THREADS_PER_BLOCK */ 512)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 70, /* THREADS_PER_BLOCK */ 560)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 8, /* THREADS_PER_BLOCK */ 128)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 80, /* THREADS_PER_BLOCK */ 640)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 84, /* THREADS_PER_BLOCK */ 672)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 96, /* THREADS_PER_BLOCK */ 768)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include \"group_norm_nhwc_bwd_one_pass_kernel.cuh\"\n#include \"group_norm_nhwc_fwd_one_pass_kernel.cuh\"\n#include \"macros.h\"\n\nGN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 98, /* THREADS_PER_BLOCK */ 392)\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\n#include \"group_norm_nhwc.h\"\n#include \"group_norm_nhwc_bwd_one_pass.h\"\n#include \"group_norm_nhwc_fwd_one_pass.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define CHECK_CUDA_STATUS(call)                                                                     \\\n  do {                                                                                              \\\n    cudaError_t status_ = call;                                                                     \\\n    if (status_ != cudaSuccess) {                                                                   \\\n      fprintf(stderr, \"CUDA error (%s:%d): %s\\n\", __FILE__, __LINE__, cudaGetErrorString(status_)); \\\n      exit(1);                                                                                      \\\n    }                                                                                               \\\n  } while (0)\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_CHANNELS_LAST(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x \" must be channels last\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n#define CHECK_NHWC_INPUT(x) \\\n  CHECK_CUDA(x);            \\\n  CHECK_CHANNELS_LAST(x)\n\nstatic bool initialized = false;\nstatic cudaDeviceProp props;\n\nconst std::unordered_set<int> supported_c_values = {128,  256,  320,  384,  448,  512,  640,  768,\n                                                    896,  960,  1024, 1280, 1344, 1536, 1792, 1920,\n                                                    
2048, 2240, 2560, 2688, 3072, 3136, 3584, 4096};\nconst std::unordered_set<int> supported_groups_values = {16, 32};\n\nstd::vector<torch::Tensor> group_norm_fwd(torch::Tensor input, int groups, torch::Tensor weight, torch::Tensor bias,\n                                          float eps, int passes, bool with_swish = false) {\n  if (!initialized) {\n    CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0));\n    initialized = true;\n  }\n  CHECK_NHWC_INPUT(input);\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  // Retrieve group norm arguments\n  int n = input.size(0);\n  int c = input.size(1);\n  int h = input.size(2);\n  int w = input.size(3);\n\n  // Check kernel constraints\n  TORCH_CHECK(supported_groups_values.count(groups), \"`groups` of {16, 32} are only supported but \", groups,\n              \" is passed\");\n  TORCH_CHECK(supported_c_values.count(c), \"`c` of \", c, \" is not included in supported_c_values\");\n\n  // Allocate tensors\n  auto options = at::TensorOptions(at::kCUDA);\n  auto output = at::empty_like(input, at::MemoryFormat::Preserve);\n  auto sums_d = at::empty({2 * n * groups}, options.dtype(at::kFloat));\n\n  // Declare the parameters.\n  Group_norm_nhwc_fwd_params params_fwd;\n  memset(&params_fwd, 0, sizeof(params_fwd));\n\n  // Initialize the parameters.\n  params_fwd.y = reinterpret_cast<void*>(output.data_ptr());\n  params_fwd.sums = reinterpret_cast<float2*>(sums_d.data_ptr());\n  params_fwd.x = const_cast<void*>(reinterpret_cast<void*>(input.data_ptr()));\n  params_fwd.gamma = const_cast<void*>(reinterpret_cast<void*>(weight.data_ptr()));\n  params_fwd.beta = const_cast<void*>(reinterpret_cast<void*>(bias.data_ptr()));\n  params_fwd.epsilon = eps;\n  params_fwd.n = n;\n  params_fwd.h = h;\n  params_fwd.w = w;\n  params_fwd.c = c;\n  params_fwd.groups = groups;\n  params_fwd.with_swish = with_swish;\n\n  PrecisionMode mode;\n  if (input.dtype() == torch::kFloat32) {\n    if (weight.dtype() == torch::kFloat16) {\n      mode 
= PrecisionMode::FP32IOFP16W;\n    } else if (weight.dtype() == torch::kBFloat16) {\n      mode = PrecisionMode::FP32IOBF16W;\n    } else {\n      mode = PrecisionMode::FP32IOFP32W;\n    }\n  } else if (input.dtype() == torch::kBFloat16) {\n    if (weight.dtype() == torch::kFloat16) {\n      mode = PrecisionMode::BF16IOFP16W;\n    } else if (weight.dtype() == torch::kBFloat16) {\n      mode = PrecisionMode::BF16IOBF16W;\n    } else {\n      mode = PrecisionMode::BF16IOFP32W;\n    }\n  } else {\n    if (weight.dtype() == torch::kFloat16) {\n      mode = PrecisionMode::FP16IOFP16W;\n    } else if (weight.dtype() == torch::kBFloat16) {\n      mode = PrecisionMode::FP16IOBF16W;\n    } else {\n      mode = PrecisionMode::FP16IOFP32W;\n    }\n  }\n  params_fwd.precision = mode;\n\n  // The number of barriers.\n  size_t barriers_elts = 0;\n  // The number of elements in the reduction buffer.\n  size_t red_buffer_elts = 0;\n  // The number of elements in the reduction buffer that must be zeroed.\n  size_t zeroed_red_buffer_elts = 0;\n\n  // Finalize the parameters.\n  dim3 grid;\n  if (passes == 1) {\n    group_norm_nhwc_fwd_one_pass_setup(params_fwd, barriers_elts, red_buffer_elts, grid, props);\n  } else {\n    group_norm_nhwc_fwd_two_passes_setup(params_fwd, zeroed_red_buffer_elts);\n  }\n\n  // Allocate on the device.\n  auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat));\n  params_fwd.red_buffer = red_buffer.data_ptr<float>();\n\n  // Allocate the buffer if needed.\n  auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt));\n  params_fwd.barriers = barriers.data_ptr<int>();\n  auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat));\n  params_fwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr<float>();\n\n  if (passes == 1) {\n    group_norm_nhwc_fwd_one_pass_run(params_fwd, grid, stream);\n  } else {\n    group_norm_nhwc_fwd_two_passes_sum(params_fwd, stream);\n    
group_norm_nhwc_fwd_two_passes_scale(params_fwd, stream);\n  }\n\n  return {output, sums_d};\n}\n\nstd::vector<torch::Tensor> group_norm_bwd(torch::Tensor grad_output, torch::Tensor sums, torch::Tensor input,\n                                          int groups, torch::Tensor weight, torch::Tensor bias, float eps, int passes,\n                                          bool with_swish = false) {\n  if (!initialized) {\n    CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0));\n    initialized = true;\n  }\n  CHECK_NHWC_INPUT(grad_output);\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  // Retrieve group norm arguments\n  int n = input.size(0);\n  int c = input.size(1);\n  int h = input.size(2);\n  int w = input.size(3);\n\n  // Check kernel constraints\n  TORCH_CHECK(supported_groups_values.count(groups), \"`groups` of {16, 32} are only supported but \", groups,\n              \" is passed\");\n  TORCH_CHECK(supported_c_values.count(c), \"`c` of \", c, \" is not included in supported_c_values\");\n\n  // Allocate tensors\n  auto options = at::TensorOptions(at::kCUDA);\n  auto grad_input = at::empty_like(input, at::MemoryFormat::Preserve);\n  auto grad_weight = at::empty_like(weight, at::MemoryFormat::Preserve);\n  auto grad_bias = at::empty_like(bias, at::MemoryFormat::Preserve);\n  auto sums_d = at::empty({2 * n * groups}, options.dtype(at::kFloat));\n\n  // Declare the parameters.\n  Group_norm_nhwc_bwd_params params_bwd;\n  memset(&params_bwd, 0, sizeof(params_bwd));\n\n  // Initialize the parameters.\n  params_bwd.dx = reinterpret_cast<void*>(grad_input.data_ptr());\n  params_bwd.dgamma = reinterpret_cast<void*>(grad_weight.data_ptr());\n  params_bwd.dbeta = reinterpret_cast<void*>(grad_bias.data_ptr());\n  params_bwd.sums = const_cast<float2*>(reinterpret_cast<float2*>(sums.data_ptr()));\n  params_bwd.dy = const_cast<void*>(reinterpret_cast<void*>(grad_output.data_ptr()));\n  params_bwd.x = 
const_cast<void*>(reinterpret_cast<void*>(input.data_ptr()));\n  params_bwd.gamma = const_cast<void*>(reinterpret_cast<void*>(weight.data_ptr()));\n  params_bwd.beta = const_cast<void*>(reinterpret_cast<void*>(bias.data_ptr()));\n  params_bwd.epsilon = eps;\n  params_bwd.n = n;\n  params_bwd.h = h;\n  params_bwd.w = w;\n  params_bwd.c = c;\n  params_bwd.groups = groups;\n  params_bwd.with_swish = with_swish;\n\n  PrecisionMode mode;\n  if (input.dtype() == torch::kFloat32) {\n    if (weight.dtype() == torch::kFloat16) {\n      mode = PrecisionMode::FP32IOFP16W;\n    } else if (weight.dtype() == torch::kBFloat16) {\n      mode = PrecisionMode::FP32IOBF16W;\n    } else {\n      mode = PrecisionMode::FP32IOFP32W;\n    }\n  } else if (input.dtype() == torch::kBFloat16) {\n    if (weight.dtype() == torch::kFloat16) {\n      mode = PrecisionMode::BF16IOFP16W;\n    } else if (weight.dtype() == torch::kBFloat16) {\n      mode = PrecisionMode::BF16IOBF16W;\n    } else {\n      mode = PrecisionMode::BF16IOFP32W;\n    }\n  } else {\n    if (weight.dtype() == torch::kFloat16) {\n      mode = PrecisionMode::FP16IOFP16W;\n    } else if (weight.dtype() == torch::kBFloat16) {\n      mode = PrecisionMode::FP16IOBF16W;\n    } else {\n      mode = PrecisionMode::FP16IOFP32W;\n    }\n  }\n  params_bwd.precision = mode;\n\n  // The number of barriers.\n  size_t barriers_elts = 0;\n  // The number of elements in the reduction buffer.\n  size_t red_buffer_elts = 0;\n  // The number of elements in the reduction buffer that must be zeroed.\n  size_t zeroed_red_buffer_elts = 0;\n\n  // Finalize the parameters.\n  dim3 grid;\n  if (passes == 1) {\n    group_norm_nhwc_bwd_one_pass_setup(params_bwd, barriers_elts, red_buffer_elts, zeroed_red_buffer_elts, grid, props);\n  } else {\n    group_norm_nhwc_bwd_two_passes_setup(params_bwd, zeroed_red_buffer_elts);\n  }\n\n  // Allocate on the device.\n  auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat));\n  
params_bwd.red_buffer = red_buffer.data_ptr<float>();\n\n  // Allocate the buffer if needed.\n  auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt));\n  params_bwd.barriers = barriers.data_ptr<int>();\n  auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat));\n  params_bwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr<float>();\n\n  if (passes == 1) {\n    group_norm_nhwc_bwd_one_pass_run(params_bwd, grid, stream);\n  } else {\n    group_norm_nhwc_bwd_two_passes_sum(params_bwd, stream);\n    group_norm_nhwc_bwd_two_passes_scale(params_bwd, stream);\n  }\n\n  return {grad_input, grad_weight, grad_bias};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &group_norm_fwd, \"NHWC group norm forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &group_norm_bwd, \"NHWC group norm backward\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/macros.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n\n#define GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \\\n  void group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_run(         \\\n      const Group_norm_nhwc_##PASS_NAME##_params& params, const dim3& grid, cudaStream_t stream)\n\n#define GN_ONE_PASS_RUN_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)            \\\n  GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) {           \\\n    auto kernel =                                                                                                     \\\n        group_norm_nhwc_##PASS_NAME##_one_pass_kernel<Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK>; \\\n                                                                                                                      \\\n    const Group_norm_nhwc_##PASS_NAME##_params* params_ = &params;                                                    \\\n    if (grid.x > 1) {                                                                                                 \\\n      CHECK_CUDA(cudaLaunchCooperativeKernel((const void*)kernel, grid, dim3(THREADS_PER_BLOCK), (void**)&params_, 0, \\\n                                             stream));                                                                \\\n                                                                                                                      \\\n    } else {                                                                                                          \\\n      CHECK_CUDA(cudaLaunchKernel((const void*)kernel, grid, dim3(THREADS_PER_BLOCK), (void**)&params_, 0, stream));  \\\n    }                             
                                                                                    \\\n                                                                                                                      \\\n    CHECK_CUDA(cudaGetLastError());                                                                                   \\\n  }\n\n//////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n                                                PASS_NAME)                                                     \\\n  int group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_blocks_per_sm()\n\n#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)  \\\n  GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \\\n    auto kernel =                                                                                                     \\\n        group_norm_nhwc_##PASS_NAME##_one_pass_kernel<Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK>; \\\n                                                                                                                      \\\n    int blocks_per_sm = 0;                                                                                            \\\n    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, THREADS_PER_BLOCK, 0));          \\\n                                                                                                                      \\\n    CHECK_CUDA(cudaGetLastError());                                                                                   \\\n    return blocks_per_sm;                                                                                         
    \\\n  }\n\n//////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_ONE_PASS_(FUNCTION, Traits, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \\\n  FUNCTION(Traits, 512, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);               \\\n  FUNCTION(Traits, 256, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);               \\\n  FUNCTION(Traits, 128, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);               \\\n  FUNCTION(Traits, 64, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);\n\n#define GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)                     \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);\n\n#define GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)                         \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  
GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);\n\n#define GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)                     \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);\n\n#define 
GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)             \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \\\n               PASS_NAME);                                                                                  \\\n  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp32W, CHANNELS_PER_GROUP, 
THREADS_PER_BLOCK, PASS_NAME);\n\n#define GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \\\n  GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)   \\\n  GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)\n\n#define GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \\\n  GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, fwd)\n\n#define GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \\\n  GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, bwd)\n\n#define GN_FWD_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \\\n  GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK)           \\\n  GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK,          \\\n                               CHANNELS_PER_GROUP, PASS_NAME)                                                    \\\n  if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP &&                            \\\n      params.precision == PrecisionMode::PRECISION) {                                                            \\\n    function =                                                                                                   \\\n        group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX; \\\n  } else\n\n#define GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, \\\n                                              CHANNELS_PER_GROUP, PASS_NAME, LIMIT_CPG)                                \\\n  if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP &&                    
              \\\n      params.precision == PrecisionMode::PRECISION && CHANNELS_PER_GROUP >= LIMIT_CPG) {                               \\\n    function =                                                                                                         \\\n        group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX;       \\\n  } else\n\n#define GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Traits, PRECISION, CHANNELS_PER_GROUP,        \\\n                                                                    FUNC_POSTFIX, function, PASS_NAME)            \\\n  GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 1024, 128, CHANNELS_PER_GROUP, \\\n                                        PASS_NAME, 80)                                                            \\\n  GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 128, CHANNELS_PER_GROUP,  \\\n                                        PASS_NAME, 160)                                                           \\\n  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 512, 512, CHANNELS_PER_GROUP, PASS_NAME)      \\\n  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 256, CHANNELS_PER_GROUP, PASS_NAME)      \\\n  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 128, 128, CHANNELS_PER_GROUP, PASS_NAME)      \\\n  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 0, 64, CHANNELS_PER_GROUP, PASS_NAME)\n\n#define GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, function) \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  
GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, fwd)\n\n#define GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, 
function) \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)                          \\\n  
GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP,         \\\n                                                              FUNC_POSTFIX, function, bwd)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, PASS_NAME)                    \\\n  GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME) \\\n  GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME)\n\n#define GN_FWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, fwd)\n\n#define GN_BWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, bwd)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define CALL_TWO_PASS_KERNEL(Kernel, Precision)               \\\n  if (params.channels_per_block == 320) {                     \\\n    Kernel<Precision, 160><<<grid, 160, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 280) {              \\\n    Kernel<Precision, 140><<<grid, 140, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 208) {              \\\n    Kernel<Precision, 104><<<grid, 104, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 240) {              \\\n    Kernel<Precision, 120><<<grid, 120, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 512) {              \\\n    Kernel<Precision, 256><<<grid, 256, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 448) {              \\\n    Kernel<Precision, 224><<<grid, 224, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 384) {              \\\n    Kernel<Precision, 192><<<grid, 192, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 256) {              \\\n    Kernel<Precision, 128><<<grid, 128, 
0, stream>>>(params); \\\n  } else if (params.channels_per_block == 128) {              \\\n    Kernel<Precision, 64><<<grid, 64, 0, stream>>>(params);   \\\n  } else if (params.channels_per_block == 336) {              \\\n    Kernel<Precision, 168><<<grid, 168, 0, stream>>>(params); \\\n  } else if (params.channels_per_block == 392) {              \\\n    Kernel<Precision, 196><<<grid, 196, 0, stream>>>(params); \\\n  } else {                                                    \\\n    assert(false);                                            \\\n  }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm/traits.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n */\n#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime_api.h>\n#include <math.h>\n#include <stdint.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fp32 {\n  // Type is float32_t\n  using Type = float;\n  // Doubled type\n  using Type2 = float2;\n\n  // Unpack input to accumulators type\n  static inline __device__ float2 unpack(const float2& f2) { return f2; }\n\n  // Pack the accumulators into outputs.\n  static inline __device__ float2 pack(const float2& f2) { return f2; }\n\n  static inline __device__ float2 zero() { return {0.f, 0.f}; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fp16 {\n  // Type is __half\n  using Type = __half;\n  // Doubled type\n  using Type2 = __half2;\n\n  // Unpack input to accumulators type\n  static inline __device__ float2 unpack(const __half2& h2) {\n    // FIXME(nkorobov): __half22float2 makes compilation error in container\n    return {__half2float(h2.x), __half2float(h2.y)};\n  }\n\n  // Pack the accumulators into outputs.\n  static inline __device__ __half2 pack(const float2& f2) {\n    // FIXME(nkorobov): __float22half2_rn makes compilation error in container\n    return {__float2half_rn(f2.x), __float2half_rn(f2.y)};\n  }\n\n  static inline __device__ __half2 zero() {\n    uint32_t zero = 0;\n    return *reinterpret_cast<__half2*>(&zero);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Bf16 {\n  // Type is __nv_bfloat16\n  using Type = __nv_bfloat16;\n  // Doubled type\n  using Type2 = __nv_bfloat162;\n\n  // Unpack input to accumulators type\n  static inline __device__ float2 
unpack(const __nv_bfloat162& h2) {\n    // FIXME(nkorobov): __half22float2 makes compilation error in container\n    return {__bfloat162float(h2.x), __bfloat162float(h2.y)};\n  }\n\n  // Pack the accumulators into outputs.\n  static inline __device__ __nv_bfloat162 pack(const float2& f2) {\n    // FIXME(nkorobov): __float22bfloat162_rn makes compilation error in container\n    return {__float2bfloat16_rn(f2.x), __float2bfloat16_rn(f2.y)};\n  }\n\n  static inline __device__ __nv_bfloat162 zero() {\n    uint32_t zero = 0;\n    return *reinterpret_cast<__nv_bfloat162*>(&zero);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\nstruct Fp32IOFp16W {\n  // IO traits\n  using IOTraits = Fp32;\n  // Weights traits\n  using WTraits = Fp16;\n};\n\nstruct Fp32IOBf16W {\n  // IO traits\n  using IOTraits = Fp32;\n  // Weights traits\n  using WTraits = Bf16;\n};\n\nstruct Fp32IOFp32W {\n  // IO traits\n  using IOTraits = Fp32;\n  // Weights traits\n  using WTraits = Fp32;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Fp16IOFp16W {\n  // IO traits\n  using IOTraits = Fp16;\n  // Weights traits\n  using WTraits = Fp16;\n};\n\nstruct Fp16IOBf16W {\n  // IO traits\n  using IOTraits = Fp16;\n  // Weights traits\n  using WTraits = Bf16;\n};\n\nstruct Fp16IOFp32W {\n  // IO traits\n  using IOTraits = Fp16;\n  // Weights traits\n  using WTraits = Fp32;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\nstruct Bf16IOFp16W {\n  // IO traits\n  using IOTraits = Bf16;\n  // Weights traits\n  using WTraits = Fp16;\n};\n\nstruct Bf16IOBf16W {\n  // IO traits\n  using IOTraits = Bf16;\n  // Weights traits\n  using WTraits = Bf16;\n};\n\nstruct Bf16IOFp32W {\n  // IO traits\n  using IOTraits = Bf16;\n  // Weights traits\n  using WTraits = 
Fp32;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/generate_gn_cuda_inst.py",
    "content": "import pathlib\n\n\nhw_c_list = [\n    (8 * 8, 1280),\n    (8 * 8, 2560),\n    (16 * 16, 640),\n    (16 * 16, 1280),\n    (16 * 16, 1920),\n    (16 * 16, 2560),\n    (32 * 32, 320),\n    (32 * 32, 640),\n    (32 * 32, 960),\n    (32 * 32, 1280),\n    (32 * 32, 1920),\n    (64 * 64, 320),\n    (64 * 64, 640),\n    (64 * 64, 960),\n]\n\n\ndef run():\n    src_path = pathlib.Path(__file__).parent.absolute()\n\n    for f in src_path.glob(\"gn_cuda_inst_*.cu\"):\n        f.unlink()\n\n    for hw, c in hw_c_list:\n        print(f\"GN_CUDA_INST_DEFINE({hw}, {c})\")\n        with open(src_path / f\"gn_cuda_inst_{hw}_{c}.cu\", \"w\") as f:\n            f.write('#include \"gn_cuda_host_template.cuh\"\\n')\n            f.write(\"\\n\")\n            f.write(\"\\n\")\n            f.write(\"namespace group_norm_v2 {\\n\")\n            f.write(\"\\n\")\n            f.write(f\"GN_CUDA_INST_DEFINE({hw}, {c})\\n\")\n            f.write(\"\\n\")\n            f.write(\"}  // namespace group_norm_v2\\n\")\n\n    with open(src_path / \"gn_dispatch_hw_c.hpp\", \"w\") as f:\n        f.write(\"#pragma once\\n\")\n        f.write(\"\\n\")\n        f.write(\"#define DISPATCH_HW_C(hw, c, HW, C, ...) [&] { \\\\\\n\")\n        for hw, c in hw_c_list:\n            f.write(\n                f\"    if (hw == {hw} && c == {c}) {{ constexpr int HW = {hw}, C = {c}; return __VA_ARGS__(); }} \\\\\\n\"\n            )\n        f.write(\n            '    throw std::invalid_argument(\"DISPATCH_HW_C \" + std::to_string(hw) + \" \" + std::to_string(c)); \\\\\\n'\n        )\n        f.write(\"    }()\\n\")\n\n\nif __name__ == \"__main__\":\n    run()\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn.cpp",
    "content": "#include \"gn.hpp\"\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\nnamespace group_norm_v2 {\n\ntorch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups,\n                 std::optional<torch::Tensor> mean_var_out, int sm_margin) {\n  if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) {\n    throw std::invalid_argument(\"gn dtype mismatch\");\n  }\n  torch::Tensor out = torch::empty_like(x);\n  float* ptr_mean_var_out = mean_var_out.has_value() ? mean_var_out->data_ptr<float>() : nullptr;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  int device_id = at::cuda::getCurrentCUDAStream().device().index();\n  group_norm_v2::Meta meta;\n  if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) {\n    group_norm_v2::gn_cuda((half*)out.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps,\n                           silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out,\n                           nullptr, nullptr, sm_margin, stream, device_id, &meta, true);\n  } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {\n    group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(),\n                           (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups,\n                           x.size(1) / num_groups, ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id,\n                           &meta, true);\n  } else {\n    throw std::invalid_argument(\"gn only supports half or bfloat16 input and weight\");\n  }\n  torch::Tensor red_buffer =\n      torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n  thread_local torch::Tensor barrier;\n  if (barrier.size(0) < 
meta.barrier_size) {\n    barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA));\n  }\n  if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) {\n    group_norm_v2::gn_cuda((half*)out.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps,\n                           silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out,\n                           red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id,\n                           nullptr, false);\n  } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {\n    group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(),\n                           (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups,\n                           x.size(1) / num_groups, ptr_mean_var_out, red_buffer.data_ptr<float>(),\n                           barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false);\n  } else {\n    throw std::invalid_argument(\"gn only supports half or bfloat16 input and weight\");\n  }\n  return out;\n}\n\nauto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var,\n            float eps, bool silu, int num_groups, int sm_margin) {\n  if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) {\n    throw std::invalid_argument(\"gn_bwd dtype mismatch\");\n  }\n  torch::Tensor grad_input = torch::empty_like(x);\n  torch::Tensor grad_weight = torch::empty_like(w);\n  torch::Tensor grad_bias = torch::empty_like(w);\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  int device_id = at::cuda::getCurrentCUDAStream().device().index();\n  group_norm_v2::Meta meta;\n  if (x.dtype() == torch::kHalf && w.dtype() == 
torch::kHalf) {\n    group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(),\n                               (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(),\n                               (half*)b.data_ptr(), mean_var.data_ptr<float>(), eps, silu, x.size(0),\n                               x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, nullptr, nullptr, sm_margin,\n                               stream, device_id, &meta, true);\n  } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {\n    group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(),\n                               (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(),\n                               (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(),\n                               mean_var.data_ptr<float>(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups,\n                               x.size(1) / num_groups, nullptr, nullptr, sm_margin, stream, device_id, &meta, true);\n  } else {\n    throw std::invalid_argument(\"gn only supports half or bfloat16 input and weight\");\n  }\n  torch::Tensor red_buffer =\n      torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n  thread_local torch::Tensor barrier;\n  if (barrier.size(0) < meta.barrier_size) {\n    barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA));\n  }\n  if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) {\n    group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(),\n                               (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(),\n                               (half*)b.data_ptr(), 
mean_var.data_ptr<float>(), eps, silu, x.size(0),\n                               x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, red_buffer.data_ptr<float>(),\n                               barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false);\n  } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {\n    group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(),\n                               (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(),\n                               (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(),\n                               mean_var.data_ptr<float>(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups,\n                               x.size(1) / num_groups, red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(),\n                               sm_margin, stream, device_id, nullptr, false);\n  } else {\n    throw std::invalid_argument(\"gn only supports half or bfloat16 input and weight\");\n  }\n  return std::make_tuple(grad_input, grad_weight, grad_bias);\n}\n\n}  // namespace group_norm_v2\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"gn\", &group_norm_v2::gn, py::arg(\"x\"), py::arg(\"w\"), py::arg(\"b\"), py::arg(\"eps\"), py::arg(\"silu\"),\n        py::arg(\"num_groups\"), py::arg(\"mean_var_out\") = py::none(), py::arg(\"sm_margin\") = 0, \"\");\n  m.def(\"gn_bwd\", &group_norm_v2::gn_bwd, py::arg(\"grad_output\"), py::arg(\"x\"), py::arg(\"w\"), py::arg(\"b\"),\n        py::arg(\"mean_var\"), py::arg(\"eps\"), py::arg(\"silu\"), py::arg(\"num_groups\"), py::arg(\"sm_margin\") = 0, \"\");\n}\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn.hpp",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n\n#include <cstdint>\n\nnamespace group_norm_v2 {\n\nstruct Meta {\n  int64_t red_buffer_size;\n  int64_t barrier_size;\n  int BLOCK_DIM_X;\n  int C_PER_BLOCK;\n  int ROWS_PER_BLOCK;\n  int VEC_ELEMS;\n  bool LOAD_TWICE;\n  int BLOCKS_PER_SM;\n  bool HARDWARE_CLUSTER;\n  int wgrad_sync_method;\n};\n\ntemplate <typename T>\nvoid gn_cuda(T* out, T* x, T* w, T* b, float eps, bool silu, int64_t n, int64_t hw, int num_groups,\n             int channels_per_group, float* mean_var_out, float* red_buffer, unsigned* barrier, int sm_margin,\n             cudaStream_t stream, int device_id, Meta* meta_ptr, bool meta_only);\n\ntemplate <typename T>\nvoid gn_bwd_cuda(T* grad_input, T* grad_weight, T* grad_bias, T* grad_output, T* x, T* w, T* b, float* mean_var,\n                 float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float* red_buffer,\n                 unsigned* barrier, int sm_margin, cudaStream_t stream, int device_id, Meta* meta_ptr, bool meta_only);\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda.cu",
    "content": "#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <cstdio>\n#include <mutex>\n#include <stdexcept>\n\n#include \"gn.hpp\"\n#include \"gn_dispatch_hw_c.hpp\"\n#include \"gn_utils.hpp\"\n\n#define DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, NUM_GROUPS, SILU, ...)                        \\\n  [&] {                                                                                              \\\n    if (num_groups == 16 && silu == true) {                                                          \\\n      constexpr int NUM_GROUPS = 16;                                                                 \\\n      constexpr bool SILU = true;                                                                    \\\n      return __VA_ARGS__();                                                                          \\\n    }                                                                                                \\\n    if (num_groups == 32 && silu == false) {                                                         \\\n      constexpr int NUM_GROUPS = 32;                                                                 \\\n      constexpr bool SILU = false;                                                                   \\\n      return __VA_ARGS__();                                                                          \\\n    }                                                                                                \\\n    throw std::invalid_argument(\"DISPATCH_NUM_GROUPS_AND_SILU \" + std::to_string(num_groups) + \" \" + \\\n                                std::to_string(silu));                                               \\\n  }()\n\nnamespace group_norm_v2 {\n\ntemplate <typename T, int HW, int C, int G, bool SILU>\nvoid gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T));\n\ntemplate <typename T, int HW, int C, int G, bool SILU>\nvoid gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T));\n\ntemplate <typename 
T>\nvoid gn_cuda(GN_CUDA_HOST_PARAMS(T)) {\n  DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] {\n    DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU,\n                                 [&] { return gn_cuda_single_shape<T, HW, C, G, SILU>(GN_CUDA_HOST_ARGS); });\n  });\n}\n\ntemplate <typename T>\nvoid gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(T)) {\n  DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] {\n    DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU,\n                                 [&] { return gn_bwd_cuda_single_shape<T, HW, C, G, SILU>(GN_BWD_CUDA_HOST_ARGS); });\n  });\n}\n\ntemplate void gn_cuda(GN_CUDA_HOST_PARAMS(half));\ntemplate void gn_cuda(GN_CUDA_HOST_PARAMS(__nv_bfloat16));\n\ntemplate void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(half));\ntemplate void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16));\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_host_template.cuh",
    "content": "#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <cstdio>\n#include <stdexcept>\n\n#include \"gn_cuda_kernel.cuh\"\n#include \"gn_utils.hpp\"\n\nnamespace group_norm_v2 {\n\n#define DISPATCH_LOWER_BOUND_N(VALUE, CONST_NAME, ...)                              \\\n  [&] {                                                                             \\\n    if (VALUE >= 16) {                                                              \\\n      constexpr int CONST_NAME = 16;                                                \\\n      return __VA_ARGS__();                                                         \\\n    }                                                                               \\\n    if (VALUE >= 8) {                                                               \\\n      constexpr int CONST_NAME = 8;                                                 \\\n      return __VA_ARGS__();                                                         \\\n    }                                                                               \\\n    if (VALUE >= 4) {                                                               \\\n      constexpr int CONST_NAME = 4;                                                 \\\n      return __VA_ARGS__();                                                         \\\n    }                                                                               \\\n    if (VALUE >= 2) {                                                               \\\n      constexpr int CONST_NAME = 2;                                                 \\\n      return __VA_ARGS__();                                                         \\\n    }                                                                               \\\n    if (VALUE >= 1) {                                                               \\\n      constexpr int CONST_NAME = 1;                                                 
\\\n      return __VA_ARGS__();                                                         \\\n    }                                                                               \\\n    throw std::invalid_argument(\"DISPATCH_LOWER_BOUND_N \" + std::to_string(VALUE)); \\\n  }()\n\n#define DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, ...) \\\n  [&] {                                                                                                               \\\n    if (runtime_cuda_arch == 1000 && sm_count >= 148) {                                                               \\\n      constexpr int RUNTIME_CUDA_ARCH = 1000, LB_SM_COUNT = 148;                                                      \\\n      return __VA_ARGS__();                                                                                           \\\n    }                                                                                                                 \\\n    throw std::invalid_argument(\"DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT \" + std::to_string(runtime_cuda_arch) +  \\\n                                \" \" + std::to_string(sm_count));                                                      \\\n  }()\n\n#define DISPATCH_SM_MARGIN(VALUE, CONST_NAME, ...)                              
\\\n  [&] {                                                                         \\\n    if (VALUE == 0) {                                                           \\\n      constexpr int CONST_NAME = 0;                                             \\\n      return __VA_ARGS__();                                                     \\\n    }                                                                           \\\n    if (VALUE == 32) {                                                          \\\n      constexpr int CONST_NAME = 32;                                            \\\n      return __VA_ARGS__();                                                     \\\n    }                                                                           \\\n    throw std::invalid_argument(\"DISPATCH_SM_MARGIN \" + std::to_string(VALUE)); \\\n  }()\n\ninline constexpr int get_max_cuda_arch() {\n  int cuda_arch_list[] = {__CUDA_ARCH_LIST__};\n  int max_cuda_arch = -1;\n  for (int cuda_arch_item : cuda_arch_list) {\n    if (cuda_arch_item > max_cuda_arch) {\n      max_cuda_arch = cuda_arch_item;\n    }\n  }\n  return max_cuda_arch;\n}\n\ntemplate <typename T, bool BWD, bool REQUIRES_WGRAD, int HW, int G, int CPG, int LB_N, int RUNTIME_CUDA_ARCH,\n          int LB_SM_COUNT, int EFFECTIVE_CUDA_ARCH, int SM_MARGIN>\nconstexpr auto compute_gn_params() {\n  constexpr int C = G * CPG;\n\n  // Initialize each variable to comply with C++17\n  int BLOCK_DIM_X = 0;\n  int C_PER_BLOCK = 0;\n  int ROWS_PER_BLOCK = 0;\n  bool LOAD_TWICE = false;\n  int BLOCKS_PER_SM = 0;\n  WgradSyncMethod wgrad_sync_method = WGRAD_SYNC_UNSPECIFIED;\n\n  // There are two tiling strategies:\n  //   - block sync: each block handles a whole group, i.e., a multiple of (G * HW) elements\n  //   - virtual cluster sync: each virtual cluster handles a group\n  // Block sync can avoid cross-block synchronization latency, but it may cause low occupancy.\n  //   Use block sync if the IO size is small, when latency 
rather than occupancy dominates the kernel running time.\n\n  // The forward pass loads only `x`; the backward pass loads both `x` and `grad_output`, hence the\n  // factor of (1 + BWD).\n  if (HW * CPG * (1 + BWD) * sizeof(T) <= 20480) {\n    // Strategy 1: block sync\n    C_PER_BLOCK = CPG;\n    ROWS_PER_BLOCK = HW;\n    BLOCK_DIM_X = lcm(32, C_PER_BLOCK);\n    while (BLOCK_DIM_X < 256) {\n      BLOCK_DIM_X *= 2;\n    }\n    BLOCKS_PER_SM = 1;\n    // The size of registers is 65536 registers * 4 bytes per register.\n    //   We have to leave some room for other variables and compiler optimizations,\n    //   so we use 36000 as the threshold.\n    LOAD_TWICE = BLOCKS_PER_SM * ROWS_PER_BLOCK * C_PER_BLOCK * (1 + BWD) * sizeof(T) > 36000 * 4;\n  } else {\n    // Strategy 2: virtual cluster sync\n    //   A virtual cluster is a group of blocks that are synchronized with each other.\n    //   Each group, i.e., a multiple of (G * HW) elements, should be handled on the same virtual cluster.\n    //   If the virtual cluster size is supported by the hardware, HARDWARE_CLUSTER is preferred;\n    //   otherwise, cooperative groups are used (i.e., PERSISTENT kernels).\n    int c_per_cluster = lcm(128 / (int)sizeof(T), CPG);\n\n    C_PER_BLOCK = c_per_cluster;\n    BLOCK_DIM_X = C_PER_BLOCK == 320 ? 
320 : 480;\n\n    // Maximum number of rows that should reside in registers\n    int register_max_rows = 36000 * 4 / (C_PER_BLOCK * (1 + BWD) * sizeof(T));\n\n    std::tuple<bool, int, int, int, int, int> best_candidate{};\n    BLOCKS_PER_SM = 0;\n    ROWS_PER_BLOCK = 0;\n    for (int blocks_per_sm = 1; blocks_per_sm <= 3; blocks_per_sm++) {\n      for (int rows_per_block = HW; rows_per_block >= 1; rows_per_block /= 2) {\n        int virtual_cluster_size = (HW / rows_per_block) * (c_per_cluster / C_PER_BLOCK);\n        if (virtual_cluster_size > blocks_per_sm * (LB_SM_COUNT - SM_MARGIN)) {\n          continue;\n        }\n        int num_clusters = blocks_per_sm * (LB_SM_COUNT - SM_MARGIN) / virtual_cluster_size;\n        int num_tasks = LB_N * (C / c_per_cluster);\n        int num_waves = up_div(num_tasks, num_clusters);\n        bool load_twice = rows_per_block > register_max_rows / blocks_per_sm;\n\n        // Wave utilization: the percent of SMs that are used for each wave\n        //   For example, SM_COUNT=100 and VIRTUAL_CLUSTER_SIZE=64,\n        //     if BLOCKS_PER_SM=1, num_clusters=1, wave_util=64%;\n        //     if BLOCKS_PER_SM=2, num_clusters=3, wave_util=96%.\n        //   This helps select a good number of BLOCKS_PER_SM\n        int wave_util = 10000 * std::min(num_tasks, num_clusters) * virtual_cluster_size /\n                        (blocks_per_sm * (LB_SM_COUNT - SM_MARGIN));\n\n        decltype(best_candidate) candidate = {\n            true,\n            !load_twice,  // Prefer no load twice\n            !(num_waves >= 2 &&\n              blocks_per_sm ==\n                  1),    // When there are multiple waves, prefer multiple blocks per SM to ensure overlapping\n            -num_waves,  // Prefer fewer waves\n            std::min(9000, wave_util),  // Prefer high wave utilization\n            -blocks_per_sm,             // Prefer fewer blocks per SM in order to reduce threads overhead\n        };\n        if (candidate > best_candidate) 
{\n          // Assign each element respectively to comply with C++17\n          std::get<0>(best_candidate) = std::get<0>(candidate);\n          std::get<1>(best_candidate) = std::get<1>(candidate);\n          std::get<2>(best_candidate) = std::get<2>(candidate);\n          std::get<3>(best_candidate) = std::get<3>(candidate);\n          std::get<4>(best_candidate) = std::get<4>(candidate);\n          std::get<5>(best_candidate) = std::get<5>(candidate);\n          static_assert(std::tuple_size<decltype(best_candidate)>::value == 6, \"missing assignments\");\n\n          BLOCKS_PER_SM = blocks_per_sm;\n          ROWS_PER_BLOCK = rows_per_block;\n        }\n      }\n    }\n\n    LOAD_TWICE = ROWS_PER_BLOCK > register_max_rows / BLOCKS_PER_SM;\n  }\n\n  int c_per_cluster = lcm(CPG, C_PER_BLOCK);\n  int virtual_cluster_size = (c_per_cluster / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);\n\n  // The occupancy is affected if cluster size is large.\n  //   For example, on H100, when gridDim=128 and each block occupies the whole SM,\n  //     if cluster is not used, all blocks can be active simultaneously.\n  //     if cluster size is 16, not all blocks can be active simultaneously (which can be queried by\n  //     cudaOccupancyMaxActiveClusters),\n  //       so there will be two waves which impacts efficiency.\n  // When SM_MARGIN is set, no cluster should be used because other kernels may occupy a part of the cluster.\n  bool HARDWARE_CLUSTER = virtual_cluster_size <= 2 && virtual_cluster_size != 1 && SM_MARGIN == 0;\n\n  int MAX_VEC_BYTES =\n      8;  // Sometimes 4 or 16 is better, but there is no trivial way to select the best vectorization size.\n  int VEC_ELEMS = std::min(gcd(MAX_VEC_BYTES / (int)sizeof(T), C_PER_BLOCK),\n                           gcd(MAX_VEC_BYTES / (int)sizeof(T), ROWS_PER_BLOCK * C_PER_BLOCK / BLOCK_DIM_X));\n\n  return std::make_tuple(BLOCK_DIM_X, C_PER_BLOCK, ROWS_PER_BLOCK, VEC_ELEMS, LOAD_TWICE, BLOCKS_PER_SM,\n                         
HARDWARE_CLUSTER, wgrad_sync_method);\n}\n\n// Save compilation time for unused CUDA_ARCHs\n//   For each template argument from DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT, the kernel is only compiled for the\n//   corresponding CUDA_ARCH\ntemplate <int EFFECTIVE_CUDA_ARCH>\nclass CompileCondition {\n public:\n  __host__ __device__ static constexpr bool matches() {\n#if defined(__CUDA_ARCH__)\n    return __CUDA_ARCH__ == EFFECTIVE_CUDA_ARCH;\n#else\n    return false;\n#endif\n  }\n};\n\ntemplate <typename T, int HW, int C, int G, bool SILU>\nvoid gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)) {\n  if (out == x) {\n    throw std::invalid_argument(\"not __restrict__\");\n  }\n\n  cudaDeviceProp const& deviceProp = get_device_prop(device_id);\n  int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10;\n  int sm_count = deviceProp.multiProcessorCount;\n\n  DISPATCH_LOWER_BOUND_N(n, LB_N, [&] {\n    DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] {\n      DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] {\n        if (hw != HW) {\n          throw std::invalid_argument(\"wrong HW\");\n        }\n        if (num_groups * channels_per_group != C) {\n          throw std::invalid_argument(\"wrong C\");\n        }\n        if (num_groups != G) {\n          throw std::invalid_argument(\"wrong G\");\n        }\n        if (silu != SILU) {\n          throw std::invalid_argument(\"wrong SILU\");\n        }\n        if (n < LB_N) {\n          throw std::invalid_argument(\"wrong LB_N\");\n        }\n        if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) {\n          throw std::invalid_argument(\"wrong RUNTIME_CUDA_ARCH\");\n        }\n        if (sm_count < LB_SM_COUNT) {\n          throw std::invalid_argument(\"wrong LB_SM_COUNT\");\n        }\n        if (sm_margin != SM_MARGIN) {\n          throw std::invalid_argument(\"wrong SM_MARGIN\");\n        }\n        constexpr int EFFECTIVE_CUDA_ARCH =\n            
std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch());  // Assume the max CUDA_ARCH is used to generate PTX\n\n        constexpr int CPG = C / G;\n\n        constexpr auto params = compute_gn_params<T, false, false, HW, G, CPG, LB_N, RUNTIME_CUDA_ARCH, LB_SM_COUNT,\n                                                  EFFECTIVE_CUDA_ARCH, SM_MARGIN>();\n        constexpr int BLOCK_DIM_X = std::get<0>(params);\n        constexpr int C_PER_BLOCK = std::get<1>(params);\n        constexpr int ROWS_PER_BLOCK = std::get<2>(params);\n        constexpr int VEC_ELEMS = std::get<3>(params);\n        constexpr bool LOAD_TWICE = std::get<4>(params);\n        constexpr int BLOCKS_PER_SM = std::get<5>(params);\n        constexpr bool HARDWARE_CLUSTER = std::get<6>(params);\n\n        constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK);\n        constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);\n        constexpr int NUM_VIRTUAL_CLUSTERS = ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE;\n        constexpr bool PERSISTENT =\n            !HARDWARE_CLUSTER &&\n            VIRTUAL_CLUSTER_SIZE >=\n                2;  // Only virtual cluster sync (not include hardware cluster sync) requires PERSISTENT kernels\n\n        if (meta_ptr) {\n          constexpr int MAX_NUM_GROUPS_PER_BLOCK =\n              C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;\n          meta_ptr->red_buffer_size = 2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2;\n          meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS;\n          meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X;\n          meta_ptr->C_PER_BLOCK = C_PER_BLOCK;\n          meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK;\n          meta_ptr->VEC_ELEMS = VEC_ELEMS;\n          meta_ptr->LOAD_TWICE = LOAD_TWICE;\n          meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM;\n          meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER;\n          meta_ptr->wgrad_sync_method = (int)WGRAD_SYNC_UNSPECIFIED;\n        }\n        if (meta_only) {\n          return;\n        }\n\n        cudaLaunchConfig_t config = {0};\n        config.gridDim = dim3(\n            VIRTUAL_CLUSTER_SIZE,\n            PERSISTENT ? std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER), 1);\n        config.blockDim = BLOCK_DIM_X;\n        config.stream = stream;\n\n        cudaLaunchAttribute attribute[2];\n        if constexpr (HARDWARE_CLUSTER) {\n          attribute[0].id = cudaLaunchAttributeClusterDimension;\n          attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;  // Cluster size in X-dimension\n          attribute[0].val.clusterDim.y = 1;\n          attribute[0].val.clusterDim.z = 1;\n          config.attrs = attribute;\n          config.numAttrs++;\n        }\n        if constexpr (PERSISTENT) {\n          attribute[config.numAttrs].id = cudaLaunchAttributeCooperative;\n          attribute[config.numAttrs].val.cooperative = 1;\n          config.attrs = attribute;\n          config.numAttrs++;\n        }\n\n        auto kernel = &gn_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, ROWS_PER_BLOCK, C_PER_BLOCK,\n                                      C_PER_CLUSTER, VEC_ELEMS, PERSISTENT, NUM_VIRTUAL_CLUSTERS, LOAD_TWICE,\n                                      HARDWARE_CLUSTER, 
CompileCondition<EFFECTIVE_CUDA_ARCH> >;\n        if constexpr (HARDWARE_CLUSTER) {\n          if constexpr (VIRTUAL_CLUSTER_SIZE > 8) {\n            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));\n          }\n          int max_cluster_size;\n          int active_clusters;\n          CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void*)kernel, &config));\n          if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) {\n            attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;\n            CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void*)kernel, &config));\n          }\n          if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size &&\n              (!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) {\n            attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;\n          } else {\n            // Fallback to cooperative groups because hardware cluster cannot be active simultaneously\n            constexpr bool HARDWARE_CLUSTER_NEW = false;\n            constexpr bool PERSISTENT_NEW = !HARDWARE_CLUSTER_NEW && VIRTUAL_CLUSTER_SIZE >= 2;\n            config.gridDim = dim3(\n                VIRTUAL_CLUSTER_SIZE,\n                PERSISTENT_NEW ? 
std::min((int)n * (C / C_PER_CLUSTER), NUM_VIRTUAL_CLUSTERS) : n * (C / C_PER_CLUSTER),\n                1);\n            config.attrs = nullptr;\n            config.numAttrs = 0;\n            if constexpr (PERSISTENT_NEW) {\n              attribute[config.numAttrs].id = cudaLaunchAttributeCooperative;\n              attribute[config.numAttrs].val.cooperative = 1;\n              config.attrs = attribute;\n              config.numAttrs++;\n            }\n            kernel = &gn_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, ROWS_PER_BLOCK, C_PER_BLOCK,\n                                     C_PER_CLUSTER, VEC_ELEMS, PERSISTENT_NEW, NUM_VIRTUAL_CLUSTERS, LOAD_TWICE,\n                                     HARDWARE_CLUSTER_NEW, CompileCondition<EFFECTIVE_CUDA_ARCH> >;\n          }\n        }\n        CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, out, x, w, b, eps, n, mean_var_out, red_buffer, barrier));\n      });\n    });\n  });\n}\n\ntemplate <typename T, int HW, int C, int G, bool SILU>\nvoid gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)) {\n  if (grad_input == grad_output || grad_input == x) {\n    throw std::invalid_argument(\"not __restrict__\");\n  }\n\n  cudaDeviceProp const& deviceProp = get_device_prop(device_id);\n  int runtime_cuda_arch = deviceProp.major * 100 + deviceProp.minor * 10;\n  int sm_count = deviceProp.multiProcessorCount;\n\n  DISPATCH_LOWER_BOUND_N(n, LB_N, [&] {\n    DISPATCH_CUDA_ARCH_AND_LOWER_BOUND_SM_COUNT(runtime_cuda_arch, sm_count, RUNTIME_CUDA_ARCH, LB_SM_COUNT, [&] {\n      DISPATCH_SM_MARGIN(sm_margin, SM_MARGIN, [&] {\n        if (hw != HW) {\n          throw std::invalid_argument(\"wrong HW\");\n        }\n        if (num_groups * channels_per_group != C) {\n          throw std::invalid_argument(\"wrong C\");\n        }\n        if (num_groups != G) {\n          throw std::invalid_argument(\"wrong G\");\n        }\n        if (silu != SILU) {\n          throw std::invalid_argument(\"wrong SILU\");\n        }\n  
      if (n < LB_N) {\n          throw std::invalid_argument(\"wrong LB_N\");\n        }\n        if (runtime_cuda_arch != RUNTIME_CUDA_ARCH) {\n          throw std::invalid_argument(\"wrong RUNTIME_CUDA_ARCH\");\n        }\n        if (sm_count < LB_SM_COUNT) {\n          throw std::invalid_argument(\"wrong LB_SM_COUNT\");\n        }\n        if (sm_margin != SM_MARGIN) {\n          throw std::invalid_argument(\"wrong SM_MARGIN\");\n        }\n        constexpr int EFFECTIVE_CUDA_ARCH =\n            std::min(RUNTIME_CUDA_ARCH, get_max_cuda_arch());  // Assume the max CUDA_ARCH is used to generate PTX\n\n        constexpr bool REQUIRES_WGRAD = true;\n        constexpr int CPG = C / G;\n\n        constexpr auto params = compute_gn_params<T, true, REQUIRES_WGRAD, HW, G, CPG, LB_N, RUNTIME_CUDA_ARCH,\n                                                  LB_SM_COUNT, EFFECTIVE_CUDA_ARCH, SM_MARGIN>();\n        constexpr int BLOCK_DIM_X = std::get<0>(params);\n        constexpr int C_PER_BLOCK = std::get<1>(params);\n        constexpr int ROWS_PER_BLOCK = std::get<2>(params);\n        constexpr int VEC_ELEMS = std::get<3>(params);\n        constexpr bool LOAD_TWICE = std::get<4>(params);\n        constexpr int BLOCKS_PER_SM = std::get<5>(params);\n        constexpr bool HARDWARE_CLUSTER = std::get<6>(params);\n        constexpr WgradSyncMethod wgrad_sync_method_hint = std::get<7>(params);\n\n        constexpr int C_PER_CLUSTER = lcm(CPG, C_PER_BLOCK);\n        constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);\n        constexpr int NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED =\n            ((LB_SM_COUNT - SM_MARGIN) * BLOCKS_PER_SM) / VIRTUAL_CLUSTER_SIZE;\n\n        // PERSISTENT is required because wgrad reduction requires synchronization.\n        //   TODO: specialize for the case where REQUIRES_WGRAD == false\n        constexpr bool PERSISTENT = true;\n\n        // Determine whether to align each virtual cluster to a fixed range of 
channels\n        //   If aligned, WGRAD_REUSE_SUM_SYNC_GROUP can be used, so less local wgrad memory is used (leaving more room\n        //   for compiler optimizations), and wgrad reduction is more efficient.\n        //   However, aligning can cause low occupancy.\n        //   There is a trade-off: align when `NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER)` or when\n        //   `NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED` is already a multiple of `C / C_PER_CLUSTER`.\n        constexpr WgradSyncMethod wgrad_sync_method =\n            wgrad_sync_method_hint == WGRAD_SYNC_UNSPECIFIED\n                ? NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED > 2 * (C / C_PER_CLUSTER) ||\n                          NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED % (C / C_PER_CLUSTER) == 0\n                      ? (HARDWARE_CLUSTER ? WGRAD_ARRIVE_AND_WAIT_GROUP : WGRAD_REUSE_SUM_SYNC_GROUP)\n                      : WGRAD_REUSE_SUM_SYNC_GRID\n                : wgrad_sync_method_hint;\n        constexpr int NUM_VIRTUAL_CLUSTERS =\n            wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP\n                ? NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED / (C / C_PER_CLUSTER) * (C / C_PER_CLUSTER)\n                : NUM_VIRTUAL_CLUSTERS_NOT_ALIGNED;\n\n        if (meta_ptr) {\n          constexpr int MAX_NUM_GROUPS_PER_BLOCK =\n              C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;\n          meta_ptr->red_buffer_size =\n              2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK * 2 +\n              std::max(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) * (HW / ROWS_PER_BLOCK) * C * 2;\n          meta_ptr->barrier_size = NUM_VIRTUAL_CLUSTERS + C / C_PER_CLUSTER;\n          meta_ptr->BLOCK_DIM_X = BLOCK_DIM_X;\n          meta_ptr->C_PER_BLOCK = C_PER_BLOCK;\n          meta_ptr->ROWS_PER_BLOCK = ROWS_PER_BLOCK;\n          meta_ptr->VEC_ELEMS = VEC_ELEMS;\n          meta_ptr->LOAD_TWICE = LOAD_TWICE;\n          meta_ptr->BLOCKS_PER_SM = BLOCKS_PER_SM;\n          meta_ptr->HARDWARE_CLUSTER = HARDWARE_CLUSTER;\n          meta_ptr->wgrad_sync_method = (int)wgrad_sync_method;\n        }\n        if (meta_only) {\n          return;\n        }\n\n        cudaLaunchConfig_t config = {0};\n        config.gridDim = dim3(VIRTUAL_CLUSTER_SIZE, PERSISTENT ? NUM_VIRTUAL_CLUSTERS : n * (C / C_PER_CLUSTER), 1);\n        config.blockDim = BLOCK_DIM_X;\n        config.stream = stream;\n\n        cudaLaunchAttribute attribute[2];\n        if constexpr (HARDWARE_CLUSTER) {\n          attribute[0].id = cudaLaunchAttributeClusterDimension;\n          attribute[0].val.clusterDim.x = 1;  // Cluster size in X-dimension\n          attribute[0].val.clusterDim.y = 1;\n          attribute[0].val.clusterDim.z = 1;\n          config.attrs = attribute;\n          config.numAttrs++;\n        }\n        if constexpr (PERSISTENT) {\n          attribute[config.numAttrs].id = cudaLaunchAttributeCooperative;\n          attribute[config.numAttrs].val.cooperative = 1;\n          config.attrs = attribute;\n          config.numAttrs++;\n        }\n\n        auto kernel =\n            &gn_bwd_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, REQUIRES_WGRAD, ROWS_PER_BLOCK,\n                                C_PER_BLOCK, C_PER_CLUSTER, VEC_ELEMS, PERSISTENT, 
NUM_VIRTUAL_CLUSTERS, LOAD_TWICE,\n                                HARDWARE_CLUSTER, wgrad_sync_method, CompileCondition<EFFECTIVE_CUDA_ARCH> >;\n        if constexpr (HARDWARE_CLUSTER) {\n          if constexpr (VIRTUAL_CLUSTER_SIZE > 8) {\n            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));\n          }\n          int max_cluster_size;\n          int active_clusters;\n          CUDA_CHECK(cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, (void*)kernel, &config));\n          if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size && PERSISTENT) {\n            attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;\n            CUDA_CHECK(cudaOccupancyMaxActiveClusters(&active_clusters, (void*)kernel, &config));\n          }\n          if (VIRTUAL_CLUSTER_SIZE <= max_cluster_size &&\n              (!PERSISTENT || PERSISTENT && NUM_VIRTUAL_CLUSTERS <= active_clusters)) {\n            attribute[0].val.clusterDim.x = VIRTUAL_CLUSTER_SIZE;\n          } else {\n            // Fallback to cooperative groups for dgrad computation because hardware cluster cannot be active\n            // simultaneously\n            attribute[0].val.clusterDim.x = 1;\n            kernel =\n                &gn_bwd_cuda_kernel<T, BLOCK_DIM_X, BLOCKS_PER_SM, G, CPG, HW, SILU, REQUIRES_WGRAD, ROWS_PER_BLOCK,\n                                    C_PER_BLOCK, C_PER_CLUSTER, VEC_ELEMS, PERSISTENT, NUM_VIRTUAL_CLUSTERS, LOAD_TWICE,\n                                    false, wgrad_sync_method, CompileCondition<EFFECTIVE_CUDA_ARCH> >;\n          }\n        }\n        CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, grad_input, grad_weight, grad_bias, grad_output, x, w, b,\n                                      mean_var, eps, n, red_buffer, barrier));\n      });\n    });\n  });\n}\n\n#define GN_CUDA_INST_DEFINE(HW, C)                                                                                \\\n  template void gn_cuda_single_shape<half, HW, C, 16, 
true>(GN_CUDA_HOST_PARAMS(half));                           \\\n  template void gn_cuda_single_shape<half, HW, C, 32, false>(GN_CUDA_HOST_PARAMS(half));                          \\\n  template void gn_bwd_cuda_single_shape<half, HW, C, 16, true>(GN_BWD_CUDA_HOST_PARAMS(half));                   \\\n  template void gn_bwd_cuda_single_shape<half, HW, C, 32, false>(GN_BWD_CUDA_HOST_PARAMS(half));                  \\\n  template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_CUDA_HOST_PARAMS(__nv_bfloat16));         \\\n  template void gn_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_CUDA_HOST_PARAMS(__nv_bfloat16));        \\\n  template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 16, true>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); \\\n  template void gn_bwd_cuda_single_shape<__nv_bfloat16, HW, C, 32, false>(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16));\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 1280)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 1920)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 320)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 640)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(1024, 960)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 1280)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 1920)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 2560)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(256, 640)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(4096, 320)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(4096, 640)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(4096, 960)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(64, 1280)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu",
    "content": "#include \"gn_cuda_host_template.cuh\"\n\nnamespace group_norm_v2 {\n\nGN_CUDA_INST_DEFINE(64, 2560)\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_cuda_kernel.cuh",
    "content": "#pragma once\n\n#include <cooperative_groups.h>\n\n#include \"gn_utils.hpp\"\n\nnamespace group_norm_v2 {\n\nnamespace cg = cooperative_groups;\n\ntemplate <typename T>\ninline constexpr T up_div(T a, T b) {\n  return (a + b - 1) / b;\n}\n\ntemplate <typename T>\ninline constexpr T round_up(T a, T b) {\n  return up_div(a, b) * b;\n}\n\ninline constexpr unsigned round_up_pow2(unsigned x) {\n  int log = 0;\n  x--;\n  while (x) {\n    x /= 2;\n    log++;\n  }\n  return 1U << log;\n}\n\ninline constexpr unsigned round_down_pow2(unsigned x) { return round_up_pow2(x + 1) / 2; }\n\ntemplate <typename T>\ninline constexpr T gcd(T a, T b) {\n  while (b != 0) {\n    int t = b;\n    b = a % b;\n    a = t;\n  }\n  return a;\n}\n\ntemplate <typename T>\ninline constexpr T lcm(T a, T b) {\n  return (a * b) / gcd(a, b);\n}\n\ntemplate <typename T>\ninline constexpr T relative_prime(T x, T min) {\n  int p = min;\n  while (gcd(p, x) != 1) {\n    p++;\n  }\n  return p;\n}\n\ntemplate <typename T>\ninline constexpr T max_divisor(T x, T max) {\n  int p = max;\n  while (x % p != 0) {\n    p--;\n  }\n  return p;\n}\n\nconstexpr unsigned FINAL_MASK = 0xffffffff;\n\ntemplate <int VIRTUAL_CLUSTER_SIZE, bool PERSISTENT, bool HARDWARE_CLUSTER>\n__device__ void virtual_cluster_sync(unsigned int* barrier) {\n  if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {\n    __syncthreads();\n  } else if constexpr (HARDWARE_CLUSTER) {\n    cg::this_cluster().sync();\n  } else {\n    static_assert(PERSISTENT, \"potential deadlock\");\n    volatile unsigned int* arrived = &barrier[blockIdx.y];\n    __syncthreads();\n    if (threadIdx.x == 0) {\n      unsigned int expected = VIRTUAL_CLUSTER_SIZE;\n      bool gpu_master = blockIdx.x == 0;\n      unsigned int nb = 1;\n      if (gpu_master) {\n        nb = 0x80000000 - (expected - 1);\n      }\n      unsigned int oldArrive;\n      asm volatile(\"atom.add.release.gpu.u32 %0,[%1],%2;\"\n                   : \"=r\"(oldArrive)\n                   : 
_CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), \"r\"(nb)\n                   : \"memory\");\n      unsigned int current_arrive;\n      do {\n        asm volatile(\"ld.acquire.gpu.u32 %0,[%1];\"\n                     : \"=r\"(current_arrive)\n                     : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived)\n                     : \"memory\");\n      } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive));\n    }\n    __syncthreads();\n  }\n}\n\ntemplate <int NUM_BLOCKS, bool PERSISTENT>\n__device__ unsigned int group_barrier_arrive(unsigned int* barrier, bool gpu_master) {\n  static_assert(PERSISTENT, \"potential deadlock\");\n  volatile unsigned int* arrived = &barrier[0];\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    unsigned int expected = NUM_BLOCKS;\n    unsigned int nb = 1;\n    if (gpu_master) {\n      nb = 0x80000000 - (expected - 1);\n    }\n    unsigned int oldArrive;\n    asm volatile(\"atom.add.release.gpu.u32 %0,[%1],%2;\"\n                 : \"=r\"(oldArrive)\n                 : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived), \"r\"(nb)\n                 : \"memory\");\n    return oldArrive;\n  } else {\n    return 0;\n  }\n}\n\n__device__ inline void group_barrier_wait(unsigned int* barrier, unsigned int oldArrive) {\n  volatile unsigned int* arrived = &barrier[0];\n  if (threadIdx.x == 0) {\n    unsigned int current_arrive;\n    do {\n      asm volatile(\"ld.acquire.gpu.u32 %0,[%1];\"\n                   : \"=r\"(current_arrive)\n                   : _CG_ASM_PTR_CONSTRAINT((unsigned int*)arrived)\n                   : \"memory\");\n    } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive));\n  }\n  __syncthreads();\n}\n\n// Calculate `n` (batch id) and `c` (channel range id) for each loop\ntemplate <bool CONSTANT_C_LOOP, int C, int C_PER_CLUSTER, int NUM_VIRTUAL_CLUSTERS, bool PERSISTENT>\nclass NCScheduler;\n\ntemplate <int C, int C_PER_CLUSTER, int NUM_VIRTUAL_CLUSTERS, bool 
PERSISTENT>\nclass NCScheduler<false, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> {\n public:\n  __device__ NCScheduler(int64_t n) {\n    nc_loop_ = blockIdx.y;\n    at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER);\n  }\n  __device__ auto get_nc() {\n    int64_t n_loop = nc_loop_ / (C / C_PER_CLUSTER);\n    int c_loop = nc_loop_ % (C / C_PER_CLUSTER);\n    return std::make_tuple(n_loop, c_loop);\n  }\n  __device__ void next(int64_t n) {\n    if constexpr (PERSISTENT) {\n      nc_loop_ += NUM_VIRTUAL_CLUSTERS;\n      at_end_ = nc_loop_ >= n * (C / C_PER_CLUSTER);\n    }\n  }\n  __device__ bool at_end(int64_t n) { return !PERSISTENT || at_end_; }\n\n private:\n  int64_t nc_loop_;\n  bool at_end_;\n};\n\ntemplate <int C, int C_PER_CLUSTER, int NUM_VIRTUAL_CLUSTERS, bool PERSISTENT>\nclass NCScheduler<true, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> {\n public:\n  __device__ NCScheduler(int64_t n) {\n    n_loop_ = blockIdx.y / (C / C_PER_CLUSTER);\n    c_loop_ = blockIdx.y % (C / C_PER_CLUSTER);\n  }\n  __device__ auto get_nc() { return std::make_tuple(n_loop_, c_loop_); }\n  __device__ void next(int64_t n) {\n    if constexpr (PERSISTENT) {\n      n_loop_ += NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER);\n    }\n  }\n  __device__ bool at_end(int64_t n) { return !PERSISTENT || n_loop_ >= n; }\n\n private:\n  int64_t n_loop_;\n  int c_loop_;\n};\n\nclass CompileConditionAlwaysTrue {\n public:\n  __device__ static constexpr bool matches() { return true; }\n};\n\ntemplate <typename T, int BLOCK_DIM_X, int BLOCKS_PER_SM, int G, int CPG, int HW, bool SILU, int ROWS_PER_BLOCK,\n          int C_PER_BLOCK, int C_PER_CLUSTER, int VEC_ELEMS, bool PERSISTENT, int NUM_VIRTUAL_CLUSTERS, bool LOAD_TWICE,\n          bool HARDWARE_CLUSTER, class CompileCondition = CompileConditionAlwaysTrue>\n__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_cuda_kernel(\n    T* __restrict__ out, T const* __restrict__ x, T const* __restrict__ w, T const* __restrict__ b, 
float eps,\n    int64_t n, float* __restrict__ mean_var_out, float* __restrict__ red_buffer, unsigned* __restrict__ barrier) {\n  // Procedure Overview\n  //   1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE)\n  //   2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is\n  //   used)\n  //   3. Group sum: read from gmem, write mean&var to smem\n  //   4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem\n\n  static_assert(BLOCK_DIM_X % 32 == 0, \"warp shuffle error\");\n\n  constexpr int C = G * CPG;\n  static_assert(C % C_PER_CLUSTER == 0, \"cannot divide channels into clusters\");\n  static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, \"cannot divide a cluster into blocks\");\n  static_assert(C_PER_CLUSTER % CPG == 0, \"no reduce between clusters, would produce incorrect results\");\n  static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK),\n                \"inefficient configuration, please reduce C_PER_CLUSTER\");\n\n  static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, \"cannot divide tile into threads\");\n  struct alignas(VEC_ELEMS * sizeof(T)) U {\n    T data[VEC_ELEMS];\n  };\n\n  auto compute_mean_var = [&](float2 sum) {\n    float mean = sum.x / (HW * CPG);\n    float var = std::max(0.f, sum.y / (HW * CPG) - mean * mean);\n    return float2{mean, var};\n  };\n\n  static_assert(HW % ROWS_PER_BLOCK == 0,\n                \"HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis\");\n  constexpr int MAX_NUM_GROUPS_PER_BLOCK =\n      C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;\n  constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);\n  constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK;\n  constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK;\n  int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x;\n  int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x;\n\n  if constexpr (CompileCondition::matches()) {\n    int step = 0;\n    constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0;\n    NCScheduler<CONSTANT_C_LOOP, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> nc_scheduler(n);\n    while (true) {  // TODO: unroll the loop\n      if constexpr (PERSISTENT) {\n        if (nc_scheduler.at_end(n)) {\n          break;\n        }\n      }\n      auto [n_loop, c_loop] = nc_scheduler.get_nc();\n      if constexpr (PERSISTENT) {\n        nc_scheduler.next(n);\n      }\n      static_assert(C_PER_BLOCK % VEC_ELEMS == 0, \"cannot vectorize\");\n      static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0,\n                    \"each block should load one or more C_PER_BLOCK at once\");\n      constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK;\n      static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, \"cannot determine the IO times per batch\");\n      int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;\n      int block_group_start = block_channel_start / CPG;\n      int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS;\n      U frag[ROWS_PER_BLOCK / ROWS_PER_IO];\n\n      // GCD_VEC_CPG is an important constant that determines how many channels can be merged in reduction computation\n      //   For example, VEC_ELEMS=4 and CPG=10, then GCD_VEC_CPG=2,\n      //   so we need to store only 2 sums on each thread, and 
compute only 2 mean&var for each thread.\n      constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG);\n\n      // If each block handles only one group, run warpReduce and store the sum to `sum_per_channel_single_group`;\n      // otherwise store (VEC_ELEMS / GCD_VEC_CPG) sums to `sum_per_channel_multi_group`, where `relative_prime` is used\n      // for swizzle.\n      constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0;\n      [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32];\n      [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime(\n          128 / (int)sizeof(float2), ROWS_PER_IO)];\n\n      if constexpr (LOAD_TWICE) {\n        float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{};\n        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n          int64_t input_idx =\n              n_loop * HW * C +\n              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +\n              thread_channel_start;\n          U val = *reinterpret_cast<U const*>(&x[input_idx]);\n          for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {\n            float2 sum = frag_sum_per_channel[i];\n            for (int k = 0; k < GCD_VEC_CPG; k++) {\n              sum.x += (float)val.data[i * GCD_VEC_CPG + k];\n              sum.y += (float)val.data[i * GCD_VEC_CPG + k] * (float)val.data[i * GCD_VEC_CPG + k];\n            }\n            frag_sum_per_channel[i] = sum;\n          }\n        }\n        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {\n          if constexpr (SINGLE_GROUP_PER_BLOCK) {\n            for (int mask = 16; mask > 0; mask >>= 1) {\n              frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32);\n              frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32);\n            }\n            static_assert(VEC_ELEMS / 
GCD_VEC_CPG == 1, \"process only one element for each warp\");\n            if (threadIdx.x % 32 == 0) {\n              sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i];\n            }\n          } else {\n            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]\n                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i];\n          }\n        }\n        __syncthreads();\n      } else {\n        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n          int64_t input_idx =\n              n_loop * HW * C +\n              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +\n              thread_channel_start;\n          frag[j] = *reinterpret_cast<U const*>(&x[input_idx]);\n        }\n\n        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {\n          float2 sum = {0.f, 0.f};\n          for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n            for (int k = 0; k < GCD_VEC_CPG; k++) {\n              sum.x += (float)frag[j].data[i * GCD_VEC_CPG + k];\n              sum.y += (float)frag[j].data[i * GCD_VEC_CPG + k] * (float)frag[j].data[i * GCD_VEC_CPG + k];\n            }\n          }\n          if constexpr (SINGLE_GROUP_PER_BLOCK) {\n            for (int mask = 16; mask > 0; mask >>= 1) {\n              sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32);\n              sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32);\n            }\n            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, \"process only one element for each warp\");\n            if (threadIdx.x % 32 == 0) {\n              sum_per_channel_single_group[threadIdx.x / 32] = sum;\n            }\n          } else {\n            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]\n                                       [threadIdx.x / (C_PER_BLOCK / 
VEC_ELEMS)] = sum;\n          }\n        }\n        __syncthreads();\n      }\n\n      U uw = *reinterpret_cast<U const*>(&w[thread_channel_start]);\n      U ub = *reinterpret_cast<U const*>(&b[thread_channel_start]);\n\n      // Three cases for the red_buffer:\n      //   - Block sync (VIRTUAL_CLUSTER_SIZE=1): use shared memory\n      //   - Virtual cluster sync with HARDWARE_CLUSTER: use distributed shared memory\n      //   - Virtual cluster sync without HARDWARE_CLUSTER: use global memory, i.e., `red_buffer`\n      constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1;\n\n      // Specialize for the case that each group is handled by only one block\n      //   For common cases, blockSum produces partial sum and stores it to the red_buffer, and groupSum produces\n      //   mean&var For the special case, blockSum produces mean&var directly\n      constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER =\n          VIRTUAL_CLUSTER_SIZE == 1 &&\n          MAX_NUM_GROUPS_PER_BLOCK == 1;  // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented\n\n      [[maybe_unused]] __align__(16)\n          __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)];\n\n      // Block sum\n      if constexpr (SINGLE_GROUP_PER_BLOCK) {\n        // block reduce\n        if (threadIdx.x < 32) {\n          float2 sum_local_group =\n              threadIdx.x < BLOCK_DIM_X / 32 ? 
sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f};\n          constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32);\n          for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) {\n            sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);\n            sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);\n          }\n          if (threadIdx.x == 0) {\n            if constexpr (USE_SHARED_RED_BUFFER) {\n              if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {\n                shared_red_buffer[0] = compute_mean_var(sum_local_group);\n              } else {\n                shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group;\n              }\n            } else {\n              *reinterpret_cast<float2*>(\n                  &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                               virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                               // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y +\n                               virtual_block_idx_y) *\n                              2]) = sum_local_group;\n            }\n          }\n        }\n      } else {\n        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)\n        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)),\n                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));\n        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, \"not enough threads\");\n        float2 sum_local_group = {0.f, 0.f};\n        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;\n          // TODO: map threads to both the CPG loop 
and the ROWS loop\n          for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) {\n            int c = local_group_idx * CPG + local_c_loop;\n            if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) {\n              for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO;\n                   src_thread_tile_y += THREADS_PER_GROUP) {\n                int channel_idx = (c - block_channel_start) / GCD_VEC_CPG;\n                channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) +\n                              channel_idx / (VEC_ELEMS / GCD_VEC_CPG);\n                sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x;\n                sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y;\n              }\n            }\n          }\n        }\n        static_assert(32 % THREADS_PER_GROUP == 0, \"cannot shuffle\");\n        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {\n          sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);\n          sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);\n        }\n        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          if constexpr (USE_SHARED_RED_BUFFER) {\n            static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, \"no distributed shared memory\");\n            if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {\n              shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group);\n            } else {\n              shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group;\n            }\n          } else {\n            *reinterpret_cast<float2*>(\n                &red_buffer[((step * 
gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                             virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                             (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) *\n                            2]) = sum_local_group;\n          }\n        }\n      }\n\n      virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);\n\n      // Group sum\n      __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK];\n      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {\n        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)\n        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)),\n                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));\n        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, \"not enough threads\");\n        float2 sum_global_group = {0.f, 0.f};\n        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          if constexpr (C_PER_BLOCK % CPG == 0) {\n            // Special case: no cross-virtual_cluster_dim_x reduction\n            float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)];\n            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {\n              float2 val;\n              if constexpr (USE_SHARED_RED_BUFFER) {\n                if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {\n                  val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];\n                } else {\n                  static_assert(HARDWARE_CLUSTER, \"no distributed shared memory\");\n                  float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank(\n                      shared_red_buffer, i * 
virtual_cluster_dim_x + virtual_block_idx_x);\n                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];\n                }\n              } else {\n                val = *reinterpret_cast<float2 const*>(\n                    &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                                 virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                                 (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) *\n                                2]);\n              }\n              buffer[i / THREADS_PER_GROUP] = val;\n            }\n            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {\n              float2 val = buffer[i / THREADS_PER_GROUP];\n              sum_global_group.x += val.x;\n              sum_global_group.y += val.y;\n            }\n          } else {\n            // Common case: cross-virtual_cluster_dim_x reduction\n            int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;\n            for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) {\n              int src_virtual_block_idx_x = i % virtual_cluster_dim_x;\n              int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;\n              int src_block_group_start = src_block_channel_start / CPG;\n              int relative_group_idx = local_group_idx - src_block_group_start;\n              if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) {\n                float2 val;\n                if constexpr (USE_SHARED_RED_BUFFER) {\n                  static_assert(HARDWARE_CLUSTER, \"no distributed shared memory\");\n                  static_assert(VIRTUAL_CLUSTER_SIZE != 1,\n                                \"layout error: should not add (step 
* MAX_NUM_GROUPS_PER_BLOCK)\");\n                  float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i);\n                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx];\n                } else {\n                  val = *reinterpret_cast<float2 const*>(\n                      &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                                   src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                                   relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) *\n                                  2]);\n                }\n                sum_global_group.x += val.x;\n                sum_global_group.y += val.y;\n              }\n            }\n          }\n        }\n        if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {\n          // Need cluster sync after distributed shared memory access, otherwise behavior is undefined\n          if constexpr (PERSISTENT) {\n            if (nc_scheduler.at_end(n)) {\n              cg::this_cluster().barrier_arrive();\n            }\n          } else {\n            cg::this_cluster().barrier_arrive();\n          }\n        }\n        static_assert(32 % THREADS_PER_GROUP == 0, \"cannot shuffle\");\n        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {\n          sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32);\n          sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32);\n        }\n        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group);\n        }\n        __syncthreads();\n      }\n\n      auto get_mean_var = [&](int relative_group_idx) {\n        return 
STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx]\n                                                   : mean_var[relative_group_idx];\n      };\n\n      if (mean_var_out) {\n        static_assert(MAX_NUM_GROUPS_PER_BLOCK <= BLOCK_DIM_X, \"need loop\");\n        if (virtual_block_idx_y == 0 && threadIdx.x < MAX_NUM_GROUPS_PER_BLOCK) {\n          int g = block_group_start + threadIdx.x;\n          if (C_PER_BLOCK % CPG == 0 || g < G) {\n            *reinterpret_cast<float2*>(&mean_var_out[(n_loop * G + g) * 2]) = get_mean_var(threadIdx.x);\n          }\n        }\n      }\n\n      float frag_mean[VEC_ELEMS / GCD_VEC_CPG];\n      float frag_var[VEC_ELEMS / GCD_VEC_CPG];\n      for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) {\n        frag_mean[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x;\n        frag_var[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y;\n      }\n\n      for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n        int64_t input_idx =\n            n_loop * HW * C +\n            (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +\n            thread_channel_start;\n        U val;\n        if constexpr (LOAD_TWICE) {\n          val = *reinterpret_cast<U const*>(&x[input_idx]);\n        } else {\n          val = frag[j];\n        }\n        for (int k = 0; k < VEC_ELEMS; k++) {\n          float f = ((float)val.data[k] - frag_mean[k / GCD_VEC_CPG]) * rsqrtf(frag_var[k / GCD_VEC_CPG] + eps) *\n                        (float)uw.data[k] +\n                    (float)ub.data[k];\n          if constexpr (SILU) f = f / (1.f + expf(-f));\n          val.data[k] = f;\n        }\n        *reinterpret_cast<U*>(&out[input_idx]) = val;\n      }\n\n      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {\n        if constexpr (PERSISTENT) {\n       
   if (nc_scheduler.at_end(n)) {\n            cg::this_cluster().barrier_wait();\n          }\n        } else {\n          cg::this_cluster().barrier_wait();\n        }\n      }\n\n      if constexpr (!PERSISTENT) {\n        break;\n      }\n      step ^= 1;\n    }\n  }\n}\n\nenum WgradSyncMethod {\n  WGRAD_ARRIVE_AND_WAIT_GRID = 0,  // grid arrive after the last virtual cluster sync\n  WGRAD_ARRIVE_AND_WAIT_GROUP,     // group arrive after the last virtual cluster sync (a group sync means synchronizing\n                                   // all clusters cooperating on the same groups)\n  WGRAD_REUSE_SUM_SYNC_GRID,       // grid sync together with the last virtual cluster sync\n  WGRAD_REUSE_SUM_SYNC_GROUP,      // group sync together with the last virtual cluster sync\n  WGRAD_SYNC_AT_LAST,              // add a sync at the end of NC loops\n  WGRAD_SYNC_UNSPECIFIED,\n};\n\ntemplate <typename T, int BLOCK_DIM_X, int BLOCKS_PER_SM, int G, int CPG, int HW, bool SILU, bool REQUIRES_WGRAD,\n          int ROWS_PER_BLOCK, int C_PER_BLOCK, int C_PER_CLUSTER, int VEC_ELEMS, bool PERSISTENT,\n          int NUM_VIRTUAL_CLUSTERS, bool LOAD_TWICE, bool HARDWARE_CLUSTER, WgradSyncMethod wgrad_sync_method,\n          class CompileCondition = CompileConditionAlwaysTrue>\n__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_bwd_cuda_kernel(\n    T* __restrict__ grad_input, T* __restrict__ grad_weight, T* __restrict__ grad_bias,\n    T const* __restrict__ grad_output, T const* __restrict__ x, T const* __restrict__ w, T const* __restrict__ b,\n    float const* __restrict__ mean_var, float eps, int64_t n, float* __restrict__ red_buffer,\n    unsigned* __restrict__ barrier) {\n  // Procedure Overview\n  //   1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE)\n  //   2. 
Block sum: read from smem, write partial sum to gmem (or distributed shared memory if\n  //        HARDWARE_CLUSTER is used), write wgrad to gmem at the last loop (at each loop if not CONSTANT_C_LOOP)\n  //   3. Group sum: read from gmem, write mean&var to smem\n  //   4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem\n  //   5. Wgrad sum: read from gmem, write to gmem\n\n  static_assert(BLOCK_DIM_X % 32 == 0, \"warp shuffle error\");\n\n  constexpr int C = G * CPG;\n  static_assert(C % C_PER_CLUSTER == 0, \"cannot divide channels into clusters\");\n  static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, \"cannot divide a cluster into blocks\");\n  static_assert(C_PER_CLUSTER % CPG == 0, \"no reduce between clusters, would produce incorrect results\");\n  static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK),\n                \"inefficient configuration, please reduce C_PER_CLUSTER\");\n\n  static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, \"cannot divide tile into threads\");\n  struct alignas(VEC_ELEMS * sizeof(T)) U {\n    T data[VEC_ELEMS];\n  };\n\n  // This function computes mean_dyw and mean_xdyw.\n  // The function name is not changed because it has the same logic as the forward pass.\n  auto compute_mean_var = [&](float2 sum) {\n    float mean_dyw = sum.x / (HW * CPG);\n    float mean_xdyw = sum.y / (HW * CPG);\n    return float2{mean_dyw, mean_xdyw};\n  };\n\n  static_assert(HW % ROWS_PER_BLOCK == 0,\n                \"HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis\");\n  constexpr int MAX_NUM_GROUPS_PER_BLOCK =\n      C_PER_BLOCK % CPG == 0 ? 
C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;\n  constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);\n  constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK;\n  constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK;\n  int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x;\n  int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x;\n\n  if constexpr (CompileCondition::matches()) {\n    int step = 0;\n    constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0;\n    if constexpr (!CONSTANT_C_LOOP) {\n      static_assert(wgrad_sync_method != WGRAD_ARRIVE_AND_WAIT_GROUP && wgrad_sync_method != WGRAD_REUSE_SUM_SYNC_GROUP,\n                    \"grid sync is required when each block is responsible for multiple channel ranges\");\n    }\n    NCScheduler<false, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> nc_scheduler(\n        n);  // TODO: I don't know why the template specialization with CONSTANT_C_LOOP=true is slower.\n\n    [[maybe_unused]] int virtual_cluster_idx_c = blockIdx.y % (C / C_PER_CLUSTER);\n    [[maybe_unused]] cg::grid_group::arrival_token wgrad_sync_token;\n    [[maybe_unused]] float dw_thread[VEC_ELEMS];\n    [[maybe_unused]] float db_thread[VEC_ELEMS];\n    [[maybe_unused]] __shared__ union {\n      float2 dwdb_block_buffer[BLOCK_DIM_X][VEC_ELEMS];\n      struct {\n        float wgrad_buffer[BLOCK_DIM_X / 32][32];\n        float bgrad_buffer[BLOCK_DIM_X / 32][32];\n      } transpose_buffer;\n    } union_smem;\n    if constexpr (REQUIRES_WGRAD && CONSTANT_C_LOOP) {\n      for (int i = 0; i < VEC_ELEMS; i++) {\n        dw_thread[i] = 0.f;\n        db_thread[i] = 0.f;\n      }\n    }\n    float* red_buffer_wgrad =\n        &red_buffer[(2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK) * 2];\n    unsigned* barrier_wgrad = barrier + 
NUM_VIRTUAL_CLUSTERS;\n    if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) {\n      if (nc_scheduler.at_end(n)) {\n        static_assert(PERSISTENT, \"persistent is a must for reducing wgrad\");\n        if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) {\n          wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(\n              barrier_wgrad, blockIdx.x + blockIdx.y == 0);\n        } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) {\n          wgrad_sync_token =\n              group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(\n                  barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);\n        } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) {\n          wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(\n              barrier_wgrad, blockIdx.x + blockIdx.y == 0);\n          group_barrier_wait(barrier_wgrad, wgrad_sync_token);\n        } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) {\n          wgrad_sync_token =\n              group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(\n                  barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);\n          group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token);\n        }\n      }\n    }\n\n    while (true) {  // TODO: unroll the loop\n      if constexpr (PERSISTENT) {\n        if (nc_scheduler.at_end(n)) {\n          break;\n        }\n      }\n      auto [n_loop, c_loop] = nc_scheduler.get_nc();\n      if constexpr (PERSISTENT) {\n        nc_scheduler.next(n);\n      }\n      static_assert(C_PER_BLOCK % VEC_ELEMS == 0, \"cannot vectorize\");\n      static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 
0,\n                    \"each block should load one or more C_PER_BLOCK at once\");\n      constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK;\n      static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, \"cannot determine the IO times per batch\");\n      int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;\n      int block_group_start = block_channel_start / CPG;\n      int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS;\n      U frag_x[ROWS_PER_BLOCK / ROWS_PER_IO];\n      U frag_dy[ROWS_PER_BLOCK / ROWS_PER_IO];\n\n      constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG);\n\n      constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0;\n      [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime(\n          128 / (int)sizeof(float2), ROWS_PER_IO)];\n      [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32];\n\n      float frag_mean[VEC_ELEMS / GCD_VEC_CPG];\n      float frag_var[VEC_ELEMS / GCD_VEC_CPG];\n      for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) {\n        float2 value = *reinterpret_cast<float2 const*>(&mean_var[(n_loop * G + (thread_channel_start + k) / CPG) * 2]);\n        frag_mean[k / GCD_VEC_CPG] = value.x;\n        frag_var[k / GCD_VEC_CPG] = value.y;\n      }\n\n      U uw = *reinterpret_cast<U const*>(&w[thread_channel_start]);\n      U ub;\n      if constexpr (SILU) {\n        ub = *reinterpret_cast<U const*>(&b[thread_channel_start]);\n      }\n      if constexpr (REQUIRES_WGRAD && !CONSTANT_C_LOOP) {\n        for (int i = 0; i < VEC_ELEMS; i++) {\n          dw_thread[i] = 0.f;\n          db_thread[i] = 0.f;\n        }\n      }\n\n      if constexpr (LOAD_TWICE) {\n        float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{};\n        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n          int64_t input_idx =\n              n_loop * HW * C +\n    
          (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +\n              thread_channel_start;\n          U ux = *reinterpret_cast<U const*>(&x[input_idx]);\n          U udy = *reinterpret_cast<U const*>(&grad_output[input_idx]);\n          for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {\n            float2 sum = frag_sum_per_channel[i];\n            for (int k = 0; k < GCD_VEC_CPG; k++) {\n              float rnorm = rsqrtf(frag_var[i] + eps);\n              float x_norm =\n                  ((float)ux.data[i * GCD_VEC_CPG + k] - frag_mean[i]) * rnorm;  // TODO: store rsqrtf in mean_var\n              float grad_gn = udy.data[i * GCD_VEC_CPG + k];\n              if constexpr (SILU) {\n                float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k];\n                float s = 1.f / (1.f + expf(-x_gn));\n                grad_gn *= s * (1.f + x_gn * (1.f - s));\n              }\n              sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k];\n              sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]);\n              if constexpr (REQUIRES_WGRAD) {\n                dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn;\n                db_thread[i * GCD_VEC_CPG + k] += grad_gn;\n              }\n            }\n            frag_sum_per_channel[i] = sum;\n          }\n        }\n        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {\n          if constexpr (SINGLE_GROUP_PER_BLOCK) {\n            for (int mask = 16; mask > 0; mask >>= 1) {\n              frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32);\n              frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32);\n            }\n            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, \"process only one element for each warp\");\n            if (threadIdx.x % 32 == 0) {\n              
sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i];\n            }\n          } else {\n            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]\n                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i];\n          }\n        }\n        __syncthreads();\n      } else {\n        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n          int64_t input_idx =\n              n_loop * HW * C +\n              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +\n              thread_channel_start;\n          frag_x[j] = *reinterpret_cast<U const*>(&x[input_idx]);\n          frag_dy[j] = *reinterpret_cast<U const*>(&grad_output[input_idx]);\n        }\n\n        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {\n          float2 sum = {0.f, 0.f};\n          for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n            for (int k = 0; k < GCD_VEC_CPG; k++) {\n              float rnorm = rsqrtf(frag_var[i] + eps);\n              float x_norm = ((float)frag_x[j].data[i * GCD_VEC_CPG + k] - frag_mean[i]) *\n                             rnorm;  // TODO: store rsqrtf in mean_var\n              float grad_gn = frag_dy[j].data[i * GCD_VEC_CPG + k];\n              if constexpr (SILU) {\n                float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k];\n                float s = 1.f / (1.f + expf(-x_gn));\n                grad_gn *= s * (1.f + x_gn * (1.f - s));\n              }\n              sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k];\n              sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]);\n              if constexpr (REQUIRES_WGRAD) {\n                dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn;\n                db_thread[i * GCD_VEC_CPG + k] += grad_gn;\n              }\n            }\n        
  }\n          if constexpr (SINGLE_GROUP_PER_BLOCK) {\n            for (int mask = 16; mask > 0; mask >>= 1) {\n              sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32);\n              sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32);\n            }\n            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, \"process only one element for each warp\");\n            if (threadIdx.x % 32 == 0) {\n              sum_per_channel_single_group[threadIdx.x / 32] = sum;\n            }\n          } else {\n            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]\n                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum;\n          }\n        }\n        __syncthreads();\n      }\n\n      if ((CONSTANT_C_LOOP && nc_scheduler.at_end(n)) || !CONSTANT_C_LOOP) {\n        constexpr int NT_C = max_divisor(C_PER_BLOCK, BLOCK_DIM_X);  // Number of threads on the C axis\n        constexpr int NT_R =\n            1;  // std::min(32, (int)round_down_pow2(BLOCK_DIM_X / NT_C));  // Number of threads on the ROWS axis\n        // TODO: swizzle for NT_R\n        for (int i = 0; i < VEC_ELEMS; i++) {\n          union_smem.dwdb_block_buffer[threadIdx.x][i ^ ((threadIdx.x / (16 / VEC_ELEMS)) & (VEC_ELEMS - 1))] =\n              float2{dw_thread[i], db_thread[i]};\n        }\n        __syncthreads();\n        static_assert(NT_C * NT_R <= BLOCK_DIM_X, \"not enough threads\");\n        static_assert(C_PER_BLOCK % NT_C == 0, \"need to loop once more and check c < C_PER_BLOCK\");\n        for (int i = 0; i < C_PER_BLOCK / NT_C; i++) {\n          int c = i * NT_C + threadIdx.x / NT_R;\n          float dw_block = 0.f;\n          float db_block = 0.f;\n          if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) {\n            for (int j = threadIdx.x % NT_R; j < ROWS_PER_IO; j += NT_R) {\n              int src_thread = j * (C_PER_BLOCK / VEC_ELEMS) + c / VEC_ELEMS;\n              float2 val = 
union_smem.dwdb_block_buffer[src_thread][(c % VEC_ELEMS) ^ ((src_thread / (16 / VEC_ELEMS)) &\n                                                                                       (VEC_ELEMS - 1))];\n              dw_block += val.x;\n              db_block += val.y;\n            }\n          }\n          static_assert(32 % NT_R == 0, \"cannot shuffle\");\n          for (int mask = NT_R / 2; mask > 0; mask >>= 1) {\n            dw_block += __shfl_xor_sync(FINAL_MASK, dw_block, mask, 32);\n            db_block += __shfl_xor_sync(FINAL_MASK, db_block, mask, 32);\n          }\n          if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) {\n            if (threadIdx.x % NT_R == 0) {\n              if constexpr (CONSTANT_C_LOOP) {\n                *reinterpret_cast<float2*>(\n                    &red_buffer_wgrad\n                        [((blockIdx.y / (C / C_PER_CLUSTER) * virtual_cluster_dim_y + virtual_block_idx_y) * C +\n                          c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) *\n                         2]) = float2{dw_block, db_block};\n              } else {\n                *reinterpret_cast<float2*>(\n                    &red_buffer_wgrad[((n_loop * virtual_cluster_dim_y + virtual_block_idx_y) * C +\n                                       c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) *\n                                      2]) = float2{dw_block, db_block};\n              }\n            }\n          }\n        }\n      }\n\n      constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1;\n      constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER =\n          VIRTUAL_CLUSTER_SIZE == 1 &&\n          MAX_NUM_GROUPS_PER_BLOCK == 1;  // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented\n      [[maybe_unused]] __align__(16)\n          __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 
1 : 2)];\n\n      // Block sum\n      if constexpr (SINGLE_GROUP_PER_BLOCK) {\n        // block reduce\n        if (threadIdx.x < 32) {\n          float2 sum_local_group =\n              threadIdx.x < BLOCK_DIM_X / 32 ? sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f};\n          constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32);\n          for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) {\n            sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);\n            sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);\n          }\n          if (threadIdx.x == 0) {\n            if constexpr (USE_SHARED_RED_BUFFER) {\n              if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {\n                shared_red_buffer[0] = compute_mean_var(sum_local_group);\n              } else {\n                shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group;\n              }\n            } else {\n              *reinterpret_cast<float2*>(\n                  &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                               virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                               // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y +\n                               virtual_block_idx_y) *\n                              2]) = sum_local_group;\n            }\n          }\n        }\n      } else {\n        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)\n        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)),\n                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));\n        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, \"not enough threads\");\n        float2 sum_local_group = {0.f, 
0.f};\n        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;\n          // TODO: map threads to both the CPG loop and the ROWS loop\n          for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) {\n            int c = local_group_idx * CPG + local_c_loop;\n            if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) {\n              for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO;\n                   src_thread_tile_y += THREADS_PER_GROUP) {\n                int channel_idx = (c - block_channel_start) / GCD_VEC_CPG;\n                channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) +\n                              channel_idx / (VEC_ELEMS / GCD_VEC_CPG);\n                sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x;\n                sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y;\n              }\n            }\n          }\n        }\n        static_assert(32 % THREADS_PER_GROUP == 0, \"cannot shuffle\");\n        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {\n          sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);\n          sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);\n        }\n        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          if constexpr (USE_SHARED_RED_BUFFER) {\n            static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, \"no distributed shared memory\");\n            if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {\n              shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group);\n            } else {\n              
shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group;\n            }\n          } else {\n            *reinterpret_cast<float2*>(\n                &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                             virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                             (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) *\n                            2]) = sum_local_group;\n          }\n        }\n      }\n\n      if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) {\n        if (nc_scheduler.at_end(n)) {\n          static_assert(PERSISTENT, \"persistent is a must for reducing wgrad\");\n          if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) {\n            virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);\n            wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(\n                barrier_wgrad, blockIdx.x + blockIdx.y == 0);\n          } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) {\n            virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);\n            wgrad_sync_token =\n                group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(\n                    barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);\n          } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) {\n            static_assert(!HARDWARE_CLUSTER,\n                          \"Distributed smem sync cannot reuse gmem sync. 
Use WGRAD_ARRIVE_AND_WAIT_GRID instead.\");\n            wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(\n                barrier_wgrad, blockIdx.x + blockIdx.y == 0);\n            group_barrier_wait(barrier_wgrad, wgrad_sync_token);\n          } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) {\n            static_assert(!HARDWARE_CLUSTER,\n                          \"Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GROUP instead.\");\n            wgrad_sync_token =\n                group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(\n                    barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);\n            group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token);\n          }\n        } else {\n          virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);\n        }\n      } else {\n        virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);\n      }\n\n      // Group sum\n      __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK];\n      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {\n        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)\n        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)),\n                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));\n        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, \"not enough threads\");\n        float2 sum_global_group = {0.f, 0.f};\n        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          if constexpr (C_PER_BLOCK % CPG == 0) {\n            // Special case: no cross-virtual_cluster_dim_x reduction\n            float2 
buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)];\n            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {\n              float2 val;\n              if constexpr (USE_SHARED_RED_BUFFER) {\n                if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {\n                  val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];\n                } else {\n                  static_assert(HARDWARE_CLUSTER, \"no distributed shared memory\");\n                  float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank(\n                      shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x);\n                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];\n                }\n              } else {\n                val = *reinterpret_cast<float2 const*>(\n                    &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                                 virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                                 (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) *\n                                2]);\n              }\n              buffer[i / THREADS_PER_GROUP] = val;\n            }\n            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {\n              float2 val = buffer[i / THREADS_PER_GROUP];\n              sum_global_group.x += val.x;\n              sum_global_group.y += val.y;\n            }\n          } else {\n            // Common case: cross-virtual_cluster_dim_x reduction\n            int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;\n            for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) {\n              int src_virtual_block_idx_x = i % 
virtual_cluster_dim_x;\n              int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;\n              int src_block_group_start = src_block_channel_start / CPG;\n              int relative_group_idx = local_group_idx - src_block_group_start;\n              if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) {\n                float2 val;\n                if constexpr (USE_SHARED_RED_BUFFER) {\n                  static_assert(HARDWARE_CLUSTER, \"no distributed shared memory\");\n                  static_assert(VIRTUAL_CLUSTER_SIZE != 1,\n                                \"layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)\");\n                  float2 const* src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i);\n                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx];\n                } else {\n                  val = *reinterpret_cast<float2 const*>(\n                      &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +\n                                   src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +\n                                   relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) *\n                                  2]);\n                }\n                sum_global_group.x += val.x;\n                sum_global_group.y += val.y;\n              }\n            }\n          }\n        }\n        if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {\n          // Need cluster sync after distributed shared memory access, otherwise behavior is undefined\n          if constexpr (PERSISTENT) {\n            if (nc_scheduler.at_end(n)) {\n              cg::this_cluster().barrier_arrive();\n            }\n          } else {\n            cg::this_cluster().barrier_arrive();\n          }\n        }\n        
static_assert(32 % THREADS_PER_GROUP == 0, \"cannot shuffle\");\n        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {\n          sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32);\n          sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32);\n        }\n        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {\n          mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group);\n        }\n        __syncthreads();\n      }\n\n      auto get_mean_var = [&](int relative_group_idx) {\n        return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx]\n                                                   : mean_var[relative_group_idx];\n      };\n\n      float frag_dyw[VEC_ELEMS / GCD_VEC_CPG];\n      float frag_xdyw[VEC_ELEMS / GCD_VEC_CPG];\n      for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) {\n        frag_dyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x;\n        frag_xdyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y;\n      }\n\n      for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {\n        int64_t input_idx =\n            n_loop * HW * C +\n            (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +\n            thread_channel_start;\n        U ux;\n        U udy;\n        if constexpr (LOAD_TWICE) {\n          ux = *reinterpret_cast<U const*>(&x[input_idx]);\n          udy = *reinterpret_cast<U const*>(&grad_output[input_idx]);\n        } else {\n          ux = frag_x[j];\n          udy = frag_dy[j];\n        }\n        U val;\n        for (int k = 0; k < VEC_ELEMS; k++) {\n          float rnorm = rsqrtf(frag_var[k / GCD_VEC_CPG] + eps);\n          float x_norm = ((float)ux.data[k] - frag_mean[k / GCD_VEC_CPG]) * rnorm;  // TODO: 
store rsqrtf in mean_var\n          float grad_gn = udy.data[k];\n          if constexpr (SILU) {\n            float x_gn = x_norm * (float)uw.data[k] + (float)ub.data[k];\n            float s = 1.f / (1.f + expf(-x_gn));\n            grad_gn *= s * (1.f + x_gn * (1.f - s));\n          }\n          val.data[k] =\n              (grad_gn * (float)uw.data[k] - frag_dyw[k / GCD_VEC_CPG] - frag_xdyw[k / GCD_VEC_CPG] * x_norm) * rnorm;\n        }\n        *reinterpret_cast<U*>(&grad_input[input_idx]) = val;\n      }\n\n      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {\n        if constexpr (PERSISTENT) {\n          if (nc_scheduler.at_end(n)) {\n            cg::this_cluster().barrier_wait();\n          }\n        } else {\n          cg::this_cluster().barrier_wait();\n        }\n      }\n\n      if constexpr (!PERSISTENT) {\n        break;\n      }\n      step ^= 1;\n    }\n\n    // Wgrad sum\n    if constexpr (REQUIRES_WGRAD) {\n      static_assert(PERSISTENT, \"cannot reduce wgrad\");\n      static_assert(C % 32 == 0, \"cannot reduce wgrad\");\n      if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) {\n        group_barrier_wait(barrier_wgrad, wgrad_sync_token);\n      } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) {\n        group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token);\n      } else if constexpr (wgrad_sync_method == WGRAD_SYNC_AT_LAST) {\n        cg::this_grid().sync();\n      }\n\n      // If group sync, map blocks that are responsible for the same range of channels to these channels (named \"split\n      // channels\"); otherwise, map all blocks to all channels.\n      constexpr bool split_channels =\n          wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP;\n\n      for (int c = split_channels ? 
virtual_cluster_idx_c * C_PER_CLUSTER +\n                                        32 * (blockIdx.y / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE + blockIdx.x)\n                                  : 32 * (blockIdx.y * VIRTUAL_CLUSTER_SIZE + blockIdx.x);\n           split_channels ? c < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER : c < C;\n           c += split_channels ? 32 * (NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE)\n                               : 32 * (NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE)) {\n        int64_t rows = (CONSTANT_C_LOOP ? std::min(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) : n) *\n                       virtual_cluster_dim_y;\n        float sum_wgrad = 0.f;\n        float sum_bgrad = 0.f;\n        if ((split_channels &&\n             (C_PER_CLUSTER % 32 == 0 || c + threadIdx.x % 32 < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) ||\n            (!split_channels && (C % 32 == 0 || c + threadIdx.x % 32 < C))) {\n          for (int64_t i = threadIdx.x / 32; i < rows; i += BLOCK_DIM_X / 32) {\n            float2 val = *reinterpret_cast<float2 const*>(&red_buffer_wgrad[(i * C + c + threadIdx.x % 32) * 2]);\n            sum_wgrad += val.x;\n            sum_bgrad += val.y;\n          }\n        }\n        constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32);\n        union_smem.transpose_buffer\n            .wgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] =\n            sum_wgrad;\n        union_smem.transpose_buffer\n            .bgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] =\n            sum_bgrad;\n        __syncthreads();\n        for (int i = threadIdx.x / warp_num_pow2;\n             i < 32 &&\n             ((split_channels && (C_PER_CLUSTER % 32 == 0 || c + i < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) ||\n              (!split_channels && (C % 32 == 0 || c + i < C)));\n             i += 
BLOCK_DIM_X / warp_num_pow2) {\n          int j = threadIdx.x % warp_num_pow2;\n          float sum_wgrad =\n              j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.wgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f;\n          float sum_bgrad =\n              j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.bgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f;\n          for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) {\n            sum_wgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_wgrad, mask, warp_num_pow2);\n            sum_bgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_bgrad, mask, warp_num_pow2);\n          }\n          if (j == 0) {\n            grad_weight[c + i] = sum_wgrad;\n            grad_bias[c + i] = sum_bgrad;\n          }\n        }\n        __syncthreads();\n      }\n    }\n  }\n}\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_dispatch_hw_c.hpp",
    "content": "#pragma once\n\n#include <stdexcept>\n#include <string>\n\n#define DISPATCH_HW_C(hw, c, HW, C, ...)                                                          \\\n  [&] {                                                                                           \\\n    if (hw == 64 && c == 1280) {                                                                  \\\n      constexpr int HW = 64, C = 1280;                                                            \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 64 && c == 2560) {                                                                  \\\n      constexpr int HW = 64, C = 2560;                                                            \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 256 && c == 640) {                                                                  \\\n      constexpr int HW = 256, C = 640;                                                            \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 256 && c == 1280) {                                                                 \\\n      constexpr int HW = 256, C = 1280;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 256 && c == 1920) {                                                                 \\\n      constexpr int HW = 256, 
C = 1920;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 256 && c == 2560) {                                                                 \\\n      constexpr int HW = 256, C = 2560;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 1024 && c == 320) {                                                                 \\\n      constexpr int HW = 1024, C = 320;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 1024 && c == 640) {                                                                 \\\n      constexpr int HW = 1024, C = 640;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 1024 && c == 960) {                                                                 \\\n      constexpr int HW = 1024, C = 960;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 1024 && c == 1280) {                                                          
      \\\n      constexpr int HW = 1024, C = 1280;                                                          \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 1024 && c == 1920) {                                                                \\\n      constexpr int HW = 1024, C = 1920;                                                          \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 4096 && c == 320) {                                                                 \\\n      constexpr int HW = 4096, C = 320;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 4096 && c == 640) {                                                                 \\\n      constexpr int HW = 4096, C = 640;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    if (hw == 4096 && c == 960) {                                                                 \\\n      constexpr int HW = 4096, C = 960;                                                           \\\n      return __VA_ARGS__();                                                                       \\\n    }                                                                                             \\\n    throw std::invalid_argument(\"DISPATCH_HW_C \" 
+ std::to_string(hw) + \" \" + std::to_string(c)); \\\n  }()\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_utils.cpp",
    "content": "#include \"gn_utils.hpp\"\n\n#include <mutex>\n#include <vector>\n\nnamespace group_norm_v2 {\n\n// Returns the cached properties of the given CUDA device.\n// The properties of all visible devices are queried once, on the first call (thread-safe via std::call_once).\n// Throws std::out_of_range if device_id is not a valid device index.\ncudaDeviceProp const& get_device_prop(int device_id) {\n  static std::vector<cudaDeviceProp> device_props;\n  static std::once_flag flag;\n  std::call_once(flag, [&] {\n    int count = 0;\n    CUDA_CHECK(cudaGetDeviceCount(&count));\n    device_props.resize(count);\n    for (int i = 0; i < count; i++) {\n      CUDA_CHECK(cudaGetDeviceProperties(&device_props[i], i));\n    }\n  });\n  return device_props.at(device_id);\n}\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/group_norm_v2/gn_utils.hpp",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n\n#include <cassert>\n#include <cstdio>\n#include <cstdlib>\n\n#include \"gn.hpp\"\n\n// Definition of CUDA_CHECK macro\n#define CUDA_CHECK(call)                                                                                               \\\n  do {                                                                                                                 \\\n    cudaError_t err_ = call;                                                                                           \\\n    if (err_ != cudaSuccess) {                                                                                         \\\n      fprintf(stderr, \"CUDA error at %s:%d code=%d(%s) \\\"%s\\\" \\n\", __FILE__, __LINE__, err_, cudaGetErrorString(err_), \\\n              #call);                                                                                                  \\\n      exit(EXIT_FAILURE);                                                                                              \\\n    }                                                                                                                  \\\n  } while (0)\n\n#define GN_CUDA_HOST_PARAMS(T)                                                                                      \\\n  T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group,    \\\n      float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, \\\n      Meta *meta_ptr, bool meta_only\n\n#define GN_BWD_CUDA_HOST_PARAMS(T)                                                                                    \\\n  T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps,          \\\n      bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, \\\n      int sm_margin, cudaStream_t stream, int 
device_id, Meta *meta_ptr, bool meta_only\n\n#define GN_CUDA_HOST_ARGS                                                                                       \\\n  out, x, w, b, eps, silu, n, hw, num_groups, channels_per_group, mean_var_out, red_buffer, barrier, sm_margin, \\\n      stream, device_id, meta_ptr, meta_only\n\n#define GN_BWD_CUDA_HOST_ARGS                                                                       \\\n  grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, silu, n, hw, num_groups, \\\n      channels_per_group, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only\n\nnamespace group_norm_v2 {\n\ncudaDeviceProp const& get_device_prop(int device_id);\n\n#ifdef __CUDA_ARCH__\n\ntemplate <class... Ts>\n__host__ __device__ inline int print_rank_0(char const* fmt, Ts&&... args) {\n  if (threadIdx.x + threadIdx.y + threadIdx.z == 0 && blockIdx.x + blockIdx.y + blockIdx.z == 0) {\n    return printf(fmt, std::forward<Ts>(args)...);\n  }\n  return 0;\n}\n\n#endif\n\n}  // namespace group_norm_v2\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/batch_norm.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <cuda.h>\n\n#include \"batch_norm.h\"\n\n#define cudaCheckErrors(msg)                                                                                  \\\n  do {                                                                                                        \\\n    cudaError_t __err = cudaGetLastError();                                                                   \\\n    if (__err != cudaSuccess) {                                                                               \\\n      fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \\\n      fprintf(stderr, \"*** FAILED - ABORTING\\n\");                                                             \\\n      exit(1);                                                                                                \\\n    }                                                                                                         \\\n  } while (0)\n\nstatic size_t round_up_to_multiple(size_t x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; }\n\nstruct Workspace {\n  Workspace(size_t size) : size(size), data(NULL) {\n    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();\n    dataPtr = allocator.allocate(size);\n    data = dataPtr.get();\n  }\n  Workspace(const Workspace&) = delete;\n  Workspace(Workspace&&) = default;\n  Workspace& operator=(Workspace&&) = default;\n  ~Workspace() = default;\n\n  size_t size;\n  void* data;\n  c10::DataPtr dataPtr;\n};\n\n// Return {y}\nat::Tensor nhwc_bn_fwd_train(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,\n                             const at::Tensor& running_mean, const at::Tensor& running_inv_var,\n                             const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,\n                             const 
at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu,\n                             void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group,\n                             const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x,\n                             const bool coop) {\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.data_ptr<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNorm* bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.data_ptr<at::Half>(), nullptr, y.data_ptr<at::Half>(), nullptr);\n\n  bn->setWeightPointers({scale.data_ptr<float>(), bias.data_ptr<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.data_ptr<float>(), running_inv_var.data_ptr<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void*> 
workspace;\n  workspace.push_back(minibatch_mean.data_ptr<float>());\n  workspace.push_back(minibatch_inv_var.data_ptr<float>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.data_ptr<uint8_t>();\n  assert(ret_cta.size(0) >= retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void* ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return y;\n}\n\nat::Tensor nhwc_bn_fwd_eval(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,\n                            const at::Tensor& running_mean, const at::Tensor& running_inv_var,\n                            const at::Tensor& ret_cta, const int bn_group, const float momentum, const float epsilon,\n                            const bool fuse_relu) {\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNorm* bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.data_ptr<at::Half>(), nullptr, y.data_ptr<at::Half>(), nullptr);\n\n  bn->setWeightPointers({scale.data_ptr<float>(), bias.data_ptr<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.data_ptr<float>(), running_inv_var.data_ptr<float>()});\n\n  // 
deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void*> workspace;\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.data_ptr<uint8_t>();\n  assert(ret_cta.size(0) >= retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void* ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwdInference(stream, fuse_relu);\n\n  return y;\n}\n\nstd::vector<at::Tensor> nhwc_bn_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale,\n                                    const at::Tensor& bias, const at::Tensor& running_mean,\n                                    const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean,\n                                    const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta,\n                                    const float momentum, const float epsilon, const bool fuse_relu, void* my_data,\n                                    void* pair_data, void* pair_data2, void* pair_data3, const int 
bn_group,\n                                    const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x,\n                                    const bool coop) {\n  // shape\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.data_ptr<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // outputs\n  at::Tensor x_grad, scale_grad, bias_grad;\n\n  // Allocate outputs\n  x_grad = at::empty_like(x);\n  scale_grad = at::empty_like(scale);\n  bias_grad = at::empty_like(bias);\n\n  // Create wrapper\n  NhwcBatchNorm* bn = new NhwcBatchNorm();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.data_ptr<at::Half>(), x_grad.data_ptr<at::Half>(), nullptr, dy.data_ptr<at::Half>());\n\n  bn->setWeightPointers({scale.data_ptr<float>(), bias.data_ptr<float>()},\n                        {scale_grad.data_ptr<float>(), bias_grad.data_ptr<float>()});\n  bn->setParameterPointers({running_mean.data_ptr<float>(), running_inv_var.data_ptr<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  
std::vector<void*> workspace;\n  workspace.push_back(minibatch_mean.data_ptr<float>());\n  workspace.push_back(minibatch_inv_var.data_ptr<float>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[2];\n  void* retired_ctas = ret_cta.data_ptr<uint8_t>();\n  assert(ret_cta.size(0) >= retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 3; index < workspace_bytes.size(); ++index) {\n    void* ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 3];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x,\n            coop);\n\n  // The wrapper is no longer needed once the kernel has been enqueued\n  delete bn;\n\n  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};\n}\n\nint nhwc_bn_fwd_occupancy() {\n  int device_id = -1;\n  cudaGetDevice(&device_id);\n\n  // max occupancy supported by the code is 2\n  return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);\n}\n\nint nhwc_bn_bwd_occupancy() {\n  int device_id = -1;\n  cudaGetDevice(&device_id);\n\n  // max occupancy supported by the code is 2\n  return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/batch_norm.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm.h\n * \\brief CUDA NHWC Batch Normalization code\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer\n */\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n\n#include <cudnn.h>\n\n#include <algorithm>\n#include <iostream>\n#include <string>\n#include <vector>\n\n#include \"cuda_utils.h\"\n#include \"nhwc_batch_norm_kernel.h\"\n\n#define VERBOSE_DEFAULT false\n\nclass NhwcBatchNorm {\n public:\n  NhwcBatchNorm() {\n    name_ = \"nhwc_batchnorm\";\n    createTensorDescriptor(&X_tensor_desc_);\n    createTensorDescriptor(&Y_tensor_desc_);\n  }\n\n  ~NhwcBatchNorm() {\n    destroyTensorDescriptor(X_tensor_desc_);\n    destroyTensorDescriptor(Y_tensor_desc_);\n  }\n\n  void die() {\n    std::cerr << \"batchnorm not initialized\" << std::endl;\n    exit(-1);\n  }\n\n  void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n           const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void 
dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n             const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void fwdInference(cudaStream_t stream, bool use_relu);\n  dim3 calc_fwd_grid(int* loop, const int grid_dim_x);\n  dim3 calc_bwd_grid(int* loop, const int grid_dim_x);\n\n  void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w,\n                          int bn_group) {\n    m_ = n * h * w;\n    int m_bn_adjusted = m_ * bn_group;\n    c_ = c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    svar_inv_count_ = 1.f / m_bn_adjusted;\n    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).\n    int divisor = m_bn_adjusted - 1;\n    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.\n    rvar_inv_count_ = divisor == 0 ? 
1.f : 1.f / divisor;\n    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  void setOutputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h,\n                           int w) {\n    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  const std::vector<size_t> numWorkspaceBytes() const;\n\n  void setWorkspacePointers(const std::vector<void*>& workspace, const std::vector<size_t>& num_workspace_bytes);\n\n  void setInputOutputPointers(void* X, void* dX, void* Y, void* dY) {\n    X_ = X;\n    dX_ = dX;\n    Y_ = Y;\n    dY_ = dY;\n  }\n\n  // Sets the pointers for the scale and bias (in that order) data and derivative buffers.\n  void setWeightPointers(const std::vector<void*>& weight_pointers, const std::vector<void*>& deriv_pointers) {\n    assert(weight_pointers.size() == 2);\n    assert(deriv_pointers.size() == 2);\n    scale_ = static_cast<float*>(weight_pointers[0]);\n    bias_ = static_cast<float*>(weight_pointers[1]);\n    dscale_ = static_cast<float*>(deriv_pointers[0]);\n    dbias_ = static_cast<float*>(deriv_pointers[1]);\n  }\n\n  // Sets the pointers for the population mean and variance buffers, in that order.\n  void setParameterPointers(const std::vector<void*>& param_pointers) {\n    assert(param_pointers.size() == 2);\n    population_mean_ = static_cast<float*>(param_pointers[0]);\n    population_variance_ = static_cast<float*>(param_pointers[1]);\n  }\n\n  void setConstants(const double exp_avg_factor, const double eps) {\n    exp_avg_factor_ = exp_avg_factor;\n    eps_ = eps;\n  }\n\n  void processCudnnStatus(const cudnnStatus_t& status, const std::string& string = std::string(),\n                          bool verbose = VERBOSE_DEFAULT) {\n    if (status != CUDNN_STATUS_SUCCESS)\n      LOG(FATAL) << string << \" \" << cudnnGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudnnGetErrorString(status);\n  }\n\n  
void checkCudaStatus(const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) {\n    cudaError_t status = cudaGetLastError();\n    if (status != cudaSuccess)\n      LOG(FATAL) << string << \" \" << cudaGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudaGetErrorString(status);\n  }\n\n  size_t size_retired_ctas(int grid_y) const {\n    // Note that the value of max_grid_y to handle known GPUs is about 160.\n    const int max_grid_y = 1024;\n    if (grid_y > max_grid_y) LOG(INFO) << \"GPU capabilities exceed assumptions.\";\n    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);\n    // Since the region will be initialized once and used for many kernels,\n    // the idea is to return an ample size that will cover all uses.\n    return retired_cta_bytes;\n  }\n\n  cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;\n  cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;\n\n  void* X_ = nullptr;\n  void* dX_ = nullptr;\n  void* Y_ = nullptr;\n  void* dY_ = nullptr;\n\n  // Learned scale and bias weights.\n  float* scale_ = nullptr;\n  float* dscale_ = nullptr;\n  float* bias_ = nullptr;\n  float* dbias_ = nullptr;\n\n  // Computed population mean and variance parameters.\n  float* population_mean_ = nullptr;\n  float* population_variance_ = nullptr;\n\n  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).\n  float* minibatch_mean_ = nullptr;\n  float* minibatch_variance_ = nullptr;\n\n  int m_ = 0;  // Number of values per channel that BN is normalizing.\n  int c_ = 0;  // Number of channels over which BN is normalizing.\n\n  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance\n  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance\n\n  double exp_avg_factor_ = 0.;\n  double eps_ = 0.;\n  std::string name_;\n\n private:\n  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, 
cudnnTensorFormat_t format, cudnnDataType_t data_type,\n                           int n, int c, int h, int w) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);\n    processCudnnStatus(status, \"set tensor descriptor\");\n  }\n\n  void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnCreateTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"create tensor_descriptor\");\n  }\n\n  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnDestroyTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"destroy tensor_descriptor\");\n  }\n\n protected:\n  float* partial_sums_ = nullptr;\n  int* partial_counts_ = nullptr;\n  int* retired_ctas_ = nullptr;\n\n  void _setFwdParams(NhwcBatchNormFwdParams* params) const;\n  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const;\n  void _setBwdParams(NhwcBatchNormBwdParams* params) const;\n\n  // @todo: ability to configure these?\n  // Kernel params\n  static const int USE_ONLINE_APPROACH = 1;\n  static const int THREADS_PER_CTA = 512;\n  static const int THREADS_PER_PIXEL = 16;\n  static const int C_ELEMENTS_PER_CTA = 64;\n  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;\n  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;\n\n  typedef uint16_t StorageType;\n  // typedef float StorageType;\n  //  increasing this to 6 causes spills in fwd kernel!\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;\n  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;\n  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;\n\n  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + PIXELS_PER_THREAD_IN_SMEM_FWD;\n  static const int 
PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + PIXELS_PER_THREAD_IN_SMEM_BWD;\n  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;\n\n  // Derived params\n  static const size_t SMEM_SIZE_FWD =\n      PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * sizeof(StorageType);\n  static const size_t SMEM_SIZE_BWD =\n      PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * 2 * sizeof(StorageType);\n  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD;\n  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_BWD;\n  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD_INFERENCE;\n\n  // max grid.y in case of group bn is limited by exchange buffer size\n  static const int MAX_GBN_BLOCK_Y = 256;\n\n  // Helper function to launch the forward kernel.\n\n  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel\n  // version that was compiled with that occupancy in its launch bounds.  
This way, we avoid\n  // needless register spills.\n  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops,\n                          bool use_relu, const int occupancy, const bool coop) {\n#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP)                          \\\n  do {                                                                                                                \\\n    CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\";                         \\\n    auto fwd_func =                                                                                                   \\\n        nhwc_batch_norm_fwd<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, PIXELS_PER_THREAD_IN_REGISTERS_FWD,      \\\n                            PIXELS_PER_THREAD_IN_SMEM_FWD, ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS,        \\\n                            USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY>;                                          \\\n    if (COMPILED_FOR_OCCUPANCY > 1) {                                                                                 \\\n      cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100);                            \\\n      checkCudaStatus(name_ + \" fwd ser coop kernel (cudaFuncSetAttribute carveout)\");                                \\\n    }                                                                                                                 \\\n    void* params_ptr = static_cast<void*>(&params);                                                                   \\\n    using FWD_FUNC = decltype(nhwc_batch_norm_fwd<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL,                    \\\n                                                  PIXELS_PER_THREAD_IN_REGISTERS_FWD, PIXELS_PER_THREAD_IN_SMEM_FWD,  \\\n                                           
       ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS, USE_RELU,       \\\n                                                  USE_ADD_RELU, COMPILED_FOR_OCCUPANCY>);                             \\\n    if (COOP) {                                                                                                       \\\n      cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_FWD, stream); \\\n    } else {                                                                                                          \\\n      cudaLaunchKernel<FWD_FUNC>(fwd_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_FWD, stream);            \\\n    }                                                                                                                 \\\n    checkCudaStatus(name_ + \" fwd ser coop kernel\");                                                                  \\\n  } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1 && use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, true, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, true, false, 1, coop);\n    } else if (outer_loops == 1 && !use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, false, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, false, false, 1, coop);\n    } else if (use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, true, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, true, false, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, false, false, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, false, false, 1, coop);\n    }\n#undef LAUNCH_FWD_KERNEL\n  }\n\n  // Helper function to launch the backward kernel.\n\n  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops,\n                          bool 
use_relu, const int occupancy, const bool coop) {\n#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP)                                                  \\\n  do {                                                                                                                \\\n    CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\";                         \\\n    auto bwd_func = nhwc_batch_norm_bwd<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL,                              \\\n                                        PIXELS_PER_THREAD_IN_REGISTERS_BWD, PIXELS_PER_THREAD_IN_SMEM_BWD,            \\\n                                        ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS, COMPILED_FOR_OCCUPANCY>;  \\\n    if (COMPILED_FOR_OCCUPANCY > 1) {                                                                                 \\\n      cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100);                            \\\n      checkCudaStatus(name_ + \" bwd coop serial kernel (cudaFuncSetAttribute carveout)\");                             \\\n    }                                                                                                                 \\\n    void* params_ptr = static_cast<void*>(&params);                                                                   \\\n    using BWD_FUNC =                                                                                                  \\\n        decltype(nhwc_batch_norm_bwd<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL,                                 \\\n                                     PIXELS_PER_THREAD_IN_REGISTERS_BWD, PIXELS_PER_THREAD_IN_SMEM_BWD,               \\\n                                     ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS, COMPILED_FOR_OCCUPANCY>);    \\\n    if (COOP) {                                                                                                    
   \\\n      cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_BWD, stream); \\\n    } else {                                                                                                          \\\n      cudaLaunchKernel<BWD_FUNC>(bwd_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_BWD, stream);            \\\n    }                                                                                                                 \\\n    checkCudaStatus(name_ + \" bwd coop serial kernel\");                                                               \\\n  } while (0)\n\n#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP)                                              \\\n  do {                                                                                                                 \\\n    CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnorm kernel smem too big.\";                          \\\n    auto bwd_relu_func =                                                                                               \\\n        nhwc_batch_norm_bwd_relu<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, PIXELS_PER_THREAD_IN_REGISTERS_BWD,  \\\n                                 PIXELS_PER_THREAD_IN_SMEM_BWD, ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS,    \\\n                                 COMPILED_FOR_OCCUPANCY>;                                                              \\\n    if (COMPILED_FOR_OCCUPANCY > 1) {                                                                                  \\\n      cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100);                        \\\n      checkCudaStatus(name_ + \" bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)\");                         \\\n    }                                                                                                                  \\\n    
void* params_ptr = static_cast<void*>(&params);                                                                    \\\n    using BWD_RELU_FUNC =                                                                                              \\\n        decltype(nhwc_batch_norm_bwd_relu<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL,                             \\\n                                          PIXELS_PER_THREAD_IN_REGISTERS_BWD, PIXELS_PER_THREAD_IN_SMEM_BWD,           \\\n                                          ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS,                          \\\n                                          COMPILED_FOR_OCCUPANCY>);                                                    \\\n    if (COOP) {                                                                                                        \\\n      cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_BWD, \\\n                                                 stream);                                                              \\\n    } else {                                                                                                           \\\n      cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_BWD, stream);   \\\n    }                                                                                                                  \\\n    checkCudaStatus(name_ + \" bwd-relu coop serial kernel\");                                                           \\\n  } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1 && use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_RELU_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_RELU_KERNEL(1, 1, coop);\n    } else if (outer_loops == 1 && !use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_KERNEL(1, 2, coop);\n      
else\n        LAUNCH_BWD_KERNEL(1, 1, coop);\n    } else if (use_relu) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_RELU_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_RELU_KERNEL(0, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_BWD_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_KERNEL(0, 1, coop);\n    }\n#undef LAUNCH_BWD_KERNEL\n#undef LAUNCH_BWD_RELU_KERNEL\n  }\n\n public:\n  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);\n    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n\n  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);\n    int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n};\n\nconst std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {\n  assert(c_ > 0);\n\n  // choose the max memory required between fwd/bwd passes\n  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);\n  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);\n  int grid_x = max(grid_x_fwd, grid_x_bwd);\n  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  const size_t num_mean_bytes = c_ * sizeof(float);\n  const size_t num_variance_bytes = num_mean_bytes;\n  const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float);\n  const 
size_t size_counts = grid_y * grid_x * sizeof(int);\n\n  return {num_mean_bytes, num_variance_bytes, size_retired_ctas(grid_y), size_sums, size_counts};\n}\n\nvoid NhwcBatchNorm::setWorkspacePointers(const std::vector<void*>& workspace,\n                                         const std::vector<size_t>& num_workspace_bytes) {\n  assert(workspace.size() == 5);\n  assert(num_workspace_bytes.size() == 5);\n\n  minibatch_mean_ = static_cast<float*>(workspace[0]);\n  minibatch_variance_ = static_cast<float*>(workspace[1]);\n  retired_ctas_ = static_cast<int*>(workspace[2]);\n  partial_sums_ = static_cast<float*>(workspace[3]);\n  partial_counts_ = static_cast<int*>(workspace[4]);\n}\n\nvoid NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams* params) const {\n  params->gmem_src = static_cast<uint16_t*>(X_);\n  params->gmem_dst = static_cast<uint16_t*>(Y_);\n  params->gmem_src1 = nullptr;\n  params->gmem_bias = bias_;\n  params->gmem_scale = scale_;\n  params->gmem_running_mean = population_mean_;\n  params->gmem_running_var = population_variance_;\n  params->gmem_saved_mean = minibatch_mean_;\n  params->gmem_saved_var = minibatch_variance_;\n  params->gmem_relu_bitmask = nullptr;\n  params->nhw = m_;\n  params->c = c_;\n  params->svar_inv_count = svar_inv_count_;\n  params->rvar_inv_count = rvar_inv_count_;\n  params->gmem_sums = partial_sums_;\n  params->gmem_counts = partial_counts_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->var_eps = eps_;\n  params->outer_loops = 0;\n  params->exp_avg_factor = static_cast<float>(exp_avg_factor_);\n  params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const {\n  params->gmem_src = static_cast<uint16_t*>(X_);\n  params->gmem_dst = static_cast<uint16_t*>(Y_);\n  params->gmem_src1 = nullptr;\n  params->gmem_bias = bias_;\n  params->gmem_scale = scale_;\n  params->gmem_mean = population_mean_;\n  params->gmem_var = 
population_variance_;\n  params->nhw = m_;\n  params->c = c_;\n  params->var_eps = eps_;\n}\n\nvoid NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams* params) const {\n  params->gmem_src = static_cast<uint16_t*>(X_);\n  params->gmem_dy = static_cast<uint16_t*>(dY_);\n  params->gmem_dst = static_cast<uint16_t*>(dX_);\n  params->gmem_dst1 = nullptr;\n  params->gmem_relu_bitmask = nullptr;\n  params->gmem_dscale = dscale_;\n  params->gmem_dbias = dbias_;\n  params->gmem_scale = scale_;\n  params->gmem_bias = bias_;\n  params->gmem_saved_mean = minibatch_mean_;\n  params->gmem_saved_var = minibatch_variance_;\n  params->nhw = m_;\n  params->c = c_;\n  params->svar_inv_count = svar_inv_count_;\n  params->gmem_sums = partial_sums_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->outer_loops = 0;\n  params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {\n  bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr &&\n                      bias_ != nullptr\n                      //      && minibatch_mean_ != nullptr\n                      //      && minibatch_variance_ != nullptr\n                      && population_mean_ != nullptr && population_variance_ != nullptr &&\n                      X_ != nullptr\n                      //      && dX_ != nullptr\n                      && Y_ != nullptr\n                      //      && dY_ != nullptr\n                      //      && dscale_ != nullptr\n                      //      && dbias_ != nullptr\n                      && partial_sums_ != nullptr && partial_counts_ != nullptr;\n\n  if (!ptrs_are_set) die();\n\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);\n  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  // @todo: maybe just move this inside initialize routine?\n  NhwcBatchNormFwdInferenceParams params;\n  _setFwdInferenceParams(&params);\n\n  if (use_relu) {\n    
nhwc_batch_norm_fwd_inference<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>\n        <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n    checkCudaStatus(name_ + \" fwd_inference-relu kernel\");\n  } else {\n    nhwc_batch_norm_fwd_inference<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>\n        <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n    checkCudaStatus(name_ + \" fwd_inference kernel\");\n  }\n}\n\ndim3 NhwcBatchNorm::calc_fwd_grid(int* loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD * PIXELS_PER_LDG * grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD * PIXELS_PER_LDG * grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\ndim3 NhwcBatchNorm::calc_bwd_grid(int* loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD * PIXELS_PER_LDG * 
grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD * PIXELS_PER_LDG * grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\nvoid NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2,\n                        void* pair_data3, const int bn_group, const int magic, const int occupancy,\n                        const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr &&\n                      minibatch_mean_ != nullptr && minibatch_variance_ != nullptr && population_mean_ != nullptr &&\n                      population_variance_ != nullptr &&\n                      X_ != nullptr\n                      //      && dX_ != nullptr\n                      && Y_ != nullptr\n                      //      && dY_ != nullptr\n                      //      && dscale_ != nullptr\n                      //      && dbias_ != nullptr\n                      && partial_sums_ != nullptr && partial_counts_ != nullptr && retired_ctas_ != nullptr;\n\n  if (!ptrs_are_set) die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormFwdParams params;\n  _setFwdParams(&params);\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group == 8) ? 
3 : (bn_group >> 1);\n\n  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);\n  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);\n}\n\nvoid NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2,\n                          void* pair_data3, const int bn_group, const int magic, const int occupancy,\n                          const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr &&\n                      (bias_ != nullptr || !use_relu) && minibatch_mean_ != nullptr &&\n                      minibatch_variance_ != nullptr\n                      //      && population_mean_ != nullptr\n                      //      && population_variance_ != nullptr\n                      && X_ != nullptr &&\n                      dX_ != nullptr\n                      //      && Y_ != nullptr\n                      && dY_ != nullptr && dscale_ != nullptr && dbias_ != nullptr;\n\n  if (!ptrs_are_set) die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormBwdParams params;\n  _setBwdParams(&params);\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);\n  params.wgrad_coeff = 1.0 / bn_group;\n\n  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);\n  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/batch_norm_add_relu.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <cuda.h>\n\n#include \"batch_norm_add_relu.h\"\n\n// FIXME move the common stuff to common h file\n#define cudaCheckErrors(msg)                                                                                  \\\n  do {                                                                                                        \\\n    cudaError_t __err = cudaGetLastError();                                                                   \\\n    if (__err != cudaSuccess) {                                                                               \\\n      fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \\\n      fprintf(stderr, \"*** FAILED - ABORTING\\n\");                                                             \\\n      exit(1);                                                                                                \\\n    }                                                                                                         \\\n  } while (0)\n\nstatic size_t round_up_to_multiple(size_t x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; }\n\nstruct Workspace {\n  Workspace(size_t size) : size(size), data(NULL) {\n    auto& allocator = *::c10::cuda::CUDACachingAllocator::get();\n    dataPtr = allocator.allocate(size);\n    data = dataPtr.get();\n  }\n  Workspace(const Workspace&) = delete;\n  Workspace(Workspace&&) = default;\n  Workspace& operator=(Workspace&&) = default;\n  ~Workspace() = default;\n\n  size_t size;\n  void* data;\n  c10::DataPtr dataPtr;\n};\n\n// Return {y}\nat::Tensor nhwc_bn_addrelu_fwd_train(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale,\n                                     const at::Tensor& bias, const at::Tensor& running_mean,\n                                     const at::Tensor& running_inv_var, 
const at::Tensor& minibatch_mean,\n                                     const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask,\n                                     const at::Tensor& ret_cta, const float momentum, const float epsilon,\n                                     void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                                     const int bn_group, const at::Tensor& magic_tensor, const int occupancy,\n                                     const int grid_dim_x, const bool coop) {\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.data_ptr<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.data_ptr<at::Half>(), nullptr, y.data_ptr<at::Half>(), nullptr, z.data_ptr<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.data_ptr<float>(), bias.data_ptr<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.data_ptr<float>(), running_inv_var.data_ptr<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 
512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void*> workspace;\n  workspace.push_back(minibatch_mean.data_ptr<float>());\n  workspace.push_back(minibatch_inv_var.data_ptr<float>());\n  workspace.push_back(bitmask.data_ptr<int32_t>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.data_ptr<uint8_t>();\n  assert(ret_cta.size(0) >= retired_cta_bytes);\n\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void* ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return y;\n}\n\nat::Tensor nhwc_bn_addrelu_fwd_eval(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale,\n                                    const at::Tensor& bias, const at::Tensor& running_mean,\n                                    const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group,\n                                    const float momentum, const float epsilon) {\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // Allocate output tensor\n  at::Tensor y = at::empty({N, H, W, C}, x.options());\n\n  // Create wrapper\n  NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, 
epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.data_ptr<at::Half>(), nullptr, y.data_ptr<at::Half>(), nullptr, z.data_ptr<at::Half>(),\n                             nullptr);\n\n  bn->setWeightPointers({scale.data_ptr<float>(), bias.data_ptr<float>()}, {nullptr, nullptr});\n  bn->setParameterPointers({running_mean.data_ptr<float>(), running_inv_var.data_ptr<float>()});\n\n  // deal with workspace(s)\n  auto workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void*> workspace;\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n  workspace.push_back(nullptr);\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.data_ptr<uint8_t>();\n  assert(ret_cta.size(0) >= retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void* ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  // Don't fuse in ReLU for now at least\n  bn->fwdInference(stream);\n\n  return y;\n}\n\nstd::vector<at::Tensor> nhwc_bn_addrelu_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale,\n                                            const at::Tensor& bias, const 
at::Tensor& running_mean,\n                                            const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean,\n                                            const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask,\n                                            const at::Tensor& ret_cta, const float momentum, const float epsilon,\n                                            void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                                            const int bn_group, const at::Tensor& magic_tensor, const int occupancy,\n                                            const int grid_dim_x, const bool coop) {\n  // shape\n  const int N = x.size(0);\n  const int H = x.size(1);\n  const int W = x.size(2);\n  const int C = x.size(3);\n\n  // generating new magic number and use that for sync\n  int* magic = magic_tensor.data_ptr<int>();\n  *magic = (*magic + 1) & 0xff;\n\n  // outputs\n  at::Tensor x_grad, z_grad, scale_grad, bias_grad;\n\n  // Allocate outputs\n  x_grad = at::empty_like(x);\n  z_grad = at::empty_like(x);\n  scale_grad = at::empty_like(scale);\n  bias_grad = at::empty_like(bias);\n\n  // Create wrapper\n  NhwcBatchNormAddRelu* bn = new NhwcBatchNormAddRelu();\n\n  bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);\n  bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);\n\n  bn->setConstants(momentum, epsilon);\n\n  // set pointers within the wrapper\n  bn->setInputOutputPointers(x.data_ptr<at::Half>(), x_grad.data_ptr<at::Half>(), nullptr, dy.data_ptr<at::Half>(),\n                             nullptr, z_grad.data_ptr<at::Half>());\n\n  bn->setWeightPointers({scale.data_ptr<float>(), bias.data_ptr<float>()},\n                        {scale_grad.data_ptr<float>(), bias_grad.data_ptr<float>()});\n  bn->setParameterPointers({running_mean.data_ptr<float>(), running_inv_var.data_ptr<float>()});\n\n  // deal with workspace(s)\n  auto 
workspace_bytes = bn->numWorkspaceBytes();\n  // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset\n  // an allocated workspace for the others\n  size_t total_workspace_bytes = 0;\n  std::vector<size_t> workspace_offsets;\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);\n    workspace_offsets.push_back(total_workspace_bytes);\n\n    auto alloc_bytes = workspace_bytes[index];\n    total_workspace_bytes += alloc_bytes;\n  }\n\n  // Allocate the workspace\n  Workspace ws(total_workspace_bytes);\n\n  std::vector<void*> workspace;\n  workspace.push_back(minibatch_mean.data_ptr<float>());\n  workspace.push_back(minibatch_inv_var.data_ptr<float>());\n  workspace.push_back(bitmask.data_ptr<int32_t>());\n\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const int retired_cta_bytes = workspace_bytes[3];\n  void* retired_ctas = ret_cta.data_ptr<uint8_t>();\n  assert(ret_cta.size(0) >= retired_cta_bytes);\n  workspace.push_back(retired_ctas);\n\n  for (auto index = 4; index < workspace_bytes.size(); ++index) {\n    void* ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index - 4];\n    workspace.push_back(ptr);\n  }\n\n  bn->setWorkspacePointers(workspace, workspace_bytes);\n\n  bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);\n\n  return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};\n}\n\nint nhwc_bn_addrelu_fwd_occupancy() {\n  int device_id = -1;\n  cudaGetDevice(&device_id);\n\n  // max occupancy supported by the code is 2\n  return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);\n}\n\nint nhwc_bn_addrelu_bwd_occupancy() {\n  int device_id = -1;\n  cudaGetDevice(&device_id);\n\n  // max occupancy supported by the code is 2\n  return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);\n}\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/batch_norm_add_relu.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm_add_relu.h\n * \\brief CUDA NHWC Batch Normalization code with fused addition\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer\n */\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n\n#include <cudnn.h>\n\n#include <algorithm>\n#include <iostream>\n#include <string>\n#include <vector>\n\n#include \"cuda_utils.h\"\n#include \"nhwc_batch_norm_kernel.h\"\n\n#define VERBOSE_DEFAULT false\n\nclass NhwcBatchNormAddRelu {\n public:\n  NhwcBatchNormAddRelu() {\n    name_ = \"nhwc_batchnormaddrelu\";\n    createTensorDescriptor(&X_tensor_desc_);\n    createTensorDescriptor(&Y_tensor_desc_);\n  }\n\n  ~NhwcBatchNormAddRelu() {\n    destroyTensorDescriptor(X_tensor_desc_);\n    destroyTensorDescriptor(Y_tensor_desc_);\n  }\n\n  void die() {\n    std::cerr << \"batchnormaddrelu not initialized\" << std::endl;\n    exit(-1);\n  }\n\n  void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group,\n           const 
int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n             const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);\n  void fwdInference(cudaStream_t stream);\n  dim3 calc_fwd_grid(int* loop, const int grid_dim_x);\n  dim3 calc_bwd_grid(int* loop, const int grid_dim_x);\n\n  void setInputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h, int w,\n                          int bn_group) {\n    m_ = n * h * w;\n    int m_bn_adjusted = m_ * bn_group;\n    c_ = c;\n    // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n    svar_inv_count_ = 1.f / m_bn_adjusted;\n    // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1).\n    int divisor = m_bn_adjusted - 1;\n    // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.\n    rvar_inv_count_ = divisor == 0 ? 
1.f : 1.f / divisor;\n    setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  void setOutputDescriptor(const cudnnTensorFormat_t format, const cudnnDataType_t data_type, int n, int c, int h,\n                           int w) {\n    setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);\n  }\n\n  const std::vector<size_t> numWorkspaceBytes() const;\n\n  void setWorkspacePointers(const std::vector<void*>& workspace, const std::vector<size_t>& num_workspace_bytes);\n\n  void setInputOutputPointers(void* X, void* dX, void* Y, void* dY, void* addend, void* dAddend) {\n    X_ = X;\n    dX_ = dX;\n    Y_ = Y;\n    dY_ = dY;\n    addend_ = addend;\n    dAddend_ = dAddend;\n  }\n\n  // Sets the pointers for the scale and bias (in that order) data and derivative buffers.\n  void setWeightPointers(const std::vector<void*>& weight_pointers, const std::vector<void*>& deriv_pointers) {\n    assert(weight_pointers.size() == 2);\n    assert(deriv_pointers.size() == 2);\n    scale_ = static_cast<float*>(weight_pointers[0]);\n    bias_ = static_cast<float*>(weight_pointers[1]);\n    dscale_ = static_cast<float*>(deriv_pointers[0]);\n    dbias_ = static_cast<float*>(deriv_pointers[1]);\n  }\n\n  // Sets the pointers for the population mean and variance buffers, in that order.\n  void setParameterPointers(const std::vector<void*>& param_pointers) {\n    assert(param_pointers.size() == 2);\n    population_mean_ = static_cast<float*>(param_pointers[0]);\n    population_variance_ = static_cast<float*>(param_pointers[1]);\n  }\n\n  void setConstants(const double exp_avg_factor, const double eps) {\n    exp_avg_factor_ = exp_avg_factor;\n    eps_ = eps;\n  }\n\n  void processCudnnStatus(const cudnnStatus_t& status, const std::string& string = std::string(),\n                          bool verbose = VERBOSE_DEFAULT) {\n    if (status != CUDNN_STATUS_SUCCESS)\n      LOG(FATAL) << string << \" \" << cudnnGetErrorString(status);\n    else if 
(verbose)\n      LOG(INFO) << string << \" \" << cudnnGetErrorString(status);\n  }\n\n  void checkCudaStatus(const std::string& string = std::string(), bool verbose = VERBOSE_DEFAULT) {\n    cudaError_t status = cudaGetLastError();\n    if (status != cudaSuccess)\n      LOG(FATAL) << string << \" \" << cudaGetErrorString(status);\n    else if (verbose)\n      LOG(INFO) << string << \" \" << cudaGetErrorString(status);\n  }\n\n  size_t size_retired_ctas(int grid_y) const {\n    // Note that the value of max_grid_y to handle known GPUs is about 160.\n    const int max_grid_y = 1024;\n    if (grid_y > max_grid_y) LOG(INFO) << \"GPU capabilities exceed assumptions.\";\n    const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);\n    // Since the region will be initialized once and used for many kernels,\n    // the idea is to return an ample size that will cover all uses.\n    return retired_cta_bytes;\n  }\n\n  cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;\n  cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;\n\n  void* X_ = nullptr;\n  void* dX_ = nullptr;\n  void* Y_ = nullptr;\n  void* dY_ = nullptr;\n  void* addend_ = nullptr;\n  void* dAddend_ = nullptr;\n\n  // Learned scale and bias weights.\n  float* scale_ = nullptr;\n  float* dscale_ = nullptr;\n  float* bias_ = nullptr;\n  float* dbias_ = nullptr;\n\n  // Computed population mean and variance parameters.\n  float* population_mean_ = nullptr;\n  float* population_variance_ = nullptr;\n\n  // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).\n  float* minibatch_mean_ = nullptr;\n  float* minibatch_variance_ = nullptr;\n\n  int m_ = 0;  // Number of values per channel that BN is normalizing.\n  int c_ = 0;  // Number of channels over which BN is normalizing.\n\n  float svar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get saved variance\n  float rvar_inv_count_ = 0.f;  // factor to scale sum of squared errors to get running variance\n\n  double 
exp_avg_factor_ = 0.;\n  double eps_ = 0.;\n  std::string name_;\n\n private:\n  void setTensorDescriptor(cudnnTensorDescriptor_t descriptor, cudnnTensorFormat_t format, cudnnDataType_t data_type,\n                           int n, int c, int h, int w) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);\n    processCudnnStatus(status, \"set tensor descriptor\");\n  }\n\n  void createTensorDescriptor(cudnnTensorDescriptor_t* descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnCreateTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"create tensor_descriptor\");\n  }\n\n  void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {\n    cudnnStatus_t status = CUDNN_STATUS_SUCCESS;\n    status = cudnnDestroyTensorDescriptor(descriptor);\n    processCudnnStatus(status, \"destroy tensor_descriptor\");\n  }\n\n protected:\n  float* partial_sums_ = nullptr;\n  int* partial_counts_ = nullptr;\n  int* retired_ctas_ = nullptr;\n  unsigned int* relu_bitmask_ = nullptr;\n\n  void _setFwdParams(NhwcBatchNormFwdParams* params) const;\n  void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const;\n  void _setBwdParams(NhwcBatchNormBwdParams* params) const;\n\n  // @todo: ability to configure these?\n  // Kernel params\n  static const int USE_ONLINE_APPROACH = 1;\n  static const int THREADS_PER_CTA = 512;\n  static const int THREADS_PER_PIXEL = 16;\n  static const int C_ELEMENTS_PER_CTA = 64;\n  static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;\n  static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;\n\n  typedef uint16_t StorageType;\n  // increasing this to 6 causes spills in fwd kernel!\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;\n  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;\n  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;\n  static const int 
PIXELS_PER_THREAD_IN_SMEM_BWD = 5;\n\n  static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + PIXELS_PER_THREAD_IN_SMEM_FWD;\n  static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + PIXELS_PER_THREAD_IN_SMEM_BWD;\n  static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;\n\n  // Derived params\n  static const size_t SMEM_SIZE_FWD =\n      PIXELS_PER_THREAD_IN_SMEM_FWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * sizeof(StorageType);\n  static const size_t SMEM_SIZE_BWD =\n      PIXELS_PER_THREAD_IN_SMEM_BWD * THREADS_PER_CTA * ELEMENTS_PER_LDG * 2 * sizeof(StorageType);\n  static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD;\n  static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_BWD;\n  static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA / THREADS_PER_PIXEL * PIXELS_PER_THREAD_FWD_INFERENCE;\n\n  // max grid.y in case of group bn is limited by exchange buffer size\n  static const int MAX_GBN_BLOCK_Y = 256;\n\n  // Helper function to launch the forward kernel.\n\n  // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel\n  // version that was compiled with that occupancy in its launch bounds.  
This way, we avoid\n  // needless register spills.\n  void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, dim3 grid_dim, int outer_loops,\n                          const int occupancy, const bool coop) {\n#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP)                          \\\n  do {                                                                                                                \\\n    CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnormaddrelu kernel smem too big.\";                  \\\n    auto fwd_func =                                                                                                   \\\n        nhwc_batch_norm_fwd<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, PIXELS_PER_THREAD_IN_REGISTERS_FWD,      \\\n                            PIXELS_PER_THREAD_IN_SMEM_FWD, ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS,        \\\n                            USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY>;                                          \\\n    if (COMPILED_FOR_OCCUPANCY > 1) {                                                                                 \\\n      cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100);                            \\\n      checkCudaStatus(name_ + \" fwd ser coop kernel (cudaFuncSetAttribute carveout)\");                                \\\n    }                                                                                                                 \\\n    void* params_ptr = static_cast<void*>(&params);                                                                   \\\n    using FWD_FUNC = decltype(nhwc_batch_norm_fwd<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL,                    \\\n                                                  PIXELS_PER_THREAD_IN_REGISTERS_FWD, PIXELS_PER_THREAD_IN_SMEM_FWD,  \\\n                                                  
ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS, USE_RELU,       \\\n                                                  USE_ADD_RELU, COMPILED_FOR_OCCUPANCY>);                             \\\n    if (COOP) {                                                                                                       \\\n      cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_FWD, stream); \\\n    } else {                                                                                                          \\\n      cudaLaunchKernel<FWD_FUNC>(fwd_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_FWD, stream);            \\\n    }                                                                                                                 \\\n    checkCudaStatus(name_ + \" fwd ser coop kernel\");                                                                  \\\n  } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1) {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(1, false, true, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(1, false, true, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_FWD_KERNEL(0, false, true, 2, coop);\n      else\n        LAUNCH_FWD_KERNEL(0, false, true, 1, coop);\n    }\n#undef LAUNCH_FWD_KERNEL\n  }\n\n  // Helper function to launch the backward kernel.\n\n  void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, dim3 grid_dim, int outer_loops,\n                          const int occupancy, const bool coop) {\n#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP)                                       \\\n  do {                                                                                                              \\\n    CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \"Nhwc batchnormaddrelu kernel smem too big.\";               
 \\\n    auto bwd_add_relu_func =                                                                                        \\\n        nhwc_batch_norm_bwd_add_relu<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL,                               \\\n                                     PIXELS_PER_THREAD_IN_REGISTERS_BWD, PIXELS_PER_THREAD_IN_SMEM_BWD,             \\\n                                     ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS, COMPILED_FOR_OCCUPANCY>;   \\\n    if (COMPILED_FOR_OCCUPANCY > 1) {                                                                               \\\n      cudaFuncSetAttribute(bwd_add_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100);                 \\\n      checkCudaStatus(name_ + \" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)\");                  \\\n    }                                                                                                               \\\n    void* params_ptr = static_cast<void*>(&params);                                                                 \\\n    using BWD_ADD_RELU_FUNC =                                                                                       \\\n        decltype(nhwc_batch_norm_bwd_add_relu<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL,                      \\\n                                              PIXELS_PER_THREAD_IN_REGISTERS_BWD, PIXELS_PER_THREAD_IN_SMEM_BWD,    \\\n                                              ELEMENTS_PER_LDG, USE_ONLINE_APPROACH, OUTER_LOOPS,                   \\\n                                              COMPILED_FOR_OCCUPANCY>);                                             \\\n    if (COOP) {                                                                                                     \\\n      cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, grid_dim, THREADS_PER_CTA, &params_ptr,     \\\n                                                     SMEM_SIZE_BWD, 
stream);                                        \\\n    } else {                                                                                                        \\\n      cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, grid_dim, THREADS_PER_CTA, &params_ptr, SMEM_SIZE_BWD, \\\n                                          stream);                                                                  \\\n    }                                                                                                               \\\n    checkCudaStatus(name_ + \" bwd-add-relu coop serial kernel\");                                                    \\\n  } while (0)\n\n    // Don't try for an occupancy > 2 as this will squeeze register use and create spills.\n    if (outer_loops == 1) {\n      if (occupancy >= 2)\n        LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);\n      else\n        LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);\n    } else {\n      if (occupancy >= 2)\n        LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);\n      else\n        LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);\n    }\n#undef LAUNCH_BWD_ADD_RELU_KERNEL\n  }\n\n public:\n  // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int fwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);\n    int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n\n  // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.\n  static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {\n    using namespace at::cuda::utils;\n    int bwd_reduction_bytes = THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG * sizeof(float);\n    int bwd_smem_bytes = 
SMEM_SIZE_BWD + bwd_reduction_bytes;\n    int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;\n    return std::min(max_cta_per_sm, occupancy);\n  }\n};\n\nconst std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {\n  assert(c_ > 0);\n\n  // choose the max memory required between fwd/bwd passes\n  int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);\n  int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);\n  int grid_x = max(grid_x_fwd, grid_x_bwd);\n  int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  const size_t num_mean_bytes = c_ * sizeof(float);\n  const size_t num_variance_bytes = num_mean_bytes;\n\n  int elems_per_group = ((m_ + 31) & ~31) * 2;\n  int group_count = div_up(c_, C_ELEMENTS_PER_CTA);\n  const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);\n\n  const size_t size_sums = grid_y * grid_x * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2 * sizeof(float);\n  const size_t size_counts = grid_y * grid_x * sizeof(int);\n\n  return {num_mean_bytes, num_variance_bytes, bitmask_bytes, size_retired_ctas(grid_y), size_sums, size_counts};\n}\n\nvoid NhwcBatchNormAddRelu::setWorkspacePointers(const std::vector<void*>& workspace,\n                                                const std::vector<size_t>& num_workspace_bytes) {\n  assert(workspace.size() == 6);\n  assert(num_workspace_bytes.size() == 6);\n\n  minibatch_mean_ = static_cast<float*>(workspace[0]);\n  minibatch_variance_ = static_cast<float*>(workspace[1]);\n  relu_bitmask_ = static_cast<unsigned int*>(workspace[2]);\n  retired_ctas_ = static_cast<int*>(workspace[3]);\n  partial_sums_ = static_cast<float*>(workspace[4]);\n  partial_counts_ = static_cast<int*>(workspace[5]);\n}\n\nvoid NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams* params) const {\n  params->gmem_src = static_cast<uint16_t*>(X_);\n  params->gmem_dst = static_cast<uint16_t*>(Y_);\n  params->gmem_src1 = static_cast<uint16_t*>(addend_);\n  params->gmem_bias = bias_;\n  
params->gmem_scale = scale_;\n  params->gmem_running_mean = population_mean_;\n  params->gmem_running_var = population_variance_;\n  params->gmem_saved_mean = minibatch_mean_;\n  params->gmem_saved_var = minibatch_variance_;\n  params->gmem_relu_bitmask = relu_bitmask_;\n  params->nhw = m_;\n  params->c = c_;\n  params->svar_inv_count = svar_inv_count_;\n  params->rvar_inv_count = rvar_inv_count_;\n  params->gmem_sums = partial_sums_;\n  params->gmem_counts = partial_counts_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->var_eps = eps_;\n  params->outer_loops = 0;\n  params->exp_avg_factor = static_cast<float>(exp_avg_factor_);\n  params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams* params) const {\n  params->gmem_src = static_cast<uint16_t*>(X_);\n  params->gmem_dst = static_cast<uint16_t*>(Y_);\n  params->gmem_src1 = static_cast<uint16_t*>(addend_);\n  params->gmem_bias = bias_;\n  params->gmem_scale = scale_;\n  params->gmem_mean = population_mean_;\n  params->gmem_var = population_variance_;\n  params->nhw = m_;\n  params->c = c_;\n  params->var_eps = eps_;\n}\n\nvoid NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams* params) const {\n  params->gmem_src = static_cast<uint16_t*>(X_);\n  params->gmem_dy = static_cast<uint16_t*>(dY_);\n  params->gmem_dst = static_cast<uint16_t*>(dX_);\n  params->gmem_dst1 = static_cast<uint16_t*>(dAddend_);\n  params->gmem_relu_bitmask = relu_bitmask_;\n  params->gmem_dscale = dscale_;\n  params->gmem_dbias = dbias_;\n  params->gmem_scale = scale_;\n  params->gmem_bias = bias_;\n  params->gmem_saved_mean = minibatch_mean_;\n  params->gmem_saved_var = minibatch_variance_;\n  params->nhw = m_;\n  params->c = c_;\n  params->svar_inv_count = svar_inv_count_;\n  params->gmem_sums = partial_sums_;\n  params->gmem_retired_ctas = retired_ctas_;\n  params->outer_loops = 0;\n  params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n}\n\nvoid 
NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {\n  bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr &&\n                      bias_ != nullptr\n                      //      && minibatch_mean_ != nullptr\n                      //      && minibatch_variance_ != nullptr\n                      && population_mean_ != nullptr && population_variance_ != nullptr &&\n                      X_ != nullptr\n                      //      && dX_ != nullptr\n                      && Y_ != nullptr &&\n                      addend_ != nullptr\n                      //      && dY_ != nullptr\n                      //      && dscale_ != nullptr\n                      //      && dbias_ != nullptr\n                      && partial_sums_ != nullptr && partial_counts_ != nullptr;\n\n  if (!ptrs_are_set) die();\n\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);\n  grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);\n\n  // @todo: maybe just move this inside initialize routine?\n  NhwcBatchNormFwdInferenceParams params;\n  _setFwdInferenceParams(&params);\n\n  nhwc_batch_norm_fwd_inference<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>\n      <<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);\n  checkCudaStatus(name_ + \" fwd_inference-relu kernel\");\n}\n\ndim3 NhwcBatchNormAddRelu::calc_fwd_grid(int* loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - 
PIXELS_PER_THREAD_IN_SMEM_FWD * PIXELS_PER_LDG * grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD * PIXELS_PER_LDG * grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\ndim3 NhwcBatchNormAddRelu::calc_bwd_grid(int* loop, const int grid_dim_x) {\n  dim3 grid_dim;\n  grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);\n  int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);\n  unsigned int max_grid_x = grid_dim_x;\n  if (grid_dim.x <= max_grid_x) {\n    *loop = 1;\n    if (max_grid_x / grid_dim.x > 1) {\n      grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));\n      assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop\n    } else {\n      grid_dim.y = 1;\n    }\n  } else {\n    grid_dim.x = max_grid_x;\n    grid_dim.y = 1;\n    int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD * PIXELS_PER_LDG * grid_dim.x;\n    int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD * PIXELS_PER_LDG * grid_dim.x;\n    *loop = div_up(nhw_in_regs, pixels_per_iteration);\n  }\n  return grid_dim;\n}\n\nvoid NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                               const int bn_group, const int magic, const int occupancy, const int grid_dim_x,\n                               const bool coop) {\n  bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr &&\n                      minibatch_mean_ != nullptr && minibatch_variance_ != nullptr && relu_bitmask_ != nullptr &&\n                      population_mean_ != nullptr && population_variance_ != nullptr &&\n                      X_ != nullptr\n                      //      && dX_ != nullptr\n                      && Y_ != nullptr &&\n                      addend_ != nullptr\n                      //      && dY_ != nullptr\n                      //      && dscale_ != nullptr\n              
        //      && dbias_ != nullptr\n                      && partial_sums_ != nullptr && partial_counts_ != nullptr && retired_ctas_ != nullptr;\n\n  if (!ptrs_are_set) die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormFwdParams params;\n  _setFwdParams(&params);\n\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group == 8) ? 3 : (bn_group >> 1);\n\n  dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);\n  _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);\n}\n\nvoid NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2,\n                                 void* pair_data3, const int bn_group, const int magic, const int occupancy,\n                                 const int grid_dim_x, const bool coop) {\n  bool ptrs_are_set = X_tensor_desc_ != nullptr && Y_tensor_desc_ != nullptr && scale_ != nullptr && bias_ != nullptr &&\n                      minibatch_mean_ != nullptr && minibatch_variance_ != nullptr &&\n                      relu_bitmask_ != nullptr\n                      //      && population_mean_ != nullptr\n                      //      && population_variance_ != nullptr\n                      && X_ != nullptr &&\n                      dX_ != nullptr\n                      //      && Y_ != nullptr\n                      && dY_ != nullptr && dAddend_ != nullptr && dscale_ != nullptr && dbias_ != nullptr &&\n                      retired_ctas_ != nullptr;\n\n  if (!ptrs_are_set) die();\n\n  // reset of retired_cta_count no longer needed\n\n  NhwcBatchNormBwdParams params;\n  _setBwdParams(&params);\n\n  params.my_data = my_data;\n  params.pair_datas[0] = pair_data;\n  params.pair_datas[1] = pair_data2;\n  params.pair_datas[2] = pair_data3;\n  params.magic = magic;\n  params.sync_iters = (bn_group == 
8) ? 3 : (bn_group >> 1);\n  params.wgrad_coeff = 1.0 / bn_group;\n\n  dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);\n  _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/cuda_utils.h",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#ifndef CUDA_UTILS_H\n#define CUDA_UTILS_H\n\nnamespace at {\nnamespace cuda {\n\nnamespace utils {\n\nstatic inline int MaxSharedMemoryPerMultiprocessor(int device_id) {\n  return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;\n}\n\n}  // namespace utils\n}  // namespace cuda\n}  // namespace at\n\n#endif\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/interface.cpp",
    "content": "#include <ATen/ATen.h>\n#include <ATen/ArrayRef.h>\n#include <ATen/ScalarType.h>\n#include <pybind11/numpy.h>\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <torch/extension.h>\n\n#include \"ATen/Generator.h\"\n#include \"ATen/Scalar.h\"\n#include \"ATen/Storage.h\"\n#include \"ATen/Tensor.h\"\n\nnamespace py = pybind11;\n\nint64_t get_buffer_size(const int bn_sync_steps);\n\nvoid* get_data_ptr(const at::Tensor& data);\n\nvoid* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset);\n\nvoid close_remote_data(const at::Tensor& handle);\n\nat::Tensor nhwc_bn_fwd_train(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,\n                             const at::Tensor& running_mean, const at::Tensor& running_inv_var,\n                             const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,\n                             const at::Tensor& ret_cta, const float momentum, const float epsilon, const bool fuse_relu,\n                             void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group,\n                             const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x,\n                             const bool coop);\n\nat::Tensor nhwc_bn_fwd_eval(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,\n                            const at::Tensor& running_mean, const at::Tensor& running_inv_var,\n                            const at::Tensor& ret_cta, const int bn_group, const float momentum, const float epsilon,\n                            const bool fuse_relu);\n\nstd::vector<at::Tensor> nhwc_bn_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale,\n                                    const at::Tensor& bias, const at::Tensor& running_mean,\n                                    const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean,\n                                    
const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta,\n                                    const float momentum, const float epsilon, const bool fuse_relu, void* my_data,\n                                    void* pair_data, void* pair_data2, void* pair_data3, const int bn_group,\n                                    const at::Tensor& magic_tensor, const int occupancy, const int grid_dim_x,\n                                    const bool coop);\n\nat::Tensor nhwc_bn_addrelu_fwd_train(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale,\n                                     const at::Tensor& bias, const at::Tensor& running_mean,\n                                     const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean,\n                                     const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask,\n                                     const at::Tensor& ret_cta, const float momentum, const float epsilon,\n                                     void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                                     const int bn_group, const at::Tensor& magic_tensor, const int occupancy,\n                                     const int grid_dim_x, const bool coop);\n\nat::Tensor nhwc_bn_addrelu_fwd_eval(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale,\n                                    const at::Tensor& bias, const at::Tensor& running_mean,\n                                    const at::Tensor& running_inv_var, const at::Tensor& ret_cta, const int bn_group,\n                                    const float momentum, const float epsilon);\n\nstd::vector<at::Tensor> nhwc_bn_addrelu_bwd(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale,\n                                            const at::Tensor& bias, const at::Tensor& running_mean,\n                                            const at::Tensor& running_inv_var, const at::Tensor& 
minibatch_mean,\n                                            const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask,\n                                            const at::Tensor& ret_cta, const float momentum, const float epsilon,\n                                            void* my_data, void* pair_data, void* pair_data2, void* pair_data3,\n                                            const int bn_group, const at::Tensor& magic_tensor, const int occupancy,\n                                            const int grid_dim_x, const bool coop);\n\nint nhwc_bn_fwd_occupancy();\nint nhwc_bn_bwd_occupancy();\n\nint nhwc_bn_addrelu_fwd_occupancy();\nint nhwc_bn_addrelu_bwd_occupancy();\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"get_buffer_size\", &get_buffer_size, \"get_buffer_size\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"get_data_ptr\", &get_data_ptr, \"get_data_ptr\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"get_remote_data_ptr\", &get_remote_data_ptr, \"get_remote_data_ptr\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"close_remote_data\", &close_remote_data, \"close_remote_data\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"bn_fwd_nhwc\", &nhwc_bn_fwd_train, \"bn_fwd_nhwc\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"bn_fwd_eval_nhwc\", &nhwc_bn_fwd_eval, \"bn_fwd_eval_nhwc\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"bn_bwd_nhwc\", &nhwc_bn_bwd, \"bn_bwd_nhwc\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"bn_fwd_nhwc_occupancy\", &nhwc_bn_fwd_occupancy, \"bn_fwd_nhwc_occupancy\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"bn_bwd_nhwc_occupancy\", &nhwc_bn_bwd_occupancy, \"bn_bwd_nhwc_occupancy\",\n        py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"bn_addrelu_fwd_nhwc\", &nhwc_bn_addrelu_fwd_train, \"bn_addrelu_fwd_nhwc\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"bn_addrelu_fwd_eval_nhwc\", 
&nhwc_bn_addrelu_fwd_eval, \"bn_addrelu_fwd_eval_nhwc\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"bn_addrelu_bwd_nhwc\", &nhwc_bn_addrelu_bwd, \"bn_addrelu_bwd_nhwc\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"bn_addrelu_fwd_nhwc_occupancy\", &nhwc_bn_addrelu_fwd_occupancy, \"bn_addrelu_fwd_nhwc_occupancy\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"bn_addrelu_bwd_nhwc_occupancy\", &nhwc_bn_addrelu_bwd_occupancy, \"bn_addrelu_bwd_nhwc_occupancy\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/ipc.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n\n#define cudaCheckErrors(msg)                                                                                  \\\n  do {                                                                                                        \\\n    cudaError_t __err = cudaGetLastError();                                                                   \\\n    if (__err != cudaSuccess) {                                                                               \\\n      fprintf(stderr, \"Fatal error: %s (%s at %s:%d)\\n\", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \\\n      fprintf(stderr, \"*** FAILED - ABORTING\\n\");                                                             \\\n      exit(1);                                                                                                \\\n    }                                                                                                         \\\n  } while (0)\n\ntemplate <>\nstruct std::hash<cudaIpcMemHandle_t> {\n  size_t operator()(const cudaIpcMemHandle_t& handle) const {\n    size_t hash = 0;\n    uint8_t* ptr = (uint8_t*)&handle;\n    assert(sizeof(uint8_t) == 1);\n    for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {\n      hash += *ptr;\n      ptr++;\n    }\n    return hash;\n  }\n};\n\ntemplate <>\nstruct std::equal_to<cudaIpcMemHandle_t> {\n  bool operator()(const cudaIpcMemHandle_t& lhs, const cudaIpcMemHandle_t& rhs) const {\n    return (std::memcmp((void*)&lhs, (void*)&rhs, sizeof(cudaIpcMemHandle_t)) == 0);\n  }\n};\n\nnamespace {\n\nnamespace gpuipc {\n// from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h\n//  The number of threads per pixel.\nconst int THREADS_PER_PIXEL = 16;\n// The number of elements per ldg.\nconst int ELEMENTS_PER_LDG = 4;\n// The number of reducing ops, each uses its own space : mean, var, dscale, dbias\nconst int REDUCE_OPS = 4;\n// Maximum block.y supported - 
limited due to buffer allocation\nconst int MAX_BLOCK_Y = 256;\nconst int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y;\nconst int BYTES_PER_ELEM = 4;\n// Buffer size per sync step\nconst int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET * THREADS_PER_PIXEL * 2 * ELEMENTS_PER_LDG * BYTES_PER_ELEM;\n};  // namespace gpuipc\n\nclass IpcMemHandleRegistry {\n public:\n  void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {\n    if (registry_.count(handle) == 0) {\n      registry_.insert(std::make_pair(handle, RegistryEntry()));\n      registry_[handle].dev_ptr = ipcOpenMem(handle);\n    }\n    registry_[handle].ref_count++;\n    return (((uint8_t*)registry_[handle].dev_ptr) + offset);\n  }\n\n  void releasePtr(const cudaIpcMemHandle_t& handle) {\n    if (registry_.count(handle) == 0) {\n    }\n    if (--registry_[handle].ref_count == 0) {\n      ipcCloseMem(registry_[handle].dev_ptr);\n      registry_.erase(handle);\n    }\n  }\n\n  struct RegistryEntry {\n    void* dev_ptr;\n    int ref_count;\n    RegistryEntry() : dev_ptr(NULL), ref_count(0) {}\n  };\n\n protected:\n  std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;\n\n  void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {\n    void* data;\n    cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);\n    cudaCheckErrors(\"ipc init\");\n    return data;\n  }\n\n  void ipcCloseMem(void* dev_ptr) {\n    cudaIpcCloseMemHandle(dev_ptr);\n    cudaCheckErrors(\"ipc close\");\n  }\n};\n\n}  // namespace\n\nstatic IpcMemHandleRegistry ipc_mem_registry;\n\nint64_t get_buffer_size(const int bn_sync_steps) { return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES; }\n\nvoid* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {\n  cudaIpcMemHandle_t my_handle;\n  memcpy((unsigned char*)(&my_handle), handle.data_ptr<uint8_t>(), sizeof(my_handle));\n  return ipc_mem_registry.getPtr(my_handle, offset);\n}\n\nvoid close_remote_data(const at::Tensor& handle) {\n  cudaIpcMemHandle_t 
my_handle;\n  memcpy((unsigned char*)(&my_handle), handle.data_ptr<uint8_t>(), sizeof(my_handle));\n  ipc_mem_registry.releasePtr(my_handle);\n}\n\nvoid* get_data_ptr(const at::Tensor& data) { return data.data_ptr<uint8_t>(); }\n"
  },
  {
    "path": "apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*!\n * Copyright (c) 2018 by Contributors\n * \\file nhwc_batch_norm_kernel.h\n * \\brief CUDA NHWC Batch Normalization code\n * \\author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer\n */\n#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n\n#include <stdint.h>\n\n#include <algorithm>\n\n#define DEVICE_FUNCTION static inline __device__\n\n// CTA margin used by cooperative launch. 
Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.\n#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3\n#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, int ELEMENTS_PER_LDG>\nstruct PackedStorage {\n  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };\n  typedef T Type;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int ELEMENTS_PER_LDG>\nstruct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {\n  enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG / 2 };\n  typedef int Type;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2 * N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    uint16_t lo, hi;\n    asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(lo) : \"f\"(src[2 * i + 0]));\n    asm volatile(\"cvt.rn.f16.f32 %0, %1;\" : \"=h\"(hi) : \"f\"(src[2 * i + 1]));\n    asm volatile(\"mov.b32 %0, {%1, %2};\" : \"=r\"(dst[i]) : \"h\"(lo), \"h\"(hi));\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    dst[i] = src[i];\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void to_float(float (&dst)[2 * N], int (&src)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    uint16_t lo, hi;\n    asm volatile(\"mov.b32 {%0, %1}, %2;\" : \"=h\"(lo), \"=h\"(hi) : \"r\"(src[i]));\n    asm volatile(\"cvt.f32.f16 %0, %1;\" : \"=f\"(dst[2 * i + 0]) : \"h\"(lo));\n    asm volatile(\"cvt.f32.f16 %0, %1;\" : 
\"=f\"(dst[2 * i + 1]) : \"h\"(hi));\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    dst[i] = src[i];\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t* gmem) { dst[0] = __ldg((const int*)gmem); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t* gmem) {\n  unsigned int tmp;\n  asm volatile(\"ld.global.cs.nc.s32 %0, [%1];\" : \"=r\"(tmp) : \"l\"((const uint*)gmem));\n  dst[0] = tmp;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t* gmem) {\n  int2 tmp = __ldg((const int2*)gmem);\n  dst[0] = tmp.x;\n  dst[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t* gmem) {\n  int2 tmp;\n  asm volatile(\"ld.global.cs.nc.v2.s32 {%0,%1}, [%2];\" : \"=r\"(tmp.x), \"=r\"(tmp.y) : \"l\"((const int2*)gmem));\n  dst[0] = tmp.x;\n  dst[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t* gmem) {\n  int tmp[N / 2];\n  ldg(tmp, gmem);\n  to_float(dst, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t* gmem) {\n  int tmp[N / 2];\n  ldg_stream(tmp, gmem);\n  to_float(dst, 
tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg(uint16_t* gmem, int (&src)[1]) { reinterpret_cast<int*>(gmem)[0] = src[0]; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg_stream(uint16_t* gmem, int (&src)[1]) {\n  unsigned int tmp = src[0];\n  asm volatile(\"st.global.cs.s32 [%0], %1;\" ::\"l\"((uint*)gmem), \"r\"(tmp));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg(uint16_t* gmem, int (&src)[2]) {\n  reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void stg_stream(uint16_t* gmem, int (&src)[2]) {\n  asm volatile(\"st.global.cs.v2.s32 [%0], {%1,%2};\" ::\"l\"((uint*)gmem), \"r\"(src[0]), \"r\"(src[1]));\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void stg(uint16_t* gmem, float (&src)[N]) {\n  int tmp[N / 2];\n  from_float(tmp, src);\n  stg(gmem, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void stg_stream(uint16_t* gmem, float (&src)[N]) {\n  int tmp[N / 2];\n  from_float(tmp, src);\n  stg_stream(gmem, tmp);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float* gmem, int idx) {\n  float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2 * idx]));\n  dst[0] = tmp.x;\n  dst[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_gmem(float (&dst)[4], 
const float* gmem, int idx) {\n  float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4 * idx]));\n  dst[0] = tmp.x;\n  dst[1] = tmp.y;\n  dst[2] = tmp.z;\n  dst[3] = tmp.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(float (&x)[2], const float* smem, int idx) {\n  float2 tmp = *(const float2*)&smem[2 * idx];\n  x[0] = tmp.x;\n  x[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(int (&x)[1], const int* smem, int idx) { x[0] = smem[idx]; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(float (&x)[4], const float* smem, int idx) {\n  float4 tmp = *(const float4*)&smem[4 * idx];\n  x[0] = tmp.x;\n  x[1] = tmp.y;\n  x[2] = tmp.z;\n  x[3] = tmp.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void read_from_smem(int (&x)[2], const int* smem, int idx) {\n  int2 tmp = *(const int2*)&smem[2 * idx];\n  x[0] = tmp.x;\n  x[1] = tmp.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_gmem(float* gmem, int idx, const float (&src)[2]) {\n  reinterpret_cast<float2*>(&gmem[2 * idx])[0] = make_float2(src[0], src[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_gmem(float* gmem, int idx, const float (&src)[4]) {\n  reinterpret_cast<float4*>(&gmem[4 * idx])[0] = make_float4(src[0], src[1], src[2], src[3]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void scaled_write_to_gmem(float* gmem, int idx, const float (&src)[4], const float 
coeff) {\n  reinterpret_cast<float4*>(&gmem[4 * idx])[0] =\n      make_float4(src[0] * coeff, src[1] * coeff, src[2] * coeff, src[3] * coeff);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(float* smem, int idx, const float (&x)[2]) {\n  reinterpret_cast<float2*>(&smem[2 * idx])[0] = make_float2(x[0], x[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(int* smem, int idx, const int (&x)[1]) { smem[idx] = x[0]; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(float* smem, int idx, const float (&x)[4]) {\n  reinterpret_cast<float4*>(&smem[4 * idx])[0] = make_float4(x[0], x[1], x[2], x[3]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nDEVICE_FUNCTION void write_to_smem(int* smem, int idx, const int (&x)[2]) {\n  reinterpret_cast<int2*>(&smem[2 * idx])[0] = make_int2(x[0], x[1]);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void zero_array(int (&dst)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    dst[i] = 0;\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void zero_array(float (&dst)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    dst[i] = 0.f;\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    x[i] += y[i];\n  
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    x[i] *= y[i];\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    x[i] *= scalar;\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N], const float (&scale)[N], const float (&m1)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    x[i] = bias[i] + scale[i] * (x[i] - m1[i]);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Storage>\nDEVICE_FUNCTION Storage relu(Storage in) {\n  Storage zero = (Storage)0.f;\n  return (in < zero) ? 
zero : in;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_activation(float (&x)[N]) {\n#pragma unroll\n  for (int i = 0; i < N; ++i) {\n    x[i] = relu(x[i]);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\ntemplate <int THREADS_PER_CTA>\nDEVICE_FUNCTION void parallel_sums_16x2(float* smem, float (&x)[4], int nhw, void* params_my_data,\n                                        void** params_pair_datas, int off, const int magic, const int sync_iters) {\n  // The size of a warp.\n  const int THREADS_PER_WARP = 32;\n  // The number of warps in a CTA.\n  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n  // The number of threads per pixel.\n  const int THREADS_PER_PIXEL = 16;\n  // The number of elements per ldg.\n  const int ELEMENTS_PER_LDG = 4;\n  // The number of reducing ops, each uses its own space : mean, var, dscale, dbias\n  const int REDUCE_OPS = 4;\n  // Maximum block.y supported - limited due to buffer allocation\n  const int MAX_BLOCK_Y = 256;\n  const int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y;\n  // The warp decomposition.\n  const int warp_id = threadIdx.x / THREADS_PER_WARP;\n  const int lane_id = threadIdx.x % THREADS_PER_WARP;\n  // total size of data per sync iter\n  const int data_total = MAX_OFFSET * THREADS_PER_PIXEL * ELEMENTS_PER_LDG * 2;\n\n#pragma unroll\n  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);\n  }\n\n  // The warp leaders, write to SMEM.\n  if (lane_id < THREADS_PER_PIXEL) {\n    write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x);\n  }\n\n  // The data is in SMEM. 
Do the final reduction.\n  __syncthreads();\n\n  // The 1st warp does all the work.\n  // For the final reduction, each half-warp sequentially reduces the final values.\n  if (warp_id == 0) {\n    read_from_smem(x, smem, threadIdx.x);\n\n#pragma unroll\n    for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {\n      float y[ELEMENTS_PER_LDG];\n      // Read the mean and variance from the other pixel.\n      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP);\n      // Compute the updated sum.\n      add(x, y);\n    }\n\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);\n    }\n\n    // Make sure the data was read from SMEM.\n    __syncwarp();\n\n    // Store the final values.\n    if (threadIdx.x < THREADS_PER_PIXEL) {\n      // probably could do it earlier, before sync\n\n      for (int sync_iter = 0; sync_iter < sync_iters; ++sync_iter) {\n        // float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];\n        void* params_pair_data = params_pair_datas[sync_iter];\n\n        // skip the space consumed by previous sync iterations\n        const int xbuf_offset = sync_iter * data_total;\n        // data starts after flags, but have to skip previous\n        const int data_offset =\n            xbuf_offset + off * ELEMENTS_PER_LDG * THREADS_PER_PIXEL * 2 + ELEMENTS_PER_LDG * threadIdx.x * 2;\n\n        // after sums for this GPU were computed, let CTA0 broadcast the sum to the other GPU\n        if (blockIdx.x == 0) {\n          volatile float* write_data = &((reinterpret_cast<float*>(params_pair_data))[data_offset]);\n\n          // write the data to memory region to be reflected to other GPU\n          asm volatile(\"st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};\" ::\"l\"(write_data), \"f\"(x[0]), \"r\"(magic), \"f\"(x[2]),\n                       \"r\"(magic));\n\n          asm volatile(\"st.global.wt.v4.b32 
[%0], {%1,%2,%3,%4};\" ::\"l\"(write_data + 4), \"f\"(x[1]), \"r\"(magic),\n                       \"f\"(x[3]), \"r\"(magic));\n        }\n\n        // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU\n        volatile float* read_data = &((reinterpret_cast<float*>(params_my_data))[data_offset]);\n\n        float other[4];\n        uint32_t other_flag_a, other_flag_b;\n        do {\n          asm volatile(\"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];\"\n                       : \"=f\"(other[0]), \"=r\"(other_flag_a), \"=f\"(other[2]), \"=r\"(other_flag_b)\n                       : \"l\"(read_data));\n        } while ((other_flag_a != magic) || (other_flag_b != magic));\n\n        do {\n          asm volatile(\"ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];\"\n                       : \"=f\"(other[1]), \"=r\"(other_flag_a), \"=f\"(other[3]), \"=r\"(other_flag_b)\n                       : \"l\"(read_data + 4));\n        } while ((other_flag_a != magic) || (other_flag_b != magic));\n\n        add(x, other);\n      }\n      // finally, after syncing up and accounting for partial sums from\n      // other GPUs as required, write the result\n\n      write_to_smem(smem, threadIdx.x, x);\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int THREADS_PER_CTA>\nDEVICE_FUNCTION void parallel_sums_8x4(float* smem, float (&x)[4], int nhw) {\n  // The size of a warp.\n  const int THREADS_PER_WARP = 32;\n  // The number of warps in a CTA.\n  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n  // The number of threads per pixel.\n  const int THREADS_PER_PIXEL = 8;\n  // The number of elements per ldg.\n  const int ELEMENTS_PER_LDG = 4;\n  // The warp decomposition.\n  const int warp_id = threadIdx.x / THREADS_PER_WARP;\n  const int lane_id = threadIdx.x % THREADS_PER_WARP;\n\n#pragma unroll\n  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n    x[i] += 
__shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);\n    x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id);\n  }\n\n  // The warp leaders, write to SMEM.\n  if (lane_id < THREADS_PER_PIXEL) {\n    write_to_smem(smem, warp_id * THREADS_PER_PIXEL + lane_id, x);\n  }\n\n  // The data is in SMEM. Do the final reduction.\n  __syncthreads();\n\n  // The 1st warp does all the work.\n  // We do the final reduction each half-warp sequentially reduces the final values.\n  if (warp_id == 0) {\n    read_from_smem(x, smem, threadIdx.x);\n\n#pragma unroll\n    for (int offset = 1; offset < WARPS_PER_CTA / (THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {\n      float y[ELEMENTS_PER_LDG];\n      // Read the mean and variance from the other pixel.\n      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_WARP);\n      // Compute the updated sum.\n      add(x, y);\n    }\n\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL + lane_id);\n      x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL * 2 + lane_id);\n    }\n\n    // Make sure the data was read from SMEM.\n    __syncwarp();\n\n    // Store the final values.\n    if (threadIdx.x < THREADS_PER_PIXEL) {\n      write_to_smem(smem, threadIdx.x, x);\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG>\nDEVICE_FUNCTION void parallel_sums(float* smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {\n  // The size of a warp.\n  const int THREADS_PER_WARP = 32;\n  // The number of warps in a CTA.\n  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;\n  // The number of pixels computed by a single warp.\n  const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;\n\n  // The position in the warp.\n  const int nhw_in_warp = nhw % PIXELS_PER_WARP;\n  // The C in the warp.\n  
const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;\n\n  // Store the values to shared memory.\n  write_to_smem(smem, threadIdx.x, x);\n\n  // Compute the parallel sums.\n  for (int offset = PIXELS_PER_WARP / 2; offset > 0; offset /= 2) {\n    // NOP.\n    __syncwarp();\n\n    // Read the running sum from the other thread.\n    float y[ELEMENTS_PER_LDG];\n    if (nhw_in_warp < offset) {\n      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL);\n    }\n\n    // Compute the updated sum.\n    add(x, y);\n\n    // NOP.\n    __syncwarp();\n\n    // Update the sum in SMEM.\n    if (offset > 1 && nhw_in_warp < offset) {\n      write_to_smem(smem, threadIdx.x, x);\n    }\n  }\n\n  // The warps are done. Do the final reduction at the CTA level.\n  __syncthreads();\n\n  // The warp leaders, write to SMEM.\n  const int idx = (threadIdx.x / THREADS_PER_WARP) * THREADS_PER_PIXEL + c_in_warp;\n  if (nhw_in_warp == 0) {\n    write_to_smem(smem, idx, x);\n  }\n\n  // The data is in SMEM. Do the final reduction.\n  __syncthreads();\n\n  // Read the 1st element to prepare the work.\n  if (nhw < WARPS_PER_CTA / 2) {\n    read_from_smem(x, smem, threadIdx.x);\n  }\n\n  // We have the running mean and running m2. 
Let's build the mean/var of the CTA.\n  for (int offset = WARPS_PER_CTA / 2; offset > 0; offset /= 2) {\n    // NOP.\n    __syncwarp();\n\n    // Read the mean and variance from the other pixel.\n    float y[ELEMENTS_PER_LDG];\n    if (nhw < offset) {\n      read_from_smem(y, smem, threadIdx.x + offset * THREADS_PER_PIXEL);\n    }\n\n    // Compute the updated sum.\n    add(x, y);\n\n    // NOP.\n    __syncwarp();\n\n    // Store the mean/var for the different pixels.\n    if (nhw < offset) {\n      write_to_smem(smem, threadIdx.x, x);\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG>\nstruct ParallelSums {\n  template <int THREADS_PER_CTA>\n  DEVICE_FUNCTION void dispatch(float* smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {\n    parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <>\nstruct ParallelSums<16, 4> {\n  template <int THREADS_PER_CTA>\n  DEVICE_FUNCTION void dispatch(float* smem, float (&x)[4], int nhw) {\n    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);\n  }\n\n  template <int THREADS_PER_CTA>\n  DEVICE_FUNCTION void dispatchX(float* smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas,\n                                 int off, const int magic, const unsigned int& sync_iters) {\n    parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);\n  }\n};\n\ntemplate <>\nstruct ParallelSums<8, 4> {\n  template <int THREADS_PER_CTA>\n  DEVICE_FUNCTION void dispatch(float* smem, float (&x)[4], int nhw) {\n    parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);\n  
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstatic inline int div_up(int m, int n) { return (m + n - 1) / n; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// It is expected that all threads in the CTA enter this function!\nDEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {\n  // Register the CTA.\n  if (threadIdx.x == 0) {\n    // Issue the membar.\n    __threadfence();\n    // Notify that the CTA is done.\n    int val_to_add = 1;\n    if (master) {\n      val_to_add = -(expected_count - 1);\n    }\n    atomicAdd(gmem_retired_ctas, val_to_add);\n  }\n\n  // Are all CTAs done?\n  if (threadIdx.x == 0) {\n    int retired_ctas = -1;\n    do {\n      __threadfence();\n      asm volatile(\"ld.global.cg.b32 %0, [%1];\" : \"=r\"(retired_ctas) : \"l\"(gmem_retired_ctas));\n    } while (retired_ctas != 0);\n  }\n  __syncthreads();\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormFwdInferenceParams {\n  // The input/output tensors.\n  uint16_t *gmem_src, *gmem_dst, *gmem_src1;\n  // the final mean and variance as calculated during the training process\n  float *gmem_mean, *gmem_var;\n  // The bias/scale.\n  float *gmem_bias, *gmem_scale;\n  // The dimensions.\n  int nhw, c;\n  // epsilon\n  float var_eps;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively\ntemplate <typename Storage, int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG, bool USE_RELU,\n          bool USE_ADD_RELU>\n__global__ __launch_bounds__(THREADS_PER_CTA) void nhwc_batch_norm_fwd_inference(\n   
 NhwcBatchNormFwdInferenceParams params) {\n  // The number of pixels loaded in a single LDG.\n  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  // The number of C elements per CTA.\n  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;\n\n  // The start position in the NHW dimension where the CTA starts.\n  const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;\n  // Compute the NHW coordinate of the thread in the CTA.\n  const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n  // thread's starting point in NHW\n  const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;\n\n  // The position in the C dimension where the CTA starts.\n  const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;\n  // Compute the C coordinate of the thread in the CTA.\n  const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n  // Compute the C coordinate of the thread.\n  const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;\n\n  // Is the thread working on a valid C dimension?\n  const int is_valid_c = thread_c < params.c;\n\n  float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];\n  float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];\n  zero_array(mean);\n  zero_array(var);\n  zero_array(scale);\n  zero_array(bias);\n  if (is_valid_c) {\n    read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);\n    read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);\n    read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);\n    read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);\n  }\n\n// Update the scale with the stddev and eps.\n#pragma unroll\n  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n    scale[i] *= rsqrtf(var[i] + params.var_eps);\n  }\n\n  // The base pointers for reading/writing\n  uint16_t* const gmem_src = &params.gmem_src[thread_c];\n  uint16_t* const gmem_dst = &params.gmem_dst[thread_c];\n  const uint16_t* gmem_src1 = nullptr;\n  if (USE_ADD_RELU) {\n    
gmem_src1 = &params.gmem_src1[thread_c];\n  }\n\n  // apply BN\n  for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {\n    float x_math[ELEMENTS_PER_LDG];\n    zero_array(x_math);\n    if (is_valid_c) {\n      ldg(x_math, &gmem_src[nhw * params.c]);\n    }\n\n    // Normalize and apply activation function\n    normalize(x_math, bias, scale, mean);\n    if (USE_ADD_RELU) {\n      float x1_math[ELEMENTS_PER_LDG];\n      ldg(x1_math, &gmem_src1[nhw * params.c]);\n      add(x_math, x1_math);\n      relu_activation(x_math);\n    } else if (USE_RELU) {\n      relu_activation(x_math);\n    }\n\n    if (is_valid_c) {\n      stg(&gmem_dst[nhw * params.c], x_math);\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormFwdParams {\n  // The input/output tensors.\n  uint16_t *gmem_src, *gmem_dst, *gmem_src1;\n  // The bias/scale.\n  float *gmem_bias, *gmem_scale;\n  // running mean/var (refer BN API from cudnn doc)\n  float *gmem_running_mean, *gmem_running_var;\n  // saved mean/var (refer BN API from cudnn doc)\n  float *gmem_saved_mean, *gmem_saved_var;\n  // ReLU bitmask\n  unsigned int* gmem_relu_bitmask;\n  // The dimensions.\n  int nhw, c;\n  // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n  float svar_inv_count;\n  // factor to scale sum of squared errors to get running variance. 
Should be 1/nhw or 1/(nhw-1).\n  float rvar_inv_count;\n  // The buffer to do the reduction for mean, stddev and count.\n  float* gmem_sums;\n  // The buffer to count items in the different CTAs.\n  int* gmem_counts;\n  // The counters of retired CTAs.\n  int* gmem_retired_ctas;\n  // The epsilon to apply to the computation of the variance.\n  float var_eps;\n  // outer loop count\n  int outer_loops;\n  // exponential average factor\n  float exp_avg_factor;\n  // number of CTAs along .x dimension\n  int c_blks;\n\n  void* my_data;\n  void* pair_datas[4];\n  int magic;\n  int sync_iters;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Storage, int THREADS_PER_CTA, int THREADS_PER_PIXEL, int PIXELS_PER_THREAD_IN_REGISTERS,\n          int PIXELS_PER_THREAD_IN_SMEM, int ELEMENTS_PER_LDG, int USE_ONLINE_APPROACH, int OUTER_LOOPS_, bool USE_RELU,\n          bool USE_ADD_RELU, int DESIRED_OCCUPANCY>\n__global__ __launch_bounds__(THREADS_PER_CTA,\n                             DESIRED_OCCUPANCY) void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) {\n  // The number of pixels loaded in a single LDG.\n  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  // The number of pixels computed per CTA stored in registers.\n  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n  // The number of pixels computed per CTA stored in SMEM.\n  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;\n  // The number of C elements per CTA.\n  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;\n\n  // Shared memory to do CTA-wide parallel sums.\n  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];\n\n  // Compute the NHW coordinate of the thread in the CTA.\n  const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n  // The adapter for the storage.\n  typedef 
PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n  // The data type for packed storage in SMEM.\n  typedef typename PackedStorage_::Type PackedStorageType;\n  // The number of elements in the packed storage.\n  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n  // Registers to keep the data live for the persistent approach.\n  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n  // Shared memory buffer to store the extra pixels.\n  extern __shared__ PackedStorageType smem_storage_packed[];\n\n  for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {\n    // The position in the NHW dimension where the CTA starts.\n    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n    // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n\n    // The position in the C dimension where the CTA starts.\n    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n    // Compute the C coordinate of the thread in the CTA.\n    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n    // Compute the C coordinate of the thread.\n    int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;\n\n    // Is the thread working on a valid C dimension?\n    const int is_valid_c = thread_c < params.c;\n\n    // Clamp thread_c so that we load from valid locations even if we don't use the value\n    if (!is_valid_c) thread_c = params.c - 4;\n\n    // Single pass numerically stable algorithm, see:\n    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm\n    //\n    // n = 0, mean = 0.0, M2 = 0.0\n    //\n    // for x in data:\n    //     n += 1\n    //     delta = x - mean\n    //     mean += delta/n\n    //     delta2 = x - mean\n    //     M2 += delta*delta2\n    //\n    // if n < 2:\n    //     return float('nan')\n    // else:\n    //     
return M2 / (n - 1)\n\n    // Register to store the number of elements read so far.\n    float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG];\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      mean[i] = 0.f;\n      m2[i] = 0.f;\n    }\n\n    // The number of elements loaded by this CTA.\n    int cta_count = 0;\n    // The base pointer to load from.\n    const uint16_t* gmem_src = &params.gmem_src[thread_c];\n\n    // outer loops\n    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;\n    // Load the batch of elements. Compute the mean/var across those elements.\n    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;\n\n    if (OUTER_LOOPS_ != 1) {\n      // We cannot load everything to store persistently, so let's make sure registers and\n      // smem are fully utilized, offset is evenly divisible by 32\n      int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;\n      cta_nhw_regs -= offset;\n      cta_nhw_smem -= offset;\n    }\n\n#pragma unroll 1\n    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n      // The nhw position.\n      int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;\n      // Update the number of elements loaded by this CTA. 
TODO: Skip if <= 0!!!\n      cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) - max(nhw_regs, 0), 0);\n\n      // Load the data and compute the local mean/sum and the variance.\n      if (USE_ONLINE_APPROACH) {\n        // Read the elements from memory.\n        float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n#pragma unroll\n        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n          const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n          zero_array(x_storage[i]);\n          is_valid[i] = 0.f;\n          if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n            if (loop_i == OUTER_LOOPS - 1) {\n              ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n            } else {\n              ldg(x_storage[i], &gmem_src[idx * params.c]);\n            }\n            is_valid[i] = 1.f;\n          }\n        }\n\n// Do the math.\n#pragma unroll\n        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n          // Convert to float.\n          float x_math[ELEMENTS_PER_LDG];\n          to_float(x_math, x_storage[i]);\n\n          // Update the count.\n          count += is_valid[i];\n          // Invert the count.\n          float inv_count = is_valid[i] ? 
1.f / count : 0.f;\n\n// Update the mean and m2 using deltas.\n#pragma unroll\n          for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n            float delta0 = x_math[j] - mean[j];\n            mean[j] += delta0 * inv_count;\n            float delta1 = x_math[j] - mean[j];\n            m2[j] += delta0 * delta1 * is_valid[i];\n          }\n        }\n      } else {\n// Read the elements from memory.\n#pragma unroll\n        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n          const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n          zero_array(x_storage[i]);\n          if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n            if (loop_i == OUTER_LOOPS - 1) {\n              ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n            } else {\n              ldg(x_storage[i], &gmem_src[idx * params.c]);\n            }\n            count += 1.f;\n          }\n        }\n\n// Sum the elements in registers.\n#pragma unroll\n        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n          // Convert to float.\n          float x_math[ELEMENTS_PER_LDG];\n          to_float(x_math, x_storage[i]);\n\n// Update the mean and m2 using deltas.\n#pragma unroll\n          for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n            mean[j] += x_math[j];\n          }\n        }\n\n        // Compute the mean.\n        float inv_count = 1.f / count;\n#pragma unroll\n        for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n          mean[j] *= inv_count;\n        }\n\n// Compute the variance.\n#pragma unroll\n        for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n          // Convert to float.\n          float x_math[ELEMENTS_PER_LDG];\n          to_float(x_math, x_storage[i]);\n\n          // Is it a valid pixel?\n          float is_valid = i < static_cast<int>(count) ? 
1.f : 0.f;\n// Update the mean and m2 using deltas.\n#pragma unroll\n          for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n            m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid;\n          }\n        }\n      }\n    }\n\n    // The elements to load and store in SMEM.\n    int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;\n    // Load elements from SMEM, update the CTA count.\n    int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0);\n    if (pixels_in_smem > 0) {\n      cta_count += pixels_in_smem;\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        float is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f;\n\n        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];\n        ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0) * params.c]);\n\n        // The offset to store in SMEM.\n        const int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        // Store in SMEM.\n        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n        // Update the count.\n        count += is_pixel_valid;\n        // Invert the count.\n        float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n        float x_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage_local);\n// Update the mean and m2 using deltas.\n#pragma unroll\n        for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n          float delta0 = x_math[j] - mean[j];\n          mean[j] += delta0 * inv_count;\n          float delta1 = x_math[j] - mean[j];\n          m2[j] += delta0 * delta1 * is_pixel_valid;\n        }\n      }\n    }\n\n    // We scale the mean by the number of elements. 
It brings more stability.\n    float m1[ELEMENTS_PER_LDG];\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      m1[i] = mean[i] * count;\n    }\n\n    // Run the parallel sum across the CTA to get the local sum.\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, m1, thread_in_cta_nhw);\n    __syncthreads();\n\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(m1, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // Adjust the variance.\n    float inv_cta_count = 1.f / static_cast<float>(cta_count);\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      float mean_diff = m1[i] * inv_cta_count - mean[i];\n      m2[i] = m2[i] + mean_diff * mean_diff * count;\n    }\n\n    // Run the parallel sum across the CTA to get the local adjusted variance.\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, m2, thread_in_cta_nhw);\n\n    // The workspace in global memory is distributed across the different CTAs.\n    int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;\n\n    // Write the data for the CTA to global memory.\n    float* gmem_sums = &params.gmem_sums[gmem_sums_offset];\n    if (threadIdx.x < THREADS_PER_PIXEL) {\n      const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;\n      write_to_gmem(&gmem_sums[0], idx, m1);\n      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, m2);\n    }\n\n    // The memory location to store the number of pixels per CTA.\n    int* gmem_counts = &params.gmem_counts[c_blk_index * gridDim.x];\n    if (threadIdx.x == 0) {\n      gmem_counts[blockIdx.x] = cta_count;\n    }\n\n    // Read the bias and scale.\n    float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG];\n    if (is_valid_c) {\n      read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);\n      read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);\n    }\n\n    // 
The counters to count how many CTAs have retired at this point.\n    // A given cta uses the same counter every other time through the outer loop.\n    int* gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n    inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n// Reset the mean to compute the global mean.\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      m1[i] = 0.f;\n    }\n\n// Build the global mean.\n#pragma unroll 1\n    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {\n      float tmp[ELEMENTS_PER_LDG];\n      read_from_gmem(tmp, gmem_sums, idx);\n      add(m1, tmp);\n    }\n\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 3, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, m1, thread_in_cta_nhw);\n    }\n    __syncthreads();\n\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(m1, smem, thread_in_cta_c);\n    __syncthreads();\n\n// Normalize the mean.\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      m1[i] = m1[i] * params.svar_inv_count;\n    }\n\n// Reset the variance.\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      m2[i] = 0.f;\n    }\n\n    // for add+relu fusion\n    const uint16_t* gmem_src1 = nullptr;\n    if (USE_ADD_RELU) {\n      gmem_src1 = &params.gmem_src1[thread_c];\n    }\n\n// Build the global variance.\n#pragma unroll 1\n    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {\n      // Read the means computed by different CTAs (again). 
Reuse tmp if we have 1 iteration.\n      float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG];\n      read_from_gmem(tmp_mean, &gmem_sums[0], idx);\n      read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx);\n\n      // Read the number of pixels visited by a given CTA.\n      cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]);\n\n      // Compute the diff to update the variance.\n      float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast<float>(cta_count);\n#pragma unroll\n      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        mean_diff[i] = m1[i] - tmp_mean[i] * inv_cta_count;\n      }\n\n// Update the variance.\n#pragma unroll\n      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        m2[i] += tmp_var[i] + mean_diff[i] * mean_diff[i] * static_cast<float>(cta_count);\n      }\n    }\n\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 2, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, m2, thread_in_cta_nhw);\n    }\n    __syncthreads();\n\n    read_from_smem(m2, smem, thread_in_cta_c);\n\n    // Finalize the stddev.\n    // because saved var and running var may have different denominators, we don't do it here\n    // scale_(m2, inv_count);\n\n    // store the saved mean/var\n    float svarinv[ELEMENTS_PER_LDG];\n    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps);\n    }\n    if (is_valid_for_saving) {\n      write_to_gmem(params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG, m1);\n      write_to_gmem(params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG, svarinv);\n    }\n\n    // 
store the running mean/var\n    float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG];\n    zero_array(rmean);\n    zero_array(rvar);\n    if (params.exp_avg_factor != 1.f && is_valid_for_saving) {\n      read_from_gmem(rmean, params.gmem_running_mean, thread_c / ELEMENTS_PER_LDG);\n      read_from_gmem(rvar, params.gmem_running_var, thread_c / ELEMENTS_PER_LDG);\n    }\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + params.exp_avg_factor * m1[i];\n      rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + params.exp_avg_factor * (m2[i] * params.rvar_inv_count);\n    }\n    if (is_valid_for_saving) {\n      write_to_gmem(params.gmem_running_mean, thread_c / ELEMENTS_PER_LDG, rmean);\n      write_to_gmem(params.gmem_running_var, thread_c / ELEMENTS_PER_LDG, rvar);\n    }\n\n    // Update the scale with the stddev and eps.\n    multiply(scale, svarinv);\n\n    // The base pointer to write to.\n    uint16_t* const gmem_dst = &params.gmem_dst[thread_c];\n\n    unsigned int* const gmem_relu_bitmask = params.gmem_relu_bitmask + ((params.nhw + 31) & ~31) * 2 * c_blk_index;\n\n// Store the elements in registers.\n#pragma unroll 1\n    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {\n      // The value for nhw.\n      int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;\n\n// Normalize the elements and write to memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        const bool is_valid_nhw = static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n        const bool is_valid = is_valid_nhw && is_valid_c;\n        // Convert to float.\n        float x_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage[i]);\n\n        // Normalize and apply activation function\n        normalize(x_math, bias, scale, m1);\n        if (USE_ADD_RELU) {\n       
   float x1_math[ELEMENTS_PER_LDG];\n          ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0) * params.c]);\n          add(x_math, x1_math);\n          unsigned int relu_mask;\n          int lane_id = threadIdx.x & 31;\n#pragma unroll\n          for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            bool rectified = x_math[i] < 0.0F;\n            unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);\n            if (lane_id == i) {\n              // Thread 0 remembers the relu_mask from the first time through this\n              // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.\n              relu_mask = local_relu_mask;\n            }\n            if (rectified) {\n              x_math[i] = 0.0F;\n            }\n          }\n          if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {\n            gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;\n          }\n        } else if (USE_RELU) {\n          relu_activation(x_math);\n        }\n\n        // Write back.\n        if (is_valid) {\n          stg_stream(&gmem_dst[idx * params.c], x_math);\n        }\n      }\n\n      // The next value of nhw.\n      out_nhw -= pixels_per_iteration;\n\n// Read the next elements from memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n          ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n        }\n      }\n    }\n\n    // Normalize the elements from SMEM and write them out.\n    if (pixels_in_smem > 0) {\n#pragma unroll 2\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        const bool is_valid_nhw = static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n        const bool is_valid = is_valid_nhw && is_valid_c;\n\n        // 
Read from SMEM.\n        const int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];\n        read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n        float x_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage_local);\n\n        // Normalize and apply activation function\n        normalize(x_math, bias, scale, m1);\n        if (USE_ADD_RELU) {\n          float x1_math[ELEMENTS_PER_LDG];\n          ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0) * params.c]);\n          add(x_math, x1_math);\n          unsigned int relu_mask;\n          int lane_id = threadIdx.x & 31;\n#pragma unroll\n          for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n            bool rectified = x_math[i] < 0.0F;\n            unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);\n            if (lane_id == i) {\n              relu_mask = local_relu_mask;\n            }\n            if (rectified) {\n              x_math[i] = 0.0F;\n            }\n          }\n          if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {\n            gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;\n          }\n        } else if (USE_RELU) {\n          relu_activation(x_math);\n        }\n\n        // Write back.\n        if (is_valid) {\n          stg_stream(&gmem_dst[idx * params.c], x_math);\n        }\n      }\n    }\n    // We're about to start on the next c-blk.  
Needed?\n    __syncthreads();\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct NhwcBatchNormBwdParams {\n  // The input/output tensors.\n  uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;\n  // dscale/dbias\n  float *gmem_dscale, *gmem_dbias;\n  // The scale and bias.\n  float *gmem_scale, *gmem_bias;\n  // The mean/inv-var saved from fwd pass\n  float *gmem_saved_mean, *gmem_saved_var;\n  // ReLU bitmask\n  unsigned int* gmem_relu_bitmask;\n  // The dimensions.\n  int nhw, c;\n  // factor to scale sum of squared errors to get saved variance.  Must be 1/nhw.\n  float svar_inv_count;\n  // The buffer to do the reduction for dscale and dbias\n  float* gmem_sums;\n  // The counters of retired CTAs.\n  int* gmem_retired_ctas;\n  // outer loop count\n  int outer_loops;\n  // number of CTAs along .x dimension\n  int c_blks;\n\n  void* my_data;\n  void* pair_datas[4];\n  int magic;\n  int sync_iters;\n  float wgrad_coeff;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N], const float (&mean_var_scale_bias)[N],\n                              const float (&var_scale)[N], bool valid_data) {\n#pragma unroll\n  for (int j = 0; j < N; ++j) {\n    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];\n    if ((y <= 0.f) && valid_data) {\n      dy[j] = 0.f;\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) {\n#pragma unroll\n  for (int j = 0; j < N; ++j) {\n    if ((y[j] <= 0.f) && valid_data) {\n      dy[j] = 0.f;\n    }\n  }\n}\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {\n#pragma unroll\n  for (int j = 0; j < N; ++j) 
{\n    if (rectified[j] && valid_data) {\n      dy[j] = 0.f;\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&x)[N], const float (&mean_var_scale_bias)[N],\n                                     const float (&var_scale)[N]) {\n#pragma unroll\n  for (int j = 0; j < N; ++j) {\n    float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];\n    if (y <= 0.f) {\n      dy[j] = 0.f;\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {\n#pragma unroll\n  for (int j = 0; j < N; ++j) {\n    if (y[j] <= 0.f) {\n      dy[j] = 0.f;\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N], const float (&dy)[N], const float (&x)[N],\n                                const float (&mean)[N], float inv_count) {\n#pragma unroll\n  for (int j = 0; j < N; ++j) {\n    float delta0 = dy[j] - dbias[j];\n    dbias[j] += delta0 * inv_count;\n    delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];\n    dscale[j] += delta0 * inv_count;\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int N>\nDEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N], const float (&var)[N], const float (&x)[N],\n                            const float (&mean)[N], const float (&dscale)[N], const float (&dbias)[N],\n                            float inv_count) {\n#pragma unroll\n  for (int j = 0; j < N; ++j) {\n    float tmp1 = dy[j] - (dbias[j] * inv_count);\n    float tmp2 = dscale[j] * inv_count;\n    float tmp3 = x[j] - mean[j];\n    dx[j] = var[j] * (tmp1 - (tmp2 * 
tmp3));\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Storage, int THREADS_PER_CTA, int THREADS_PER_PIXEL, int PIXELS_PER_THREAD_IN_REGISTERS,\n          int PIXELS_PER_THREAD_IN_SMEM, int ELEMENTS_PER_LDG, int USE_ONLINE_APPROACH, int OUTER_LOOPS_,\n          int DESIRED_OCCUPANCY>\n__global__ __launch_bounds__(THREADS_PER_CTA,\n                             DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {\n  // The number of pixels loaded in a single LDG.\n  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  // The number of pixels computed per CTA stored in registers.\n  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n  // The number of pixels computed per CTA stored in SMEM.\n  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;\n  // The number of C elements per CTA.\n  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;\n\n  // Shared memory to do CTA-wide parallel sums.\n  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];\n\n  // The adapter for the storage.\n  typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n  // The data type for packed storage in SMEM.\n  typedef typename PackedStorage_::Type PackedStorageType;\n  // The number of elements in the packed storage.\n  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n  // Registers to keep the data live for the persistent approach.\n  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n  PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n  // Shared memory buffer to store the extra pixels.\n  extern __shared__ PackedStorageType smem_storage_packed[];\n\n  for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) 
{\n    // The position in the NHW dimension where the CTA starts.\n    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n    // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n    // Compute the NHW coordinate of the thread in the CTA.\n    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n    // The position in the C dimension where the CTA starts.\n    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n    // Compute the C coordinate of the thread in the CTA.\n    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n    // Compute the C coordinate of the thread.\n    const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;\n\n    // Is the thread working on a valid C dimension?\n    const int is_valid_c = thread_c < params.c;\n\n    // Registers to store the mean used for entire duration\n    float mean[ELEMENTS_PER_LDG];\n    zero_array(mean);\n    if (is_valid_c) {\n      read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG);\n    }\n\n    // accumulation related registers\n    float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n    zero_array(dscale);\n    zero_array(dbias);\n\n    // The number of elements loaded by this CTA.\n    int cta_count = 0;\n    // The base pointers to load from.\n    const uint16_t* gmem_src = &params.gmem_src[thread_c];\n    const uint16_t* gmem_dy = &params.gmem_dy[thread_c];\n\n    // outer loops\n    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;\n    // Load the batch of elements. 
Compute sum across them\n    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;\n\n    if (OUTER_LOOPS_ != 1) {\n      // We cannot load everything to store persistently, so let's make sure registers and\n      // smem are fully utilized\n      int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - PIXELS_PER_CTA_IN_SMEM * gridDim.x;\n      cta_nhw_regs += offset;\n      cta_nhw_smem += offset;\n    }\n\n#pragma unroll 1\n    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n      // The nhw position.\n      int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;\n      // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!\n      cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs));\n\n      // Read the elements from memory.\n      float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        zero_array(x_storage[i]);\n        zero_array(dy_storage[i]);\n        is_valid[i] = 0.f;\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n          if (loop_i == OUTER_LOOPS - 1) {\n            ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n            ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);\n          } else {\n            ldg(x_storage[i], &gmem_src[idx * params.c]);\n            ldg(dy_storage[i], &gmem_dy[idx * params.c]);\n          }\n          is_valid[i] = 1.f;\n        }\n      }\n\n// Do the math.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        // Convert to float and update\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage[i]);\n        to_float(dy_math, dy_storage[i]);\n\n        // Update the count.\n        count += is_valid[i];\n        // Invert the count.\n        float inv_count 
= is_valid[i] ? 1.f / count : 0.f;\n\n        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n      }\n    }\n\n    // The elements to load and store in SMEM.\n    int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;\n    // Load elements from SMEM, update the CTA count.\n    int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);\n    if (pixels_in_smem > 0) {\n      cta_count += pixels_in_smem;\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c);\n        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n        zero_array(x_storage_local);\n        zero_array(dy_storage_local);\n        if (is_pixel_valid) {\n          ldg_stream(x_storage_local, &gmem_src[idx * params.c]);\n          ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]);\n        }\n\n        // The offset to store in SMEM.\n        int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        // Store in SMEM.\n        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n        offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n        // Update the count.\n        count += is_pixel_valid;\n        // Invert the count.\n        float inv_count = is_pixel_valid ? 1.f / count : 0.f;\n\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage_local);\n        to_float(dy_math, dy_storage_local);\n\n        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n      }\n    }\n\n// We scale the mean by the number of elements. 
It brings more stability.\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      dbias[i] *= count;\n      dscale[i] *= count;\n    }\n\n    // dscale parallel sum\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dscale, thread_in_cta_nhw);\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dscale, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // dbias parallel sum\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dbias, thread_in_cta_nhw);\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dbias, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // The workspace in global memory is distributed across the different CTA.\n    int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;\n    // Write the data for the CTA to global memory.\n    float* gmem_sums = &params.gmem_sums[gmem_sums_offset];\n    if (threadIdx.x < THREADS_PER_PIXEL) {\n      const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;\n      write_to_gmem(&gmem_sums[0], idx, dscale);\n      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias);\n    }\n\n    // The counters to count how many CTAs have retired at this point.\n    // A given cta uses the same counter every other time through the outer loop.\n    int* gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n    inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n    // Reset the accumulators for global summation\n    zero_array(dscale);\n    zero_array(dbias);\n\n// Build the global accumulation\n#pragma unroll 1\n    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {\n      float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n      read_from_gmem(tmp1, gmem_sums, idx);\n      read_from_gmem(tmp2, 
gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx);\n\n#pragma unroll\n      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        dscale[i] += tmp1[i];\n        dbias[i] += tmp2[i];\n      }\n    }\n\n    // dscale parallel sum\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dscale, thread_in_cta_nhw);\n    }\n\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dscale, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // dbias parallel sum\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dbias, thread_in_cta_nhw);\n    }\n\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dbias, smem, thread_in_cta_c);\n\n    // inv-var\n    float var[ELEMENTS_PER_LDG];\n    zero_array(var);\n    if (is_valid_c) {\n      read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);\n    }\n\n    // Normalize the dscale.\n    multiply(dscale, var);\n\n    // store dscale/dbias\n    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n    if (is_valid_for_saving) {\n      if (params.sync_iters > 0) {\n        scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n        scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, 
params.wgrad_coeff);\n      } else {\n        write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale);\n        write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias);\n      }\n    }\n\n    // scale\n    float scale[ELEMENTS_PER_LDG];\n    zero_array(scale);\n    if (is_valid_c) {\n      read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);\n    }\n\n    // Further normalize the dscale to be used in dx calculation\n    multiply(dscale, var);\n    // scale the inv-var as well, afterwards\n    multiply(var, scale);\n\n    // inverse count\n    float inv_count = params.svar_inv_count;\n\n    // The base pointer to write to.\n    uint16_t* const gmem_dst = &params.gmem_dst[thread_c];\n\n// Store the elements in registers.\n#pragma unroll 1\n    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {\n      // The value for nhw.\n      int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;\n\n// Normalize the elements and write to memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        // Convert to float.\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage[i]);\n        to_float(dy_math, dy_storage[i]);\n\n        float dx[ELEMENTS_PER_LDG];\n        bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n        // Write back.\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n          stg_stream(&gmem_dst[idx * params.c], dx);\n        }\n      }\n\n      // The next value of nhw.\n      out_nhw -= pixels_per_iteration;\n\n// Read the next elements from memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n          
ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n          ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);\n        }\n      }\n    }\n\n    // Normalize the elements from SMEM and write them out.\n    if (pixels_in_smem > 0) {\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n        if (is_valid) {\n          // Read from SMEM.\n          int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n          PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n          read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n          offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n          read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n          float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n          to_float(x_math, x_storage_local);\n          to_float(dy_math, dy_storage_local);\n\n          float dx[ELEMENTS_PER_LDG];\n          bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n          // Write back.\n          stg_stream(&gmem_dst[idx * params.c], dx);\n        }\n      }\n    }\n    // We're about to start on the next c-blk.  
Needed?\n    __syncthreads();\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Storage, int THREADS_PER_CTA, int THREADS_PER_PIXEL, int PIXELS_PER_THREAD_IN_REGISTERS,\n          int PIXELS_PER_THREAD_IN_SMEM, int ELEMENTS_PER_LDG, int USE_ONLINE_APPROACH, int OUTER_LOOPS_,\n          int DESIRED_OCCUPANCY>\n__global__ __launch_bounds__(THREADS_PER_CTA,\n                             DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {\n  // The number of pixels loaded in a single LDG.\n  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  // The number of pixels computed per CTA stored in registers.\n  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n  // The number of pixels computed per CTA stored in SMEM.\n  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;\n  // The number of C elements per CTA.\n  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;\n\n  // Shared memory to do CTA-wide parallel sums.\n  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];\n\n  // The adapter for the storage.\n  typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n  // The data type for packed storage in SMEM.\n  typedef typename PackedStorage_::Type PackedStorageType;\n  // The number of elements in the packed storage.\n  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n  // Registers to keep the data live for the persistent approach.\n  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n  PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n  // Shared memory buffer to store the extra pixels.\n  extern __shared__ PackedStorageType smem_storage_packed[];\n\n  for (int c_blk_index = blockIdx.y; c_blk_index < 
params.c_blks; c_blk_index += gridDim.y) {\n    // The position in the NHW dimension where the CTA starts.\n    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n    // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n    // Compute the NHW coordinate of the thread in the CTA.\n    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n    // The position in the C dimension where the CTA starts.\n    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n    // Compute the C coordinate of the thread in the CTA.\n    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n    // Compute the C coordinate of the thread.\n    const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;\n\n    // Is the thread working on a valid C dimension?\n    const int is_valid_c = thread_c < params.c;\n\n    // Registers to store the mean/var/scale/bias used for the entire duration\n    // Register usage optimizations:\n    // 1. Can combine bias - (mean * var * scale) into a single register\n    // 2. 
Can combine var * scale into a single register\n    float varscale[ELEMENTS_PER_LDG];\n    zero_array(varscale);\n    if (is_valid_c) {\n      read_from_gmem(varscale, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);\n    }\n    float tmp[ELEMENTS_PER_LDG];\n    zero_array(tmp);\n    if (is_valid_c) {\n      read_from_gmem(tmp, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);\n    }\n    multiply(varscale, tmp);\n    float mean[ELEMENTS_PER_LDG];\n    zero_array(mean);\n    if (is_valid_c) {\n      read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG);\n    }\n    zero_array(tmp);\n    if (is_valid_c) {\n      read_from_gmem(tmp, params.gmem_bias, thread_c / ELEMENTS_PER_LDG);\n    }\n    float mean_var_scale_bias[ELEMENTS_PER_LDG];\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);\n    }\n\n    // accumulation related registers\n    float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n    zero_array(dscale);\n    zero_array(dbias);\n\n    // The number of elements loaded by this CTA.\n    int cta_count = 0;\n    // The base pointers to load from.\n    const uint16_t* gmem_src = &params.gmem_src[thread_c];\n    const uint16_t* gmem_dy = &params.gmem_dy[thread_c];\n\n    // outer loops\n    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;\n    // Load the batch of elements. 
Compute sum across them\n    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;\n\n    if (OUTER_LOOPS_ != 1) {\n      // We cannot load everything to store persistently, so let's make sure registers and\n      // smem are fully utilized\n      int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - PIXELS_PER_CTA_IN_SMEM * gridDim.x;\n      cta_nhw_regs += offset;\n      cta_nhw_smem += offset;\n    }\n\n#pragma unroll 1\n    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n      // The nhw position.\n      int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;\n      // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!\n      cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs));\n\n      // Read the elements from memory.\n      float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        zero_array(x_storage[i]);\n        zero_array(dy_storage[i]);\n        is_valid[i] = 0.f;\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n          if (loop_i == OUTER_LOOPS - 1) {\n            ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n            ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);\n          } else {\n            ldg(x_storage[i], &gmem_src[idx * params.c]);\n            ldg(dy_storage[i], &gmem_dy[idx * params.c]);\n          }\n          is_valid[i] = 1.f;\n        }\n      }\n\n// Do the math.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        // Convert to float and update\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage[i]);\n        to_float(dy_math, dy_storage[i]);\n\n        // Update the count.\n        count += is_valid[i];\n        // Invert the count.\n        float inv_count 
= is_valid[i] ? 1.f / count : 0.f;\n\n        relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);\n        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n      }\n    }\n\n    // The elements to load and store in SMEM.\n    int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;\n    // Load elements from SMEM, update the CTA count.\n    int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);\n    if (pixels_in_smem > 0) {\n      cta_count += pixels_in_smem;\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        bool is_pixel_valid = (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c);\n        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n        zero_array(x_storage_local);\n        zero_array(dy_storage_local);\n        if (is_pixel_valid) {\n          ldg_stream(x_storage_local, &gmem_src[idx * params.c]);\n          ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]);\n        }\n\n        // The offset to store in SMEM.\n        int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        // Store in SMEM.\n        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n        offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n        // Update the count.\n        count += is_pixel_valid;\n        // Invert the count.\n        float inv_count = is_pixel_valid ? 
1.f / count : 0.f;\n\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage_local);\n        to_float(dy_math, dy_storage_local);\n\n        relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);\n        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n      }\n    }\n\n// We scale the mean by the number of elements. It brings more stability.\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      dbias[i] *= count;\n      dscale[i] *= count;\n    }\n\n    // dscale parallel sum\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dscale, thread_in_cta_nhw);\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dscale, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // dbias parallel sum\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dbias, thread_in_cta_nhw);\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dbias, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // The workspace in global memory is distributed across the different CTA.\n    int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;\n    // Write the data for the CTA to global memory.\n    float* gmem_sums = &params.gmem_sums[gmem_sums_offset];\n    if (threadIdx.x < THREADS_PER_PIXEL) {\n      const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;\n      write_to_gmem(&gmem_sums[0], idx, dscale);\n      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias);\n    }\n\n    // The counters to count how many CTAs have retired at this point.\n    // A given cta uses the same counter every other time through the outer loop.\n    int* gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n    inter_block_sync(gmem_retired_ctas, gridDim.x, 
blockIdx.x == 0);\n\n    // Reset the accumulators for global summation\n    zero_array(dscale);\n    zero_array(dbias);\n\n// Build the global accumulation\n#pragma unroll 1\n    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {\n      float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n      read_from_gmem(tmp1, gmem_sums, idx);\n      read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx);\n\n#pragma unroll\n      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        dscale[i] += tmp1[i];\n        dbias[i] += tmp2[i];\n      }\n    }\n\n    // dscale parallel sum\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dscale, thread_in_cta_nhw);\n    }\n\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dscale, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // dbias parallel sum\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dbias, thread_in_cta_nhw);\n    }\n\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dbias, smem, thread_in_cta_c);\n\n    // Normalize the dscale.\n    float var[ELEMENTS_PER_LDG];\n    zero_array(var);\n    if (is_valid_c) {\n      read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);\n    }\n    multiply(dscale, 
var);\n\n    // store dscale/dbias\n    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n    if (is_valid_for_saving) {\n      if (params.sync_iters > 0) {\n        scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n        scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n      } else {\n        write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale);\n        write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias);\n      }\n    }\n\n    // Further normalize the dscale to be used in dx calculation\n    float scale[ELEMENTS_PER_LDG];\n    zero_array(scale);\n    if (is_valid_c) {\n      read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);\n    }\n    multiply(dscale, var);\n    // scale the inv-var as well, afterwards\n    multiply(var, scale);\n\n    // inverse count\n    float inv_count = params.svar_inv_count;\n\n    // The base pointer to write to.\n    uint16_t* const gmem_dst = &params.gmem_dst[thread_c];\n\n// Store the elements in registers.\n#pragma unroll 1\n    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {\n      // The value for nhw.\n      int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;\n\n// Normalize the elements and write to memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        // Convert to float.\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage[i]);\n        to_float(dy_math, dy_storage[i]);\n        relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);\n\n        float dx[ELEMENTS_PER_LDG];\n        bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n        // Write back.\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n 
         stg_stream(&gmem_dst[idx * params.c], dx);\n        }\n      }\n\n      // The next value of nhw.\n      out_nhw -= pixels_per_iteration;\n\n// Read the next elements from memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n          ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n          ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);\n        }\n      }\n    }\n\n    // Normalize the elements from SMEM and write them out.\n    if (pixels_in_smem > 0) {\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n        if (is_valid) {\n          // Read from SMEM.\n          int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n          PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n          read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n          offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n          read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n          float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n          to_float(x_math, x_storage_local);\n          to_float(dy_math, dy_storage_local);\n          relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);\n\n          float dx[ELEMENTS_PER_LDG];\n          bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n          // Write back.\n          stg_stream(&gmem_dst[idx * params.c], dx);\n        }\n      }\n    }\n    // We're about to start on the next c-blk.  
Needed?\n    __syncthreads();\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Storage, int THREADS_PER_CTA, int THREADS_PER_PIXEL, int PIXELS_PER_THREAD_IN_REGISTERS,\n          int PIXELS_PER_THREAD_IN_SMEM, int ELEMENTS_PER_LDG, int USE_ONLINE_APPROACH, int OUTER_LOOPS_,\n          int DESIRED_OCCUPANCY>\n__global__ __launch_bounds__(THREADS_PER_CTA,\n                             DESIRED_OCCUPANCY) void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {\n  // The number of pixels loaded in a single LDG.\n  const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;\n  // The number of pixels computed per CTA stored in registers.\n  const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;\n  // The number of pixels computed per CTA stored in SMEM.\n  const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM * PIXELS_PER_LDG;\n  // The number of C elements per CTA.\n  const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL * ELEMENTS_PER_LDG;\n\n  // Shared memory to do CTA-wide parallel sums.\n  __shared__ float smem[THREADS_PER_PIXEL * (THREADS_PER_CTA / 32) * ELEMENTS_PER_LDG];\n\n  // The adapter for the storage.\n  typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;\n  // The data type for packed storage in SMEM.\n  typedef typename PackedStorage_::Type PackedStorageType;\n  // The number of elements in the packed storage.\n  const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;\n  // Registers to keep the data live for the persistent approach.\n  PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n  PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];\n\n  // Shared memory buffer to store the extra pixels.\n  extern __shared__ PackedStorageType smem_storage_packed[];\n\n  for (int c_blk_index = blockIdx.y; c_blk_index < 
params.c_blks; c_blk_index += gridDim.y) {\n    // The position in the NHW dimension where the CTA starts.\n    int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;\n    // The position in the NHW dimension where the CTA starts for the portion in SMEM.\n    int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;\n    // Compute the NHW coordinate of the thread in the CTA.\n    const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;\n\n    // The position in the C dimension where the CTA starts.\n    const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;\n    // Compute the C coordinate of the thread in the CTA.\n    const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;\n    // Compute the C coordinate of the thread.\n    const int thread_c = cta_c + thread_in_cta_c * ELEMENTS_PER_LDG;\n\n    // Is the thread working on a valid C dimension?\n    const int is_valid_c = thread_c < params.c;\n\n    float mean[ELEMENTS_PER_LDG];\n    zero_array(mean);\n    if (is_valid_c) {\n      read_from_gmem(mean, params.gmem_saved_mean, thread_c / ELEMENTS_PER_LDG);\n    }\n\n    // accumulation related registers\n    float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];\n    zero_array(dscale);\n    zero_array(dbias);\n\n    // The number of elements loaded by this CTA.\n    int cta_count = 0;\n    // The base pointers to load from.\n    const uint16_t* gmem_src = &params.gmem_src[thread_c];\n    const uint16_t* gmem_dy = &params.gmem_dy[thread_c];\n    uint16_t* gmem_dst1 = &params.gmem_dst1[thread_c];\n\n    // outer loops\n    int OUTER_LOOPS = OUTER_LOOPS_ == 1 ? 1 : params.outer_loops;\n    // Load the batch of elements. 
Compute sum across them\n    const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS * gridDim.x;\n\n    if (OUTER_LOOPS_ != 1) {\n      // We cannot load everything to store persistently, so let's make sure registers and\n      // smem are fully utilized, offset is evenly divisible by 32\n      int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;\n      cta_nhw_regs -= offset;\n      cta_nhw_smem -= offset;\n    }\n\n    const unsigned int* const gmem_relu_bitmask =\n        params.gmem_relu_bitmask + ((params.nhw + 31) & ~31) * 2 * c_blk_index;\n\n#pragma unroll 1\n    for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {\n      // The nhw position.\n      int nhw_regs = cta_nhw_regs + loop_i * pixels_per_iteration;\n      // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!\n      cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw - nhw_regs));\n\n      int lane_id = threadIdx.x & 31;\n\n      // Read the elements from memory.\n      float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];\n      unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        zero_array(x_storage[i]);\n        zero_array(dy_storage[i]);\n        is_valid[i] = 0.f;\n        const bool is_valid_nhw = static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n        if (is_valid_nhw) {\n          if (is_valid_c) {\n            if (loop_i == OUTER_LOOPS - 1) {\n              ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n              ldg_stream(dy_storage[i], &gmem_dy[idx * params.c]);\n            } else {\n              ldg(x_storage[i], &gmem_src[idx * params.c]);\n              ldg(dy_storage[i], &gmem_dy[idx * params.c]);\n            }\n            is_valid[i] = 1.f;\n          }\n\n          if (lane_id < 
ELEMENTS_PER_LDG) {\n            relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id];\n          }\n        }\n      }\n\n// Do the math.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = nhw_regs + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        // Convert to float and update\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        bool rectified[ELEMENTS_PER_LDG];\n#pragma unroll\n        for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n          rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) & (1U << lane_id)) != 0);\n        }\n        to_float(x_math, x_storage[i]);\n        to_float(dy_math, dy_storage[i]);\n\n        // Update the count.\n        count += is_valid[i];\n        // Invert the count.\n        float inv_count = is_valid[i] ? 1.f / count : 0.f;\n\n        relu_bwd(dy_math, rectified, is_valid[i]);\n        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n\n        // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version\n        from_float(dy_storage[i], dy_math);\n\n        // dZ for elementwise add\n        if (is_valid[i]) {\n          if (loop_i == OUTER_LOOPS - 1) {\n            stg_stream(&gmem_dst1[idx * params.c], dy_storage[i]);\n          } else {\n            stg(&gmem_dst1[idx * params.c], dy_storage[i]);\n          }\n        }\n      }\n    }\n\n    // The elements to load and store in SMEM.\n    int smem_nhw = OUTER_LOOPS * pixels_per_iteration + cta_nhw_smem;\n    // Load elements from SMEM, update the CTA count.\n    int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw - smem_nhw);\n    if (pixels_in_smem > 0) {\n      cta_count += pixels_in_smem;\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        const bool is_pixel_valid_nhw = static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);\n        const 
bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;\n        PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n        unsigned int relu_mask;\n        int lane_id = threadIdx.x & 31;\n        zero_array(x_storage_local);\n        zero_array(dy_storage_local);\n        if (is_pixel_valid_nhw) {\n          if (is_valid_c) {\n            ldg_stream(x_storage_local, &gmem_src[idx * params.c]);\n            ldg_stream(dy_storage_local, &gmem_dy[idx * params.c]);\n          }\n          if (lane_id < ELEMENTS_PER_LDG) {\n            relu_mask = gmem_relu_bitmask[idx * 2 + lane_id];\n          }\n        }\n        bool rectified[ELEMENTS_PER_LDG];\n#pragma unroll\n        for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {\n          rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) & (1U << lane_id)) != 0);\n        }\n\n        // The offset to store in SMEM.\n        int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        // Store in SMEM.\n        write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);\n        offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n        // Update the count.\n        count += is_pixel_valid;\n        // Invert the count.\n        float inv_count = is_pixel_valid ? 
1.f / count : 0.f;\n\n        float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage_local);\n        to_float(dy_math, dy_storage_local);\n\n        relu_bwd(dy_math, rectified, is_pixel_valid);\n        bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);\n\n        from_float(dy_storage_local, dy_math);\n        // dZ for elementwise add\n        if (is_pixel_valid) {\n          stg_stream(&gmem_dst1[idx * params.c], dy_storage_local);\n        }\n        // only store the 'relu-dgrad'ed version!\n        write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);\n      }\n    }\n\n// We scale the mean by the number of elements. It brings more stability.\n#pragma unroll\n    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n      dbias[i] *= count;\n      dscale[i] *= count;\n    }\n\n    // dscale parallel sum\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dscale, thread_in_cta_nhw);\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dscale, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // dbias parallel sum\n    ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dbias, thread_in_cta_nhw);\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dbias, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // The workspace in global memory is distributed across the different CTA.\n    int gmem_sums_offset = c_blk_index * gridDim.x * C_ELEMENTS_PER_CTA * 2;\n    // Write the data for the CTA to global memory.\n    float* gmem_sums = &params.gmem_sums[gmem_sums_offset];\n    if (threadIdx.x < THREADS_PER_PIXEL) {\n      const int idx = blockIdx.x * THREADS_PER_PIXEL + threadIdx.x;\n      write_to_gmem(&gmem_sums[0], idx, dscale);\n      write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA * gridDim.x], idx, dbias);\n    
}\n\n    // The counters to count how many CTAs have retired at this point.\n    // A given cta uses the same counter every other time through the outer loop.\n    int* gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];\n    inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);\n\n    // Reset the accumulators for global summation\n    zero_array(dscale);\n    zero_array(dbias);\n\n// Build the global accumulation\n#pragma unroll 1\n    for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL * gridDim.x; idx += THREADS_PER_CTA) {\n      float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];\n      read_from_gmem(tmp1, gmem_sums, idx);\n      read_from_gmem(tmp2, gmem_sums + C_ELEMENTS_PER_CTA * gridDim.x, idx);\n\n#pragma unroll\n      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {\n        dscale[i] += tmp1[i];\n        dbias[i] += tmp2[i];\n      }\n    }\n\n    // dscale parallel sum\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 1, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dscale, thread_in_cta_nhw);\n    }\n\n    __syncthreads();\n    // The values in shared memory correspond to the CTA-wide sums.\n    read_from_smem(dscale, smem, thread_in_cta_c);\n    __syncthreads();\n\n    // dbias parallel sum\n    if (params.sync_iters > 0) {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(\n          smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4 * c_blk_index + 0, params.magic,\n          params.sync_iters);\n    } else {\n      ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(smem, dbias, thread_in_cta_nhw);\n    }\n\n    __syncthreads();\n    // The values in shared memory 
correspond to the CTA-wide sums.\n    read_from_smem(dbias, smem, thread_in_cta_c);\n\n    // Normalize the dscale.\n    float var[ELEMENTS_PER_LDG];\n    zero_array(var);\n    if (is_valid_c) {\n      read_from_gmem(var, params.gmem_saved_var, thread_c / ELEMENTS_PER_LDG);\n    }\n    multiply(dscale, var);\n\n    // store dscale/dbias\n    bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;\n    if (is_valid_for_saving) {\n      if (params.sync_iters > 0) {\n        scaled_write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);\n        scaled_write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);\n      } else {\n        write_to_gmem(params.gmem_dscale, thread_c / ELEMENTS_PER_LDG, dscale);\n        write_to_gmem(params.gmem_dbias, thread_c / ELEMENTS_PER_LDG, dbias);\n      }\n    }\n\n    // Further normalize the dscale to be used in dx calculation\n    float scale[ELEMENTS_PER_LDG];\n    zero_array(scale);\n    if (is_valid_c) {\n      read_from_gmem(scale, params.gmem_scale, thread_c / ELEMENTS_PER_LDG);\n    }\n    multiply(dscale, var);\n    // scale the inv-var as well, afterwards\n    multiply(var, scale);\n\n    // inverse count\n    float inv_count = params.svar_inv_count;\n\n    // The base pointer to write to.\n    uint16_t* const gmem_dst = &params.gmem_dst[thread_c];\n\n// Store the elements in registers.\n#pragma unroll 1\n    for (int loop_i = OUTER_LOOPS - 1; loop_i >= 0; --loop_i) {\n      // The value for nhw.\n      int out_nhw = cta_nhw_regs + loop_i * pixels_per_iteration;\n\n// Normalize the elements and write to memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n        // Convert to float.\n        float x_math[ELEMENTS_PER_LDG], 
dy_math[ELEMENTS_PER_LDG];\n        to_float(x_math, x_storage[i]);\n        to_float(dy_math, dy_storage[i]);\n\n        float dx[ELEMENTS_PER_LDG];\n        bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);\n\n        // Write back.\n        if (is_valid) {\n          stg_stream(&gmem_dst[idx * params.c], dx);\n        }\n      }\n\n      // The next value of nhw.\n      out_nhw -= pixels_per_iteration;\n\n// Read the next elements from memory.\n#pragma unroll\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {\n        const int idx = out_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        float y[ELEMENTS_PER_LDG];\n        zero_array(y);\n        if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {\n          ldg_stream(x_storage[i], &gmem_src[idx * params.c]);\n          ldg_stream(dy_storage[i], &gmem_dst1[idx * params.c]);\n        }\n      }\n    }\n\n    // Normalize the elements from SMEM and write them out.\n    if (pixels_in_smem > 0) {\n      for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {\n        const int idx = smem_nhw + thread_in_cta_nhw + i * PIXELS_PER_LDG;\n        const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;\n        if (is_valid) {\n          // Read from SMEM.\n          int offset = i * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n          PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], dy_storage_local[PACKED_ELEMENTS_PER_LDG];\n          read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);\n          offset += PIXELS_PER_THREAD_IN_SMEM * THREADS_PER_CTA * PACKED_ELEMENTS_PER_LDG;\n          read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);\n          float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];\n          to_float(x_math, x_storage_local);\n          to_float(dy_math, dy_storage_local);\n\n          float dx[ELEMENTS_PER_LDG];\n          bwd_dx(dx, dy_math, var, x_math, 
mean, dscale, dbias, inv_count);\n\n          // Write back.\n          stg_stream(&gmem_dst[idx * params.c], dx);\n        }\n      }\n    }\n    // We're about to start on the next c-blk.  Needed?\n    __syncthreads();\n  }\n}\n\n#endif  // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_\n"
  },
  {
    "path": "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp",
    "content": "#include <torch/torch.h>\n\n#include <cstdint>\n#include <vector>\n\nvoid index_mul_2d_float_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,\n                                    const at::Tensor& idx1);\n\nvoid index_mul_2d_float_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,\n                                      const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1);\n\nvoid index_mul_2d_float_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,\n                                               const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,\n                                               const at::Tensor& grad_grad_in2, const at::Tensor& in1,\n                                               const at::Tensor& in2, const at::Tensor& idx1);\n\nvoid index_mul_2d_half_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,\n                                   const at::Tensor& idx1);\n\nvoid index_mul_2d_half_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,\n                                     const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1);\n\nvoid index_mul_2d_half_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,\n                                              const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,\n                                              const at::Tensor& grad_grad_in2, const at::Tensor& in1,\n                                              const at::Tensor& in2, const at::Tensor& idx1);\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\nvoid 
index_mul_2d_float_forward(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {\n  return index_mul_2d_float_foward_cuda(out, in1, in2, idx1);\n}\n\nvoid index_mul_2d_float_backward(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,\n                                 const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {\n  return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);\n}\n\nvoid index_mul_2d_float_backwrad_backward(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,\n                                          const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,\n                                          const at::Tensor& grad_grad_in2, const at::Tensor& in1, const at::Tensor& in2,\n                                          const at::Tensor& idx1) {\n  return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1,\n                                                   grad_grad_in2, in1, in2, idx1);\n}\n\nvoid index_mul_2d_half_forward(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {\n  return index_mul_2d_half_foward_cuda(out, in1, in2, idx1);\n}\n\nvoid index_mul_2d_half_backward(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,\n                                const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {\n  return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);\n}\n\nvoid index_mul_2d_half_backwrad_backward(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,\n                                         const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,\n                                         const at::Tensor& grad_grad_in2, const at::Tensor& in1, const at::Tensor& in2,\n                                         const at::Tensor& idx1) {\n 
 return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1,\n                                                  grad_grad_in2, in1, in2, idx1);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"float_forward\", &index_mul_2d_float_forward, \"index mul float calculation forward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"float_backward\", &index_mul_2d_float_backward, \"index mul float calculation backward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"float_backward_backward\", &index_mul_2d_float_backwrad_backward,\n        \"index mul float calculation backward backward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"half_forward\", &index_mul_2d_half_forward, \"index mul half calculation forward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"half_backward\", &index_mul_2d_half_backward, \"index mul half calculation backward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"half_backward_backward\", &index_mul_2d_half_backwrad_backward,\n        \"index mul half calculation backward backward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <ATen/cuda/Atomic.cuh>\n\n__global__ void index_mul_2d_float_dim64(float* out, const float* in1, const float* in2, const int64_t* idx1,\n                                         const int64_t size) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  constexpr int fea_dim = 64;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;\n    int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;\n\n    float4 res, src1, src2;\n    src1 = reinterpret_cast<const float4*>(in1)[vec_idx1];\n    src2 = reinterpret_cast<const float4*>(in2)[vec_idx2];\n    res.x = src1.x * src2.x;\n    res.y = src1.y * src2.y;\n    res.z = src1.z * src2.z;\n    res.w = src1.w * src2.w;\n    reinterpret_cast<float4*>(out)[vec_idx2] = res;\n  }\n}\n\n__global__ void index_mul_2d_float(float* out, const float* in1, const float* in2, const int64_t* idx1,\n                                   const int64_t size, const int64_t fea_dim) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  const int stride = blockDim.x;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = (idx1[start_idx] * fea_dim);\n    int64_t vec_idx2 = (start_idx * fea_dim);\n\n    for (int i = tidx; i < fea_dim; i += stride) {\n      out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i];\n    }\n  }\n}\n\n__global__ void index_mul_2d_half(at::Half* out, const at::Half* in1, const at::Half* in2, const int64_t* idx1,\n                                  const int64_t size, const int64_t fea_dim) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  const int stride = 
blockDim.x;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = (idx1[start_idx] * fea_dim);\n    int64_t vec_idx2 = (start_idx * fea_dim);\n\n    for (int i = tidx; i < fea_dim; i += stride) {\n      out[vec_idx2 + i] = at::Half(static_cast<float>(in1[vec_idx1 + i]) * static_cast<float>(in2[vec_idx2 + i]));\n    }\n  }\n}\n\n__global__ void index_mul_2d_grad_float_dim64(float* grad_in1, float* grad_in2, const float* grad_out, const float* in1,\n                                              const float* in2, const int64_t* idx1, const int64_t size) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  constexpr int fea_dim = 64;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;\n    int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;\n\n    float4 src_in1, src_in2, src_grad_out, dst_grad_in2;\n    src_grad_out = reinterpret_cast<const float4*>(grad_out)[vec_idx2];\n    src_in1 = reinterpret_cast<const float4*>(in1)[vec_idx1];\n    src_in2 = reinterpret_cast<const float4*>(in2)[vec_idx2];\n    int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x);\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y);\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z);\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w);\n    dst_grad_in2.x = src_grad_out.x * src_in1.x;\n    dst_grad_in2.y = src_grad_out.y * src_in1.y;\n    dst_grad_in2.z = src_grad_out.z * src_in1.z;\n    dst_grad_in2.w = src_grad_out.w * src_in1.w;\n    reinterpret_cast<float4*>(grad_in2)[vec_idx2] = dst_grad_in2;\n  }\n}\n\n__global__ void index_mul_2d_grad_float(float* grad_in1, float* grad_in2, const float* grad_out, const float* in1,\n                                        const float* in2, 
const int64_t* idx1, const int64_t size,\n                                        const int64_t fea_dim) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  const int stride = blockDim.x;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = idx1[start_idx] * fea_dim;\n    int64_t vec_idx2 = start_idx * fea_dim;\n\n    for (int i = tidx; i < fea_dim; i += stride) {\n      float src_in1 = in1[vec_idx1 + i];\n      float src_in2 = in2[vec_idx2 + i];\n      float src_grad_out = grad_out[vec_idx2 + i];\n      grad_in2[vec_idx2 + i] = src_grad_out * src_in1;\n      gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2);\n    }\n  }\n}\n\n__global__ void index_mul_2d_grad_half(at::Half* grad_in1, at::Half* grad_in2, const at::Half* grad_out,\n                                       const at::Half* in1, const at::Half* in2, const int64_t* idx1,\n                                       const int64_t size, const int64_t fea_dim) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  const int stride = blockDim.x;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = idx1[start_idx] * fea_dim;\n    int64_t vec_idx2 = start_idx * fea_dim;\n\n    for (int i = tidx; i < fea_dim; i += stride) {\n      float src_in1 = static_cast<float>(in1[vec_idx1 + i]);\n      float src_in2 = static_cast<float>(in2[vec_idx2 + i]);\n      float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);\n      grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1);\n      gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2));\n    }\n  }\n}\n\n__global__ void index_mul_2d_grad_grad_float_dim64(float* grad_grad_out, float* grad_in1, float* grad_in2,\n                                                   const float* grad_out, const float* grad_grad_in1,\n                         
                          const float* grad_grad_in2, const float* in1, const float* in2,\n                                                   const int64_t* idx1, const int64_t size) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  constexpr int fea_dim = 64;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;\n    int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;\n\n    float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out;\n    float4 dst_grad_grad_out, dst_grad_in2;\n    src_grad_grad_in1 = reinterpret_cast<const float4*>(grad_grad_in1)[vec_idx1];\n    src_in1 = reinterpret_cast<const float4*>(in1)[vec_idx1];\n    src_grad_grad_in2 = reinterpret_cast<const float4*>(grad_grad_in2)[vec_idx2];\n    src_in2 = reinterpret_cast<const float4*>(in2)[vec_idx2];\n    dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x;\n    dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y;\n    dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z;\n    dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w;\n    reinterpret_cast<float4*>(grad_grad_out)[vec_idx2] = dst_grad_grad_out;\n    src_grad_out = reinterpret_cast<const float4*>(grad_out)[vec_idx2];\n    int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x);\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y);\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z);\n    gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w);\n    dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x;\n    dst_grad_in2.y = src_grad_grad_in1.y 
* src_grad_out.y;\n    dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z;\n    dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w;\n    reinterpret_cast<float4*>(grad_in2)[vec_idx2] = dst_grad_in2;\n  }\n}\n\n__global__ void index_mul_2d_grad_grad_float(float* grad_grad_out, float* grad_in1, float* grad_in2,\n                                             const float* grad_out, const float* grad_grad_in1,\n                                             const float* grad_grad_in2, const float* in1, const float* in2,\n                                             const int64_t* idx1, const int64_t size, const int64_t fea_dim) {\n  const int tidx = threadIdx.x;\n  const int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  const int stride = blockDim.x;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = idx1[start_idx] * fea_dim;\n    int64_t vec_idx2 = start_idx * fea_dim;\n\n    for (int i = tidx; i < fea_dim; i += stride) {\n      float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i];\n      float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i];\n      float src_in1 = in1[vec_idx1 + i];\n      float src_in2 = in2[vec_idx2 + i];\n      float src_grad_out = grad_out[vec_idx2 + i];\n      grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1;\n      grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out;\n      gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out);\n    }\n  }\n}\n\n__global__ void index_mul_2d_grad_grad_half(at::Half* grad_grad_out, at::Half* grad_in1, at::Half* grad_in2,\n                                            const at::Half* grad_out, const at::Half* grad_grad_in1,\n                                            const at::Half* grad_grad_in2, const at::Half* in1, const at::Half* in2,\n                                            const int64_t* idx1, const int64_t size, const int64_t fea_dim) {\n  const int tidx = threadIdx.x;\n  const 
int tidy = threadIdx.y;\n  const int bidx = blockIdx.x;\n  const int start_idx = bidx * blockDim.y + tidy;\n  const int stride = blockDim.x;\n\n  if (start_idx < size) {\n    int64_t vec_idx1 = idx1[start_idx] * fea_dim;\n    int64_t vec_idx2 = start_idx * fea_dim;\n\n    for (int i = tidx; i < fea_dim; i += stride) {\n      float src_grad_grad_in1 = static_cast<float>(grad_grad_in1[vec_idx1 + i]);\n      float src_grad_grad_in2 = static_cast<float>(grad_grad_in2[vec_idx2 + i]);\n      float src_in1 = static_cast<float>(in1[vec_idx1 + i]);\n      float src_in2 = static_cast<float>(in2[vec_idx2 + i]);\n      float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);\n      grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1);\n      grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out);\n      gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out));\n    }\n  }\n}\n\nvoid index_mul_2d_float_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,\n                                    const at::Tensor& idx1) {\n  const int64_t size = in2.size(0);\n  const int64_t fea_dim = in2.size(1);\n  if (size < 0) {\n    return;\n  }\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (fea_dim == 64) {\n    const int BLOCK_THREADS_DIMX = 16;\n    const int BLOCK_THREADS_DIMY = 16;\n    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n    index_mul_2d_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n        out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);\n  } else {\n    const int BLOCK_THREADS_DIMX = 32;\n    const int BLOCK_THREADS_DIMY = 8;\n    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n    index_mul_2d_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n        
out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);\n  }\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid index_mul_2d_float_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,\n                                      const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {\n  const int64_t size = in2.size(0);\n  const int64_t fea_dim = in2.size(1);\n  if (size < 0) {\n    return;\n  }\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (fea_dim == 64) {\n    const int BLOCK_THREADS_DIMX = 16;\n    const int BLOCK_THREADS_DIMY = 16;\n    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n    index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n        grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(), in1.data_ptr<float>(),\n        in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);\n  } else {\n    const int BLOCK_THREADS_DIMX = 32;\n    const int BLOCK_THREADS_DIMY = 8;\n    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n    index_mul_2d_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n        grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(), in1.data_ptr<float>(),\n        in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);\n  }\n\n  // Check the launch status for both branches, matching the forward and double-backward launchers.\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid index_mul_2d_float_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,\n                                               const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,\n                                               const at::Tensor& grad_grad_in2, const at::Tensor& in1,\n                                               const at::Tensor& in2, const at::Tensor& idx1) {\n  const 
int64_t size = in2.size(0);\n  const int64_t fea_dim = in2.size(1);\n  if (size < 0) {\n    return;\n  }\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (fea_dim == 64) {\n    const int BLOCK_THREADS_DIMX = 16;\n    const int BLOCK_THREADS_DIMY = 16;\n    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n    index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n        grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),\n        grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),\n        in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);\n  } else {\n    const int BLOCK_THREADS_DIMX = 32;\n    const int BLOCK_THREADS_DIMY = 8;\n    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n    index_mul_2d_grad_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n        grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),\n        grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),\n        in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);\n  }\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid index_mul_2d_half_foward_cuda(at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2,\n                                   const at::Tensor& idx1) {\n  const int64_t size = in2.size(0);\n  const int64_t fea_dim = in2.size(1);\n  if (size < 0) {\n    return;\n  }\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  const int BLOCK_THREADS_DIMX = 32;\n  const int BLOCK_THREADS_DIMY = 8;\n  const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n  index_mul_2d_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n      
out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size,\n      fea_dim);\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid index_mul_2d_half_backward_cuda(at::Tensor& grad_in1, at::Tensor& grad_in2, const at::Tensor& grad_out,\n                                     const at::Tensor& in1, const at::Tensor& in2, const at::Tensor& idx1) {\n  const int64_t size = in2.size(0);\n  const int64_t fea_dim = in2.size(1);\n  if (size < 0) {\n    return;\n  }\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  const int BLOCK_THREADS_DIMX = 32;\n  const int BLOCK_THREADS_DIMY = 8;\n  const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n  index_mul_2d_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n      grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),\n      in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);\n\n  // Check the launch status, matching the other half-precision launchers.\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid index_mul_2d_half_backward_backward_cuda(at::Tensor& grad_grad_out, at::Tensor& grad_in1, at::Tensor& grad_in2,\n                                              const at::Tensor& grad_out, const at::Tensor& grad_grad_in1,\n                                              const at::Tensor& grad_grad_in2, const at::Tensor& in1,\n                                              const at::Tensor& in2, const at::Tensor& idx1) {\n  const int64_t size = in2.size(0);\n  const int64_t fea_dim = in2.size(1);\n  if (size < 0) {\n    return;\n  }\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  const int BLOCK_THREADS_DIMX = 32;\n  const int BLOCK_THREADS_DIMY = 8;\n  const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;\n\n  index_mul_2d_grad_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(\n      grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), 
grad_in2.data_ptr<at::Half>(),\n      grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),\n      in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln.h",
    "content": "#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <stdint.h>\n#include <stdio.h>\n\n#include <unordered_map>\n\nnamespace layer_norm {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Params>\nstruct LaunchParams {\n  size_t workspace_bytes;\n  size_t barrier_size;\n\n  cudaDeviceProp* props;\n\n  cudaStream_t stream;\n\n  Params params;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct FwdParams {\n  FwdParams()\n      : ctas_per_col(0),\n        rows(0),\n        cols(0),\n        x(nullptr),\n        z(nullptr),\n        mu(nullptr),\n        rs(nullptr),\n        gamma(nullptr),\n        beta(nullptr),\n        workspace(nullptr),\n        barrier(nullptr),\n        epsilon(0.f) {}\n\n  // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.\n  int ctas_per_col;\n\n  // Input is interpreted as matrix. We normalize across columns.\n  int rows;\n  int cols;\n\n  // Common data pointers.\n  void* x;\n  void* z;\n  void* mu;\n  void* rs;\n  void* gamma;\n  void* beta;\n\n  // Multi-CTA workspace in gmem.\n  void* workspace;\n\n  // Multi-CTA sync barriers in gmem.\n  int* barrier;\n\n  // Output of LN FWD.\n  float epsilon;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct BwdParams : public FwdParams {\n  BwdParams()\n      : FwdParams(),\n        dz(nullptr),\n        dbeta_part(nullptr),\n        dgamma_part(nullptr),\n        dx(nullptr),\n        dbeta(nullptr),\n        dgamma(nullptr) {}\n  // Input: gradient wrt. 
LN FWD output.\n  void* dz;\n\n  // Workspace for Wgrad pre-reduction.\n  void* dbeta_part;\n  void* dgamma_part;\n\n  // Output: Dgrad.\n  void* dx;\n  // Output: Wgrad.\n  void* dbeta;\n  void* dgamma;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nusing FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;\nusing BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;\nusing FunctionKey = uint64_t;\nusing FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;\nusing BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;\n\nextern FwdRegistry FWD_FUNCS;\nextern BwdRegistry BWD_FUNCS;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nusing fp32 = float;\nusing fp16 = half;\nusing bf16 = nv_bfloat16;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nstruct TypeId {};\n\ntemplate <>\nstruct TypeId<fp16> {\n  constexpr static uint32_t Value = 0;\n};\n\ntemplate <>\nstruct TypeId<bf16> {\n  constexpr static uint32_t Value = 1;\n};\n\ntemplate <>\nstruct TypeId<fp32> {\n  constexpr static uint32_t Value = 2;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, int S>\nstruct Type2Key {\n  constexpr static uint32_t Value = TypeId<T>::Value << S;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nstruct WeightType2Key : public Type2Key<T, 0> {};\n\ntemplate <typename T>\nstruct InputType2Key : public Type2Key<T, 2> {};\n\ntemplate <typename T>\nstruct OutputType2Key : public Type2Key<T, 4> {};\n\ntemplate <typename T>\nstruct ComputeType2Key : public Type2Key<T, 6> 
{};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename W, typename I, typename O, typename C>\nstruct Types2Key {\n  constexpr static uint32_t Value =\n      WeightType2Key<W>::Value | InputType2Key<I>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;\n  constexpr static inline uint64_t get(const uint64_t hidden_size) {\n    constexpr uint64_t type_key = Value;\n    return (type_key << 32) | hidden_size;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>\nstruct FwdRegistrar {\n  FwdRegistrar(FwdFunction f) {\n    uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);\n    FWD_FUNCS.insert({key, f});\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>\nstruct BwdRegistrar {\n  BwdRegistrar(BwdFunction f) {\n    uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);\n    BWD_FUNCS.insert({key, f});\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace layer_norm\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln_api.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"ATen/cuda/CUDAContext.h\"\n#include \"ln.h\"\n\n/*\n\nSupported Type combinations:\n\ninput    compute   weights   output\n=======================================\nfp32     fp32      fp32      fp32\nfp16     fp32      fp16      fp16\nbf16     fp32      bf16      bf16\nfp32     fp32      fp16      fp16\nfp32     fp32      bf16      bf16\n\nRemarks:\nOutput type = Weight type\nCompute always in FP32\n\n*/\n\nnamespace layer_norm {\n\n// Create registries and provide runtime versions of config hash functions.\n\nFwdRegistry FWD_FUNCS;\nBwdRegistry BWD_FUNCS;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nuint32_t get_type_id(torch::Dtype dtype) {\n  if (dtype == torch::kFloat16) {\n    return TypeId<fp16>::Value;\n  } else if (dtype == torch::kBFloat16) {\n    return TypeId<bf16>::Value;\n  } else if (dtype == torch::kFloat32) {\n    return TypeId<fp32>::Value;\n  } else {\n    TORCH_CHECK(false, \"Type not supported: \", dtype);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nuint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {\n  using namespace layer_norm;\n  uint64_t type_key =\n      get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6);\n  uint64_t launcher_key = (type_key << 32) | hidden_size;\n  return launcher_key;\n}\n\n}  // namespace layer_norm\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nlayer_norm::FwdFunction& get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype,\n                                          torch::Dtype ctype, uint32_t hidden_size) {\n  auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));\n  if (iter != 
layer_norm::FWD_FUNCS.end()) {\n    return iter->second;\n  } else {\n    TORCH_CHECK(false, \"FWD: Unsupported hidden_size or types: \", hidden_size, wtype, itype, otype, ctype);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nlayer_norm::BwdFunction& get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype,\n                                          torch::Dtype ctype, uint32_t hidden_size) {\n  auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size));\n  if (iter != layer_norm::BWD_FUNCS.end()) {\n    return iter->second;\n  } else {\n    TORCH_CHECK(false, \"BWD: Unsupported hidden_size or types: \", hidden_size, wtype, itype, otype, ctype);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstd::vector<at::Tensor> ln_fwd(const at::Tensor& x,      // BxSxhidden_size\n                               const at::Tensor& gamma,  // hidden_size\n                               const at::Tensor& beta,   // hidden_size\n                               const float epsilon) {\n  auto itype = x.scalar_type();\n  auto wtype = gamma.scalar_type();\n  auto otype = wtype;\n  auto ctype = torch::kFloat32;\n\n  TORCH_CHECK(beta.scalar_type() == wtype);\n\n  TORCH_CHECK(x.is_cuda())\n  TORCH_CHECK(gamma.is_cuda())\n  TORCH_CHECK(beta.is_cuda())\n\n  TORCH_CHECK(x.is_contiguous());\n  auto sizes = x.sizes();\n  TORCH_CHECK(sizes.size() == 2);\n\n  const int rows = sizes[0];\n  const int cols = sizes[1];\n  auto hidden_size = gamma.numel();\n\n  TORCH_CHECK(gamma.sizes() == beta.sizes());\n  TORCH_CHECK(hidden_size == cols);\n\n  TORCH_CHECK(epsilon >= 0.f);\n\n  auto opts = x.options();\n\n  auto z = torch::empty(sizes, opts.dtype(otype));\n\n  auto mu = torch::empty({rows}, opts.dtype(ctype));\n  auto rsigma = torch::empty({rows}, opts.dtype(ctype));\n\n  
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;\n\n  launch_params.props = at::cuda::getCurrentDeviceProperties();\n  launch_params.stream = at::cuda::getCurrentCUDAStream().stream();\n\n  // Request the kernel launcher.\n  auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size);\n\n  // Query the kernel-specific launch parameters.\n  launcher(launch_params, true);\n\n  at::Tensor workspace, barrier;\n\n  // Set the kernel runtime parameters.\n  layer_norm::FwdParams& params = launch_params.params;\n  params.rows = rows;\n  params.cols = cols;\n  params.z = z.data_ptr();\n  params.mu = mu.data_ptr();\n  params.rs = rsigma.data_ptr();\n  params.gamma = gamma.data_ptr();\n  params.beta = beta.data_ptr();\n  params.x = x.data_ptr();\n  params.epsilon = epsilon;\n\n  if (launch_params.barrier_size > 0) {\n    auto options = x.options();\n    barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));\n    workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));\n    params.workspace = workspace.data_ptr();\n    params.barrier = barrier.data_ptr<int>();\n  }\n\n  // Launch the kernel.\n  launcher(launch_params, false);\n\n  return {z, mu, rsigma};\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\nstd::vector<at::Tensor> ln_bwd(const at::Tensor& dz,                    // BxSxhidden_size\n                               const at::Tensor& x_or_z,                // BxSxhidden_size\n                               c10::optional<const at::Tensor>& mu_,    // BxS, FP32!\n                               const at::Tensor& rsigma,                // BxS, FP32!\n                               const at::Tensor& gamma,                 // hidden_size\n                               c10::optional<const at::Tensor>& beta_,  // hidden_size\n                               bool memory_efficient) {\n  auto itype = x_or_z.scalar_type();\n  auto wtype = 
gamma.scalar_type();\n  auto otype = wtype;\n  auto ctype = torch::kFloat32;\n\n  TORCH_CHECK(dz.dtype() == otype);\n  TORCH_CHECK(rsigma.dtype() == ctype);\n  if (mu_.has_value()) {\n    TORCH_CHECK(mu_.value().dtype() == ctype);\n  }\n\n  TORCH_CHECK(x_or_z.is_cuda());\n  TORCH_CHECK(dz.is_cuda());\n  TORCH_CHECK(rsigma.is_cuda());\n  TORCH_CHECK(gamma.is_cuda());\n  if (beta_.has_value()) {\n    TORCH_CHECK(beta_.value().is_cuda());\n    TORCH_CHECK(beta_.value().dtype() == wtype);\n  }\n\n  TORCH_CHECK(x_or_z.is_contiguous());\n  TORCH_CHECK(dz.is_contiguous());\n\n  auto sizes = x_or_z.sizes();\n  TORCH_CHECK(sizes.size() == 2);\n  TORCH_CHECK(dz.sizes() == sizes);\n  auto rows = sizes[0];\n  auto cols = sizes[1];\n\n  auto hidden_size = gamma.numel();\n\n  TORCH_CHECK(gamma.numel() == cols);\n  if (beta_.has_value()) {\n    TORCH_CHECK(beta_.value().numel() == cols);\n  }\n\n  auto options = x_or_z.options();\n\n  auto dx = torch::empty_like(x_or_z);\n  auto dgamma = torch::empty_like(gamma);\n  auto dbeta = torch::empty_like(gamma);\n\n  layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;\n  launch_params.stream = at::cuda::getCurrentCUDAStream().stream();\n  launch_params.props = at::cuda::getCurrentDeviceProperties();\n\n  auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size);\n\n  launcher(launch_params, true);\n\n  auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype));\n  auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype));\n  at::Tensor workspace, barrier;\n\n  layer_norm::BwdParams& params = launch_params.params;\n  params.rows = rows;\n  params.cols = cols;\n  if (memory_efficient) {\n    params.z = x_or_z.data_ptr();\n    params.beta = beta_.value().data_ptr();\n  } else {\n    params.x = x_or_z.data_ptr();\n    params.mu = mu_.value().data_ptr();\n  }\n  params.rs = rsigma.data_ptr();\n  params.gamma = 
gamma.data_ptr();\n  params.dz = dz.data_ptr();\n  params.dx = dx.data_ptr();\n  params.dbeta = dbeta.data_ptr();\n  params.dgamma = dgamma.data_ptr();\n  params.dbeta_part = dbeta_part.data_ptr();\n  params.dgamma_part = dgamma_part.data_ptr();\n\n  if (launch_params.barrier_size > 0) {\n    // TODO Any way to avoid this?\n    barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));\n    workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));\n    params.workspace = workspace.data_ptr();\n    params.barrier = barrier.data_ptr<int>();\n  }\n\n  launcher(launch_params, false);\n\n  return {dx, dgamma, dbeta, dgamma_part, dbeta_part};\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.doc() = \"CUDA LayerNorm\";\n  m.def(\"ln_fwd\", &ln_fwd, \"Run LayerNorm forward kernel\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"ln_bwd\", &ln_bwd, \"Run LayerNorm backward kernel\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh",
    "content": "#pragma once\n\n#include \"ln_utils.cuh\"\n\nnamespace layer_norm {\n\ntemplate <typename Ktraits>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) {\n  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n  enum { WARPS_M = Ktraits::WARPS_M };\n  enum { WARPS_N = Ktraits::WARPS_N };\n  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n  enum { COLS = Ktraits::COLS };\n  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n  enum { LDGS = Ktraits::LDGS };\n  enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };\n  enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };\n  enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };\n\n  using compute_t = typename Ktraits::compute_t;\n  using index_t = typename Ktraits::index_t;\n  using Ivec = typename Ktraits::Ivec;\n  using Ovec = typename Ktraits::Ovec;\n  using Wvec = typename Ktraits::Wvec;\n  using Cvec = typename Ktraits::Cvec;\n  using Reducer = typename Ktraits::Reducer;\n  using reduce_t = typename Reducer::Type;\n\n  extern __shared__ char smem_[];\n\n  const index_t tidx = threadIdx.x;\n  const index_t bidn = blockIdx.x % CTAS_PER_ROW;\n  const index_t bidm = blockIdx.x / CTAS_PER_ROW;\n  const index_t lane = tidx % THREADS_PER_WARP;\n  const index_t warp = tidx / THREADS_PER_WARP;\n  const index_t warp_m = warp / Ktraits::WARPS_N;\n  const index_t warp_n = warp % Ktraits::WARPS_N;\n  const index_t tid_r = warp_n * THREADS_PER_WARP + lane;\n\n  const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;\n  const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;\n\n  static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);\n\n  Cvec dzy_sum[LDGS];\n  Cvec dz_sum[LDGS];\n\n  memset(dzy_sum, 0, sizeof(dzy_sum));\n  memset(dz_sum, 0, sizeof(dz_sum));\n\n  compute_t* smem_wgrad = reinterpret_cast<compute_t*>(smem_);\n  char* smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;\n\n  Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, 
smem_dgrad);\n\n  Sum<reduce_t> sum;\n\n  constexpr float rn = 1.f / float(COLS);\n  Wvec gamma[LDGS];\n  Wvec beta[LDGS];\n  index_t idx = c;\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n    gamma[it].load_from(params.gamma, idx);\n    if (params.z != nullptr) {\n      beta[it].load_from(params.beta, idx);\n    }\n    idx += Ktraits::VEC_COLS_PER_LDG;\n  }\n// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the\n// last blocks with syncthreads!\n// grid stride over rows\n#pragma unroll 1\n  for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {\n    const compute_t mu_r = params.z == nullptr ? static_cast<const compute_t*>(params.mu)[row] : 0.f;\n    const compute_t rs_r = static_cast<const compute_t*>(params.rs)[row];\n    Ivec x_or_z[LDGS];\n    Ovec dz[LDGS];\n    index_t idx = row * Ktraits::VEC_COLS + c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      dz[it].load_from(params.dz, idx);\n      if (params.z != nullptr) {\n        x_or_z[it].load_from(params.z, idx);\n      } else {\n        x_or_z[it].load_from(params.x, idx);\n      }\n      idx += Ktraits::VEC_COLS_PER_LDG;\n    }\n\n    compute_t dy[LDGS * NUM_ELTS];\n    compute_t y[LDGS * NUM_ELTS];\n\n    compute_t mdy_local = 0.f;\n    compute_t mdyy_local = 0.f;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        compute_t gamma_tmp = compute_t(gamma[it].data.elt[jt]);\n        compute_t beta_tmp = compute_t(beta[it].data.elt[jt]);\n        compute_t x_or_z_tmp = compute_t(x_or_z[it].data.elt[jt]);\n        compute_t y_tmp = params.z != nullptr ? 
(x_or_z_tmp - beta_tmp) / gamma_tmp : rs_r * (x_or_z_tmp - mu_r);\n        compute_t dy_tmp = compute_t(dz[it].data.elt[jt]) * gamma_tmp;\n        compute_t dz_tmp = dz[it].data.elt[jt];\n\n        mdy_local += dy_tmp;\n        mdyy_local += dy_tmp * y_tmp;\n\n        dy[it * NUM_ELTS + jt] = dy_tmp;\n        y[it * NUM_ELTS + jt] = y_tmp;\n\n        dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;\n        dz_sum[it].data.elt[jt] += dz_tmp;\n      }\n    }\n\n    reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);\n    mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;\n    mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;\n\n    Ivec dx[LDGS];\n    idx = row * Ktraits::VEC_COLS + c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        compute_t dy_tmp = dy[it * NUM_ELTS + jt];\n        compute_t y_tmp = y[it * NUM_ELTS + jt];\n        compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));\n        dx[it].data.elt[jt] = dx_tmp;\n      }\n      dx[it].store_to(params.dx, idx);\n      idx += Ktraits::VEC_COLS_PER_LDG;\n    }\n\n  }  // end: grid stride loop\n\n  if (WARPS_M == 1) {\n    idx = r * Ktraits::VEC_COLS + c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      dz_sum[it].store_to(params.dbeta_part, idx);\n      dzy_sum[it].store_to(params.dgamma_part, idx);\n      idx += Ktraits::VEC_COLS_PER_LDG;\n    }\n  } else {\n    static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, \"Multiple rows per CTA not supported for Multi-CTA.\");\n    // Finalize reduction of part dgamma and dbeta for this CTA\n    // by reducing over the rows held across the WARPS_M warps\n\n    // Assumption: blockSize divides hidden size.\n    enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };\n    static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, \"\");\n\n    idx = warp_m * Ktraits::VEC_COLS + tid_r;\n#pragma unroll\n    
for (int it = 0; it < LDGS; it++) {\n      dz_sum[it].store_to(smem_wgrad, idx);\n      idx += THREADS_PER_ROW;\n    }\n    __syncthreads();\n    compute_t cta_dz_sum[NUM_RES];\n    memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);\n    for (int it = 0; it < ROWS_PER_CTA; it++) {\n      for (int jt = 0; jt < NUM_RES; jt++) {\n        cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n      }\n    }\n    __syncthreads();\n\n    idx = warp_m * Ktraits::VEC_COLS + tid_r;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      dzy_sum[it].store_to(smem_wgrad, idx);\n      idx += THREADS_PER_ROW;\n    }\n    __syncthreads();\n    compute_t cta_dzy_sum[NUM_RES];\n    memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);\n    for (int it = 0; it < ROWS_PER_CTA; it++) {\n      for (int jt = 0; jt < NUM_RES; jt++) {\n        cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n      }\n    }\n\n    compute_t* dgamma_part = static_cast<compute_t*>(params.dgamma_part) + bidm * COLS + tidx;\n    for (int jt = 0; jt < NUM_RES; jt++) {\n      *dgamma_part = cta_dzy_sum[jt];\n      dgamma_part += Ktraits::THREADS_PER_CTA;\n    }\n\n    compute_t* dbeta_part = static_cast<compute_t*>(params.dbeta_part) + bidm * COLS + tidx;\n    for (int jt = 0; jt < NUM_RES; jt++) {\n      *dbeta_part = cta_dz_sum[jt];\n      dbeta_part += Ktraits::THREADS_PER_CTA;\n    }\n  }\n}\n\ntemplate <typename Kernel_traits>\n__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(BwdParams params) {\n  using compute_t = typename Kernel_traits::compute_t;\n  using weight_t = typename Kernel_traits::weight_t;\n  using index_t = typename Kernel_traits::index_t;\n  using Reducer = typename Kernel_traits::Reducer;\n  using reduce_t = typename Reducer::Type;\n\n  Sum<reduce_t> sum;\n  enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };\n  enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };\n\n  __shared__ 
char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];\n\n  constexpr uint32_t bidm = 0;\n\n  const uint32_t bidn = blockIdx.x;\n  const uint32_t tidx = threadIdx.x;\n  const uint32_t warp = tidx / THREADS_PER_WARP;\n  const uint32_t lane = tidx % THREADS_PER_WARP;\n\n  Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);\n\n  const uint32_t c = bidn * THREADS_PER_WARP + lane;\n  const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;\n  constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;\n  for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2) {\n    // Each thread sums over NUM_ELT columns.\n    Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;\n    memset(&dgamma_local, 0, sizeof(dgamma_local));\n    memset(&dbeta_local, 0, sizeof(dbeta_local));\n    for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {\n      index_t idx = row * Kernel_traits::COLS + col;\n\n      Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;\n      dbeta_part.load_from(params.dbeta_part, idx);\n      dgamma_part.load_from(params.dgamma_part, idx);\n#pragma unroll\n      for (int it = 0; it < NUM_ELT; it++) {\n        dgamma_local.data.elt[it] += dgamma_part.data.elt[it];\n        dbeta_local.data.elt[it] += dbeta_part.data.elt[it];\n      }\n    }\n\n    void* smem_gamma = smem_;\n    void* smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];\n\n    const int write_row = warp;\n    const int write_col = lane ^ write_row;\n    const int write_idx = write_row * THREADS_PER_WARP + write_col;\n\n    dgamma_local.store_to(smem_gamma, write_idx);\n    dbeta_local.store_to(smem_beta, write_idx);\n\n    __syncthreads();\n\n    // It would be probably safe to reuse the first row of smem_beta and smem_gamma\n    void* smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];\n    void* smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 
Kernel_traits::SMEM_BYTES_OUTPUT];\n\n    // More than one iter iff ROWS_PER_CTA < 32.\n    for (int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA) {\n      const int read_row = lane;\n      const int read_col = w ^ read_row;\n      const int read_idx = read_row * THREADS_PER_WARP + read_col;\n\n      memset(&dbeta_local, 0, sizeof(dbeta_local));\n      memset(&dgamma_local, 0, sizeof(dgamma_local));\n\n      // Load beta and gamma transposed\n      if (read_row < Kernel_traits::ROWS_PER_CTA) {\n        dbeta_local.load_from(smem_beta, read_idx);\n        dgamma_local.load_from(smem_gamma, read_idx);\n      }\n\n// Call reducer on the loaded value(s) and convert.\n#pragma unroll\n      for (int it = 0; it < NUM_ELT; it++) {\n        compute_t b_i = dbeta_local.data.elt[it];\n        compute_t g_i = dgamma_local.data.elt[it];\n        b_i = reducer.allreduce(b_i, sum);\n        g_i = reducer.allreduce(g_i, sum);\n\n        dgamma_local.data.elt[it] = g_i;\n        dbeta_local.data.elt[it] = b_i;\n      }\n\n      // Leader stores the result at the current column.\n      if (lane == 0) {\n        dgamma_local.store_to(smem_gamma_out, w);\n        dbeta_local.store_to(smem_beta_out, w);\n      }\n    }\n\n    // All writes done.\n    __syncthreads();\n\n    // Pack and store: 2-wide stores with half the threads.\n    if (warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2) {\n      using src_t = typename TypeToVec2<compute_t>::Type;\n      using dst_t = typename TypeToVec2<weight_t>::Type;\n      Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;\n      Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;\n\n      dgamma_vec2.load_from(smem_gamma_out, lane);\n      dbeta_vec2.load_from(smem_beta_out, lane);\n#pragma unroll\n      for (int it = 0; it < NUM_ELT; it++) {\n        dgamma_out2.data.elt[it] = Converter<src_t, dst_t>::convert(dgamma_vec2.data.elt[it]);\n        dbeta_out2.data.elt[it] = Converter<src_t, 
dst_t>::convert(dbeta_vec2.data.elt[it]);\n      }\n      dgamma_out2.store_to(params.dgamma, col_out);\n      dbeta_out2.store_to(params.dbeta, col_out);\n    }\n  }\n}\n}  // namespace layer_norm\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu",
    "content": "#include \"ln.h\"\n#include \"ln_bwd_kernels.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ln_utils.cuh\"\n\nusing namespace layer_norm;\n\ntemplate <typename weight_t, typename input_t, typename output_t, typename compute_t, typename index_t, int HIDDEN_SIZE,\n          int CTAS_PER_ROW, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>\nvoid launch_(LaunchParams<BwdParams>& launch_params, const bool configure_params) {\n  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE, CTAS_PER_ROW,\n                                      WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;\n  auto kernel = &ln_bwd_kernel<Kernel_traits>;\n\n  if (configure_params) {\n    int ctas_per_sm;\n    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);\n    launch_params.params.ctas_per_col =\n        launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;\n    launch_params.barrier_size = 0;\n    launch_params.workspace_bytes = 0;\n    if (Kernel_traits::CTAS_PER_ROW > 1) {\n      launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;\n      launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *\n                                      Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::reduce_t) * 2;\n    }\n    return;\n  }\n\n  if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {\n    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));\n  }\n  auto stream = launch_params.stream;\n  auto ctas_per_col = launch_params.params.ctas_per_col;\n\n  if (Kernel_traits::CTAS_PER_ROW == 1) {\n    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);\n  } else {\n    dim3 grid(Kernel_traits::CTAS_PER_ROW * 
ctas_per_col);\n    dim3 block(Kernel_traits::THREADS_PER_CTA);\n    void* params_ = (void*)&launch_params.params;\n    cudaLaunchCooperativeKernel((void*)kernel, grid, block, (void**)&params_, Kernel_traits::SMEM_BYTES, stream);\n  }\n\n  using Kernel_traits_f =\n      layer_norm::Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t, output_t, compute_t, index_t,\n                                         32 * 32,  // THREADS_PER_CTA\n                                         BYTES_PER_LDG_FINAL>;\n\n  auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;\n  kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);\n}\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\n\nREGISTER_BWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\n\nREGISTER_BWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 
4);\nREGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);\nREGISTER_BWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);\nREGISTER_BWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);\n\nREGISTER_BWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);\nREGISTER_BWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);\nREGISTER_BWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);\n\nREGISTER_BWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(5120, bf16, 
fp32, bf16, fp32, 1, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\n\nREGISTER_BWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, 
fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);\n\nREGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);\nREGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);\nREGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);\n\nREGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);\nREGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);\n\nREGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 
1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);\n\nREGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);\nREGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);\nREGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);\nREGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);\nREGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);\n\nREGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);\n\nREGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);\n\nREGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);\n\nREGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 
16, 4);\nREGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu",
    "content": "#include \"ln.h\"\n#include \"ln_fwd_kernels.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"ln_utils.cuh\"\n\nusing namespace layer_norm;\n\ntemplate <typename weight_t, typename input_t, typename output_t, typename compute_t, typename index_t, int HIDDEN_SIZE,\n          int CTAS_PER_ROW, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>\nvoid launch_(LaunchParams<FwdParams>& launch_params, const bool configure_params) {\n  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE, CTAS_PER_ROW,\n                                      WARPS_M, WARPS_N, BYTES_PER_LDG>;\n  auto kernel = &ln_fwd_kernel<Kernel_traits>;\n\n  if (configure_params) {\n    int ctas_per_sm;\n    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);\n    launch_params.params.ctas_per_col =\n        launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;\n    launch_params.barrier_size = 0;\n    launch_params.workspace_bytes = 0;\n    if (Kernel_traits::CTAS_PER_ROW > 1) {\n      launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;\n      launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *\n                                      Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::Stats::stats_t) * 2;\n    }\n    return;\n  }\n\n  if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {\n    CHECK_CUDA(\n        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));\n  }\n  auto stream = launch_params.stream;\n  auto ctas_per_col = launch_params.params.ctas_per_col;\n\n  if (Kernel_traits::CTAS_PER_ROW == 1) {\n    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(\n        launch_params.params);\n  } else {\n    dim3 grid(Kernel_traits::CTAS_PER_ROW * 
ctas_per_col);\n    dim3 block(Kernel_traits::THREADS_PER_CTA);\n    void* params_ = (void*)&launch_params.params;\n    cudaLaunchCooperativeKernel((void*)kernel, grid, block, (void**)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);\n  }\n}\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\n\nREGISTER_FWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\n\nREGISTER_FWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\n\nREGISTER_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\n\nREGISTER_FWD_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2304, bf16, bf16, bf16, 
fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\n\nREGISTER_FWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);\n\nREGISTER_FWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(8192, 
bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);\n\nREGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);\n\nREGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);\nREGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);\nREGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);\nREGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);\nREGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);\n\nREGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 
1, 4, 16);\nREGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);\nREGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);\nREGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);\n\nREGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);\nREGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);\n\nREGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 
16);\nREGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);\n\nREGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);\nREGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh",
    "content": "#pragma once\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n\nnamespace layer_norm {\n\ntemplate <typename Ktraits>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) {\n  enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n  enum { WARPS_N = Ktraits::WARPS_N };\n  enum { WARPS_M = Ktraits::WARPS_M };\n  enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n  enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };\n  enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n  enum { LDGS = Ktraits::LDGS };\n  enum { NUM_ELTS = Ktraits::NUM_ELTS };\n  enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };\n  enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };\n\n  using output_t = typename Ktraits::output_t;\n  using index_t = typename Ktraits::index_t;\n  using compute_t = typename Ktraits::compute_t;\n  using Ivec = typename Ktraits::Ivec;\n  using Ovec = typename Ktraits::Ovec;\n  using Wvec = typename Ktraits::Wvec;\n  using Cvec = typename Ktraits::Cvec;\n\n  using Stats = typename Ktraits::Stats;\n  using stats_t = typename Stats::stats_t;\n\n  extern __shared__ char smem_[];\n\n  const index_t tidx = threadIdx.x;\n  const index_t bidn = blockIdx.x % CTAS_PER_ROW;\n  const index_t bidm = blockIdx.x / CTAS_PER_ROW;\n  const index_t lane = tidx % THREADS_PER_WARP;\n  const index_t warp = tidx / THREADS_PER_WARP;\n  const index_t warp_m = warp / WARPS_N;\n  const index_t warp_n = warp % WARPS_N;\n\n  const index_t r = bidm * ROWS_PER_CTA + warp_m;\n  const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;\n\n  Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);\n\n  compute_t* mu_ptr = static_cast<compute_t*>(params.mu);\n  compute_t* rs_ptr = static_cast<compute_t*>(params.rs);\n\n  Wvec gamma[LDGS];\n  Wvec beta[LDGS];\n  index_t idx = c;\n#pragma unroll\n  for (int it = 0; it < LDGS; it++) {\n    gamma[it].load_from(params.gamma, idx);\n    beta[it].load_from(params.beta, idx);\n    idx 
+= VEC_COLS_PER_LDG;\n  }\n\n  constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);\n\n  for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {\n    Ivec x[LDGS];\n    index_t idx = row * Ktraits::VEC_COLS + c;\n    compute_t xf[LDGS * NUM_ELTS];\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n      x[it].load_from(params.x, idx);\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        compute_t x_ij = compute_t(x[it].data.elt[jt]);\n        xf[it * NUM_ELTS + jt] = x_ij;\n      }\n      idx += VEC_COLS_PER_LDG;\n    }\n\n    stats_t s = stats.compute(xf, rn);\n\n    compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);\n    compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);\n\n    if (bidn == 0 && warp_n == 0 && lane == 0) {\n      mu_ptr[row] = mu;\n    }\n\n    compute_t rs = rsqrtf(rn * m2 + params.epsilon);\n\n    if (bidn == 0 && warp_n == 0 && lane == 0) {\n      rs_ptr[row] = rs;\n    }\n\n    Ovec z[LDGS];\n    idx = row * Ktraits::VEC_COLS + c;\n#pragma unroll\n    for (int it = 0; it < LDGS; it++) {\n#pragma unroll\n      for (int jt = 0; jt < NUM_ELTS; jt++) {\n        output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));\n        output_t g_ij = gamma[it].data.elt[jt];\n        output_t b_ij = beta[it].data.elt[jt];\n        z[it].data.elt[jt] = (g_ij * y_ij + b_ij);\n      }\n      z[it].store_to(params.z, idx);\n      idx += VEC_COLS_PER_LDG;\n    }\n  }\n}\n\n}  // namespace layer_norm\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln_kernel_traits.h",
    "content": "#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace layer_norm {\ntemplate <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,\n          typename index_t_, uint32_t THREADS_PER_CTA_>\nstruct Kernel_traits_base {\n  using weight_t = weight_t_;\n  using input_t = input_t_;\n  using output_t = output_t_;\n  using compute_t = compute_t_;\n  using index_t = index_t_;\n\n  enum { HIDDEN_SIZE = HIDDEN_SIZE_ };\n  enum { THREADS_PER_CTA = THREADS_PER_CTA_ };\n  enum { THREADS_PER_WARP = 32 };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,\n          typename index_t_, uint32_t THREADS_PER_CTA_, uint32_t BYTES_PER_LDG_,\n          typename Base =\n              Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_, index_t_, THREADS_PER_CTA_> >\nstruct Kernel_traits_finalize : public Base {\n  enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };\n  static_assert((int)ROWS_PER_CTA <= (int)Base::THREADS_PER_WARP);\n  // Bytes per global load from the input.\n  enum { BYTES_PER_LDG = BYTES_PER_LDG_ };\n  // Number of elements fetched by a global load.\n  enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };\n  // Bytes per global store of the weights.\n  enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };\n  static_assert(sizeof(BYTES_PER_LDG) == 4, \"Conflict-free smem transpose only implemented for 4B compute type!\");\n  static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, \"We assume one warp per row!\");\n  // The total number of BYTES_PER_LDG-wide words in a hidden vector.\n  enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };\n  static_assert(COLS * 
BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));\n\n  // Shared memory size to transpose the CTA result.\n  enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };\n  // Shared memory size to coalesce the CTA result.\n  enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };\n  // Shared memory requirement per CTA.\n  enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };\n\n  // The type of the reducer.\n  using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;\n\n  // Condition for the whole CTA to participate in syncthreads.\n  static_assert(COLS % Base::THREADS_PER_WARP == 0);\n  enum { CTAS = COLS / Base::THREADS_PER_WARP };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_, typename index_t_,\n          uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_, uint32_t WARPS_N_,\n          uint32_t BYTES_PER_LDG_ = 16,\n          typename Base = Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_, index_t_,\n                                             WARPS_M_ * WARPS_N_ * THREADS_PER_WARP> >\nstruct Kernel_traits : public Base {\n  using input_t = typename Base::input_t;\n  using weight_t = typename Base::weight_t;\n  using compute_t = typename Base::compute_t;\n  using output_t = typename Base::output_t;\n  using index_t = typename Base::index_t;\n\n  enum { CTAS_PER_ROW = CTAS_PER_ROW_ };\n  enum { WARPS_M = WARPS_M_ };\n  enum { WARPS_N = WARPS_N_ };\n  enum { COLS = HIDDEN_SIZE_ };\n  enum { HIDDEN_SIZE = HIDDEN_SIZE_ };\n  enum { BYTES_PER_LDG = BYTES_PER_LDG_ };\n  enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };\n\n  enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };\n  enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };\n  enum { ROWS_PER_CTA = WARPS_M };\n\n  enum { BYTES_PER_ROW = COLS * 
sizeof(input_t) };\n  enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };\n  // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed\n  enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };\n  static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);\n\n  using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;\n  using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;\n\n  enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };\n  enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };\n\n  using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;\n  using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;\n  using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;\n  using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;\n  enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };\n\n  // Assume that each thread can handle the same number of elements in the output and weights as in the input.\n  static_assert(sizeof(input_t) >= sizeof(output_t));\n  static_assert(sizeof(input_t) >= sizeof(weight_t));\n  // The number of columns fetched per load from input: one per thread.\n  enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };\n  // The total number of vectorized loads/stores per hidden vector.\n  enum { VEC_COLS = COLS / ELTS_PER_LDG };\n  // The number of loads per thread for the input.\n  enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };\n  static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);\n  // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, \"\");\n\n  using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;\n  enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace layer_norm\n"
  },
  {
    "path": "apex/contrib/csrc/layer_norm/ln_utils.cuh",
    "content": "#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#include <cassert>\n\n#include \"ln.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nconstexpr uint32_t THREADS_PER_WARP = 32;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline void check_cuda_(cudaError_t status, const char* file, int line) {\n  if (status != cudaSuccess) {\n    fprintf(stderr, \"CUDA Error: %s %s %d\\n\", cudaGetErrorString(status), file, line);\n    exit(status);\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define CHECK_CUDA(ans)                     \\\n  {                                         \\\n    check_cuda_((ans), __FILE__, __LINE__); \\\n  }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define DIVUP(x, y) (((x) + ((y) - 1)) / (y))\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \\\n  void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams>& launch_params,           \\\n                                                                    const bool configure_params) {                    \\\n    launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(        \\\n        launch_params, configure_params);                                                                             \\\n  }                                                                                                                   \\\n  static FwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                                                        \\\n  
reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                                                          \\\n      ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define REGISTER_BWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \\\n                              BYTES_PER_LDG_FINALIZE)                                                                 \\\n  void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams>& launch_params,           \\\n                                                                    const bool configure_params) {                    \\\n    launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG,         \\\n            BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                 \\\n  }                                                                                                                   \\\n  static BwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                                                        \\\n  reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                                                          \\\n      ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ float2 operator+(const float2& a, const float2& b) { return {a.x + b.x, a.y + b.y}; }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void operator+=(float2& a, const float2& b) {\n  a.x += b.x;\n  a.y += b.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate 
<typename T>\nstruct Sum {\n  inline __device__ Sum() {}\n  inline __device__ T operator()(const T& a, const T& b) { return a + b; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\ninline __device__ T warp_shuffle_xor(const T& x, uint32_t idx) {\n  return __shfl_xor_sync(uint32_t(-1), x, idx);\n}\n\ntemplate <>\ninline __device__ float2 warp_shuffle_xor<float2>(const float2& x, uint32_t idx) {\n  return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)};\n}\n\ntemplate <typename T>\ninline __device__ T warp_shuffle_down(const T& x, uint32_t idx) {\n  return __shfl_down_sync(uint32_t(-1), x, idx);\n}\n\ntemplate <>\ninline __device__ float2 warp_shuffle_down<float2>(const float2& x, uint32_t idx) {\n  return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)};\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace layer_norm {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct uint16 {\n  uint4 u;\n  uint4 v;\n  uint4 s;\n  uint4 t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct uint8 {\n  uint4 u;\n  uint4 v;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int BYTES>\nstruct BytesToType {};\n\ntemplate <>\nstruct BytesToType<64> {\n  using Type = uint16;\n  static_assert(sizeof(Type) == 64);\n};\n\ntemplate <>\nstruct BytesToType<32> {\n  using Type = uint8;\n  static_assert(sizeof(Type) == 32);\n};\n\ntemplate <>\nstruct BytesToType<16> {\n  using Type = uint4;\n  static_assert(sizeof(Type) == 16);\n};\n\ntemplate <>\nstruct BytesToType<8> {\n  using Type = uint64_t;\n  static_assert(sizeof(Type) == 8);\n};\n\ntemplate <>\nstruct BytesToType<4> {\n  using Type = uint32_t;\n  
static_assert(sizeof(Type) == 4);\n};\n\ntemplate <>\nstruct BytesToType<2> {\n  using Type = uint16_t;\n  static_assert(sizeof(Type) == 2);\n};\n\ntemplate <>\nstruct BytesToType<1> {\n  using Type = uint8_t;\n  static_assert(sizeof(Type) == 1);\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nstruct TypeToVec2 {};\n\ntemplate <>\nstruct TypeToVec2<float> {\n  using Type = float2;\n};\n\ntemplate <>\nstruct TypeToVec2<half> {\n  using Type = half2;\n};\n\ntemplate <>\nstruct TypeToVec2<nv_bfloat16> {\n  using Type = nv_bfloat162;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int INDEX>\nstruct Get {\n  template <typename T, typename R>\n  static inline __device__ R of(const T& vec);\n};\n\ntemplate <>\ntemplate <typename T, typename R>\ninline __device__ R Get<0>::of(const T& vec) {\n  return vec.x;\n}\n\ntemplate <>\ntemplate <typename T, typename R>\ninline __device__ R Get<1>::of(const T& vec) {\n  return vec.y;\n}\n\ntemplate <>\ntemplate <typename T, typename R>\ninline __device__ R Get<2>::of(const T& vec) {\n  return vec.z;\n}\n\ntemplate <>\ntemplate <typename T, typename R>\ninline __device__ R Get<3>::of(const T& vec) {\n  return vec.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Src, typename Dst>\nstruct Converter {\n  static inline __device__ Dst convert(const Src& from) { return Dst(from); }\n};\n\ntemplate <>\nstruct Converter<float2, half2> {\n  static inline __device__ half2 convert(const float2& x) { return __float22half2_rn(x); }\n};\n\ntemplate <>\nstruct Converter<float2, nv_bfloat162> {\n  static inline __device__ nv_bfloat162 convert(const float2& x) {\n#if __CUDA_ARCH__ >= 800\n    return __float22bfloat162_rn(x);\n#else\n    union {\n      nv_bfloat162 raw;\n      nv_bfloat16 x;\n      nv_bfloat16 
y;\n    } tmp;\n    tmp.x = __float2bfloat16_rn(x.x);\n    tmp.y = __float2bfloat16_rn(x.y);\n    return tmp.raw;\n#endif\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\nstruct Zeros {\n  static inline __device__ T get() { return T(0.f); }\n};\n\ntemplate <>\nstruct Zeros<float2> {\n  static inline __device__ float2 get() { return make_float2(0.f, 0.f); }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Elt_type, uint32_t NUM_ELT>\nstruct Vec {\n  enum { BYTES = NUM_ELT * sizeof(Elt_type) };\n\n  using Vec_type = typename BytesToType<BYTES>::Type;\n\n  using Alias_type = union {\n    Vec_type vec;\n    Elt_type elt[NUM_ELT];\n  };\n\n  Alias_type data;\n\n  template <typename S>\n  inline __device__ void to(Vec<S, NUM_ELT>& other) {\n#pragma unroll\n    for (int it = 0; it < NUM_ELT; it++) {\n      other.data.elt[it] = S(this->data.elt[it]);\n    }\n  }\n\n  template <typename Op>\n  inline __device__ void assign(const Op& op) {\n#pragma unroll\n    for (int it = 0; it < NUM_ELT; it++) {\n      this->data.elt[it] = op(it);\n    }\n  }\n\n  inline __device__ void load_from(const void* base_ptr, const size_t idx) {\n    this->data.vec = static_cast<const Vec_type*>(base_ptr)[idx];\n  }\n\n  inline __device__ void store_to(void* base_ptr, const size_t idx) {\n    static_cast<Vec_type*>(base_ptr)[idx] = this->data.vec;\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <uint32_t CTAS_PER_ROW>\nstruct InterCTASync {\n  template <typename Params>\n  inline __device__ InterCTASync(Params& params, uint32_t bidm, uint32_t bidn)\n      : phase_counter_(0),\n        b0_(params.barrier + bidm)  // The barrier for this group of CTAs.\n        ,\n        b1_(params.barrier + bidm + params.ctas_per_col)  // The barrier for this group of 
CTAs.\n  {\n    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!\n  }\n\n  inline __device__ void spin_wait_(int* barrier, int step, int expected) {\n    asm volatile(\"red.release.gpu.global.add.s32 [%0], %1;\" ::\"l\"(barrier), \"r\"(step));\n    for (int found = -1; found != expected;) {\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\" : \"=r\"(found) : \"l\"(barrier));\n    }\n  }\n\n  inline __device__ void sync() {\n    // ALL THREADS MUST ENTER!\n\n    // We switch barrier every iteration.\n    int* barrier = phase_counter_ & 0x1 ? b1_ : b0_;\n    // We decrement every other iteration.\n    bool dec = phase_counter_ & 0x2;\n    int step = dec ? -1 : 1;\n    int expected = dec ? 0 : CTAS_PER_ROW;\n    // There are only 4 phases: up/down for b0/b1.\n    phase_counter_ = (phase_counter_ + 1) & 0x3;\n\n    if (threadIdx.x == 0) {\n      spin_wait_(barrier, step, expected);\n    }\n    // CTA waits for thread 0\n    __syncthreads();\n  }\n\n  int phase_counter_;\n  int* b0_;\n  int* b1_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>\nstruct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {\n  using InterCTASync = InterCTASync<CTAS_PER_ROW>;\n  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;\n  using Type = typename Base::Type;\n\n  enum { SMEM_BYTES = Base::SMEM_BYTES };\n\n  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };\n  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };\n\n  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)\n  enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };\n\n  template <typename Params>\n  inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,\n                            uint32_t lane, void* smem)\n      : Base(params, bidm, bidn, 
warp_m, warp_n, lane, smem),\n        inter_cta_(params, bidm, bidn),\n        bidn_(bidn)  // CTA id within the group.\n        ,\n        w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),\n        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {}\n\n  template <typename Op>\n  inline __device__ T allreduce(T data, Op& op) {\n    data = Base::reduce(data, op);\n    // We switch workspace every iteration.\n    T* workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;\n\n    // Warp leaders 0 hold the CTA-local results.\n    if (this->warp_n_ == 0 && this->lane_ == 0) {\n      workspace[bidn_] = data;\n    }\n    inter_cta_.sync();\n    static_assert(CTAS_PER_ROW <= 32);\n    T total = Zeros<T>::get();\n    if (this->lane_ < CTAS_PER_ROW) {\n      total = workspace[this->lane_];\n    }\n    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);\n\n    return total;\n  }\n\n  InterCTASync inter_cta_;\n\n  T* w0_;\n  T* w1_;\n  int bidn_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, uint32_t WARPS_M>\nstruct Reducer<T, 1, WARPS_M, 1> {\n  using Type = T;\n  enum { SMEM_BYTES = 0 };\n  enum { WORKSPACE_BYTES_PER_GROUP = 0 };\n\n  enum { THREADS_PER_WARP = 32 };\n\n  template <typename Params>\n  inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,\n                            uint32_t lane, void* smem)\n      : warp_n_(warp_n), lane_(lane) {}\n\n  template <typename Op>\n  static inline __device__ T allreduce_(T data, Op& op) {\n#pragma unroll\n    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {\n      data = op(data, warp_shuffle_xor(data, it));\n    }\n    return data;\n  }\n\n  template <typename Op>\n  inline __device__ T allreduce(T data, Op& op) {\n    return allreduce_(data, op);\n  }\n\n  template <typename Op>\n  inline __device__ T reduce(T data, Op& op) {\n// only lane 0 holds 
the result!\n#pragma unroll\n    for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) {\n      data = op(data, warp_shuffle_down(data, it));\n    }\n    return data;\n  }\n  int warp_n_;\n  int lane_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, uint32_t WARPS_M, uint32_t WARPS_N>\nstruct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {\n  using Base = Reducer<T, 1, WARPS_M, 1>;\n\n  using Type = T;\n\n  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };\n  enum { WORKSPACE_BYTES_PER_GROUP = 0 };\n\n  enum { THREADS_PER_WARP = 32 };\n\n  template <typename Params>\n  inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,\n                            uint32_t lane, void* smem)\n      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {\n    smem0_ = &static_cast<T*>(smem)[warp_m * WARPS_N];\n    smem1_ = smem0_ + WARPS_M * WARPS_N;\n  }\n\n  template <typename Op>\n  inline __device__ T allreduce(T data, Op& op) {\n    T* smem = use0_ ? smem0_ : smem1_;\n    use0_ = !use0_;\n    data = Base::reduce(data, op);\n    if (this->lane_ == 0) {\n      smem[this->warp_n_] = data;\n    }\n    __syncthreads();\n    T out = Zeros<T>::get();\n#pragma unroll\n    for (int it = 0; it < WARPS_N; it++) {\n      out = op(out, smem[it]);\n    }\n    return out;\n  }\n\n  template <typename Op>\n  inline __device__ T reduce(T data, Op& op) {\n    T* smem = use0_ ? 
smem0_ : smem1_;\n    use0_ = !use0_;\n    // only intra-CTA group leader holds the result!\n    data = Base::reduce(data, op);\n    if (this->lane_ == 0) {\n      smem[this->warp_n_] = data;\n    }\n    __syncthreads();\n    T out = Zeros<T>::get();\n    if (this->warp_n_ == 0 && this->lane_ == 0) {\n#pragma unroll\n      for (int it = 0; it < WARPS_N; it++) {\n        out = op(out, smem[it]);\n      }\n    }\n    return out;\n  }\n\n  T* smem0_;\n  T* smem1_;\n  bool use0_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T>\ninline __device__ void warp_chan_upd_dynamic(T& m_a, T& m2_a, T& n_a, int num_active) {\n  // Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)\n  int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);\n\n#pragma unroll\n  for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) {\n    // Exchange\n    T n_b = warp_shuffle_down(n_a, step);\n    T m_b = warp_shuffle_down(m_a, step);\n    T m2_b = warp_shuffle_down(m2_a, step);\n\n    // Update\n    const T n_ab = n_a + n_b;    // We can handle one of them being 0, not both.\n    const T rn_ab = 1.f / n_ab;  // Might have different n per thread, otherwise this would simplify :(\n    const T delta = m_a - m_b;\n    const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;\n    const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;\n\n    n_a = n_ab;\n    m_a = m_ab;\n    m2_a = m2_ab;\n  }\n  // Intra-warp broadcast (only lane 0 has valid stats).\n  m_a = __shfl_sync(uint32_t(-1), m_a, 0);\n  m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>\nstruct Stats {\n  // This could be done generically with the Reducer. 
But then we would have to exchange 3 instead of 2 fields.\n\n  using InterCTASync = InterCTASync<CTAS_PER_ROW>;\n  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;\n  using stats_t = typename BlockStats::stats_t;\n\n  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };\n\n  template <typename Params>\n  inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane,\n                          void* smem)\n      : inter_cta_(params, bidm, bidn),\n        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),\n        bidn_(bidn)  // CTA id within the group.\n        ,\n        w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),\n        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),\n        warp_n_(warp_n),\n        lane_(lane) {}\n\n  template <uint32_t N>\n  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {\n    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;\n    // TODO rn is not really needed here..\n    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);\n    stats_t block_stats = block_stats_.compute(elts, block_rn);\n\n    stats_t* workspace = inter_cta_.phase_counter_ & 0x1 ? 
w1_ : w0_;\n\n    if (warp_n_ == 0 && lane_ == 0) {\n      workspace[bidn_] = block_stats;\n    }\n\n    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.\n    inter_cta_.sync();\n\n    T n = Zeros<T>::get();\n    T m = Zeros<T>::get();\n    T m2 = Zeros<T>::get();\n\n    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.\n    static_assert(CTAS_PER_ROW <= 32);\n\n    // Every warp does the final reduction locally.\n    if (lane_ < CTAS_PER_ROW) {\n      stats_t result = workspace[lane_];\n      n = ELTS_PER_ROW_PER_CTA;\n      m = layer_norm::Get<0>::of<stats_t, T>(result);\n      m2 = layer_norm::Get<1>::of<stats_t, T>(result);\n    }\n\n    warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);\n\n    return {m, m2};\n  }\n\n  InterCTASync inter_cta_;\n  BlockStats block_stats_;\n\n  stats_t* w0_;\n  stats_t* w1_;\n  int bidn_;\n  int warp_n_;\n  int lane_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, uint32_t WARPS_M, uint32_t WARPS_N>\nstruct Stats<T, 1, WARPS_M, WARPS_N> {\n  using WarpStats = Stats<T, 1, WARPS_M, 1>;\n  using stats_t = typename WarpStats::stats_t;\n\n  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };\n\n  template <typename Params>\n  inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane,\n                          void* smem)\n      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {\n    smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;\n    smem1_ = smem0_ + WARPS_M * WARPS_N;\n  }\n\n  template <uint32_t N>\n  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {\n    stats_t* smem = use0_ ? 
smem0_ : smem1_;\n    use0_ = !use0_;\n    // Compute warp local for all WARPS_N\n    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);\n    stats_t warp_stats = warp_stats_.compute(elts, warp_rn);\n\n    // Each warp leader stores its stats\n    const auto warp_n = warp_stats_.reducer_.warp_n_;\n    const auto lane = warp_stats_.reducer_.lane_;\n    if (lane == 0) {\n      smem[warp_n] = warp_stats;\n    }\n    __syncthreads();\n\n    T n = Zeros<T>::get();\n    T m = Zeros<T>::get();\n    T m2 = Zeros<T>::get();\n\n    // Assume that there are at most 32 warps, such that we can finalize with a single warp\n    static_assert(WARPS_N <= 32);\n    if (lane < WARPS_N) {\n      stats_t result = smem[lane];\n      n = N * THREADS_PER_WARP;\n      m = layer_norm::Get<0>::of<stats_t, T>(result);\n      m2 = layer_norm::Get<1>::of<stats_t, T>(result);\n    }\n\n    warp_chan_upd_dynamic(m, m2, n, WARPS_N);\n\n    return {m, m2};\n  }\n  WarpStats warp_stats_;\n  stats_t* smem0_;\n  stats_t* smem1_;\n  bool use0_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, uint32_t WARPS_M>\nstruct Stats<T, 1, WARPS_M, 1> {\n  using stats_t = typename TypeToVec2<T>::Type;\n  // The simple Warp reducer.\n  using Reducer = Reducer<T, 1, WARPS_M, 1>;\n\n  enum { SMEM_BYTES = 0 };\n\n  template <typename Params>\n  inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane,\n                          void* smem)\n      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}\n\n  template <uint32_t N>\n  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {\n    auto sum = Sum<T>();\n\n    T m = Zeros<T>::get();\n#pragma unroll\n    for (int it = 0; it < N; it++) {\n      m += elts[it];\n    }\n    m = reducer_.allreduce(m, sum) * rn;\n\n    T m2 = Zeros<T>::get();\n#pragma unroll\n    for (int it = 0; it < N; it++) 
{\n      T diff = (elts[it] - m);\n      m2 += diff * diff;\n    }\n    m2 = reducer_.allreduce(m2, sum);\n\n    return {m, m2};\n  }\n\n  Reducer reducer_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace layer_norm\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"softmax.cuh\"\n\n// symbol to be automatically resolved by PyTorch libs\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace additive_mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask,\n                                    float dropout_prob) {\n  const int attn_batches = input.size(0);\n  const int sequences = attn_batches / heads;\n  const int q_seq_len = input.size(1);\n  const int k_seq_len = q_seq_len;\n  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = input.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  // Padded Softmax\n  [[maybe_unused]] bool softmax_success = false;\n  if (pad_mask == 
nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(reinterpret_cast<half*>(softmax_results_ptr),\n                                                          reinterpret_cast<const half*>(input_ptr), k_seq_len,\n                                                          k_seq_len, attn_batches * q_seq_len);\n  } else {\n    softmax_success = dispatch_additive_masked_softmax<half, half, float>(\n        reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(input_ptr), pad_mask, k_seq_len,\n        k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);\n  }\n\n  if (is_training) {\n    // use at:: function so that C++ version generates the same random mask as\n    // python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n\n  return {dropout_results, dropout_mask, softmax_results};\n}\n\ntorch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results,\n                       torch::Tensor const& dropout_mask, float dropout_prob) {\n  const int attn_batches = output_grads.size(0);\n  const int q_seq_len = output_grads.size(1);\n  const int k_seq_len = q_seq_len;\n  // const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  //  torch::Tensor input_grads         = torch::empty_like(output_grads);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(\n      
static_cast<half*>(output_grads.data_ptr()), static_cast<half*>(output_grads.data_ptr()),\n      reinterpret_cast<half const*>(softmax_results.data_ptr()), static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream);\n  // backward pass is completely in-place\n  return output_grads;\n}\n}  // namespace additive_mask_softmax_dropout\n}  // namespace fused_softmax\n}  // namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/dropout.cuh",
    "content": "#pragma once\n#include <ATen/ATen.h>\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl.h>\n#endif\n\n#include <ATen/cuda/CUDAContext.h>\n#include <curand_kernel.h>\n\nnamespace {\nconstexpr int UNROLL = 4;\n}  // namespace\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\n__global__ void apex_fused_dropout_kernel(scalar_t const* inputs, scalar_t* outputs, uint8_t* mask,\n                                          IndexType totalElements, accscalar_t p, std::pair<uint64_t, uint64_t> seeds) {\n  accscalar_t pinv = accscalar_t(1) / p;\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n  curandStatePhilox4_32_10_t state;\n  curand_init(seeds.first, idx, seeds.second, &state);\n\n  IndexType rounded_size =\n      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) {\n    float4 rand = curand_uniform4(&state);\n    scalar_t src[UNROLL];\n    rand.x = rand.x <= p;\n    rand.y = rand.y <= p;\n    rand.z = rand.z <= p;\n    rand.w = rand.w <= p;\n\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        src[ii] = inputs[li];\n      }\n    }\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        outputs[li] = src[ii] * (&rand.x)[ii] * pinv;\n        mask[li] = (uint8_t)(&rand.x)[ii];\n      }\n    }\n    __syncthreads();\n  }\n}\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\n__global__ void apex_dropout_add_kernel(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs,\n                                        uint8_t* mask, IndexType totalElements, accscalar_t p,\n                    
                    std::pair<uint64_t, uint64_t> seeds) {\n  accscalar_t pinv = accscalar_t(1) / p;\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n  curandStatePhilox4_32_10_t state;\n  curand_init(seeds.first, idx, seeds.second, &state);\n\n  IndexType rounded_size =\n      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) {\n    float4 rand = curand_uniform4(&state);\n    scalar_t src[UNROLL];\n    scalar_t add_src[UNROLL];\n    rand.x = rand.x <= p;\n    rand.y = rand.y <= p;\n    rand.z = rand.z <= p;\n    rand.w = rand.w <= p;\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        src[ii] = inputs[li];\n        add_src[ii] = add_inputs[li];\n      }\n    }\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;\n        outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);\n        mask[li] = (uint8_t)(&rand.x)[ii];\n      }\n    }\n    __syncthreads();\n  }\n}\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\n__global__ void apex_add_kernel(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs,\n                                IndexType totalElements) {\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  IndexType rounded_size =\n      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) {\n    scalar_t src[UNROLL];\n    scalar_t add_src[UNROLL];\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = 
linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        src[ii] = inputs[li];\n        add_src[ii] = add_inputs[li];\n      }\n    }\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        outputs[li] = src[ii] + add_src[ii];\n      }\n    }\n    __syncthreads();\n  }\n}\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\n__global__ void apex_masked_scale_kernel(scalar_t const* inputs, scalar_t* outputs, uint8_t const* mask,\n                                         IndexType totalElements, accscalar_t scale) {\n  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  IndexType rounded_size =\n      ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * blockDim.x * gridDim.x * UNROLL;\n  for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x * UNROLL) {\n    scalar_t src[UNROLL];\n    scalar_t msk[UNROLL];\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        src[ii] = static_cast<scalar_t>(inputs[li]);\n        msk[ii] = static_cast<scalar_t>(mask[li]);\n      }\n    }\n    for (int ii = 0; ii < UNROLL; ii++) {\n      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;\n      if (li < totalElements) {\n        outputs[li] = static_cast<accscalar_t>(src[ii]) * scale * static_cast<accscalar_t>(msk[ii]);\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\nvoid apex_fused_dropout_cuda(scalar_t const* inputs, scalar_t* outputs, uint8_t* mask, IndexType totalElements,\n                             accscalar_t p) {\n  auto gen = at::cuda::detail::getDefaultCUDAGenerator();\n\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size - 1) / block_size);\n  unsigned int blocks_per_sm = 
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  // number of times random will be generated per thread, to offset philox\n  // counter in the random state\n  int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;\n  std::pair<uint64_t, uint64_t> rng_engine_inputs;\n  {\n    // See Note [Acquire lock when using random generators]\n    std::lock_guard<std::mutex> lock(gen.mutex());\n    rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);\n  }\n\n  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(\n      inputs, outputs, mask, totalElements, p, rng_engine_inputs);\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\nvoid apex_dropout_add_cuda(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs, uint8_t* mask,\n                           IndexType totalElements, accscalar_t p) {\n  auto gen = at::cuda::detail::getDefaultCUDAGenerator();\n\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size - 1) / block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  // number of times random will be generated per thread, to offset philox\n  // counter in the random state\n  int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;\n  std::pair<uint64_t, uint64_t> rng_engine_inputs;\n  {\n    // See Note [Acquire lock when using random generators]\n    std::lock_guard<std::mutex> lock(gen.mutex());\n    
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);\n  }\n\n  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(\n      inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\nvoid apex_add_cuda(scalar_t const* inputs, scalar_t const* add_inputs, scalar_t* outputs, IndexType totalElements) {\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size - 1) / block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  apex_add_kernel<scalar_t, accscalar_t, IndexType>\n      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, totalElements);\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\ntemplate <typename scalar_t, typename accscalar_t, typename IndexType>\nvoid apex_masked_scale_cuda(scalar_t const* inputs, scalar_t* outputs, uint8_t const* mask, IndexType totalElements,\n                            accscalar_t scale) {\n  int block_size = 256;\n  dim3 dim_block(block_size);\n  dim3 grid((totalElements + block_size - 1) / block_size);\n  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;\n  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);\n\n  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>\n      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, scale);\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"softmax.cuh\"\n#include \"strided_batched_gemm.cuh\"\n\nnamespace multihead_attn {\nnamespace encdec {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q,\n                                    torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q,\n                                    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,\n                                    const uint8_t* pad_mask, float dropout_prob) {\n  const int embed_dim = inputs_q.size(2);\n  const int sequences = inputs_q.size(1);\n  const int q_seq_len = inputs_q.size(0);\n  const int k_seq_len = inputs_kv.size(0);\n  const int batches_q = sequences * q_seq_len;\n  const int batches_kv = sequences * k_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_q_dim = embed_dim;\n  const int output_lin_kv_dim = 2 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim_q = attn_batches * head_dim;\n  const int lead_dim_kv = attn_batches * 2 * head_dim;\n  const int batch_stride_q = head_dim;\n  const int batch_stride_kv = 2 * head_dim;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + 
Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = inputs_q.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);\n  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);\n  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);\n  torch::Tensor outputs = torch::empty_like(inputs_q, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr());\n  void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr());\n  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Q Fwd\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(inputs_q.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, 
CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear KV Fwd\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,\n                        static_cast<const half*>(k_lin_results_ptr), lead_dim_kv, batch_stride_kv,\n                        static_cast<const half*>(q_lin_results_ptr), lead_dim_q, batch_stride_q, beta,\n                        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(reinterpret_cast<half*>(softmax_results_ptr),\n                                                          reinterpret_cast<const half*>(softmax_results_ptr), k_seq_len,\n                                                          k_seq_len, attn_batches * q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / 
sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half, float, uint32_t>(\n        static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half*>(dropout_results.data_ptr()),\n        static_cast<uint8_t*>(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,\n                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr())\n                                      : static_cast<const half*>(softmax_results.data_ptr()),\n                        k_seq_len, k_seq_len * q_seq_len, beta, static_cast<half*>(matmul2_results.data_ptr()),\n                        head_dim * attn_batches, head_dim, attn_batches);\n\n  // Output Linear\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO1_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_lin_q_results, input_lin_kv_results, softmax_results, dropout_results,\n          dropout_mask,        matmul2_results,      outputs};\n}\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& 
input_lin_q_results, torch::Tensor const& input_lin_kv_results,\n                                    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,\n                                    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,\n                                    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                                    float dropout_prob) {\n  const int embed_dim = inputs_q.size(2);\n  const int sequences = inputs_q.size(1);\n  const int q_seq_len = inputs_q.size(0);\n  const int k_seq_len = inputs_kv.size(0);\n  const int batches_q = sequences * q_seq_len;\n  const int batches_kv = sequences * k_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_q_dim = embed_dim;\n  const int output_lin_kv_dim = 2 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim_q = attn_batches * head_dim;\n  const int lead_dim_kv = attn_batches * 2 * head_dim;\n  const int batch_stride_q = head_dim;\n  const int batch_stride_kv = 2 * head_dim;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_q_grads = torch::empty_like(inputs_q);\n  torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);\n  torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);\n  torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  
at::Tensor output_lin_grads = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads = torch::empty_like(dropout_results);\n  at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);\n  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Output Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    
static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,\n                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,\n                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,\n                        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  apex_masked_scale_cuda<at::Half, float, uint32_t>(\n      static_cast<at::Half const*>(matmul2_grads.data_ptr()), static_cast<at::Half*>(matmul2_grads.data_ptr()),\n      static_cast<uint8_t const*>(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n      static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),\n      reinterpret_cast<half const*>(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, 
k_lin_results_ptr, lead_dim_kv,\n                        batch_stride_kv, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim_q,\n                        batch_stride_q, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches);\n\n  // Input Linear Q Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast<const void*>(&beta),\n      static_cast<void*>(input_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO10_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Q Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q,\n                                    static_cast<const void*>(&alpha), static_cast<const void*>(inputs_q.data_ptr()),\n                                    CUDA_R_16F, embed_dim, static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F,\n                                    output_lin_q_dim, static_cast<const void*>(&beta),\n                                    static_cast<void*>(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear KV Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, output_lin_kv_dim, static_cast<const void*>(&alpha),\n      
static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, static_cast<const void*>(&beta),\n      static_cast<void*>(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO10_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear KV Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, batches_kv,\n                                    static_cast<const void*>(&alpha), static_cast<const void*>(inputs_kv.data_ptr()),\n                                    CUDA_R_16F, embed_dim, static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F,\n                                    output_lin_kv_dim, static_cast<const void*>(&beta),\n                                    static_cast<void*>(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads};\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace encdec\n}  // end namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"layer_norm.cuh\"\n#include \"softmax.cuh\"\n#include \"strided_batched_gemm.cuh\"\n\nnamespace multihead_attn {\nnamespace encdec_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q,\n                                    torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q,\n                                    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,\n                                    const uint8_t* pad_mask, float dropout_prob) {\n  const int embed_dim = inputs_q.size(2);\n  const int sequences = inputs_q.size(1);\n  const int q_seq_len = inputs_q.size(0);\n  const int k_seq_len = inputs_kv.size(0);\n  const int batches_q = sequences * q_seq_len;\n  const int batches_kv = sequences * k_seq_len;\n  const int total_tokens_q = batches_q * embed_dim;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_q_dim = embed_dim;\n  const int output_lin_kv_dim = 2 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim_q = attn_batches * head_dim;\n  const int lead_dim_kv = attn_batches * 2 * head_dim;\n  const int batch_stride_q = head_dim;\n  const int batch_stride_kv = 2 * head_dim;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially 
dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = inputs_q.options().requires_grad(false);\n  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options);\n\n  torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options);\n  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);\n  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);\n  torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options);\n  torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options);\n  torch::Tensor outputs = torch::empty_like(inputs_q, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr());\n  void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr());\n  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* 
softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Layer Norm\n  HostApplyLayerNorm<at::Half, float>(\n      static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<float*>(lyr_nrm_mean.data_ptr()),\n      static_cast<float*>(lyr_nrm_invvar.data_ptr()), static_cast<const at::Half*>(inputs_q.data_ptr()),\n      static_cast<int>(batches_q),  // n1\n      static_cast<int>(embed_dim),  // n2\n      1.0e-5, static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),\n      static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));\n\n  // Input Linear Q Fwd\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_q_dim, batches_q, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,\n      // static_cast<const void*>(inputs_q.data_ptr()),\n      static_cast<const void*>(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear KV Fwd\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_kv_dim, batches_kv, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,\n                        static_cast<const half*>(k_lin_results_ptr), lead_dim_kv, 
batch_stride_kv,\n                        static_cast<const half*>(q_lin_results_ptr), lead_dim_q, batch_stride_q, beta,\n                        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(reinterpret_cast<half*>(softmax_results_ptr),\n                                                          reinterpret_cast<const half*>(softmax_results_ptr), k_seq_len,\n                                                          k_seq_len, attn_batches * q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half, float, uint32_t>(\n        static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half*>(dropout_results.data_ptr()),\n        static_cast<uint8_t*>(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,\n                        (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr())\n                                      : static_cast<const half*>(softmax_results.data_ptr()),\n                        // static_cast<const half*>(dropout_results.data_ptr()),\n                        k_seq_len, k_seq_len * q_seq_len, beta, static_cast<half*>(matmul2_results.data_ptr()),\n                        head_dim * attn_batches, head_dim, attn_batches);\n\n  // Output Linear\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches_q, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      static_cast<void*>(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO1_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // End-of-block Dropout-Add\n  if (is_training) {\n    apex_dropout_add_cuda<at::Half, float, uint32_t>(\n        static_cast<at::Half const*>(output_lin_results.data_ptr()), static_cast<at::Half const*>(inputs_q.data_ptr()),\n        static_cast<at::Half*>(outputs.data_ptr()), static_cast<uint8_t*>(dropout_add_mask.data_ptr()), total_tokens_q,\n        (1.0f - dropout_prob));\n  } else {\n    apex_add_cuda<at::Half, float, uint32_t>(static_cast<at::Half const*>(output_lin_results.data_ptr()),\n                                             static_cast<at::Half const*>(inputs_q.data_ptr()),\n                                             static_cast<at::Half*>(outputs.data_ptr()), total_tokens_q);\n  }\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {lyr_nrm_results,      lyr_nrm_mean,     lyr_nrm_invvar,  input_lin_q_results,\n          input_lin_kv_results, softmax_results,  dropout_results, dropout_mask,\n          matmul2_results,      dropout_add_mask, 
outputs};\n}\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,\n                                    torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,\n                                    torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q,\n                                    torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q,\n                                    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,\n                                    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,\n                                    float dropout_prob) {\n  const int embed_dim = inputs_q.size(2);\n  const int sequences = inputs_q.size(1);\n  const int q_seq_len = inputs_q.size(0);\n  const int k_seq_len = inputs_kv.size(0);\n  const int batches_q = sequences * q_seq_len;\n  const int batches_kv = sequences * k_seq_len;\n  const int total_tokens_q = batches_q * embed_dim;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_q_dim = embed_dim;\n  const int output_lin_kv_dim = 2 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim_q = attn_batches * head_dim;\n  const int lead_dim_kv = attn_batches * 2 * head_dim;\n  const int batch_stride_q = head_dim;\n  const int batch_stride_kv = 2 * head_dim;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: 
Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_q_grads = torch::empty_like(inputs_q);\n  torch::Tensor input_kv_grads = torch::empty_like(inputs_kv);\n  torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);\n  torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);\n  torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q);\n  torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor dropout_add_grads = torch::empty_like(output_grads);\n  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads = torch::empty_like(dropout_results);\n  at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);\n  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);\n  at::Tensor input_lin_q_grads = torch::empty_like(inputs_q);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  
// Dropout Add Backward\n  apex_masked_scale_cuda<at::Half, float, uint32_t>(\n      static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half*>(dropout_add_grads.data_ptr()),\n      static_cast<uint8_t const*>(dropout_add_mask.data_ptr()), total_tokens_q, (1.0 / (1.0 - dropout_prob)));\n\n  // Output Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, embed_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Output Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches_q,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,\n                        
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,\n                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,\n                        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  apex_masked_scale_cuda<at::Half, float, uint32_t>(\n      static_cast<at::Half const*>(matmul2_grads.data_ptr()), static_cast<at::Half*>(matmul2_grads.data_ptr()),\n      static_cast<uint8_t const*>(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n      static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),\n      reinterpret_cast<half const*>(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim_kv,\n                        batch_stride_kv, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim_q,\n                        batch_stride_q, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches);\n\n  // Input Linear Q Dgrad\n  
TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_q, output_lin_q_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim, static_cast<const void*>(&beta),\n      // static_cast<void*>(input_q_grads.data_ptr()),\n      static_cast<void*>(input_lin_q_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO10_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Q Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_q_dim, batches_q,\n                                    static_cast<const void*>(&alpha), static_cast<const void*>(inputs_q.data_ptr()),\n                                    CUDA_R_16F, embed_dim, static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F,\n                                    output_lin_q_dim, static_cast<const void*>(&beta),\n                                    static_cast<void*>(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear KV Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches_kv, output_lin_kv_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim, static_cast<const void*>(&beta),\n      static_cast<void*>(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO10_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear KV Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_kv_dim, batches_kv,\n                                    static_cast<const 
void*>(&alpha), static_cast<const void*>(inputs_kv.data_ptr()),\n                                    CUDA_R_16F, embed_dim, static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F,\n                                    output_lin_kv_dim, static_cast<const void*>(&beta),\n                                    static_cast<void*>(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Fused Layer Norm Bwd with Residual Add\n  HostLayerNormGradient<half, float>(\n      static_cast<const half*>(input_lin_q_grads.data_ptr()), static_cast<half const*>(output_grads.data_ptr()),\n      static_cast<const float*>(lyr_nrm_mean.data_ptr()), static_cast<const float*>(lyr_nrm_invvar.data_ptr()),\n      inputs_q,\n      static_cast<int>(batches_q),  // n1\n      static_cast<int>(embed_dim),  // n2\n      static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),\n      static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, static_cast<half*>(input_q_grads.data_ptr()),\n      static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()), static_cast<half*>(lyr_nrm_beta_grads.data_ptr()));\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_q_grads,        input_kv_grads,        lyr_nrm_gamma_grads, lyr_nrm_beta_grads,\n          input_weight_q_grads, input_weight_kv_grads, output_weight_grads};\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace encdec_norm_add\n}  // end namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/layer_norm.cuh",
    "content": "#pragma once\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/DeviceUtils.cuh>\n\nnamespace {\ntemplate <typename U>\n__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {\n  count = count + U(1);\n  U delta = curr - mu;\n  U lmean = mu + delta / count;\n  mu = lmean;\n  U delta2 = curr - lmean;\n  sigma2 = sigma2 + delta * delta2;\n}\n\ntemplate <typename U>\n__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, U& mu, U& sigma2, U& count) {\n  U delta = muB - mu;\n  U nA = count;\n  U nB = countB;\n  count = count + countB;\n  U nX = count;\n  if (nX > U(0)) {\n    nA = nA / nX;\n    nB = nB / nX;\n    mu = nA * mu + nB * muB;\n    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;\n  } else {\n    mu = U(0);\n    sigma2 = U(0);\n  }\n}\n\ntemplate <typename T, typename U>\n__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, const int n2, const int i1, U& mu,\n                                  U& sigma2, U* buf) {\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean over n2\n  U count = U(0);\n  mu = U(0);\n  sigma2 = U(0);\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const T* lvals = vals + i1 * n2;\n    int l = 4 * thrx;\n    for (; l + 3 < n2; l += 4 * numx) {\n      for (int k = 0; k < 4; ++k) {\n        U curr = static_cast<U>(lvals[l + k]);\n        cuWelfordOnlineSum<U>(curr, mu, sigma2, count);\n      }\n    }\n    for (; l < n2; ++l) {\n      U curr = static_cast<U>(lvals[l]);\n      cuWelfordOnlineSum<U>(curr, mu, sigma2, count);\n    }\n    // intra-warp 
reductions\n    for (int l = 0; l <= 4; ++l) {\n      int srcLaneB = (threadIdx.x + (1 << l)) & 31;\n      U muB = WARP_SHFL(mu, srcLaneB);\n      U countB = WARP_SHFL(count, srcLaneB);\n      U sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      U* ubuf = (U*)buf;\n      U* ibuf = (U*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n          const int wrt_y = threadIdx.y - offset;\n          ubuf[2 * wrt_y] = mu;\n          ubuf[2 * wrt_y + 1] = sigma2;\n          ibuf[wrt_y] = count;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          U muB = ubuf[2 * threadIdx.y];\n          U sigma2B = ubuf[2 * threadIdx.y + 1];\n          U countB = ibuf[threadIdx.y];\n          cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);\n        }\n        __syncthreads();\n      }\n      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        ubuf[0] = mu;\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      mu = ubuf[0];\n      sigma2 = ubuf[1] / U(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      mu = WARP_SHFL(mu, 0);\n      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);\n    }\n  }\n}\n\ntemplate <>\n__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, const int n1, const int n2, const int i1,\n                                  float& mu, float& sigma2, float* buf) {\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 
2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean over n2\n  float count = 0.0f;\n  mu = float(0);\n  sigma2 = float(0);\n\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const at::Half* lvals = vals + i1 * n2;\n    int l = 8 * thrx;\n    if ((((size_t)lvals) & 3) != 0) {\n      // 16 bit alignment\n      // first thread consumes first point\n      if (thrx == 0) {\n        float curr = static_cast<float>(lvals[0]);\n        cuWelfordOnlineSum(curr, mu, sigma2, count);\n      }\n      ++l;\n    }\n    // at this point, lvals[l] are 32 bit aligned for all threads.\n    for (; l + 7 < n2; l += 8 * numx) {\n      for (int k = 0; k < 8; k += 2) {\n        float2 curr = __half22float2(*((__half2*)(lvals + l + k)));\n        cuWelfordOnlineSum(curr.x, mu, sigma2, count);\n        cuWelfordOnlineSum(curr.y, mu, sigma2, count);\n      }\n    }\n    for (; l < n2; ++l) {\n      float curr = static_cast<float>(lvals[l]);\n      cuWelfordOnlineSum(curr, mu, sigma2, count);\n    }\n    // intra-warp reductions\n    for (int l = 0; l <= 4; ++l) {\n      int srcLaneB = (threadIdx.x + (1 << l)) & 31;\n      float muB = WARP_SHFL(mu, srcLaneB);\n      float countB = WARP_SHFL(count, srcLaneB);\n      float sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      float* ubuf = (float*)buf;\n      float* ibuf = (float*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n     
     const int wrt_y = threadIdx.y - offset;\n          ubuf[2 * wrt_y] = mu;\n          ubuf[2 * wrt_y + 1] = sigma2;\n          ibuf[wrt_y] = count;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          float muB = ubuf[2 * threadIdx.y];\n          float sigma2B = ubuf[2 * threadIdx.y + 1];\n          float countB = ibuf[threadIdx.y];\n          cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);\n        }\n        __syncthreads();\n      }\n      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        ubuf[0] = mu;\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      mu = ubuf[0];\n      sigma2 = ubuf[1] / float(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      mu = WARP_SHFL(mu, 0);\n      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);\n    }\n  }\n}\n\ntemplate <typename U>\n__device__ U rsqrt(U v) {\n  return U(1) / sqrt(v);\n}\ntemplate <>\n__device__ float rsqrt(float v) {\n  return rsqrtf(v);\n}\ntemplate <>\n__device__ double rsqrt(double v) {\n  // Qualify the call so it resolves to the CUDA math library's rsqrt(double)\n  // rather than recursing into this specialization.\n  return ::rsqrt(v);\n}\n\n// This is the un-specialized struct.  
Note that we prevent instantiation of\n// this struct by putting an undefined symbol in the function body so it won't\n// compile.\n//  template <typename T>\n//  struct SharedMemory\n//  {\n//      // Ensure that we won't compile any un-specialized types\n//      __device__ T *getPointer()\n//      {\n//          extern __device__ void error(void);\n//          error();\n//          return NULL;\n//      }\n//  };\n// https://github.com/NVIDIA/apex/issues/246\ntemplate <typename T>\nstruct SharedMemory;\ntemplate <>\nstruct SharedMemory<float> {\n  __device__ float* getPointer() {\n    extern __shared__ float s_float[];\n    return s_float;\n  }\n};\n\ntemplate <>\nstruct SharedMemory<double> {\n  __device__ double* getPointer() {\n    extern __shared__ double s_double[];\n    return s_double;\n  }\n};\n\ntemplate <typename T, typename U>\n__global__ void cuApplyLayerNorm(T* __restrict__ output_vals, U* __restrict__ mean, U* __restrict__ invvar,\n                                 const T* __restrict__ vals, const int n1, const int n2, const U epsilon,\n                                 const T* __restrict__ gamma, const T* __restrict__ beta) {\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensors are contiguous\n  //\n  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer();\n    U mu, sigma2;\n    cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);\n    const T* lvals = vals + i1 * n2;\n    T* ovals = output_vals + i1 * n2;\n    U c_invvar = rsqrt(sigma2 + epsilon);\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != NULL && beta != NULL) {\n      for (int i = thrx; i < n2; i += numx) {\n        U curr = static_cast<U>(lvals[i]);\n        ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];\n      }\n    } else {\n      for (int i = thrx; i < n2; i += numx) {\n        U curr = 
static_cast<U>(lvals[i]);\n        ovals[i] = static_cast<T>(c_invvar * (curr - mu));\n      }\n    }\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      mean[i1] = mu;\n      invvar[i1] = c_invvar;\n    }\n  }\n}\n\ntemplate <typename T, typename U>\n__device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off,\n                                         const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,\n                                         const T* input, const T* dout, const int i1_end, const int n2,\n                                         const U* __restrict__ mean, const U* __restrict__ invvar) {\n  int i1 = i1_block + thr_load_row_off;\n  if (i1 < i1_end) {\n    U curr_mean = mean[i1];\n    U curr_invvar = invvar[i1];\n    for (int k = 0; k < blockDim.y; ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1 * n2 + i2;\n      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;\n      if (i2 < n2) {\n        U curr_input = static_cast<U>(input[load_idx]);\n        U curr_dout = static_cast<U>(dout[load_idx]);\n        warp_buf1[write_idx] = curr_dout;\n        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;\n      } else {\n        warp_buf1[write_idx] = U(0);\n        warp_buf2[write_idx] = U(0);\n      }\n    }\n  } else {\n    for (int k = 0; k < blockDim.y; ++k) {\n      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;\n      warp_buf1[write_idx] = U(0);\n      warp_buf2[write_idx] = U(0);\n    }\n  }\n}\n\ntemplate <typename T, typename U>\n__device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off,\n                                       const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,\n                                       const T* input, const T* dout, const int i1_end, const int n2,\n                                       const 
U* __restrict__ mean, const U* __restrict__ invvar) {\n  int i1 = i1_block + thr_load_row_off;\n  if (i1 < i1_end) {\n    U curr_mean = mean[i1];\n    U curr_invvar = invvar[i1];\n    for (int k = 0; k < blockDim.y; ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1 * n2 + i2;\n      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;\n      if (i2 < n2) {\n        U curr_input = static_cast<U>(input[load_idx]);\n        U curr_dout = static_cast<U>(dout[load_idx]);\n        warp_buf1[write_idx] += curr_dout;\n        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename U>\n__global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout, const T* __restrict__ input, const int n1,\n                                           const int n2, const U* __restrict__ mean, const U* __restrict__ invvar,\n                                           U epsilon, U* part_grad_gamma, U* part_grad_beta) {\n  const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);\n  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;\n  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;\n  const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;\n  const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1;\n  const int row_stride = blockDim.x + 1;\n  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);\n  const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;\n  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;\n  SharedMemory<U> shared;\n  U* buf = shared.getPointer();  // buf has at least blockDim.x * blockDim.y *\n                                 // blockDim.y + (blockDim.y -\n                                 // 1)*(blockDim.x/blockDim.y) elements\n  U* warp_buf1 = (U*)buf;\n  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;\n  // compute partial sums from strided inputs\n  // do this to increase number of loads in flight\n  cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input,\n                           dout, i1_end, n2, mean, invvar);\n  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) {\n    cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2,\n                           input, dout, i1_end, n2, mean, invvar);\n  }\n  __syncthreads();\n  // inter-warp reductions\n  // sum within each warp\n  U acc1 = U(0);\n  U acc2 = U(0);\n  for (int k = 0; k < blockDim.y; ++k) {\n    int row1 = threadIdx.y + k * blockDim.y;\n    int idx1 = row1 * row_stride + threadIdx.x;\n    acc1 += warp_buf1[idx1];\n    acc2 += warp_buf2[idx1];\n  }\n  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;\n  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;\n  __syncthreads();\n  // sum all warps\n  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {\n    if (threadIdx.y < offset) {\n      int row1 = threadIdx.y;\n      int row2 = threadIdx.y + offset;\n      int idx1 = row1 * row_stride + threadIdx.x;\n      int idx2 = row2 * row_stride + threadIdx.x;\n      
warp_buf1[idx1] += warp_buf1[idx2];\n      warp_buf2[idx1] += warp_buf2[idx2];\n    }\n    __syncthreads();\n  }\n  int i2 = blockIdx.x * blockDim.x + threadIdx.x;\n  if (threadIdx.y == 0 && i2 < n2) {\n    int row1 = threadIdx.y;\n    int row2 = threadIdx.y + 1;\n    int idx1 = row1 * row_stride + threadIdx.x;\n    int idx2 = row2 * row_stride + threadIdx.x;\n    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];\n    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];\n  }\n}\n\ntemplate <typename T, typename U>\n__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_grad_beta, const int part_size,\n                                       const int n1, const int n2, T* grad_gamma, T* grad_beta) {\n  // sum partial gradients for gamma and beta\n  SharedMemory<U> shared;\n  U* buf = shared.getPointer();\n  int i2 = blockIdx.x * blockDim.x + threadIdx.x;\n  if (i2 < n2) {\n    // each warp does sequential reductions until reduced part_size is num_warps\n    int num_warp_reductions = part_size / blockDim.y;\n    U sum_gamma = U(0);\n    U sum_beta = U(0);\n    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;\n    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;\n    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {\n      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];\n      sum_beta += part_grad_beta_ptr[warp_offset * n2];\n    }\n    // inter-warp reductions\n    const int nbsize3 = blockDim.x * blockDim.y / 2;\n    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {\n      // top half write to shared memory\n      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n        buf[write_idx] = sum_gamma;\n        buf[write_idx + nbsize3] = sum_beta;\n      }\n      __syncthreads();\n   
   // bottom half sums\n      if (threadIdx.y < offset) {\n        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;\n        sum_gamma += buf[read_idx];\n        sum_beta += buf[read_idx + nbsize3];\n      }\n      __syncthreads();\n    }\n    // write out fully summed gradients\n    if (threadIdx.y == 0) {\n      grad_gamma[i2] = sum_gamma;\n      grad_beta[i2] = sum_beta;\n    }\n  }\n}\n\ntemplate <typename T, typename U>\n__global__ void cuComputeGradInput(const T* __restrict__ dout, const T* __restrict__ dout_resid,\n                                   const T* __restrict__ input, const int n1, const int n2, const U* __restrict__ mean,\n                                   const U* __restrict__ invvar, U epsilon, const T* gamma, T* grad_input) {\n  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    U sum_loss1 = U(0);\n    U sum_loss2 = U(0);\n    const U c_mean = mean[i1];\n    const U c_invvar = invvar[i1];\n    const T* k_input = input + i1 * n2;\n    const T* k_dout = dout + i1 * n2;\n    const T* k_dout_resid = dout_resid + i1 * n2;\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != NULL) {\n      int l = 4 * thrx;\n      for (; l + 3 < n2; l += 4 * numx) {\n        for (int k = 0; k < 4; ++k) {\n          const U c_h = static_cast<U>(k_input[l + k]);\n          const U c_loss = static_cast<U>(k_dout[l + k]);\n          sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);\n          sum_loss2 += c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;\n        }\n      }\n      for (; l < n2; ++l) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        sum_loss1 += c_loss * static_cast<U>(gamma[l]);\n        sum_loss2 += c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;\n      }\n    } else {\n      int l = 4 * thrx;\n      for (; l + 3 < n2; l += 4 * numx) {\n        for (int k = 0; 
k < 4; ++k) {\n          const U c_h = static_cast<U>(k_input[l + k]);\n          const U c_loss = static_cast<U>(k_dout[l + k]);\n          sum_loss1 += c_loss;\n          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n        }\n      }\n      for (; l < n2; ++l) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        sum_loss1 += c_loss;\n        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n      }\n    }\n    // intra-warp reductions\n    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {\n      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);\n      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);\n    }\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      SharedMemory<U> shared;\n      U* buf = shared.getPointer();\n      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n          buf[2 * wrt_i] = sum_loss1;\n          buf[2 * wrt_i + 1] = sum_loss2;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.y < offset) {\n          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;\n          sum_loss1 += buf[2 * read_i];\n          sum_loss2 += buf[2 * read_i + 1];\n        }\n        __syncthreads();\n      }\n      if (threadIdx.y == 0) {\n        buf[2 * threadIdx.x] = sum_loss1;\n        buf[2 * threadIdx.x + 1] = sum_loss2;\n      }\n      __syncthreads();\n      if (threadIdx.y != 0) {\n        sum_loss1 = buf[2 * threadIdx.x];\n        sum_loss2 = buf[2 * threadIdx.x + 1];\n      }\n    }\n    // all threads now have the two sums over l\n    U fH = (U)n2;\n    U term1 = (U(1) / fH) * c_invvar;\n    T* k_grad_input = grad_input + i1 * n2;\n    if (gamma != NULL) {\n      for (int l = thrx; l < n2; l += numx) {\n        const U c_h = 
static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        const T c_resid = static_cast<T>(k_dout_resid[l]);\n        U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);\n        f_grad_input -= sum_loss1;\n        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n        f_grad_input *= term1;\n        k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;\n      }\n    } else {\n      for (int l = thrx; l < n2; l += numx) {\n        const U c_h = static_cast<U>(k_input[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        const T c_resid = static_cast<T>(k_dout_resid[l]);\n        U f_grad_input = fH * c_loss;\n        f_grad_input -= sum_loss1;\n        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n        f_grad_input *= term1;\n        k_grad_input[l] = static_cast<T>(f_grad_input) + c_resid;\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename U>\nvoid HostApplyLayerNorm(T* output, U* mean, U* invvar, const T* input, int n1, int n2, double epsilon, const T* gamma,\n                        const T* beta) {\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const dim3 threads(32, 4, 1);\n  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);\n  int nshared = threads.y > 1 ? 
threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;\n  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);\n}\n\ntemplate <typename T, typename U>\nvoid HostLayerNormGradient(const T* dout, const T* dout_resid, const U* mean, const U* invvar, const at::Tensor& input,\n                           int n1, int n2, const T* gamma, const T* beta, double epsilon, T* grad_input, T* grad_gamma,\n                           T* grad_beta) {\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n  if (gamma != NULL && beta != NULL) {\n    // compute grad_gamma(j) and grad_beta(j)\n    const int part_size = 16;\n    const dim3 threads2(32, 4, 1);\n    const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);\n    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);\n    const int nshared2_b = threads2.x * threads2.y * sizeof(U);\n    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;\n    at::Tensor part_grad_gamma = at::empty(\n        {part_size, n2}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ? 
at::ScalarType::Float\n                                                                                           : input.scalar_type()));\n    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);\n    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(\n        dout, static_cast<T*>(input.data_ptr()), n1, n2, mean, invvar, U(epsilon),\n        static_cast<U*>(part_grad_gamma.data_ptr()), static_cast<U*>(part_grad_beta.data_ptr()));\n\n    const dim3 threads3(32, 8, 1);\n    // grid width must follow the block width of the kernel launched below (threads3)\n    const dim3 blocks3((n2 + threads3.x - 1) / threads3.x, 1, 1);\n    const int nshared3 = threads3.x * threads3.y * sizeof(U);\n    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(static_cast<U*>(part_grad_gamma.data_ptr()),\n                                                                    static_cast<U*>(part_grad_beta.data_ptr()),\n                                                                    part_size, n1, n2, grad_gamma, grad_beta);\n  }\n\n  // compute grad_input\n  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);\n  const dim3 threads1(32, 4, 1);\n  int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;\n  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(dout, dout_resid, static_cast<T*>(input.data_ptr()), n1,\n                                                             n2, mean, invvar, U(epsilon), gamma, grad_input);\n}\n}  // namespace\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"softmax.cuh\"\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask,\n                                    float dropout_prob) {\n  const int attn_batches = input.size(0);\n  const int sequences = attn_batches / heads;\n  const int q_seq_len = input.size(1);\n  const int k_seq_len = q_seq_len;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = input.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, 
float>(reinterpret_cast<half*>(softmax_results_ptr),\n                                                          reinterpret_cast<const half*>(input_ptr), k_seq_len,\n                                                          k_seq_len, attn_batches * q_seq_len);\n  } else {\n    softmax_success = dispatch_masked_softmax<half, half, float>(\n        reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(input_ptr), pad_mask, k_seq_len,\n        k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);\n  }\n\n  if (is_training) {\n    // use at:: function so that C++ version generates the same random mask as\n    // python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n\n  return {dropout_results, dropout_mask, softmax_results};\n}\n\ntorch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results,\n                       torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) {\n  const int attn_batches = output_grads.size(0);\n  const int q_seq_len = output_grads.size(1);\n  const int k_seq_len = q_seq_len;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  //  torch::Tensor input_grads         = torch::empty_like(output_grads);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  // Softmax Grad\n  if (padding_mask == nullptr) {\n    dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(\n        
static_cast<half*>(output_grads.data_ptr()), static_cast<half*>(output_grads.data_ptr()),\n        reinterpret_cast<half const*>(softmax_results.data_ptr()), static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n        1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream);\n  } else {\n    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float, false>(\n        static_cast<half*>(output_grads.data_ptr()), static_cast<half*>(output_grads.data_ptr()),\n        reinterpret_cast<half const*>(softmax_results.data_ptr()), static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n        static_cast<uint8_t const*>(padding_mask), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,\n        attn_batches * q_seq_len, heads, stream);\n  }\n  // backward pass is completely in-place\n  return output_grads;\n}\n}  // namespace mask_softmax_dropout\n}  // namespace fused_softmax\n}  // namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp",
    "content": "#include <cuda_fp16.h>\n#include <torch/extension.h>\n\n#include <vector>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace additive_mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask,\n                                    float dropout_prob);\n\ntorch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results,\n                       torch::Tensor const& dropout_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input,\n                               torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(input.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  }\n\n  return fwd_cuda(is_training, heads, input, use_mask ? 
static_cast<const half*>(pad_mask.data_ptr()) : nullptr,\n                  dropout_prob);\n}\n\ntorch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results,\n                  torch::Tensor const& dropout_mask, float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  //  TORCH_CHECK(dropout_mask.scalar_type()      == at::ScalarType::Byte,\n  //  \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, dropout_prob);\n}\n\n}  // namespace additive_mask_softmax_dropout\nnamespace mask_softmax_dropout {\n\nstd::vector<torch::Tensor> fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask,\n                                    float dropout_prob);\n\ntorch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results,\n                       torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input,\n                               torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(input.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(is_training, heads, input, use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,\n                  dropout_prob);\n}\n\ntorch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results,\n                  torch::Tensor const& dropout_mask, torch::Tensor const& padding_mask, float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n\n  TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  //  TORCH_CHECK(dropout_mask.scalar_type()      == at::ScalarType::Byte,\n  //  \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,\n                  use_mask ? static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr, dropout_prob);\n}\n\n}  // end namespace mask_softmax_dropout\n}  // end namespace fused_softmax\n\nnamespace encdec {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q,\n                                    torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q,\n                                    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,\n                                    const uint8_t* pad_mask, float dropout_prob);\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,\n                                    torch::Tensor const& inputs_q, 
torch::Tensor const& inputs_kv,\n                                    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,\n                                    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                                    float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,\n                               torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,\n                               torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(inputs_q.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs_kv.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_weights_q.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(input_weights_kv.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv,\n                  output_weights, use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, dropout_prob);\n}\n\nstd::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,\n                               torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                               float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(matmul2_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_q_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_kv_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs_q.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs_kv.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_weights_q.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(input_weights_kv.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n\n  TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  
TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results,\n                  input_lin_kv_results, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights,\n                  dropout_mask, dropout_prob);\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace encdec\n\nnamespace encdec_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q,\n                                    torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q,\n                                    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,\n                                    const uint8_t* pad_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                  
  torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,\n                                    torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,\n                                    torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q,\n                                    torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q,\n                                    torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,\n                                    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,\n                                    float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,\n                               torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,\n                               torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,\n                               torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,\n                               torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(inputs_q.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs_kv.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(input_weights_q.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(input_weights_kv.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, \"Only 
HALF is supported\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,\n                  input_weights_q, input_weights_kv, output_weights,\n                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, dropout_prob);\n}\n\nstd::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,\n                               torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,\n                               torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q,\n                               torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q,\n                               torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,\n     
                          float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(matmul2_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_q_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_kv_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_mean.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(lyr_nrm_invvar.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(inputs_q.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs_kv.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(input_weights_q.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(input_weights_kv.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_add_mask.dim() == 3, \"expected 3D tensor\");\n\n  TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_lin_q_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_lin_kv_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  
TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, \"Only FLOAT is supported\");\n  TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, \"Only FLOAT is supported\");\n  TORCH_CHECK(inputs_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs_kv.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_q.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights_kv.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results,\n                  input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, inputs_kv,\n                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights,\n                  dropout_mask, dropout_add_mask, dropout_prob);\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace encdec_norm_add\n\nnamespace self {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    const uint8_t* pad_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& 
output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    torch::Tensor const& dropout_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights,\n                  use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, dropout_prob);\n}\n\nstd::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results, torch::Tensor const& inputs,\n                               torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask, float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(matmul2_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n\n  TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  
TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs,\n                  input_weights, output_weights, dropout_mask, dropout_prob);\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace self\nnamespace self_bias {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                                    const uint8_t* pad_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    // torch::Tensor const& input_biases,\n                                    // torch::Tensor const& output_biases,\n                                    torch::Tensor const& dropout_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights, torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D 
tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases,\n                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, dropout_prob);\n}\n\nstd::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results, torch::Tensor const& inputs,\n                               torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                               torch::Tensor const& dropout_mask, float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(matmul2_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n\n  
TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs,\n                  input_weights, output_weights, dropout_mask, dropout_prob);\n}\n\n}  // end namespace cublas_gemmex\n}  // namespace self_bias\nnamespace self_bias_additive_mask {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                                    const half* pad_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results,\n                                    // torch::Tensor const& softmax_results,\n                                    torch::Tensor const& bmm1_results, torch::Tensor const& pad_mask,\n   
                                 torch::Tensor const& input_lin_results, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    // torch::Tensor const& input_biases,\n                                    // torch::Tensor const& output_biases,\n                                    torch::Tensor const& dropout_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights, torch::Tensor const& input_biases,\n                               torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(use_mask, \"no mask is not supported\");\n\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Half, \"Only Half is supported\");\n  }\n\n  return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases,\n                  use_mask ? 
static_cast<const half*>(pad_mask.data_ptr()) : nullptr, dropout_prob);\n}\n\nstd::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results,\n                               torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results,\n                               torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                               float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(matmul2_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n\n  TORCH_CHECK(output_grads.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, 
output_grads, matmul2_results, dropout_results, bmm1_results, pad_mask, input_lin_results,\n                  inputs, input_weights, output_weights, dropout_mask, dropout_prob);\n}\n\n}  // end namespace cublas_gemmex\n}  // namespace self_bias_additive_mask\n\nnamespace self_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,\n                                    torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results,\n                                    torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar,\n                                    torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,\n                                    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                                    torch::Tensor const& dropout_add_mask, float dropout_prob);\n\nstd::vector<torch::Tensor> fwd(bool use_mask, bool use_time_mask, bool is_training, int heads,\n                               torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,\n                               
torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) {\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n\n  if (use_mask) {\n    TORCH_CHECK(pad_mask.dim() == 2, \"expected 2D tensor\");\n    TORCH_CHECK(pad_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  }\n\n  return fwd_cuda(use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights,\n                  output_weights, use_mask ? 
static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr, dropout_prob);\n}\n\nstd::vector<torch::Tensor> bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                               torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                               torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results,\n                               torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar,\n                               torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights,\n                               torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,\n                               torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                               torch::Tensor const& dropout_add_mask, float dropout_prob) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(matmul2_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(input_lin_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_results.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_mean.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(lyr_nrm_invvar.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(inputs.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(input_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(output_weights.dim() == 2, \"expected 2D tensor\");\n  TORCH_CHECK(dropout_mask.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(dropout_add_mask.dim() == 3, \"expected 3D tensor\");\n\n  TORCH_CHECK(output_grads.scalar_type() == 
at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(matmul2_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(softmax_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_lin_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_results.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_mean.scalar_type() == at::ScalarType::Float, \"Only FLOAT is supported\");\n  TORCH_CHECK(lyr_nrm_invvar.scalar_type() == at::ScalarType::Float, \"Only FLOAT is supported\");\n  TORCH_CHECK(inputs.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_gamma_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(lyr_nrm_beta_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(input_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(output_weights.scalar_type() == at::ScalarType::Half, \"Only HALF is supported\");\n  TORCH_CHECK(dropout_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n  TORCH_CHECK(dropout_add_mask.scalar_type() == at::ScalarType::Byte, \"Only BYTE is supported\");\n\n  return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results,\n                  lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,\n                  input_weights, output_weights, dropout_mask, dropout_add_mask, dropout_prob);\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace self_norm_add\n}  // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"additive_mask_softmax_dropout_forward\", 
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,\n        \"Self Multihead Attention masked softmax dropout -- Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"additive_mask_softmax_dropout_backward\", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,\n        \"Self Multihead Attention masked softmax dropout -- Backward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"mask_softmax_dropout_forward\", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd,\n        \"Self Multihead Attention masked softmax dropout -- Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"mask_softmax_dropout_backward\", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd,\n        \"Self Multihead Attention masked softmax dropout -- Backward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"encdec_multihead_attn_forward\", &multihead_attn::encdec::cublas_gemmex::fwd,\n        \"Encdec Multihead Attention Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"encdec_multihead_attn_backward\", &multihead_attn::encdec::cublas_gemmex::bwd,\n        \"Encdec Multihead Attention Backward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"encdec_multihead_attn_norm_add_forward\", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd,\n        \"Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"encdec_multihead_attn_norm_add_backward\", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd,\n        \"Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_forward\", &multihead_attn::self::cublas_gemmex::fwd, \"Self Multihead Attention Forward.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_backward\", &multihead_attn::self::cublas_gemmex::bwd, \"Self Multihead Attention Backward.\",\n 
       py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_bias_forward\", &multihead_attn::self_bias::cublas_gemmex::fwd,\n        \"Self Multihead Attention with Bias -- Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_bias_backward\", &multihead_attn::self_bias::cublas_gemmex::bwd,\n        \"Self Multihead Attention with Bias -- Backward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_bias_additive_mask_forward\", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd,\n        \"Self Multihead Attention with Bias -- Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_bias_additive_mask_backward\", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd,\n        \"Self Multihead Attention with Bias -- Backward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_norm_add_forward\", &multihead_attn::self_norm_add::cublas_gemmex::fwd,\n        \"Self Multihead Attention Plus Layer Norm and Residual Add Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"self_attn_norm_add_backward\", &multihead_attn::self_norm_add::cublas_gemmex::bwd,\n        \"Self Multihead Attention Plus Layer Norm and Residual Add Backward.\",\n        py::call_guard<py::gil_scoped_release>());\n}\n\n#undef CHECK_CUDA\n#undef CHECK_CONTIGUOUS\n#undef CHECK_INPUT\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/philox.cuh",
    "content": "#pragma once\n// Philox CUDA.\n\nnamespace {\n\nclass Philox {\n public:\n  __device__ inline Philox(unsigned long long seed, unsigned long long subsequence, unsigned long long offset)\n      : STATE(0) {\n    // key.x = (unsigned int)seed;\n    // key.y = (unsigned int)(seed >> 32);\n    // counter = make_uint4(0, 0, 0, 0);\n    // counter.z = (unsigned int)(subsequence);\n    // counter.w = (unsigned int)(subsequence >> 32);\n    // STATE = 0;\n    // incr_n(offset / 4);\n\n    key = reinterpret_cast<const uint2&>(seed);\n    ull2* tmp = reinterpret_cast<ull2*>(&counter);\n    tmp->x = offset / 4;\n    tmp->y = subsequence;\n  }\n  __device__ inline uint4 operator()() {\n    if (STATE == 0) {\n      uint4 counter_ = counter;\n      uint2 key_ = key;\n      // 7-round philox\n      for (int i = 0; i < 6; i++) {\n        counter_ = single_round(counter_, key_);\n        key_.x += (kPhilox10A);\n        key_.y += (kPhilox10B);\n      }\n      output = single_round(counter_, key_);\n      incr();\n    }\n    // return a float4 directly\n    // unsigned long ret;\n    // switch(STATE) {\n    //  case 0: ret = output.x; break;\n    //  case 1: ret = output.y; break;\n    //  case 2: ret = output.z; break;\n    //  case 3: ret = output.w; break;\n    //}\n    // STATE = (STATE + 1) % 4;\n    return output;\n  }\n\n private:\n  struct ull2 {\n    uint64_t x;\n    uint64_t y;\n  };\n  uint4 counter;\n  uint4 output;\n  uint2 key;\n  unsigned int STATE;\n  __device__ inline void incr_n(unsigned long long n) {\n    unsigned int nlo = (unsigned int)(n);\n    unsigned int nhi = (unsigned int)(n >> 32);\n    counter.x += nlo;\n    if (counter.x < nlo) nhi++;\n    counter.y += nhi;\n    if (nhi <= counter.y) return;\n    if (++counter.z) return;\n    ++counter.w;\n  }\n\n  __device__ uint4 incr128(uint4 ctr) {\n    uint4 res;\n    asm(\"add.cc.u32      %0, %4, %8;\\n\\t\"\n        \"addc.cc.u32     %1, %5, %9;\\n\\t\"\n        \"addc.cc.u32     %2, %6, 
%10;\\n\\t\"\n        \"addc.u32        %3, %7, %11;\\n\\t\"\n        : \"=r\"(res.x), \"=r\"(res.y), \"=r\"(res.z), \"=r\"(res.w)\n        : \"r\"(ctr.x), \"r\"(ctr.y), \"r\"(ctr.z), \"r\"(ctr.w), \"n\"(1), \"n\"(0), \"n\"(0), \"n\"(0));\n    return res;\n  }\n\n  __device__ inline void incr() { counter = incr128(counter); }\n  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, unsigned int* result_high) {\n    *result_high = __umulhi(a, b);\n    return a * b;\n  }\n  __device__ uint2 mulhilo32_v2(unsigned int a, unsigned int b) {\n    uint2* res;\n    unsigned long long tmp;\n    asm(\"mul.wide.u32      %0, %1, %2;\\n\\t\" : \"=l\"(tmp) : \"r\"(a), \"r\"(b));\n    res = (uint2*)(&tmp);\n    return *res;\n  }\n  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {\n    // unsigned int hi0;\n    // unsigned int hi1;\n    // unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);\n    // unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);\n    // uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};\n    uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);\n    uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);\n    uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};\n    return ret;\n  }\n  static const unsigned long kPhilox10A = 0x9E3779B9;\n  static const unsigned long kPhilox10B = 0xBB67AE85;\n  static const unsigned long kPhiloxSA = 0xD2511F53;\n  static const unsigned long kPhiloxSB = 0xCD9E8D57;\n};\n// Inverse of 2^32.\nconstexpr float M_RAN_INVM32 = 2.3283064e-10f;\n__device__ __inline__ float4 uniform4(uint4 x) {\n  return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, x.w * M_RAN_INVM32);\n}\n\n}  // namespace\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"softmax.cuh\"\n#include \"strided_batched_gemm.cuh\"\n\nnamespace multihead_attn {\nnamespace self_bias_additive_mask {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                                    const half* pad_mask, float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta_zero = 0.0;\n  const float beta_one = 1.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = inputs.options().requires_grad(false);\n  auto 
mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);\n  torch::Tensor outputs = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());\n  void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  input_lin_results.copy_(input_biases);\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta_one),\n      q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(a_layout_t, 
b_layout_n, k_seq_len, q_seq_len, head_dim, scale,\n                        static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride, beta_zero,\n                        static_cast<half*>(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n  // Padded Softmax\n  [[maybe_unused]] bool softmax_success = false;\n  if (is_training) {\n    softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(\n        reinterpret_cast<half*>(dropout_results_ptr),\n        (is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,\n        reinterpret_cast<const half*>(bmm1_results_ptr), pad_mask, attn_batches * q_seq_len * q_seq_len, k_seq_len,\n        k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences, 1.0f - dropout_prob, stream);\n  } else {\n    softmax_success = dispatch_additive_masked_softmax<half, half, float>(\n        reinterpret_cast<half*>(dropout_results_ptr),  // this is actually softmax results, but\n                                                       // making it consistent for the next function\n        reinterpret_cast<const half*>(bmm1_results_ptr), pad_mask, k_seq_len, k_seq_len, attn_batches * q_seq_len,\n        attn_batches * q_seq_len / sequences);\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(\n      a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, static_cast<const half*>(v_lin_results_ptr),\n      lead_dim, batch_stride, static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n      beta_zero, static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches);\n\n  outputs.copy_(output_biases);\n\n  // Output Linear\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const 
void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta_one),\n      static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO1_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_lin_results, bmm1_results, dropout_results, dropout_mask, matmul2_results, outputs};\n}\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results,\n                                    torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results,\n                                    torch::Tensor const& inputs, torch::Tensor const& input_weights,\n                                    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                                    float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = 
at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads = torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2 * head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  // Output Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, 
CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,\n                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,\n                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,\n                        v_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(\n      static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half* const>(matmul2_grads.data_ptr()),\n      reinterpret_cast<half const*>(bmm1_results.data_ptr()), reinterpret_cast<half const*>(pad_mask.data_ptr()),\n      static_cast<uint8_t const*>(dropout_mask.data_ptr()), 1.0 / (1.0 - dropout_prob), 
k_seq_len, k_seq_len,\n      attn_batches * q_seq_len / sequences, attn_batches * q_seq_len, stream);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n  // Input Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(input_lin_output_grads.data_ptr()),\n                                    // static_cast<const void*>(q_lin_grads_ptr),\n                                    CUDA_R_16F, output_lin_dim, static_cast<const void*>(&beta),\n                                    static_cast<void*>(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n                                    // CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, static_cast<const void*>(&alpha),\n      static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(q_lin_grads_ptr),\n      CUDA_R_16F, output_lin_dim, static_cast<const 
void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()),\n      CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads};\n}\n\n}  // end namespace cublas_gemmex\n}  // namespace self_bias_additive_mask\n}  // end namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"softmax.cuh\"\n#include \"strided_batched_gemm.cuh\"\n\nnamespace multihead_attn {\nnamespace self_bias {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    torch::Tensor const& input_biases, torch::Tensor const& output_biases,\n                                    const uint8_t* pad_mask, float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta_zero = 0.0;\n  const float beta_one = 1.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = inputs.options().requires_grad(false);\n  auto mask_options 
= act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);\n  torch::Tensor outputs = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  input_lin_results.copy_(input_biases);\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta_one),\n      q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,\n                        
static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride, beta_zero,\n                        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n  // Padded Softmax\n  [[maybe_unused]] bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(reinterpret_cast<half*>(softmax_results_ptr),\n                                                          reinterpret_cast<const half*>(softmax_results_ptr), k_seq_len,\n                                                          k_seq_len, attn_batches * q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);\n    }\n  }\n\n  if (is_training) {\n    // use at:: function so that C++ version generates the same random mask as\n    // python version\n    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f - dropout_prob);\n    dropout_results = std::get<0>(dropout_tuple);\n    dropout_mask = std::get<1>(dropout_tuple);\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,\n                        (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr())\n                                      : static_cast<const half*>(softmax_results.data_ptr()),\n                        k_seq_len, k_seq_len * q_seq_len, beta_zero, static_cast<half*>(matmul2_results.data_ptr()),\n                        head_dim * attn_batches, head_dim, attn_batches);\n\n  outputs.copy_(output_biases);\n\n  // Output Linear\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta_one),\n      static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO1_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs};\n}\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    torch::Tensor const& dropout_mask, float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int 
lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  [[maybe_unused]] const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads = torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2 * head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,\n                                    static_cast<const void*>(&alpha),\n                  
                  static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  // Output Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,\n                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,\n                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,\n                        v_lin_grads_ptr, 
lead_dim, batch_stride, attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  // Softmax Grad\n  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(\n      static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),\n      reinterpret_cast<half const*>(softmax_results.data_ptr()), static_cast<uint8_t const*>(dropout_mask.data_ptr()),\n      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n  // Input Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(input_lin_output_grads.data_ptr()),\n                                    // static_cast<const void*>(q_lin_grads_ptr),\n                                    CUDA_R_16F, output_lin_dim, static_cast<const void*>(&beta),\n                                    static_cast<void*>(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n                                    // CUBLAS_GEMM_ALGO10_TENSOR_OP));\n                    
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, static_cast<const void*>(&alpha),\n      static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(q_lin_grads_ptr),\n      CUDA_R_16F, output_lin_dim, static_cast<const void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()),\n      CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads};\n}\n\n}  // end namespace cublas_gemmex\n}  // namespace self_bias\n}  // end namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"softmax.cuh\"\n#include \"strided_batched_gemm.cuh\"\n\nnamespace multihead_attn {\nnamespace self {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    const uint8_t* pad_mask, float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = inputs.options().requires_grad(false);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor 
softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);\n  torch::Tensor outputs = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Input Linear Fwd\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,\n                        static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride, beta,\n                        
static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(reinterpret_cast<half*>(softmax_results_ptr),\n                                                          reinterpret_cast<const half*>(softmax_results_ptr), k_seq_len,\n                                                          k_seq_len, attn_batches * q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half, float, uint32_t>(\n        static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half*>(dropout_results.data_ptr()),\n        static_cast<uint8_t*>(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,\n                        (is_training) ? 
static_cast<const half*>(dropout_results.data_ptr())\n                                      : static_cast<const half*>(softmax_results.data_ptr()),\n                        k_seq_len, k_seq_len * q_seq_len, beta, static_cast<half*>(matmul2_results.data_ptr()),\n                        head_dim * attn_batches, head_dim, attn_batches);\n\n  // Output Linear\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs};\n}\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,\n                                    torch::Tensor const& input_weights, torch::Tensor const& output_weights,\n                                    torch::Tensor const& dropout_mask, float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  const int 
dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads = torch::empty_like(inputs);\n  torch::Tensor input_weight_grads = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  at::Tensor output_lin_grads = torch::empty_like(matmul2_results);\n  at::Tensor matmul2_grads = torch::empty_like(dropout_results);\n  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2 * head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Output Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                 
                   static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Output Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul2 Dgrad1\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,\n                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,\n                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,\n                        v_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  apex_masked_scale_cuda<at::Half, float, uint32_t>(\n      static_cast<at::Half 
const*>(matmul2_grads.data_ptr()), static_cast<at::Half*>(matmul2_grads.data_ptr()),\n      static_cast<uint8_t const*>(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n      static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),\n      reinterpret_cast<half const*>(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Input Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, static_cast<const void*>(&beta),\n      static_cast<void*>(input_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches, static_cast<const void*>(&alpha),\n      static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim, 
static_cast<const void*>(q_lin_grads_ptr),\n      CUDA_R_16F, output_lin_dim, static_cast<const void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()),\n      CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_grads, input_weight_grads, output_weight_grads};\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace self\n}  // end namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <math.h>\n#include <torch/extension.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"dropout.cuh\"\n#include \"layer_norm.cuh\"\n#include \"softmax.cuh\"\n#include \"strided_batched_gemm.cuh\"\n\nnamespace multihead_attn {\nnamespace self_norm_add {\nnamespace cublas_gemmex {\n\nstd::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs,\n                                    torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,\n                                    torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int total_tokens = batches * embed_dim;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // There is no reason to use more than one stream as every kernel is\n  // sequentially dependent\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // 3 Intermediate Results + Output (Note: dropout intermediates are generated\n  // by ATen library code)\n  auto act_options = 
inputs.options().requires_grad(false);\n  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);\n  auto mask_options = act_options.dtype(torch::kUInt8);\n\n  torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options);\n  torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options);\n\n  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);\n  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);\n  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);\n  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);\n  torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);\n  torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options);\n  torch::Tensor outputs = torch::empty_like(inputs, act_options);\n\n  // Input Linear Results Pointers to Q, K, and V of interleaved activations\n  void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());\n  void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);\n  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim);\n\n  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  char a_layout_t{'t'};\n  char a_layout_n{'n'};\n  char b_layout_n{'n'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n  // Layer Norm\n  HostApplyLayerNorm<at::Half, float>(\n      static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<float*>(lyr_nrm_mean.data_ptr()),\n      
static_cast<float*>(lyr_nrm_invvar.data_ptr()), static_cast<const at::Half*>(inputs.data_ptr()),\n      static_cast<int>(batches),    // n1\n      static_cast<int>(embed_dim),  // n2\n      1.0e-5, static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),\n      static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));\n\n  // Input Linear Fwd\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      // static_cast<const void*>(inputs.data_ptr()),\n      static_cast<const void*>(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim, static_cast<const void*>(&beta),\n      q_lin_results_ptr, CUDA_R_16F, output_lin_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)\n  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,\n                        static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride, beta,\n                        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Padded Softmax\n  bool softmax_success = false;\n  if (pad_mask == nullptr) {\n    softmax_success = dispatch_softmax<half, half, float>(reinterpret_cast<half*>(softmax_results_ptr),\n                                                          reinterpret_cast<const half*>(softmax_results_ptr), k_seq_len,\n                                                          k_seq_len, attn_batches * q_seq_len);\n  } else {\n    if (use_time_mask) {\n      softmax_success = dispatch_time_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, 
k_seq_len, attn_batches * q_seq_len, q_seq_len);\n    } else {\n      softmax_success = dispatch_masked_softmax<half, half, float>(\n          reinterpret_cast<half*>(softmax_results_ptr), reinterpret_cast<const half*>(softmax_results_ptr), pad_mask,\n          k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);\n    }\n  }\n  assert(softmax_success);\n\n  if (is_training) {\n    apex_fused_dropout_cuda<at::Half, float, uint32_t>(\n        static_cast<at::Half const*>(softmax_results.data_ptr()), static_cast<at::Half*>(dropout_results.data_ptr()),\n        static_cast<uint8_t*>(dropout_mask.data_ptr()), dropout_elems, (1.0f - dropout_prob));\n  }\n\n  // Matmul2\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,\n                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr())\n                                      : static_cast<const half*>(softmax_results.data_ptr()),\n                        // static_cast<const half*>(dropout_results.data_ptr()),\n                        k_seq_len, k_seq_len * q_seq_len, beta, static_cast<half*>(matmul2_results.data_ptr()),\n                        head_dim * attn_batches, head_dim, attn_batches);\n\n  // Output Linear\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_lin_results.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  
// End-of-block Dropout-Add\n  if (is_training) {\n    apex_dropout_add_cuda<at::Half, float, uint32_t>(\n        static_cast<at::Half const*>(output_lin_results.data_ptr()), static_cast<at::Half const*>(inputs.data_ptr()),\n        static_cast<at::Half*>(outputs.data_ptr()), static_cast<uint8_t*>(dropout_add_mask.data_ptr()), total_tokens,\n        (1.0f - dropout_prob));\n  } else {\n    apex_add_cuda<at::Half, float, uint32_t>(static_cast<at::Half const*>(output_lin_results.data_ptr()),\n                                             static_cast<at::Half const*>(inputs.data_ptr()),\n                                             static_cast<at::Half*>(outputs.data_ptr()), total_tokens);\n  }\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar,  input_lin_results, softmax_results,\n          dropout_results, dropout_mask, matmul2_results, dropout_add_mask,  outputs};\n}\n\nstd::vector<torch::Tensor> bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,\n                                    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,\n                                    torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results,\n                                    torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar,\n                                    torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights,\n                                    torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights,\n                                    torch::Tensor const& output_weights, torch::Tensor const& dropout_mask,\n                                    torch::Tensor const& dropout_add_mask, float dropout_prob) {\n  const int embed_dim = inputs.size(2);\n  const int sequences = inputs.size(1);\n  const int q_seq_len = inputs.size(0);\n  const int 
k_seq_len = q_seq_len;\n  const int batches = sequences * q_seq_len;\n  const int total_tokens = batches * embed_dim;\n  const int head_dim = embed_dim / heads;\n  const int output_lin_dim = 3 * embed_dim;\n  const int attn_batches = heads * sequences;\n  const int lead_dim = attn_batches * 3 * head_dim;\n  const int batch_stride = 3 * head_dim;\n  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;\n  const float alpha = 1.0;\n  const float beta = 0.0;\n  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));\n\n  // TODO: Streams can be used in Backprop but I haven't added more than one\n  // in my first attempt to create the code\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n\n  // Output Tensor Allocations\n  torch::Tensor input_grads = torch::empty_like(inputs);\n  torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);\n  torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights);\n  torch::Tensor input_weight_grads = torch::empty_like(input_weights);\n  torch::Tensor output_weight_grads = torch::empty_like(output_weights);\n  // Intermediate Tensor Allocations\n  torch::Tensor dropout_add_grads = torch::empty_like(output_grads);\n  torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);\n  torch::Tensor matmul2_grads = torch::empty_like(dropout_results);\n  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);\n  torch::Tensor input_lin_grads = torch::empty_like(inputs);\n\n  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());\n  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;\n  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim;\n\n  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());\n  auto k_lin_grads_ptr = 
static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;\n  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2 * head_dim;\n\n  char a_layout_n{'n'};\n  char a_layout_t{'t'};\n  char b_layout_n{'n'};\n  char b_layout_t{'t'};\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));\n\n  // Dropout Add Backward\n  apex_masked_scale_cuda<at::Half, float, uint32_t>(\n      static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half*>(dropout_add_grads.data_ptr()),\n      static_cast<uint8_t const*>(dropout_add_mask.data_ptr()), total_tokens, (1.0 / (1.0 - dropout_prob)));\n\n  // Output Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Output Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,\n                                    static_cast<const void*>(&alpha),\n                                    static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // MatMul2 Dgrad1\n  
gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,\n                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,\n                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches);\n\n  // Matmul2 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,\n                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,\n                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,\n                        v_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Apply Dropout Mask and Scale by Dropout Probability\n  apex_masked_scale_cuda<at::Half, float, uint32_t>(\n      static_cast<at::Half const*>(matmul2_grads.data_ptr()), static_cast<at::Half*>(matmul2_grads.data_ptr()),\n      static_cast<uint8_t const*>(dropout_mask.data_ptr()), dropout_elems, (1.0 / (1.0 - dropout_prob)));\n\n  // Softmax Grad\n  bool softmax_success = false;\n  softmax_success = dispatch_softmax_backward<half, half, float>(\n      static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),\n      reinterpret_cast<half const*>(softmax_results.data_ptr()), k_seq_len, k_seq_len, attn_batches * q_seq_len);\n  assert(softmax_success);\n\n  // Matmul1 Dgrad1\n  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Matmul1 Dgrad2\n  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, 
q_lin_results_ptr, lead_dim,\n                        batch_stride, static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,\n                        beta, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches);\n\n  // Input Linear Dgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(\n      handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim, static_cast<const void*>(&alpha),\n      static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,\n      static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim, static_cast<const void*>(&beta),\n      // static_cast<void*>(input_grads.data_ptr()),\n      static_cast<void*>(input_lin_grads.data_ptr()), CUDA_R_16F, embed_dim, CUDA_R_32F,\n      // CUBLAS_GEMM_ALGO10_TENSOR_OP));\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Input Linear Wgrad\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,\n                                    static_cast<const void*>(&alpha),\n                                    // static_cast<const void*>(inputs.data_ptr()),\n                                    static_cast<const void*>(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim,\n                                    static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim,\n                                    static_cast<const void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()),\n                                    CUDA_R_16F, embed_dim, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n\n  // Fused Layer Norm Bwd with Residual Add\n  HostLayerNormGradient<half, float>(\n      static_cast<const half*>(input_lin_grads.data_ptr()), static_cast<const half*>(output_grads.data_ptr()),\n      static_cast<const float*>(lyr_nrm_mean.data_ptr()), static_cast<const float*>(lyr_nrm_invvar.data_ptr()), inputs,\n      static_cast<int>(batches),    // n1\n      static_cast<int>(embed_dim),  // n2\n      static_cast<const 
half*>(lyr_nrm_gamma_weights.data_ptr()),\n      static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, static_cast<half*>(input_grads.data_ptr()),\n      static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()), static_cast<half*>(lyr_nrm_beta_grads.data_ptr()));\n\n  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));\n\n  return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, input_weight_grads, output_weight_grads};\n}\n\n}  // end namespace cublas_gemmex\n}  // end namespace self_norm_add\n}  // end namespace multihead_attn\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/softmax.cuh",
    "content": "#pragma once\n#include <curand_kernel.h>\n\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\n\n#include \"philox.cuh\"\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl.h>\n#endif\n\n#include <assert.h>\n#include <cuda_fp16.h>\n#include <stdint.h>\n\n#include <cfloat>\n#include <cmath>\n#include <limits>\n\nnamespace {\ntemplate <typename Datatype, int ELEMENTS_PER_LDG>\n__device__ __inline__ void copy_vector(Datatype* dst, const Datatype* src);\n\ntemplate <typename Datatype, int ELEMENTS_PER_LDG>\n__device__ __inline__ void apply_mask(Datatype* dst, Datatype value, const uint8_t* src);\n\ntemplate <typename Datatype, int ELEMENTS_PER_LDG>\n__device__ __inline__ void apply_additive_mask(Datatype* dst, const Datatype* additive_mask);\n\ntemplate <>\n__device__ __inline__ void copy_vector<__half, 1>(__half* dst, const __half* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<float, 1>(float* dst, const float* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<__half, 4>(__half* dst, const __half* src) {\n  *((float2*)dst) = *((float2*)src);\n}\ntemplate <>\n__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t* dst, const uint8_t* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t* dst, const uint8_t* src) {\n  *((half2*)dst) = *((half2*)src);\n}\n\ntemplate <>\n__device__ __inline__ void apply_mask<__half, 1>(__half* dst, __half value, const uint8_t* src) {\n  if (*src == 1) {\n    *dst = value;\n  }\n}\n\ntemplate <>\n__device__ __inline__ void apply_additive_mask<__half, 1>(__half* dst, const __half* additive_mask) {\n  *dst += *additive_mask;\n}\n\ntemplate <>\n__device__ __inline__ void apply_additive_mask<__half, 4>(__half* dst, const __half* additive_mask) {\n  *dst += *additive_mask;\n  *(dst + 1) += *(additive_mask + 1);\n  *(dst + 2) += *(additive_mask + 2);\n  *(dst + 3) 
+= *(additive_mask + 3);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp Softmax forward\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32,\n          int ELEMENTS_PER_LDG_STG = 1>\n__global__ void softmax_warp_forward(input_t* dst, const output_t* src, int batch_size, int stride, int element_count) {\n  assert(ELEMENTS_PER_LDG_STG == 1);\n\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n\n  src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n\n  // load data from global memory\n  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n#pragma unroll\n      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n        elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n      }\n\n      if (element_index < batch_element_count) {\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + i * element_count + it * WARP_SIZE);\n      }\n    }\n  }\n\n  // convert input_t to acc_t\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = elements_input[i][it];\n    }\n  }\n\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n\n  // compute local max_value\n\n  // take the max_value of the first element to avoid one max call\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n  }\n\n#pragma unroll\n  for (int it = 1; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n    }\n  }\n\n// reduction max_value\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    float val[WARP_BATCH];\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n    }\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = max_value[i] > val[i] ? 
max_value[i] : val[i];\n    }\n  }\n\n  // compute local sum\n  acc_t sum[WARP_BATCH]{0.0f};\n\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      // elements[i][it] = expf(elements[i][it] - max_value[i]);\n      elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n      sum[i] += elements[i][it];\n    }\n  }\n\n// reduction sum\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n        output_t out[ELEMENTS_PER_LDG_STG];\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = elements[i][it + element] / sum[i];\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n      } else {\n        break;\n      }\n    }\n  }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATOINS The number of iterations required for one warp to iterate\n// over all data. WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. 
ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing softmax_forward_func = void (*)(input_t* dst, const output_t* src, int batch_size, int stride, int element_count);\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,\n                         softmax_forward_func<input_t, output_t>& kernel) {\n  // determine size of a warp\n  const int next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;\n      break;\n    case 1:  // 2\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;\n      break;\n    case 2:  // 4\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;\n      break;\n    case 3:  // 8\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;\n      break;\n    case 4:  // 16\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;\n      break;\n    case 5:  // 32\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;\n      break;\n    case 6:  // 64\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;\n      break;\n    case 7:  // 128\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;\n      break;\n    case 8:  // 256\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;\n      break;\n    case 9:  // 512\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;\n      break;\n    case 10:  // 1024\n      kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;\n      break;\n    default:\n      return false;\n  
}\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax(output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride,\n                      int batch_count) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    softmax_forward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride,\n                                                                     softmax_elements);\n    return true;\n  }\n  return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE,\n          int ELEMENTS_PER_LDG_STG>\n__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t* dst, uint8_t* dropout_mask,\n                                                                  const input_t* src, const input_t* pad_mask,\n                                                                  int batch_size, int stride, int element_count,\n                                                               
   int pad_batch_stride, at::PhiloxCudaState philox_args,\n                                                                  float p) {\n  assert(ELEMENTS_PER_LDG_STG == 4);\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n  int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;\n  acc_t pinv = acc_t(1) / p;\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n  // vectorize if element_count is multiple of 4, else don't vectorize\n  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  src += thread_offset;\n  dst += thread_offset;\n  dropout_mask += thread_offset;\n\n  // load data from global memory\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    const half* curr_mask = pad_mask + pad_thread_offset;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n#pragma unroll\n      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n        // masking_value is a large negative value\n        elements_input[i][it + element] = -10000;\n      }\n\n      if (element_index < batch_element_count) {\n        int itr_jmp = it * WARP_SIZE;\n        int itr_idx = i * element_count + itr_jmp;\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n        apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(\n            &elements_input[i][it],\n            curr_mask + itr_jmp);  //(__half)-std::numeric_limits<float>::infinity()\n      }\n    }\n  }\n  // convert input_t to acc_t\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = elements_input[i][it];\n    }\n  }\n\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n\n  // compute local max_value\n\n  // take the max_value of the first element to avoid one max call\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n  }\n\n#pragma unroll\n  for (int it = 1; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n    }\n  }\n\n// reduction max_value\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    float val[WARP_BATCH];\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n    }\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n    }\n  }\n\n  // compute local sum\n  acc_t sum[WARP_BATCH]{0.0f};\n\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n      sum[i] += elements[i][it];\n    }\n  }\n\n// reduction sum\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n  auto seeds = at::cuda::philox::unpack(philox_args);\n  Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));\n  uint8_t rands[WARP_BATCH][WARP_ITERATIONS];\n  float4 rand_num;\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        rand_num = uniform4(ph());\n        rands[i][it] = (rand_num.x <= p) > 0.5;\n        rands[i][it + 1] = (rand_num.y <= p) > 0.5;\n        rands[i][it + 2] = (rand_num.z <= p) > 0.5;\n        rands[i][it + 3] = (rand_num.w <= p) > 0.5;\n        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]);\n      }\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for 
(int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        output_t out[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = rands[i][it + element] * (pinv * (elements[i][it + element] / sum[i]));\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n\n      } else {\n        break;\n      }\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE,\n          int ELEMENTS_PER_LDG_STG>\n__global__ void additive_masked_softmax_dropout_warp_forward(output_t* dst, uint8_t* dropout_mask, const input_t* src,\n                                                             const input_t* pad_mask, int batch_size, int stride,\n                                                             int element_count, int pad_batch_stride,\n                                                             at::PhiloxCudaState philox_args, float p) {\n  assert(ELEMENTS_PER_LDG_STG == 1);\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n  int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;\n  acc_t pinv = acc_t(1) / p;\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. 
compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n  // vectorize if element_count is multiple of 4, else don't vectorize\n  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n  int thread_offset = first_batch * stride + local_idx;\n  src += thread_offset;\n  dst += thread_offset;\n  dropout_mask += thread_offset;\n\n  // load data from global memory\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 0 : element_count;\n    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + local_idx;\n    const half* curr_mask = pad_mask + pad_thread_offset;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += 1) {\n      int element_index = local_idx + it * WARP_SIZE;\n#pragma unroll\n      for (int element = 0; element < 1; ++element) {\n        // masking_value is a large negative value\n        elements_input[i][it + element] = -10000;\n      }\n\n      if (element_index < batch_element_count) {\n        int itr_jmp = it * WARP_SIZE;\n        int itr_idx = i * element_count + itr_jmp;\n        copy_vector<input_t, 1>(&elements_input[i][it], src + itr_idx);\n        apply_additive_mask<input_t, 1>(&elements_input[i][it], curr_mask + itr_jmp);\n      }\n    }\n  }\n  // convert input_t to acc_t\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = elements_input[i][it];\n    }\n  }\n\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n\n  // compute local max_value\n\n  // take the max_value of the first element to avoid one max call\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n  }\n\n#pragma unroll\n  for (int it = 1; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n    }\n  }\n\n// reduction max_value\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    float val[WARP_BATCH];\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n    }\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n    }\n  }\n\n  // compute local sum\n  acc_t sum[WARP_BATCH]{0.0f};\n\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n      sum[i] += elements[i][it];\n    }\n  }\n\n// reduction sum\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n  curandStatePhilox4_32_10_t state;\n  auto seeds = at::cuda::philox::unpack(philox_args);\n  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += 1) {\n      int element_index = local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        output_t out[1];\n        acc_t softmax_out[1];\n        uint8_t dropout_mask_temp[1];\n        // generate a vector of random numbers here\n        float rand = curand_uniform(&state);\n        float* rand_ptr = (float*)(&rand);\n#pragma unroll\n        for (int element = 0; element < 1; ++element) {\n          softmax_out[element] = (elements[i][it + element] / sum[i]);\n          rand_ptr[element] = rand_ptr[element] <= p;\n          out[element] = rand_ptr[element] * pinv * softmax_out[element];\n          dropout_mask_temp[element] = 
rand_ptr[element] > 0.5;  // just to distinguish 0.0f and 1.0f\n        }\n        copy_vector<output_t, 1>(dst + i * element_count + it * WARP_SIZE, out);\n        copy_vector<uint8_t, 1>(dropout_mask + i * element_count + it * WARP_SIZE, dropout_mask_temp);\n\n      } else {\n        break;\n      }\n    }\n  }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG is 1 for the scalar kernel and 4 for the vec4 kernel.\ntemplate <typename input_t, typename output_t, typename acc_t>\nusing additive_masked_softmax_dropout_forward_func = void (*)(output_t* dst, uint8_t* dropout_mask, const input_t* src,\n                                                              const input_t* pad_mask, int batch_size, int stride,\n                                                              int element_count, int pad_batch_stride,\n                                                              at::PhiloxCudaState philox_args, float p);\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_additive_masked_softmax_dropout_kernel(\n    int element_count, int log2_elements, int& warp_size, int& batches_per_warp,\n    additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t>& kernel) {\n  // determine size of a warp\n  const int next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n  bool flag_vec4 = (element_count % 4 == 0);\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;\n      break;\n    case 1:  // 2\n      kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;\n      break;\n    case 2:  // 4\n      kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;\n      break;\n    case 3:  // 8\n      kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;\n      break;\n    case 4:  // 16\n      kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;\n      break;\n    case 5:  // 32\n      kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;\n      break;\n    case 6:  // 64\n      kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;\n      break;\n    case 7:  // 128\n      if (flag_vec4)\n        kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 2, 4, 32, 4>;\n      else\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;\n      break;\n    case 8:  // 256\n      if (flag_vec4)\n        kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1, 8, 32, 4>;\n      else\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;\n      break;\n    case 9:  // 512\n      if (flag_vec4)\n        kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1, 16, 32, 4>;\n      else\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;\n      break;\n    case 10:  // 1024\n      if (flag_vec4)\n        kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, 
output_t, acc_t, 1, 32, 32, 4>;\n      else\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;\n      break;\n    case 11:  // 2048\n      if (flag_vec4)\n        kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1, 64, 32, 4>;\n      else\n        kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 1, 64, 32, 1>;\n      break;\n    default:\n      return false;\n  }\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax_dropout(output_t* dst, uint8_t* dropout_mask, const input_t* src,\n                                              const input_t* pad_mask, int totalElements, int softmax_elements,\n                                              int softmax_elements_stride, int batch_count, int pad_batch_stride,\n                                              float p,\n                                              cudaStream_t streamid)  // p is the probability to keep, not drop\n{\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 2048) {\n    // compute function index. 
there's a function for each power of two size up\n    // to 2048.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    additive_masked_softmax_dropout_forward_func<input_t, output_t, acc_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_additive_masked_softmax_dropout_kernel<input_t, output_t, acc_t>(softmax_elements, log2_elements,\n                                                                               warp_size, batches_per_warp, kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    c10::optional<at::Generator> gen_;\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());\n    int64_t counter_offset = (totalElements / (blocks * threads_per_block) + 1);\n    at::PhiloxCudaState rng_engine_inputs;\n    {\n      std::lock_guard<std::mutex> lock(gen->mutex_);\n      rng_engine_inputs = gen->philox_cuda_state(counter_offset);\n    }\n\n    // compute launch size\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, streamid>>>(dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride,\n                                             softmax_elements, pad_batch_stride, rng_engine_inputs, p);\n    return true;\n  }\n  return false;\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. 
ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32,\n          int ELEMENTS_PER_LDG_STG = 1>\n__global__ void additive_masked_softmax_warp_forward(input_t* dst, const output_t* src, const input_t* pad_mask,\n                                                     int batch_size, int stride, int element_count,\n                                                     int pad_batch_stride) {\n  assert(ELEMENTS_PER_LDG_STG == 1);\n\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n\n  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  src += thread_offset;\n  dst += thread_offset;\n\n  // load data from global memory\n  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    const half* curr_mask = pad_mask + pad_thread_offset;\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n#pragma unroll\n      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n        // masking_value is a large negative value\n        elements_input[i][it + element] = -10000;\n      }\n\n      if (element_index < batch_element_count) {\n        int itr_jmp = it * WARP_SIZE;\n        int itr_idx = i * element_count + itr_jmp;\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n        // apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],\n        //                                          (__half)-std::numeric_limits<float>::infinity(),\n        //                                          curr_mask + itr_jmp);\n        elements_input[i][it] += *(curr_mask + itr_jmp);\n      }\n    }\n  }\n\n  // convert input_t to acc_t\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = elements_input[i][it];\n    }\n  }\n\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n\n  // compute local max_value\n\n  // take the max_value of the first element to avoid one max call\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n  }\n\n#pragma unroll\n  for (int it = 1; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n    }\n  }\n\n// reduction max_value\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    float val[WARP_BATCH];\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n    }\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n    }\n  }\n\n  // compute local sum\n  acc_t sum[WARP_BATCH]{0.0f};\n\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      // elements[i][it] = expf(elements[i][it] - max_value[i]);\n      elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n      sum[i] += elements[i][it];\n    }\n  }\n\n// reduction sum\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n        output_t out[ELEMENTS_PER_LDG_STG];\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = elements[i][it + element] / sum[i];\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n      } else {\n        break;\n      }\n    }\n  }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATOINS The number of iterations required for one warp to iterate\n// over all data. 
WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing additive_masked_softmax_forward_func = void (*)(input_t* dst, const output_t* src, const half* pad_mask,\n                                                      int batch_size, int stride, int element_count,\n                                                      int pad_batch_stride);\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_additive_masked_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,\n                                         additive_masked_softmax_forward_func<input_t, output_t>& kernel) {\n  // determine size of a warp\n  const int next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;\n      break;\n    case 1:  // 2\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;\n      break;\n    case 2:  // 4\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;\n      break;\n    case 3:  // 8\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;\n      break;\n    case 4:  // 16\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;\n      break;\n    case 5:  // 32\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;\n      break;\n    case 6:  // 64\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;\n      break;\n    case 7:  // 128\n      kernel = 
&additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;\n      break;\n    case 8:  // 256\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;\n      break;\n    case 9:  // 512\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;\n      break;\n    case 10:  // 1024\n      kernel = &additive_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;\n      break;\n    default:\n      return false;\n  }\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax(output_t* dst, const input_t* src, const input_t* pad_mask, int softmax_elements,\n                                      int softmax_elements_stride, int batch_count, int pad_batch_stride) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    additive_masked_softmax_forward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp,\n                                                                       kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(\n        dst, src, 
pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n    return true;\n  }\n  return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_additive_masked_softmax_stream(output_t* dst, const input_t* src, const input_t* pad_mask,\n                                             int softmax_elements, int softmax_elements_stride, int batch_count,\n                                             int pad_batch_stride, cudaStream_t streamid) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n    additive_masked_softmax_forward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_additive_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp,\n                                                                       kernel)) {\n      return false;\n    }\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n    // launch\n    kernel<<<blocks, threads, 0, streamid>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements,\n                                             pad_batch_stride);\n    return true;\n  }\n  return false;\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATOINS The number of iterations required for one warp to iterate\n// over all data. 
WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32,\n          int ELEMENTS_PER_LDG_STG = 1>\n__global__ void masked_softmax_warp_forward(input_t* dst, const output_t* src, const uint8_t* pad_mask, int batch_size,\n                                            int stride, int element_count, int pad_batch_stride) {\n  assert(ELEMENTS_PER_LDG_STG == 1);\n\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this warp.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n\n  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  src += thread_offset;\n  dst += thread_offset;\n\n  // load data from global memory\n  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    const uint8_t* curr_mask = pad_mask + pad_thread_offset;\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n#pragma unroll\n      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n        elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n      }\n\n      if (element_index < batch_element_count) {\n        int itr_jmp = it * WARP_SIZE;\n        int itr_idx = i * element_count + itr_jmp;\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n        apply_mask<input_t, ELEMENTS_PER_LDG_STG>(\n            &elements_input[i][it], __float2half(-std::numeric_limits<float>::infinity()), curr_mask + itr_jmp);\n      }\n    }\n  }\n\n  // convert input_t to acc_t\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = elements_input[i][it];\n    }\n  }\n\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n\n  // compute local max_value\n\n  // take the max_value of the first element to avoid one max call\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n  }\n\n#pragma unroll\n  for (int it = 1; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n    }\n  }\n\n// reduction max_value\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    float val[WARP_BATCH];\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n    }\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n    }\n  }\n\n  // compute local sum\n  acc_t sum[WARP_BATCH]{0.0f};\n\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      // elements[i][it] = expf(elements[i][it] - max_value[i]);\n      elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n      sum[i] += elements[i][it];\n    }\n  }\n\n// reduction sum\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n        output_t out[ELEMENTS_PER_LDG_STG];\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = elements[i][it + element] / sum[i];\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n      } else {\n        break;\n      }\n    }\n  }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. 
WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing masked_softmax_forward_func = void (*)(input_t* dst, const output_t* src, const uint8_t* pad_mask, int batch_size,\n                                             int stride, int element_count, int pad_batch_stride);\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_masked_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,\n                                masked_softmax_forward_func<input_t, output_t>& kernel) {\n  // determine size of a warp\n  const int next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;\n      break;\n    case 1:  // 2\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;\n      break;\n    case 2:  // 4\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;\n      break;\n    case 3:  // 8\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;\n      break;\n    case 4:  // 16\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;\n      break;\n    case 5:  // 32\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;\n      break;\n    case 6:  // 64\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;\n      break;\n    case 7:  // 128\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;\n      break;\n    case 8:  // 256\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;\n      
break;\n    case 9:  // 512\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;\n      break;\n    case 10:  // 1024\n      kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;\n      break;\n    default:\n      return false;\n  }\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_masked_softmax(output_t* dst, const input_t* src, const uint8_t* pad_mask, int softmax_elements,\n                             int softmax_elements_stride, int batch_count, int pad_batch_stride) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    masked_softmax_forward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(\n        dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n    return true;\n  }\n  return false;\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. 
WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32,\n          int ELEMENTS_PER_LDG_STG = 1>\n__global__ void time_masked_softmax_warp_forward(input_t* dst, const output_t* src, const uint8_t* pad_mask,\n                                                 int batch_size, int stride, int element_count, int mod_seq_len) {\n  assert(ELEMENTS_PER_LDG_STG == 1);\n\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this warp.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n\n  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  src += thread_offset;\n  dst += thread_offset;\n\n  // load data from global memory\n  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    const uint8_t* curr_mask = pad_mask + pad_thread_offset;\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n#pragma unroll\n      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n        elements_input[i][it + element] = -std::numeric_limits<float>::infinity();\n      }\n\n      if (element_index < batch_element_count) {\n        int itr_jmp = it * WARP_SIZE;\n        int itr_idx = i * element_count + itr_jmp;\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);\n        apply_mask<input_t, ELEMENTS_PER_LDG_STG>(\n            &elements_input[i][it], __float2half(-std::numeric_limits<float>::infinity()), curr_mask + itr_jmp);\n      }\n    }\n  }\n\n  // convert input_t to acc_t\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = elements_input[i][it];\n    }\n  }\n\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n\n  // compute local max_value\n\n  // take the max_value of the first element to avoid one max call\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n  }\n\n#pragma unroll\n  for (int it = 1; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n    }\n  }\n\n// reduction max_value\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    float val[WARP_BATCH];\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n    }\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n    }\n  }\n\n  // compute local sum\n  acc_t sum[WARP_BATCH]{0.0f};\n\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      // elements[i][it] = expf(elements[i][it] - max_value[i]);\n      elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n      sum[i] += elements[i][it];\n    }\n  }\n\n// reduction sum\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];\n        output_t out[ELEMENTS_PER_LDG_STG];\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = elements[i][it + element] / sum[i];\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n      } else {\n        break;\n      }\n    }\n  }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. 
WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing time_masked_softmax_forward_func = void (*)(input_t* dst, const output_t* src, const uint8_t* pad_mask,\n                                                  int batch_size, int stride, int element_count, int mod_seq_len);\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_time_masked_softmax_kernel(int log2_elements, int& warp_size, int& batches_per_warp,\n                                     time_masked_softmax_forward_func<input_t, output_t>& kernel) {\n  // determine size of a warp\n  const int next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 1, 1>;\n      break;\n    case 1:  // 2\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 2, 1>;\n      break;\n    case 2:  // 4\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 4, 1>;\n      break;\n    case 3:  // 8\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 8, 1>;\n      break;\n    case 4:  // 16\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 16, 1>;\n      break;\n    case 5:  // 32\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 1, 32, 1>;\n      break;\n    case 6:  // 64\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 2, 32, 1>;\n      break;\n    case 7:  // 128\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2, 4, 32, 1>;\n      break;\n    case 8:  // 256\n      kernel = 
&time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 8, 32, 1>;\n      break;\n    case 9:  // 512\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>;\n      break;\n    case 10:  // 1024\n      kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>;\n      break;\n    default:\n      return false;\n  }\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_time_masked_softmax(output_t* dst, const input_t* src, const uint8_t* pad_mask, int softmax_elements,\n                                  int softmax_elements_stride, int batch_count, int mod_seq_len) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    time_masked_softmax_forward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_time_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp,\n                                                                   kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(\n        dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len);\n    return true;\n  }\n  return false;\n}\n\nint log2_ceil_native(int value) {\n  int 
log2_value = 0;\n  while ((1 << log2_value) < value) ++log2_value;\n  return log2_value;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,\n                                                  unsigned int mask = 0xffffffff) {\n#if CUDA_VERSION >= 9000\n  return __shfl_xor_sync(mask, value, laneMask, width);\n#else\n  return __shfl_xor(value, laneMask, width);\n#endif\n}\n\ntemplate <typename acc_t, int WARP_BATCH, int WARP_SIZE>\n__device__ __forceinline__ void warp_reduce_sum(acc_t* sum) {\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);\n      sum[i] = sum[i] + b;\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp softmax backward functions as fused variants of\n// at::softmax_backward_data function\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// softmax backward data function is taken from native pytorch, elementwise mul\n// is fused in the epilogue, as well as masking and scaling for fusing dropout\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward_masked_dgrad(output_t* gradInput, const input_t* grad,\n                                                                const input_t* output, const uint8_t* mask,\n                                                                const uint8_t* pad_mask, acc_t scale, int batch_size,\n                                                                int stride, int element_count, int heads) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method 
warp_softmax_backward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this warp.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x % WARP_SIZE;\n\n  // the first element to process by the current thread\n  int thread_offset = first_batch * stride + local_idx;\n  grad += thread_offset;\n  output += thread_offset;\n  gradInput += thread_offset;\n  mask += thread_offset;\n\n  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified\n  // to one loop, but I think doing so would obfuscate the logic of the\n  // algorithm, thus I chose to keep the nested loops. This should have no\n  // impact on performance because the loops are unrolled anyway.\n\n  // load data from global memory\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      int element_index = local_idx + it * WARP_SIZE;\n      if (element_index < batch_element_count) {\n        grad_reg[i][it] = (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *\n                                    (acc_t)grad[i * element_count + it * WARP_SIZE] * (acc_t)scale) *\n                          output[i * element_count + it * WARP_SIZE];\n        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];\n      } else {\n        grad_reg[i][it] = acc_t(0);\n        output_reg[i][it] = acc_t(0);\n      }\n    }\n  }\n\n  acc_t sum[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    sum[i] = grad_reg[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      sum[i] += grad_reg[i][it];\n    }\n  }\n  warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      int element_index = local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        int total_ind = thread_offset + i * element_count + it * WARP_SIZE;\n        int pad_mask_ind =\n            element_count * (total_ind / (heads * element_count * element_count)) + total_ind % element_count;\n        uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind];\n        if (pad_mask_element == 0)\n          gradInput[i * element_count + it * WARP_SIZE] = 0;\n        else {\n          if (is_log_softmax) {\n            gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n          } else {\n            gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n          }\n        }\n      }\n    }\n  }\n}\ntemplate <typename input_t, typename output_t, typename 
acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_masked_out(output_t* grad_input, const input_t* grad, const input_t* output,\n                                                       const uint8_t* mask, const uint8_t* pad_mask, acc_t scale,\n                                                       int softmax_elements, int softmax_elements_stride,\n                                                       int batch_count, int heads) {\n  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);\n  if (softmax_elements == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil_native(softmax_elements);\n    const int next_power_of_two = 1 << log2_elements;\n\n    // This value must match the WARP_SIZE constexpr value computed inside\n    // softmax_warp_backward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n    // This value must match the WARP_BATCH constexpr value computed inside\n    // softmax_warp_backward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 1:  // 2\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 2:  // 4\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 3:  // 8\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, 
is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 4:  // 16\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 5:  // 32\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 6:  // 64\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 7:  // 128\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, 
pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 8:  // 256\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 9:  // 512\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      case 10:  // 1024\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, mask, pad_mask, scale,\n                                                                       batch_count, softmax_elements_stride,\n                                                                       softmax_elements, heads);\n        break;\n      default:\n        break;\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_masked_out_stream(output_t* grad_input, const input_t* grad,\n                                                              const input_t* output, const uint8_t* 
mask,\n                                                              const uint8_t* pad_mask, acc_t scale,\n                                                              int softmax_elements, int softmax_elements_stride,\n                                                              int batch_count, int heads, cudaStream_t streamid) {\n  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);\n  if (softmax_elements == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil_native(softmax_elements);\n    const int next_power_of_two = 1 << log2_elements;\n    // This value must match the WARP_SIZE constexpr value computed inside\n    // softmax_warp_backward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    // This value must match the WARP_BATCH constexpr value computed inside\n    // softmax_warp_backward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 0, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 1:  // 2\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 1, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, 
batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 2:  // 4\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 2, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 3:  // 8\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 3, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 4:  // 16\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 4, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 5:  // 32\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 5, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 6:  // 64\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 6, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 7:  // 128\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 7, is_log_softmax>\n   
         <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 8:  // 256\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 8, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 9:  // 512\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 9, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      case 10:  // 1024\n        masked_scale_softmax_warp_backward_masked_dgrad<input_t, output_t, acc_t, 10, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, pad_mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements, heads);\n        break;\n      default:\n        break;\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward(output_t* gradInput, const input_t* grad, const input_t* output,\n                                                   const uint8_t* mask, acc_t scale, int batch_size, int stride,\n                                                   int element_count) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method warp_softmax_backward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x % WARP_SIZE;\n\n  // the first element to process by the current thread\n  int thread_offset = first_batch * stride + local_idx;\n  grad += thread_offset;\n  output += thread_offset;\n  gradInput += thread_offset;\n  mask += thread_offset;\n\n  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified\n  // to one loop, but I think doing so would obfuscate the logic of the\n  // algorithm, thus I chose to keep the nested loops. This should have no\n  // impact on performance because the loops are unrolled anyway.\n\n  // load data from global memory\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      int element_index = local_idx + it * WARP_SIZE;\n      if (element_index < batch_element_count) {\n        grad_reg[i][it] = (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] *\n                                    (acc_t)grad[i * element_count + it * WARP_SIZE] * (acc_t)scale) *\n                          output[i * element_count + it * WARP_SIZE];\n        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];\n      } else {\n        grad_reg[i][it] = acc_t(0);\n        output_reg[i][it] = acc_t(0);\n      }\n    }\n  }\n\n  acc_t sum[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    sum[i] = grad_reg[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      sum[i] += grad_reg[i][it];\n    }\n  }\n  warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      int element_index = local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        if (is_log_softmax) {\n          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n        } else {\n          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32,\n          int ELEMENTS_PER_LDG_STG, bool is_log_softmax>\n__global__ void masked_scale_softmax_warp_backward_recompute(output_t* gradInput, const input_t* grad,\n                                                             const input_t* softmax_input, const input_t* pad_mask,\n                                                             const 
uint8_t* mask, acc_t scale, int batch_size,\n                                                             int stride, int pad_batch_stride, int element_count) {\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x % WARP_SIZE;\n  // vectorize if a row length is multiple of 4\n  // (parentheses are required: == binds tighter than &)\n  int flag_vec4 = (element_count & 3) == 0;\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];\n\n  // the first element to process by the current thread\n  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n\n  grad += thread_offset;\n  softmax_input += thread_offset;\n  gradInput += thread_offset;\n  mask += thread_offset;\n\n  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified\n  // to one loop, but I think doing so would obfuscate the logic of the\n  // algorithm, thus I chose to keep the nested loops. This should have no\n  // impact on performance because the loops are unrolled anyway.\n\n  // load data from global memory\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    const input_t* curr_mask = pad_mask + pad_thread_offset;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n\n#pragma unroll\n      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n        // masking_value is a large negative value\n        elements_input[i][it + element] = -10000;\n        grad_reg[i][it + element] = acc_t(0);\n      }\n\n      if (element_index < batch_element_count) {\n        int itr_jmp = it * WARP_SIZE;\n        int itr_idx = i * element_count + itr_jmp;\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], softmax_input + itr_idx);\n        apply_additive_mask<input_t, ELEMENTS_PER_LDG_STG>(\n            &elements_input[i][it],\n            curr_mask + itr_jmp);  //(__half)-std::numeric_limits<float>::infinity()\n        uint8_t mask_temp[ELEMENTS_PER_LDG_STG];\n        input_t grad_temp[ELEMENTS_PER_LDG_STG];\n        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(&mask_temp[0], mask + itr_idx);\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_temp[0], grad + itr_idx);\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          grad_reg[i][it + element] = ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * (acc_t)scale);\n        }\n      }\n    }\n  }\n  // load data from global memory\n\n  // convert input_t to acc_t\n  // TODO : remove this, input is already acc_t type in register\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = elements_input[i][it];\n    }\n  }\n\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n\n  // compute local max_value\n\n  
// take the max_value of the first element to avoid one max call\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n  }\n\n#pragma unroll\n  for (int it = 1; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n    }\n  }\n\n// reduction max_value\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    float val[WARP_BATCH];\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);\n    }\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];\n    }\n  }\n\n  // compute local sum\n  acc_t sum[WARP_BATCH]{0.0f};\n\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      // elements[i][it] = expf(elements[i][it] - max_value[i]);\n      elements[i][it] = std::exp(elements[i][it] - max_value[i]);\n      sum[i] += elements[i][it];\n    }\n  }\n\n// reduction sum\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it++) {\n      elements[i][it] = elements[i][it] / sum[i];\n      grad_reg[i][it] = grad_reg[i][it] * elements[i][it];\n    }\n  }\n\n  acc_t grad_sum[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    grad_sum[i] = grad_reg[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      grad_sum[i] += grad_reg[i][it];\n    }\n  }\n  warp_reduce_sum<acc_t, 
WARP_BATCH, WARP_SIZE>(grad_sum);\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        output_t grad_input_reg[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; element++) {\n          if (is_log_softmax) {\n            grad_input_reg[element] = (grad_reg[i][it + element] - std::exp(elements[i][it + element]) * grad_sum[i]);\n          } else {\n            grad_input_reg[element] = (grad_reg[i][it + element] - elements[i][it + element] * grad_sum[i]);\n          }\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, grad_input_reg);\n      }\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nusing masked_scale_softmax_warp_backward_recompute_func = void (*)(output_t* gradInput, const input_t* grad,\n                                                                   const input_t* softmax_input,\n                                                                   const input_t* pad_mask, const uint8_t* mask,\n                                                                   acc_t scale, int batch_size, int stride,\n                                                                   int pad_batch_stride, int element_count);\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nbool masked_scale_softmax_warp_backward_recompute_kernel(\n    int element_count, int log2_elements, int& warp_size, int& batches_per_warp,\n    masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax>& kernel) {\n  // determine size of a warp\n  const int 
next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n  bool flag_vec4 = (element_count % 4 == 0);\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 1, 1, is_log_softmax>;\n      break;\n    case 1:  // 2\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 2, 1, is_log_softmax>;\n      break;\n    case 2:  // 4\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 4, 1, is_log_softmax>;\n      break;\n    case 3:  // 8\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 8, 1, is_log_softmax>;\n      break;\n    case 4:  // 16\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 16, 1, is_log_softmax>;\n      break;\n    case 5:  // 32\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 1, 32, 1, is_log_softmax>;\n      break;\n    case 6:  // 64\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 2, 32, 1, is_log_softmax>;\n      break;\n    case 7:  // 128\n      kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 2, 4, 32, 1, is_log_softmax>;\n      break;\n    case 8:  // 256\n      if (flag_vec4)\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 8, 32, 4, is_log_softmax>;\n      else\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 8, 32, 1, is_log_softmax>;\n      break;\n    case 9:  // 512\n      if (flag_vec4)\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 16, 32, 4, is_log_softmax>;\n      else\n        
kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 16, 32, 1, is_log_softmax>;\n      break;\n    case 10:  // 1024\n      if (flag_vec4)\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 32, 32, 4, is_log_softmax>;\n      else\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 32, 32, 1, is_log_softmax>;\n      break;\n    case 11:  // 2048\n      if (flag_vec4)\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 64, 32, 4, is_log_softmax>;\n      else\n        kernel = &masked_scale_softmax_warp_backward_recompute<input_t, output_t, acc_t, 1, 64, 32, 1, is_log_softmax>;\n      break;\n    default:\n      return false;\n  }\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nbool dispatch_masked_scale_softmax_backward_recompute(output_t* grad_input, const input_t* grad,\n                                                      const input_t* softmax_input, const input_t* pad_mask,\n                                                      const uint8_t* mask, acc_t scale, int softmax_elements,\n                                                      int softmax_elements_stride, int pad_batch_stride,\n                                                      int batch_count, cudaStream_t streamid) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 2048) {\n    // compute function index. 
there's a function for each power of two size up\n    // to 2048.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    masked_scale_softmax_warp_backward_recompute_func<input_t, output_t, acc_t, is_log_softmax> kernel;\n    int warp_size, batches_per_warp;\n    if (!masked_scale_softmax_warp_backward_recompute_kernel<input_t, output_t, acc_t, is_log_softmax>(\n            softmax_elements, log2_elements, warp_size, batches_per_warp, kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n\n    // compute launch size\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count,\n                                             softmax_elements_stride, pad_batch_stride, softmax_elements);\n    return true;\n  }\n  return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_masked_scale_softmax_backward_stream(output_t* grad_input, const input_t* grad, const input_t* output,\n                                                   const uint8_t* mask, acc_t scale, int softmax_elements,\n                                                   int softmax_elements_stride, int batch_count,\n                                                   cudaStream_t streamid) {\n  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);\n  if (softmax_elements == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil_native(softmax_elements);\n    const int next_power_of_two = 1 << log2_elements;\n    // This value must 
match the WARP_SIZE constexpr value computed inside\n    // softmax_warp_backward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n    // This value must match the WARP_BATCH constexpr value computed inside\n    // softmax_warp_backward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 0, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 1:  // 2\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 1, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 2:  // 4\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 2, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 3:  // 8\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 3, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                         
                      softmax_elements_stride, softmax_elements);\n        break;\n      case 4:  // 16\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 4, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 5:  // 32\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 5, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 6:  // 64\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 6, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 7:  // 128\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 7, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 8:  // 256\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 8, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      case 9:  // 512\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 9, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        
break;\n      case 10:  // 1024\n        masked_scale_softmax_warp_backward<input_t, output_t, acc_t, 10, is_log_softmax>\n            <<<blocks, threads, 0, streamid>>>(grad_input, grad, output, mask, scale, batch_count,\n                                               softmax_elements_stride, softmax_elements);\n        break;\n      default:\n        break;\n    }\n  }\n}\n\n// elementwise multiplication called in at::softmax_backward_data is fused\n// inside softmax dgrad kernel. As a result of fusion, the intermediate\n// multiplication result is stored in fp32 in registers, instead of fp16\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>\n__global__ void softmax_warp_backward_fused_native(output_t* gradInput, const input_t* grad, const input_t* output,\n                                                   int batch_size, int stride, int element_count) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method warp_softmax_backward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. 
compute the index within the\n  // batch\n  int local_idx = threadIdx.x % WARP_SIZE;\n\n  // the first element to process by the current thread\n  int thread_offset = first_batch * stride + local_idx;\n  grad += thread_offset;\n  output += thread_offset;\n  gradInput += thread_offset;\n\n  // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified\n  // to one loop, but I think doing so would obfuscate the logic of the\n  // algorithm, thus I chose to keep the nested loops. This should have no\n  // impact on performance because the loops are unrolled anyway.\n\n  // load data from global memory\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 0 : element_count;\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      int element_index = local_idx + it * WARP_SIZE;\n      if (element_index < batch_element_count) {\n        grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] * output[i * element_count + it * WARP_SIZE];\n        output_reg[i][it] = output[i * element_count + it * WARP_SIZE];\n      } else {\n        grad_reg[i][it] = acc_t(0);\n        output_reg[i][it] = acc_t(0);\n      }\n    }\n  }\n\n  acc_t sum[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    sum[i] = grad_reg[i][0];  //* output_reg[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      sum[i] += grad_reg[i][it];  // * output_reg[i][it];\n    }\n  }\n  warp_reduce_sum<acc_t, WARP_BATCH, WARP_SIZE>(sum);\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      int element_index = local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        if (is_log_softmax) {\n          gradInput[i * 
element_count + it * WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);\n        } else {\n          gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>\nvoid dispatch_softmax_backward_fused_native(output_t* grad_input, const input_t* grad, const input_t* output,\n                                            int softmax_elements, int softmax_elements_stride, int batch_count) {\n  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024);\n  if (softmax_elements == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil_native(softmax_elements);\n    const int next_power_of_two = 1 << log2_elements;\n\n    // This value must match the WARP_SIZE constexpr value computed inside\n    // softmax_warp_backward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n    // This value must match the WARP_BATCH constexpr value computed inside\n    // softmax_warp_backward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 0, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 1:  // 2\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 1, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 2:  // 4\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 2, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 3:  // 8\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 3, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 4:  // 16\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 4, is_log_softmax>\n    
        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 5:  // 32\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 5, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 6:  // 64\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 6, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 7:  // 128\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 7, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 8:  // 256\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 8, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 9:  // 512\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 9, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n    
    break;\n      case 10:  // 1024\n        softmax_warp_backward_fused_native<input_t, output_t, acc_t, 10, is_log_softmax>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      default:\n        break;\n    }\n  }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Warp softmax backward\n////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32,\n          int ELEMENTS_PER_LDG_STG = 1>\n__global__ void softmax_warp_backward(__half* gradInput, const __half* grad, const __half* output, int batch_size,\n                                      int stride, int element_count) {\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n\n  // the first element to process by the current thread\n  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  grad += thread_offset;\n  output += thread_offset;\n  gradInput += thread_offset;\n\n  // load data from global memory\n  input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n  input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < batch_element_count) {\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],\n                                                   output + i * element_count + it * WARP_SIZE);\n      }\n    }\n  }\n\n  // convert half to floating point\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      grad_reg[i][it] = grad_reg_input[i][it];\n      output_reg[i][it] = output_reg_input[i][it];\n    }\n  }\n\n  // compute thread local sum\n  acc_t sum[WARP_BATCH] = {0};\n#pragma unroll\n  for (int it = 0; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += grad_reg[i][it] * output_reg[i][it];\n    }\n  }\n\n  // reduction sum\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        output_t out[ELEMENTS_PER_LDG_STG];\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = (output_reg[i][it + element] * (grad_reg[i][it + 
element] - sum[i]));\n        }\n        // store them in global memory\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);\n      }\n    }\n  }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing softmax_backward_func = void (*)(output_t* gradInput, const input_t* grad, const input_t* output, int batch_size,\n                                       int stride, int element_count);\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_softmax_backward_kernel(int log2_elements, int& warp_size, int& batches_per_warp,\n                                  softmax_backward_func<input_t, output_t>& kernel) {\n  // determine size of a warp\n  const int next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;\n      break;\n    case 1:  // 2\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;\n      break;\n    case 2:  // 4\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;\n      break;\n    case 3:  // 8\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;\n      break;\n    case 4:  // 16\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;\n      break;\n    case 5:  // 32\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;\n      break;\n    case 6:  // 64\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;\n      break;\n    case 7:  // 128\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;\n      break;\n    case 8:  // 256\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 32, 1>;\n      break;\n    case 9:  // 512\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;\n      break;\n    case 10:  // 1024\n      kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;\n      break;\n    default:\n      return false;\n  }\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax_backward(output_t* grad_input, const input_t* grad, const input_t* output, int softmax_elements,\n                               int softmax_elements_stride, int batch_count) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. 
there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    softmax_backward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count,\n                                                                     softmax_elements_stride, softmax_elements);\n    return true;\n  }\n  return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_softmax_backward_stream(output_t* grad_input, const input_t* grad, const input_t* output,\n                                      int softmax_elements, int softmax_elements_stride, int batch_count,\n                                      cudaStream_t streamid) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. 
there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n    softmax_backward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {\n      return false;\n    }\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n    // launch\n    kernel<<<blocks, threads, 0, streamid>>>(grad_input, grad, output, batch_count, softmax_elements_stride,\n                                             softmax_elements);\n    return true;\n  }\n  return false;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32,\n          int ELEMENTS_PER_LDG_STG = 1>\n__global__ void masked_softmax_warp_backward(__half* gradInput, const __half* grad, const __half* output,\n                                             const uint8_t* pad_mask, int batch_size, int stride, int element_count,\n                                             int pad_batch_stride) {\n  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. 
compute the index within the\n  // batch\n  int local_idx = threadIdx.x;\n\n  // the first element to process by the current thread\n  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  grad += thread_offset;\n  output += thread_offset;\n  gradInput += thread_offset;\n\n  // load data from global memory\n  input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n  input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 0 : element_count;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < batch_element_count) {\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE);\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],\n                                                   output + i * element_count + it * WARP_SIZE);\n      }\n    }\n  }\n\n  // convert half to floating point\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];\n  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      grad_reg[i][it] = grad_reg_input[i][it];\n      output_reg[i][it] = output_reg_input[i][it];\n    }\n  }\n\n  // compute thread local sum\n  acc_t sum[WARP_BATCH] = {0};\n#pragma unroll\n  for (int it = 0; it < WARP_ITERATIONS; ++it) {\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += grad_reg[i][it] * output_reg[i][it];\n    }\n  }\n\n  // reduction sum\n  constexpr uint32_t FULL_MASK = 0xffffffff;\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, 
WARP_SIZE);\n    }\n  }\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;\n    const uint8_t* curr_mask = pad_mask + pad_thread_offset;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        output_t out[ELEMENTS_PER_LDG_STG];\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i]));\n        }\n        // store them in global memory\n        int itr_jmp = it * WARP_SIZE;\n        int itr_idx = i * element_count + itr_jmp;\n        // It is kind of unfortunate this has to be here to zero something out\n        // that is close to zero in the first place\n        apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0, curr_mask + itr_jmp);\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);\n      }\n    }\n  }\n}\n\n// WARP_BATCH number of batches.\n// WARP_ITERATIONS The number of iterations required for one warp to iterate\n// over all data. WARP_SIZE number of elements working on a single batch, has to\n// be a power of two. 
ELEMENTS_PER_LDG_STG has to be 1.\ntemplate <typename input_t, typename output_t>\nusing masked_softmax_backward_func = void (*)(output_t* gradInput, const input_t* grad, const input_t* output,\n                                              const uint8_t* pad_mask, int batch_size, int stride, int element_count,\n                                              int pad_batch_stride);\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool warp_masked_softmax_backward_kernel(int log2_elements, int& warp_size, int& batches_per_warp,\n                                         masked_softmax_backward_func<input_t, output_t>& kernel) {\n  // determine size of a warp\n  const int next_power_of_two = 1 << log2_elements;\n  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;\n\n  // determine how many batches a warp should process.\n  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n  switch (log2_elements) {\n    case 0:  // 1\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 1, 1>;\n      break;\n    case 1:  // 2\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 2, 1>;\n      break;\n    case 2:  // 4\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 4, 1>;\n      break;\n    case 3:  // 8\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 8, 1>;\n      break;\n    case 4:  // 16\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 16, 1>;\n      break;\n    case 5:  // 32\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 1, 32, 1>;\n      break;\n    case 6:  // 64\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 2, 32, 1>;\n      break;\n    case 7:  // 128\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2, 4, 32, 1>;\n      break;\n    case 8:  // 256\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 8, 
32, 1>;\n      break;\n    case 9:  // 512\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>;\n      break;\n    case 10:  // 1024\n      kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>;\n      break;\n    default:\n      return false;\n  }\n  return true;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nbool dispatch_masked_softmax_backward(output_t* grad_input, const input_t* grad, const input_t* output,\n                                      const uint8_t* pad_mask, int softmax_elements, int softmax_elements_stride,\n                                      int batch_count, int pad_batch_stride) {\n  if (softmax_elements == 0) {\n    return true;\n  } else if (softmax_elements <= 1024) {\n    // compute function index. there's a function for each power of two size up\n    // to 1024.\n    int log2_elements = 0;\n    while ((1 << log2_elements) < softmax_elements) ++log2_elements;\n\n    masked_softmax_backward_func<input_t, output_t> kernel;\n    int warp_size, batches_per_warp;\n    if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp,\n                                                                       kernel)) {\n      return false;\n    }\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    // compute warps per block.\n    int warps_per_block = (threads_per_block / warp_size);\n\n    // compute launch size\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n\n    // launch\n    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(\n        grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);\n    return true;\n  }\n  return false;\n}\n}  // 
namespace\n"
  },
  {
    "path": "apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh",
    "content": "#pragma once\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n\n#include <iostream>\n#include <vector>\n\n// #include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <cutlass/cutlass.h>\n#include <cutlass/fast_math.h>\n#include <cutlass/gemm/device/gemm_batched.h>\n#include <cutlass/gemm/gemm.h>\n#include <cutlass/layout/matrix.h>\n#include <cutlass/matrix_coord.h>\n#include <cutlass/pitch_linear_coord.h>\n\nnamespace {\ncublasOperation_t convertTransToCublasOperation(char trans) {\n  if (trans == 't')\n    return CUBLAS_OP_T;\n  else if (trans == 'n')\n    return CUBLAS_OP_N;\n  else if (trans == 'c')\n    return CUBLAS_OP_C;\n  else {\n    TORCH_CHECK(false, \"trans must be one of: t, n, c\");\n    return CUBLAS_OP_T;\n  }\n}\n\nvoid CublasStridedBatchedGemm(char transa, char transb, long m, long n, long k, float alpha, const half* a, long lda,\n                              long strideA, const half* b, long ldb, long strideB, float beta, half* c, long ldc,\n                              long strideC, long batchCount, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {\n  cublasOperation_t opa = convertTransToCublasOperation(transa);\n  cublasOperation_t opb = convertTransToCublasOperation(transb);\n\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();\n  cublasSetStream(handle, stream);\n  float fAlpha = alpha;\n  float fBeta = beta;\n  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(\n      handle, opa, opb, (int)m, (int)n, (int)k, (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, b, CUDA_R_16F,\n      (int)ldb, strideB, (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo));\n}\n\n}  // namespace\n\n// TODO(mkozuki): Make use of the int template parameters or discard them.\ntemplate <typename LayoutA, typename LayoutB, int SRC_A, 
int SRC_B, int DST_C>\nvoid CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k, float alpha, const half* a, long lda,\n                           long long int batch_stride_A, const half* b, long ldb, long long int batch_stride_B,\n                           float beta, half* c, long ldc, long long int batch_stride_C, long batch_count) {\n  using Gemm = cutlass::gemm::device::GemmBatched<\n      /* Element type of A matrix */ half, /* Layout of A matrix */ LayoutA,\n      /* Element type of B matrix */ half, /* Layout of B matrix */ LayoutB,\n      /* Element type of C matrix */ half, /* Layout of C matrix */ cutlass::layout::ColumnMajor,\n      /* Element Accumulator*/ float>;\n  Gemm gemm_op;\n  cutlass::Status status = gemm_op({{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k)},\n                                    {a, lda},\n                                    batch_stride_A,\n                                    {b, ldb},\n                                    batch_stride_B,\n                                    {c, ldc},\n                                    batch_stride_C,\n                                    {c, ldc},\n                                    batch_stride_C,\n                                    {alpha, beta},\n                                    static_cast<int>(batch_count)},\n                                   nullptr, stream);\n  C10_CUDA_CHECK(status != cutlass::Status::kSuccess ? cudaErrorUnknown : cudaSuccess);\n}\n\nnamespace {\nvoid gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, float alpha, const half* a, long lda,\n                           long strideA, const half* b, long ldb, long strideB, float beta, half* c, long ldc,\n                           long strideC, long batchCount) {\n  auto stream = c10::cuda::getCurrentCUDAStream();\n  // printf(\"GEMM   -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\\n\", (transa ==\n  // 't' ? 'T' : 'N'), (transb =='t' ? 
'T' : 'N'), m, n, k, alpha, beta);\n  if ((transa == 't') && (transb == 'n')) {\n    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC,\n                               batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      
CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 8, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 8, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, 
ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 4, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 8, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      
CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, 2, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else {\n      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC,\n                               batchCount);\n    }\n  } else if ((transa == 'n') && (transb == 'n')) {\n    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC,\n                               batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, 
batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 8, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      
CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 8, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 2, 4>(\n          stream, m, n, k, 
alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 4, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 8, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) 
&& !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor, 2, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else {\n      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC,\n                               batchCount);\n    }\n  } else if ((transa == 'n') && (transb == 't')) {\n    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC,\n                               batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);\n    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, 
strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 8, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 8, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      
CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 4, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 8, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 8, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, 
ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 8, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 4, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 4, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 4, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 2, 8>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 2, 4>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {\n      CutlassGemm_FP32Accum<cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, 2, 2, 2>(\n          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);\n    } else {\n      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, 
b, ldb, strideB, beta, c, ldc, strideC,\n                               batchCount);\n    }\n  } else {\n    TORCH_CHECK(false, \"TransA and TransB are invalid\");\n  }\n}\n\nvoid adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t* lda, int64_t* ldb,\n                    int64_t* ldc) {\n  int transa_ = ((transa == 't') || (transa == 'T'));\n  int transb_ = ((transb == 't') || (transb == 'T'));\n\n  // Note: leading dimensions generally are checked that they are > 0 and at\n  // least as big as the result requires (even if the value won't be used).\n  if (n <= 1) *ldc = std::max<int64_t>(m, 1);\n\n  if (transa_) {\n    if (m <= 1) *lda = std::max<int64_t>(k, 1);\n  } else {\n    if (k <= 1) *lda = std::max<int64_t>(m, 1);\n  }\n\n  if (transb_) {\n    if (k <= 1) *ldb = std::max<int64_t>(n, 1);\n  } else {\n    if (n <= 1) *ldb = std::max<int64_t>(k, 1);\n  }\n}\n\nvoid HgemmStridedBatched(char transa, char transb, long m, long n, long k, float alpha, const half* a, long lda,\n                         long strideA, const half* b, long ldb, long strideB, float beta, half* c, long ldc,\n                         long strideC, long batchCount) {\n  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) ||\n      (batchCount >= INT_MAX))\n\n  {\n    // TORCH_CHECK concatenates its message arguments; it does not apply printf-style formatting.\n    TORCH_CHECK(false,\n                \"Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, \"\n                \"batchCount \"\n                \"with the bound [val] <= \",\n                INT_MAX);\n  }\n\n  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);\n\n  gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC,\n                        batchCount);\n}\n\n}  // namespace\n"
  },
  {
    "path": "apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp",
    "content": "#include <c10/cuda/CUDACachingAllocator.h>\n#include <c10/util/Exception.h>\n#include <nccl.h>\n#include <torch/csrc/cuda/CUDAPluggableAllocator.h>\n#include <torch/extension.h>\n\n#define NCCL_CHECK(cmd)                                                                                     \\\n  do {                                                                                                      \\\n    ncclResult_t result = cmd;                                                                              \\\n    if (result != ncclSuccess) {                                                                            \\\n      std::string err = \"NCCL error in: \" + std::string(__FILE__) + \":\" + std::to_string(__LINE__) + \", \" + \\\n                        std::string(ncclGetErrorString(result));                                            \\\n      TORCH_CHECK(false, err);                                                                              \\\n    }                                                                                                       \\\n  } while (0)\n\nvoid* nccl_alloc_plug(size_t size, int device, void* stream) {\n  void* ptr;\n  NCCL_CHECK(ncclMemAlloc(&ptr, size));\n  return ptr;\n}\n\nvoid nccl_free_plug(void* ptr, std::size_t size, int device, void* stream) { NCCL_CHECK(ncclMemFree(ptr)); }\n\nstd::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> nccl_allocator;\n\nvoid maybe_init() {\n  if (!nccl_allocator) {\n    nccl_allocator =\n        std::make_shared<torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>(nccl_alloc_plug, nccl_free_plug);\n  }\n}\n\nstd::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> get_nccl_allocator() {\n  maybe_init();\n  return nccl_allocator;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"get_nccl_allocator\", []() { return get_nccl_allocator(); });\n};\n"
  },
  {
    "path": "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
    "content": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"nccl_p2p_cuda.cuh\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"get_unique_nccl_id\", &apex::contrib::nccl_p2p::get_unique_nccl_id, \"get_unique_nccl_id\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"init_nccl_comm\", &apex::contrib::nccl_p2p::init_nccl_comm, \"init_nccl_comm\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"left_right_halo_exchange_inplace\", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace,\n        \"left_right_halo_exchange_inplace\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"left_right_halo_exchange\", &apex::contrib::nccl_p2p::left_right_halo_exchange, \"left_right_halo_exchange\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"add_delay\", &apex::contrib::nccl_p2p::add_delay, \"add_delay\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <torch/extension.h>\n\n#include <cassert>\n#include <cstdio>\n#include <ctime>\n#include <list>\n\n#include \"nccl.h\"\n\n/*\n * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks\n * on the same machine using cudaMemcpyAsync peer-to-peer transfers.\n */\n\nnamespace {\n\n__global__ void AddDelay_kernel(const int delay, int* counter) {\n  if (blockIdx.x == 0 && threadIdx.x == 0) {\n    // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.\n    int new_counter = 0;\n    double elapsed = 0;\n    clock_t start = clock();\n    do {\n      clock_t now = clock();\n      elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;\n      ++new_counter;\n    } while (elapsed < (double)delay);\n    *counter = new_counter;\n  }\n}\n\nclass NcclCommWrapper {\n private:\n  ncclComm_t comm;\n  int rank, world_size;\n\n  ncclDataType_t get_nccl_type(at::Tensor input) {\n    switch (input.scalar_type()) {\n      case at::ScalarType::Half:\n        return ncclFloat16;\n      case at::ScalarType::Float:\n        return ncclFloat32;\n      case at::ScalarType::Double:\n        return ncclFloat64;\n      case at::ScalarType::Byte:\n        return ncclUint8;\n      case at::ScalarType::Char:\n        return ncclInt8;\n      case at::ScalarType::Int:\n        return ncclInt32;\n      case at::ScalarType::Long:\n        return ncclInt64;\n      case at::ScalarType::BFloat16:\n        return ncclBfloat16;\n      default:\n        assert(false);\n    }\n  }\n\n public:\n  NcclCommWrapper() {\n    memset(&comm, 0, sizeof(ncclComm_t));\n    rank = 0;\n    world_size = 0;\n  }\n  NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks) {\n    ncclCommInitRank(&comm, num_ranks, id, my_rank);\n    rank = my_rank;\n    world_size = num_ranks;\n  }\n\n  ~NcclCommWrapper() {\n    
printf(\"ncclCommDestroy()\\n\");\n    ncclCommDestroy(comm);\n  }\n\n  void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo,\n                                        at::Tensor right_output_halo, at::Tensor left_input_halo,\n                                        at::Tensor right_input_halo) {\n    auto stream = at::cuda::getCurrentCUDAStream();\n    ncclGroupStart();\n    ncclDataType_t ncclType = get_nccl_type(left_output_halo);\n    bool left_zero = (left_rank < 0);\n    bool right_zero = (right_rank < 0);\n    size_t left_n = torch::numel(left_output_halo);\n    size_t right_n = torch::numel(right_output_halo);\n    assert(left_n > 0 && left_n == right_n);\n    if (left_zero) {\n      left_input_halo.zero_();\n    } else {\n      AT_DISPATCH_ALL_TYPES_AND3(\n          at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(),\n          \"left_halo_exch\", [&]() {\n            // send left (to my_rank - 1)\n            ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);\n            // receive left (from my_rank - 1)\n            ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);\n          });\n    }\n    if (right_zero) {\n      right_input_halo.zero_();\n    } else {\n      AT_DISPATCH_ALL_TYPES_AND3(\n          at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(),\n          \"right_halo_exch\", [&]() {\n            // send right (to my_rank + 1 )\n            ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);\n            // receive right (from my_rank + 1)\n            ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);\n          });\n    }\n    ncclGroupEnd();\n  }\n\n  std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, 
at::Tensor left_output_halo,\n                                                   at::Tensor right_output_halo) {\n    // after halo exchange:\n    // left_output_halo of rank+1 ends up in right_input_halo of rank\n    // right_output_halo of rank-1 ends up in left_input_halo of rank\n    auto right_input_halo = torch::empty_like(left_output_halo);\n    auto left_input_halo = torch::empty_like(right_output_halo);\n    left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo,\n                                     right_input_halo);\n    return {left_input_halo, right_input_halo};\n  }\n};\n\nclass ManagedObjects {\n public:\n  ManagedObjects() {}\n  ~ManagedObjects() {\n    for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) {\n      delete *it;\n    }\n  }\n\n  int add_comm(NcclCommWrapper* comm) {\n    int handle = _nccl_comms.size();\n    _nccl_comms.push_back(comm);\n    return handle;\n  }\n\n  NcclCommWrapper& get_comm(int handle) {\n    assert(handle >= 0 && handle < _nccl_comms.size());\n    return *_nccl_comms[handle];\n  }\n\n private:\n  std::vector<NcclCommWrapper*> _nccl_comms;\n};\nclass ManagedObjects mo;\n\n}  // end anonymous namespace\n\nnamespace apex {\nnamespace contrib {\nnamespace nccl_p2p {\n\nat::Tensor get_unique_nccl_id(int n) {\n  ncclUniqueId id;\n  ncclGetUniqueId(&id);\n  auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)},\n                                torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));\n  auto id_ptr = id_tensor.data_ptr<uint8_t>();\n  size_t offset = 0;\n  for (int i = 0; i < n; ++i) {\n    ncclUniqueId id;\n    ncclGetUniqueId(&id);\n    memcpy(id_ptr + offset, &id, sizeof(ncclUniqueId));\n    offset += sizeof(ncclUniqueId);\n  }\n  return id_tensor;\n}\n\nint init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks) {\n  ncclUniqueId id;\n  auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();\n  
memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));\n  NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);\n  int handle = mo.add_comm(comm);\n  comm = 0L;\n  return handle;\n}\n\nvoid left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,\n                                      at::Tensor right_output_halo, at::Tensor left_input_halo,\n                                      at::Tensor right_input_halo) {\n  class NcclCommWrapper& communicator = mo.get_comm(handle);\n  return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo,\n                                                       left_input_halo, right_input_halo);\n}\n\nstd::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,\n                                                 at::Tensor right_output_halo) {\n  class NcclCommWrapper& communicator = mo.get_comm(handle);\n  return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);\n}\n\nvoid add_delay(int delay) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n  auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));\n  AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr<int>());\n}\n\n}  // namespace nccl_p2p\n}  // namespace contrib\n}  // namespace apex\n"
  },
  {
    "path": "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh",
    "content": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n#include <torch/extension.h>\n#ifndef _nccl_p2p_h_\n#define _nccl_p2p_h_\n\nnamespace apex {\nnamespace contrib {\nnamespace nccl_p2p {\nat::Tensor get_unique_nccl_id(int n);\nint init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks);\nvoid left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,\n                                      at::Tensor right_output_halo, at::Tensor left_input_halo,\n                                      at::Tensor right_input_halo);\nstd::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo,\n                                                 at::Tensor right_output_halo);\nvoid add_delay(int delay);\n}  // namespace nccl_p2p\n}  // namespace contrib\n}  // namespace apex\n#endif\n"
  },
  {
    "path": "apex/contrib/csrc/nccl_p2p/nccl_version.cpp",
    "content": "// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n// This file is used to check the version of NCCL detected.\n#include <torch/extension.h>\n\n#include <tuple>\n\nstd::tuple<int, int> get_nccl_version();\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def(\"get_nccl_version\", &get_nccl_version); }\n"
  },
  {
    "path": "apex/contrib/csrc/nccl_p2p/nccl_version_check.cu",
    "content": "// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\n// This file is used to check the version of NCCL detected.\n#include <nccl.h>\n\n#include <tuple>\n\nstd::tuple<int, int> get_nccl_version() { return {int(NCCL_MAJOR), int(NCCL_MINOR)}; }\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/fused_adam_cuda.cpp",
    "content": "#include <torch/extension.h>\n\n// CUDA forward declaration\nvoid fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first);\n\nvoid fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr,\n                     float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction,\n                     float decay);\nvoid fused_reversible_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,\n                                float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode,\n                                int bias_correction, float decay);\nvoid fused_maybe_adam_undo_cuda(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g,\n                                float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode,\n                                int bias_correction, float decay);\n\nvoid fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                        float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode,\n                        int bias_correction, float decay);\n\nvoid maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out);\nvoid maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\n// C++ interface\nvoid strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first) {\n  CHECK_INPUT(p_copy);\n  
fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);\n}\nvoid adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr, float beta1,\n          float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {\n  CHECK_INPUT(p);\n  if (p_copy.numel() > 0) CHECK_INPUT(p_copy);\n  CHECK_INPUT(m);\n  CHECK_INPUT(v);\n  CHECK_INPUT(g);\n  int64_t num_elem = p.numel();\n  TORCH_CHECK(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n  TORCH_CHECK(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n  TORCH_CHECK(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n  TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0,\n              \"number of elements in p_copy and p tensors should be equal, or p_copy should be empty\");\n\n  fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid reversible_adam(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr,\n                     float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction,\n                     float decay) {\n  CHECK_INPUT(p);\n  if (p_copy.numel() > 0) CHECK_INPUT(p_copy);\n  CHECK_INPUT(m);\n  CHECK_INPUT(v);\n  CHECK_INPUT(g);\n  int64_t num_elem = p.numel();\n  TORCH_CHECK(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n  TORCH_CHECK(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n  TORCH_CHECK(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n  TORCH_CHECK(p_copy.numel() == num_elem || p_copy.numel() == 0,\n              \"number of elements in p_copy and p tensors should be equal, or p_copy should be empty\");\n\n  fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, 
beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);\n}\nvoid maybe_adam_undo(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr,\n                     float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction,\n                     float decay) {\n  CHECK_INPUT(p);\n  CHECK_INPUT(m);\n  CHECK_INPUT(v);\n  CHECK_INPUT(g);\n  int64_t num_elem = p.numel();\n  TORCH_CHECK(m.numel() == num_elem, \"number of elements in m and p tensors should be equal\");\n  TORCH_CHECK(v.numel() == num_elem, \"number of elements in v and p tensors should be equal\");\n  TORCH_CHECK(g.numel() == num_elem, \"number of elements in g and p tensors should be equal\");\n\n  fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction,\n                             decay);\n}\nvoid maybe_cast(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) {\n  CHECK_INPUT(p_in);\n  CHECK_INPUT(p_out);\n  int64_t num_elem = p_in.numel();\n  TORCH_CHECK(p_out.numel() == num_elem, \"number of elements in p_in and p_out should be equal\");\n\n  maybe_cast_cuda(overflow_flag, p_in, p_out);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"strided_check_finite\", &strided_check_finite, \"Strided finite check.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"adam\", &adam, \"Adam optimized CUDA implementation.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"reversible_adam\", &reversible_adam, \"Reversible Adam optimized CUDA implementation.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"adam_mt\", &fused_adam_cuda_mt, \"Multi tensor Adam optimized CUDA implementation.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"maybe_adam_undo\", &maybe_adam_undo, \"Undo function for Adam optimized CUDA implementation.\",\n        py::call_guard<py::gil_scoped_release>());\n  
m.def(\"maybe_cast\", &maybe_cast, \"Unpack byte tensor containing e5m2 floats.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"maybe_cast_mt\", &maybe_cast_cuda_mt, \"Unpack byte tensor containing e5m2 floats.\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu",
    "content": "#include <cuda.h>\n#include <cuda_runtime.h>\n#include <stdio.h>\n\n#include <cmath>\n\n#include \"ATen/ATen.h\"\n#include \"ATen/TensorUtils.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n#include \"ATen/cuda/detail/IndexUtils.cuh\"\n// #include \"ATen/Type.h\"\n#include \"ATen/AccumulateType.h\"\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\n#include \"type_shim.h\"\n\ntypedef enum {\n  ADAM_MODE_0 = 0,  // eps under square root\n  ADAM_MODE_1 = 1   // eps outside square root\n} adamMode_t;\n\ntemplate <typename T, typename GRAD_T>\n__global__ void adam_cuda_kernel(T* __restrict__ p,\n                                 GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed\n                                 T* __restrict__ m, T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1,\n                                 const float b2, const float eps, const float grad_scale, const float step_size,\n                                 const size_t tsize, adamMode_t mode, const float decay) {\n  // Assuming 2D grids and 2D blocks\n  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n  const int threadsPerBlock = blockDim.x * blockDim.y;\n  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n  const int i = (blockId * threadsPerBlock + threadIdInBlock);\n  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;\n\n  for (int j = i; j < tsize; j += totThreads) {\n    T scaled_grad = g[j] / grad_scale;\n    m[j] = b1 * m[j] + (1 - b1) * scaled_grad;\n    v[j] = b2 * v[j] + (1 - 
b2) * scaled_grad * scaled_grad;\n    float denom;\n    if (mode == ADAM_MODE_0)\n      denom = sqrtf(v[j] + eps);\n    else  // Mode 1\n      denom = sqrtf(v[j]) + eps;\n    float update = (m[j] / denom) + (decay * p[j]);\n    p[j] = p[j] - (step_size * update);\n    if (p_copy != NULL) p_copy[j] = (GRAD_T)p[j];\n  }\n}\n\ntemplate <int DEPTH, typename T, typename GRAD_T>\nstruct AdamFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<DEPTH>& tl,\n                                             const float b1, const float b2, const float eps, const float grad_scale,\n                                             const float step_size, adamMode_t mode, const float decay) {\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    T* p = (T*)tl.addresses[0][tensor_loc];\n    p += chunk_idx * chunk_size;\n    T* m = (T*)tl.addresses[1][tensor_loc];\n    m += chunk_idx * chunk_size;\n    T* v = (T*)tl.addresses[2][tensor_loc];\n    v += chunk_idx * chunk_size;\n    GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];\n    g += chunk_idx * chunk_size;\n    GRAD_T* p_copy = NULL;\n    if (DEPTH == 5) {\n      p_copy = (GRAD_T*)tl.addresses[4][tensor_loc];\n      p_copy += chunk_idx * chunk_size;\n    }\n\n    n -= chunk_idx * chunk_size;\n\n    T incoming_p[ILP];\n    T incoming_m[ILP];\n    T incoming_v[ILP];\n    T incoming_g[ILP];\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(m) && is_aligned(v) && is_aligned(g) &&\n        is_aligned(p_copy)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        GRAD_T tmp_g[ILP];\n        load_store(incoming_p, p, 0, i_start);\n        load_store(incoming_m, m, 0, i_start);\n        
load_store(incoming_v, v, 0, i_start);\n        load_store(tmp_g, g, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          incoming_g[ii] = static_cast<T>(tmp_g[ii]);\n          T scaled_grad = incoming_g[ii] / grad_scale;\n          incoming_m[ii] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad;\n          incoming_v[ii] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad;\n          float denom;\n          if (mode == ADAM_MODE_0)\n            denom = sqrtf(incoming_v[ii] + eps);\n          else  // Mode 1\n            denom = sqrtf(incoming_v[ii]) + eps;\n          float update = (incoming_m[ii] / denom) + (decay * incoming_p[ii]);\n          incoming_p[ii] = incoming_p[ii] - (step_size * update);\n          if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);\n        }\n        load_store(p, incoming_p, i_start, 0);\n        load_store(m, incoming_m, i_start, 0);\n        load_store(v, incoming_v, i_start, 0);\n        if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          incoming_p[ii] = 0;\n          incoming_m[ii] = 0;\n          incoming_v[ii] = 0;\n          incoming_g[ii] = 0;\n\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            incoming_p[ii] = p[i];\n            incoming_m[ii] = m[i];\n            incoming_v[ii] = v[i];\n            incoming_g[ii] = static_cast<T>(g[i]);\n          }\n        }\n\n        // note for clarification to future michael:\n        // From a pure memory dependency perspective, there's likely no point unrolling\n        // the write loop, since writes just fire off once their LDGs arrive.\n        // Put another way, the STGs are dependent on the LDGs, but not on each other.\n        // There is still compute ILP benefit from 
unrolling the loop though.\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int j = i_start + threadIdx.x + ii * blockDim.x;\n\n          if (j < n && j < chunk_size) {\n            T scaled_grad = incoming_g[ii] / grad_scale;\n            m[j] = b1 * incoming_m[ii] + (1 - b1) * scaled_grad;\n            v[j] = b2 * incoming_v[ii] + (1 - b2) * scaled_grad * scaled_grad;\n            float denom;\n            if (mode == ADAM_MODE_0)\n              denom = sqrtf(v[j] + eps);\n            else  // Mode 1\n              denom = sqrtf(v[j]) + eps;\n            float update = (m[j] / denom) + (decay * incoming_p[ii]);\n            p[j] = incoming_p[ii] - (step_size * update);\n            if (DEPTH == 5) p_copy[j] = (GRAD_T)p[j];\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid fused_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g, float lr,\n                     float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction,\n                     float decay) {\n  //        using namespace at;\n\n  // Get tensor size\n  int tsize = p.numel();\n  // Determine #threads and #blocks\n  const int threadsPerBlock = 512;\n  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);\n  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n  // Constants\n  float step_size = 0;\n  if (bias_correction == 1) {\n    const float bias_correction1 = 1 - std::pow(beta1, step);\n    const float bias_correction2 = 1 - std::pow(beta2, step);\n    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;\n  } else {\n    step_size = lr;\n  }\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (g.scalar_type() == at::ScalarType::Half) {\n    // all other values should be fp32 for half gradients\n    TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float 
type\");\n    // dispatch is done on the gradient type\n    using namespace at;  // prevents \"toString is undefined\" errors\n    DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n                            adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(\n                                p.data_ptr<accscalar_t>(), p_copy.numel() ? p_copy.data_ptr<scalar_t_0>() : NULL,\n                                m.data_ptr<accscalar_t>(), v.data_ptr<accscalar_t>(), g.data_ptr<scalar_t_0>(), beta1,\n                                beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay););\n  } else {\n    using namespace at;\n    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                              adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(\n                                  p.data_ptr<scalar_t_0>(),\n                                  NULL,  // don't output p_copy for fp32, it's wasted write\n                                  m.data_ptr<scalar_t_0>(), v.data_ptr<scalar_t_0>(), g.data_ptr<scalar_t_0>(), beta1,\n                                  beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay););\n  }\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag,\n                        std::vector<std::vector<at::Tensor>> tensor_lists,  // p, m, v, g, p_copy\n                        float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode,\n                        int bias_correction, float decay) {\n  // Constants\n  float step_size = 0;\n  if (bias_correction == 1) {\n    const float bias_correction1 = 1 - std::pow(beta1, step);\n    const float bias_correction2 = 1 - std::pow(beta2, step);\n    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;\n  } else {\n    step_size = lr;\n  }\n  
cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  size_t tl_sz = tensor_lists.size();\n  TORCH_CHECK(tl_sz == 4 || tl_sz == 5, \"expected tensor lists of size 4 or 5\");\n\n  if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {\n    // all other values should be fp32 for half gradients\n    TORCH_CHECK(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n    // dispatch is done on the gradient type\n    if (tl_sz == 5) {\n      DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                              using accscalar_t = at::acc_type<scalar_t_0, true>;\n                              multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                    AdamFunctor<5, accscalar_t, scalar_t_0>(), beta1, beta2, eps,\n                                                    grad_scale, step_size, (adamMode_t)mode, decay););\n    } else {\n      DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                              using accscalar_t = at::acc_type<scalar_t_0, true>;\n                              multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                    AdamFunctor<4, accscalar_t, scalar_t_0>(), beta1, beta2, eps,\n                                                    grad_scale, step_size, (adamMode_t)mode, decay););\n    }\n  } else {\n    if (tl_sz == 5) {\n      DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                                multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                      AdamFunctor<5, scalar_t_0, scalar_t_0>(), beta1, beta2, eps,\n                                                      grad_scale, step_size, (adamMode_t)mode, decay););\n    } else {\n      
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, \"adam_cuda_mt_kernel\",\n                                multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                      AdamFunctor<4, scalar_t_0, scalar_t_0>(), beta1, beta2, eps,\n                                                      grad_scale, step_size, (adamMode_t)mode, decay););\n    }\n  }\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\ntemplate <typename FROM_T, typename TO_T>\n__device__ void convert(const FROM_T vi, TO_T& vo) {\n  vo = static_cast<TO_T>(vi);\n}\n\ntemplate <>\n__device__ void convert(const float vi, uint8_t& vo) {\n  union S {\n    float as_float;\n    int as_int;\n  };\n  S s;\n  s.as_float = vi;\n  s.as_int = s.as_int & 0xFF800000;\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n  vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, float& vo) {\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_byte[0] = 0;\n  t.as_byte[1] = vi;\n  vo = static_cast<float>(t.as_half);\n}\n\ntemplate <>\n__device__ void convert(const at::Half vi, uint8_t& vo) {\n  union S {\n    float as_float;\n    int as_int;\n  };\n  S s;\n  s.as_float = static_cast<float>(vi);\n  s.as_int = s.as_int & 0xFF800000;\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n  vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, at::Half& vo) {\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_byte[0] = 0;\n  t.as_byte[1] = vi;\n  vo = t.as_half;\n}\n\ntemplate <typename GRAD_T>\n__global__ void strided_check_finite_cuda_kernel(volatile int* noop_gmem, GRAD_T* __restrict__ p_copy,\n                                                 const size_t tsize, int stride, int 
clear_overflow_first) {\n  // Assuming 2D grids and 2D blocks\n  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n  const int threadsPerBlock = blockDim.x * blockDim.y;\n  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n  const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;\n  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride;\n\n  if (clear_overflow_first) {\n    if (i == 0) {\n      *noop_gmem = 0;\n    }\n    __syncthreads();\n  }\n\n  for (int j = i; j < tsize; j += totThreads) {\n    GRAD_T pi = p_copy[j];\n    if (!isfinite(pi)) {\n      *noop_gmem = 1;\n    }\n  }\n}\ntemplate <>\n__global__ void strided_check_finite_cuda_kernel(volatile int* noop_gmem, uint8_t* __restrict__ p_copy,\n                                                 const size_t tsize, int stride, int clear_overflow_first) {\n  // Assuming 2D grids and 2D blocks\n  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n  const int threadsPerBlock = blockDim.x * blockDim.y;\n  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n  const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;\n  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride;\n\n  if (clear_overflow_first) {\n    if (i == 0) {\n      *noop_gmem = 0;\n    }\n    __syncthreads();\n  }\n\n  for (int j = i; j < tsize; j += totThreads) {\n    at::Half pi;\n    convert(p_copy[j], pi);\n    if (!isfinite(pi)) {\n      *noop_gmem = 1;\n    }\n  }\n}\n\ntemplate <typename FROM_T, typename TO_T>\n__global__ void maybe_cast_kernel(volatile int* overflow_flag, const FROM_T* p_in, TO_T* p_out, const size_t tsize) {\n  if (overflow_flag && *overflow_flag != 0) return;\n\n  // Assuming 2D grids and 2D blocks\n  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n  const int threadsPerBlock = blockDim.x * blockDim.y;\n  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n  const int i = (blockId * 
threadsPerBlock + threadIdInBlock);\n  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;\n\n  FROM_T pi[ILP];\n  TO_T po[ILP];\n\n  for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      pi[ii] = 0;\n\n      int j = j_start + i + totThreads * ii;\n      if (j < tsize) {\n        pi[ii] = p_in[j];\n      }\n    }\n\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      convert(pi[ii], po[ii]);\n    }\n\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      int j = j_start + i + totThreads * ii;\n      if (j < tsize) {\n        p_out[j] = po[ii];\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename GRAD_T, typename REDU_T>\n__global__ void reversible_adam_cuda_kernel(\n    T* __restrict__ p,\n    REDU_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed\n    T* __restrict__ m, T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1, const float b2, const float eps,\n    const float grad_scale, const float step_size, const size_t tsize, adamMode_t mode, const float decay) {\n  // Assuming 2D grids and 2D blocks\n  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n  const int threadsPerBlock = blockDim.x * blockDim.y;\n  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n  const int i = (blockId * threadsPerBlock + threadIdInBlock);\n  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;\n\n  T mi[ILP];\n  T vi[ILP];\n  T pi[ILP];\n  T gi[ILP];\n\n  bool overflow = false;\n  for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      mi[ii] = T(0);\n      vi[ii] = T(0);\n      pi[ii] = T(0);\n      gi[ii] = GRAD_T(0);\n\n      int j = j_start + i + totThreads * ii;\n      if (j < tsize) {\n        pi[ii] = p[j];\n        mi[ii] = m[j];\n        vi[ii] = v[j];\n        gi[ii] = static_cast<T>(g[j]);\n      }\n    
}\n\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      T scaled_grad = gi[ii] / grad_scale;\n      if (isfinite(scaled_grad)) {\n        mi[ii] = b1 * mi[ii] + (1 - b1) * scaled_grad;\n        vi[ii] = b2 * vi[ii] + (1 - b2) * scaled_grad * scaled_grad;\n        float denom;\n        if (mode == ADAM_MODE_0)\n          denom = sqrtf(vi[ii] + eps);\n        else  // Mode 1\n          denom = sqrtf(vi[ii]) + eps;\n        float update = (mi[ii] / denom) + (decay * pi[ii]);\n        pi[ii] = pi[ii] - (step_size * update);\n      } else {\n        overflow = true;\n      }\n    }\n\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      int j = j_start + i + totThreads * ii;\n      if (j < tsize) {\n        m[j] = mi[ii];\n        v[j] = vi[ii];\n        p[j] = pi[ii];\n        if (p_copy != NULL) {\n          convert(pi[ii], p_copy[j]);\n        }\n      }\n    }\n  }\n\n  if (p_copy != NULL) {\n    __syncthreads();\n    if (overflow) {\n      convert(float(INFINITY), p_copy[0]);\n    }\n  }\n}\n\ntemplate <typename T, typename GRAD_T>\n__global__ void maybe_adam_undo_cuda_kernel(volatile int* overflow_flag, T* __restrict__ p, T* __restrict__ m,\n                                            T* __restrict__ v, const GRAD_T* __restrict__ g, const float b1,\n                                            const float b2, const float eps, const float grad_scale,\n                                            const float step_size, const size_t tsize, adamMode_t mode,\n                                            const float decay) {\n  // NB! 
Skip undo kernel when overflow flag is NOT set\n  if (overflow_flag && *overflow_flag == 0) return;\n\n  // Assuming 2D grids and 2D blocks\n  const int blockId = gridDim.x * blockIdx.y + blockIdx.x;\n  const int threadsPerBlock = blockDim.x * blockDim.y;\n  const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;\n  const int i = (blockId * threadsPerBlock + threadIdInBlock);\n  const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;\n\n  T mi[ILP];\n  T vi[ILP];\n  T pi[ILP];\n  T gi[ILP];\n\n  for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      mi[ii] = T(0);\n      vi[ii] = T(0);\n      pi[ii] = T(0);\n      gi[ii] = GRAD_T(0);\n\n      int j = j_start + i * ILP;\n      if (j < tsize) {\n        pi[ii] = p[j];\n        mi[ii] = m[j];\n        vi[ii] = v[j];\n        gi[ii] = static_cast<T>(g[j]);\n      }\n    }\n\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      T scaled_grad = gi[ii] / grad_scale;\n      if (isfinite(scaled_grad)) {\n        float denom;\n        if (mode == ADAM_MODE_0)\n          denom = sqrtf(vi[ii] + eps);\n        else  // Mode 1\n          denom = sqrtf(vi[ii]) + eps;\n        pi[ii] = (pi[ii] + step_size * (mi[ii] / denom)) / (1.0f - step_size * decay);\n        mi[ii] = (mi[ii] - (1 - b1) * scaled_grad) / b1;\n        vi[ii] = (vi[ii] - (1 - b2) * scaled_grad * scaled_grad) / b2;\n        // Make sure round off errors don't create (small) negative value.\n        // This can happen if we have to revert the very first step.\n        vi[ii] = vi[ii] >= 0.0f ? 
vi[ii] : 0.0f;\n      }\n    }\n\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      int j = j_start + i * ILP;\n      if (j < tsize) {\n        m[j] = mi[ii];\n        v[j] = vi[ii];\n        p[j] = pi[ii];\n      }\n    }\n  }\n}\n\ntemplate <int DEPTH, typename FROM_T, typename TO_T>\nstruct MaybeCastFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* overflow_flag,\n                                             TensorListMetadata<DEPTH>& tl) {\n    if (overflow_flag && *overflow_flag != 0) return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    FROM_T* p_in = (FROM_T*)tl.addresses[0][tensor_loc];\n    p_in += chunk_idx * chunk_size;\n    TO_T* p_out = (TO_T*)tl.addresses[1][tensor_loc];\n    p_out += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n    int dim = chunk_size < n ? chunk_size : n;\n\n    FROM_T pi[ILP];\n    TO_T po[ILP];\n\n    for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        pi[ii] = FROM_T(0);\n        int j = j_start + threadIdx.x + ii * blockDim.x;\n        if (j < dim) {\n          pi[ii] = p_in[j];\n        }\n      }\n\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        convert(pi[ii], po[ii]);\n      }\n\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int j = j_start + threadIdx.x + ii * blockDim.x;\n        if (j < dim) {\n          p_out[j] = po[ii];\n        }\n      }\n    }\n  }\n};\n\nvoid fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first) {\n  // Get tensor size\n  int tsize = p_copy.numel();\n  int niter = (tsize + stride - 1) / stride;\n\n  // Determine #threads and #blocks\n  const int threadsPerBlock = 512;\n  // In order to avoid race condition, blocks must be 1 when clear_overflow_first flag 
is set.\n  const dim3 blocks(clear_overflow_first ? 1 : (niter + threadsPerBlock - 1) / threadsPerBlock);\n  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_copy), \"parameter tensor is too large to be indexed with int32\");\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  using namespace at;  // prevents \"toString is undefined\" errors\n  DISPATCH_FLOAT_HALF_AND_BYTE(\n      p_copy.scalar_type(), 0, \"check_finite_cuda_kernel\",\n      strided_check_finite_cuda_kernel<scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(\n          overflow_flag.data_ptr<int>(), p_copy.data_ptr<scalar_t_0>(), tsize, stride, clear_overflow_first););\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid fused_reversible_adam_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m, at::Tensor& v, at::Tensor& g,\n                                float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode,\n                                int bias_correction, float decay) {\n  //      using namespace at;\n\n  // Get tensor size\n  int tsize = p.numel();\n  // Determine #threads and #blocks\n  const int threadsPerBlock = 512;\n  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);\n  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n  // Constants\n  float step_size = 0;\n  if (bias_correction == 1) {\n    const float bias_correction1 = 1 - std::pow(beta1, step);\n    const float bias_correction2 = 1 - std::pow(beta2, step);\n    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;\n  } else {\n    step_size = lr;\n  }\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (g.scalar_type() == at::ScalarType::Half) {\n    // all other values should be fp32 for half gradients\n    TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n    // dispatch is done on the gradient type\n    using namespace at;  
// prevents \"toString is undefined\" errors\n    if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {\n      DISPATCH_FLOAT_AND_HALF(\n          g.scalar_type(), 0, \"adam_cuda_kernel\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n          reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(\n              p.data_ptr<accscalar_t>(), p_copy.numel() ? p_copy.data_ptr<scalar_t_0>() : NULL,\n              m.data_ptr<accscalar_t>(), v.data_ptr<accscalar_t>(), g.data_ptr<scalar_t_0>(), beta1, beta2, eps,\n              grad_scale, step_size, tsize, (adamMode_t)mode, decay););\n    } else {\n      TORCH_CHECK(p_copy.scalar_type() == at::ScalarType::Byte, \"expected parameter to be of byte type\");\n      DISPATCH_FLOAT_AND_HALF(\n          g.scalar_type(), 0, \"adam_cuda_e5m2_kernel\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n          reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks, threadsPerBlock, 0, stream>>>(\n              p.data_ptr<accscalar_t>(), p_copy.data_ptr<uint8_t>(), m.data_ptr<accscalar_t>(),\n              v.data_ptr<accscalar_t>(), g.data_ptr<scalar_t_0>(), beta1, beta2, eps, grad_scale, step_size, tsize,\n              (adamMode_t)mode, decay););\n    }\n  } else {\n    using namespace at;\n    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, \"adam_cuda_kernel\",\n                              reversible_adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0>\n                              <<<blocks, threadsPerBlock, 0, stream>>>(\n                                  p.data_ptr<scalar_t_0>(),\n                                  NULL,  // don't output p_copy for fp32, it's wasted write\n                                  m.data_ptr<scalar_t_0>(), v.data_ptr<scalar_t_0>(), g.data_ptr<scalar_t_0>(), beta1,\n                                  beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay););\n  }\n  
C10_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid maybe_cast_cuda(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) {\n  // Get tensor size\n  int tsize = p_in.numel();\n  TORCH_CHECK(tsize == p_out.numel(), \"p_in.numel() must equal p_out.numel()\");\n  // Determine #threads and #blocks\n  const int threadsPerBlock = 512;\n  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);\n  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p_in), \"parameter tensor is too large to be indexed with int32\");\n  // Constants\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0,\n                               \"maybe_cast_cuda\" DISPATCH_FLOAT_HALF_AND_BYTE(\n                                   p_out.scalar_type(), 1, \"maybe_cast_cuda\",\n                                   maybe_cast_kernel<scalar_t_0, scalar_t_1><<<blocks, threadsPerBlock, 0, stream>>>(\n                                       overflow_flag.numel() ? 
overflow_flag.data_ptr<int>() : NULL,\n                                       p_in.data_ptr<scalar_t_0>(), p_out.data_ptr<scalar_t_1>(), tsize);))\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag,\n                        std::vector<std::vector<at::Tensor>> tensor_lists)  // p_in, p_out\n{\n  // Constants\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  size_t tl_sz = tensor_lists.size();\n  TORCH_CHECK(tl_sz == 2, \"expected tensor lists of size 2\");\n\n  DISPATCH_FLOAT_HALF_AND_BYTE(\n      tensor_lists[0][0].scalar_type(), 0, \"maybe_cast_cuda_mt_kernel\",\n      DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, \"maybe_cast_cuda_mt_kernel\",\n                                   multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, overflow_flag, tensor_lists,\n                                                         MaybeCastFunctor<2, scalar_t_0, scalar_t_1>());))\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid fused_maybe_adam_undo_cuda(at::Tensor& overflow_flag, at::Tensor& p, at::Tensor& m, at::Tensor& v, at::Tensor& g,\n                                float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode,\n                                int bias_correction, float decay) {\n  // Get tensor size\n  int tsize = p.numel();\n  // Determine #threads and #blocks\n  const int threadsPerBlock = 512;\n  const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);\n  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(p), \"parameter tensor is too large to be indexed with int32\");\n  // Constants\n  float step_size = 0;\n  if (bias_correction == 1) {\n    const float bias_correction1 = 1 - std::pow(beta1, step);\n    const float bias_correction2 = 1 - std::pow(beta2, step);\n    step_size = lr * std::sqrt(bias_correction2) / bias_correction1;\n  } else {\n    step_size = lr;\n  }\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  
if (g.scalar_type() == at::ScalarType::Half) {\n    // all other values should be fp32 for half gradients\n    TORCH_CHECK(p.scalar_type() == at::ScalarType::Float, \"expected parameter to be of float type\");\n    // dispatch is done on the gradient type\n    using namespace at;  // prevents \"toString is undefined\" errors\n    DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, \"adam_cuda_kernel\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n                            maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0>\n                            <<<blocks, threadsPerBlock, 0, stream>>>(\n                                overflow_flag.numel() ? overflow_flag.data_ptr<int>() : NULL, p.data_ptr<accscalar_t>(),\n                                m.data_ptr<accscalar_t>(), v.data_ptr<accscalar_t>(), g.data_ptr<scalar_t_0>(), beta1,\n                                beta2, eps, grad_scale, step_size, tsize, (adamMode_t)mode, decay););\n  } else {\n    using namespace at;\n    DISPATCH_DOUBLE_AND_FLOAT(\n        g.scalar_type(), 0, \"adam_cuda_kernel\",\n        maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(\n            overflow_flag.numel() ? overflow_flag.data_ptr<int>() : NULL, p.data_ptr<scalar_t_0>(),\n            m.data_ptr<scalar_t_0>(), v.data_ptr<scalar_t_0>(), g.data_ptr<scalar_t_0>(), beta1, beta2, eps, grad_scale,\n            step_size, tsize, (adamMode_t)mode, decay););\n  }\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                            const float lr, const float beta1, const float beta2, const float epsilon, const int step,\n                            const int bias_correction, const float weight_decay, const int grad_averaging,\n                            const int mode, const float global_grad_norm, const float max_grad_norm);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"lamb\", &multi_tensor_lamb_cuda, \"Computes and applies the update for the LAMB optimizer\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntypedef enum {\n  MOMENT_MODE_0 = 0,  // L2 regularization mode\n  MOMENT_MODE_1 = 1   // Decoupled weight decay mode\n} adamMode_t;\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,\n                                                            std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                            at::optional<bool> per_tensor_python);\n\nusing MATH_T = float;\n\ntemplate <typename T>\nstruct LAMBStage1Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<4>& tl,\n                                             const float beta1, const float beta2, const float beta3,\n                                             const float beta1_correction, const float beta2_correction,\n                                             const float epsilon, adamMode_t mode, const float decay,\n                                             const float global_grad_norm, const float max_global_grad_norm) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float clipped_global_grad_norm =\n        global_grad_norm > max_global_grad_norm ? 
global_grad_norm / max_global_grad_norm : 1.0f;\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_m[ILP];\n      MATH_T r_v[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_g[ii] = g[i];\n          // special ?optimization? for lamb stage 1\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          } else {\n            r_p[ii] = p[i];\n          }\n          r_m[ii] = m[i];\n          r_v[ii] = v[i];\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_m[ii] = MATH_T(0);\n          r_v[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if (mode == MOMENT_MODE_0) {\n          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n          // L2 on scaled grad\n          scaled_grad = scaled_grad + decay * r_p[ii];\n          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n          r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          r_p[ii] = next_m_unbiased / denom;\n        } else {\n          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n          r_m[ii] = r_m[ii] * beta1 + beta3 * 
scaled_grad;\n          r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          g[i] = r_p[ii];\n          m[i] = r_m[ii];\n          v[i] = r_v[ii];\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter value.\ntemplate <typename T>\nstruct LAMBStage2Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl,\n                                             const float* per_tensor_param_norm, const float* per_tensor_update_norm,\n                                             const float learning_rate, const float decay) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T ratio = learning_rate;\n    // apply adaptive learning rate to parameters with non-zero weight decay\n    if (decay != 0.0) {\n      float param_norm = per_tensor_param_norm[tensor_num];\n      float update_norm = per_tensor_update_norm[tensor_num];\n      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? 
learning_rate * (param_norm / update_norm) : learning_rate;\n    }\n\n    T* update = (T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      MATH_T r_p[ILP];\n      MATH_T r_update[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_p[ii] = p[i];\n          r_update[ii] = update[i];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          p[i] = r_p[ii];\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                            const float lr, const float beta1, const float beta2, const float epsilon, const int step,\n                            const int bias_correction, const float weight_decay, const int grad_averaging,\n                            const int mode, const float global_grad_norm, const float max_grad_norm) {\n  using namespace at;\n  // Master weight and 32bit momentum(potentially changing) is not handled by this\n  // So we assume every tensor are all in the same type\n\n  // Handle bias correction mode\n  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;\n  if (bias_correction == 1) {\n    bias_correction1 = 1 - std::pow(beta1, step);\n    bias_correction2 = 1 - std::pow(beta2, step);\n  }\n\n  // Handle grad averaging mode\n  float beta3 = 1.0f;\n  if (grad_averaging == 1) beta3 = 1 - beta1;\n\n  std::vector<std::vector<at::Tensor>> 
grad_list(tensor_lists.begin(), tensor_lists.begin() + 1);\n  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1, tensor_lists.begin() + 2);\n\n  // Compute per-tensor param norm\n  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);\n\n  // We now modify grad in place to store the update before computing its norm\n  // Generally this is not an issue since people modify grad in the step() method all the time\n  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/cpu code\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_1\",\n                          multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                LAMBStage1Functor<scalar_t_0>(), beta1, beta2,\n                                                beta3,  // 1-beta1 or 1, depending on averaging mode\n                                                bias_correction1, bias_correction2, epsilon, (adamMode_t)mode,\n                                                weight_decay, global_grad_norm, max_grad_norm);)\n\n  // Compute update norms\n  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);\n\n  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin() + 2);\n\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_2\",\n      multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, LAMBStage2Functor<scalar_t_0>(),\n                            std::get<1>(param_norm_tuple).data_ptr<float>(),\n                            std::get<1>(update_norm_tuple).data_ptr<float>(), lr, weight_decay);)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag,\n                                  std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale, float lr,\n                                  float beta1, float beta2, float eps, int step, int mode, int bias_correction,\n                                  float weight_decay);\n\nvoid multi_tensor_fused_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,\n                                             std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale,\n                                             at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step,\n                                             int mode, int bias_correction, float weight_decay);\n\nvoid multi_tensor_fused_adam_with_param_remainders_cuda(int chunk_size, at::Tensor noop_flag,\n                                                        std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                        at::Tensor grad_scale, float lr, float beta1, float beta2,\n                                                        float eps, int step, int mode, int bias_correction,\n                                                        float weight_decay);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"multi_tensor_fused_adam\", &multi_tensor_fused_adam_cuda,\n        \"CUDA kernels for multi-tensor Adam, \"\n        \"with param copy\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_fused_adam_capturable\", &multi_tensor_fused_adam_capturable_cuda,\n        \"CUDA kernels for multi-tensor Adam, \"\n        \"with param copy, capturable for CUDA graph\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_fused_adam_with_param_remainders\", &multi_tensor_fused_adam_with_param_remainders_cuda,\n        \"CUDA kernel for multi-tensor Adam, \"\n     
   \"with stored param remainders and param copy\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include <cmath>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(const T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, const T* src, int dst_offset = 0, int src_offset = 0) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((const LT*)src)[src_offset];\n}\n\n// (1-t)*x + t*y\n// Note: Named _lerp to avoid ambiguity with std::lerp under C++20.\n__device__ __forceinline__ float _lerp(float t, float x, float y) {\n  // See https://developer.nvidia.com/blog/lerp-faster-cuda/\n  return fma(t, y, fma(-t, x, x));\n}\n\ntypedef enum {\n  ADAM_MODE_0 = 0,  // L2 regularization mode\n  ADAM_MODE_1 = 1   // Decoupled weight decay mode(AdamW)\n} adamMode_t;\n\n/* Multi-tensor Adam\n *\n * Updates params in-place and outputs a copy with a desired datatype.\n */\ntemplate <typename T, typename GRAD_T, typename PARAM_OUT_T>\nstruct DistAdamFunctor {\n  // Vectorized local compute\n  __device__ __forceinline__ static void local_step(T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP],\n                                                    const float grad_scale, const float beta1, const float beta2,\n                                                    const float beta1_correction, const float beta2_correction,\n                                                    const float eps, const float lr, adamMode_t mode,\n                                                    const float weight_decay) {\n    if (mode == ADAM_MODE_0) {  // L2\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n   
     float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]);\n        float next_m = _lerp(beta1, scaled_grad, m[ii]);\n        float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]);\n        float next_m_unbiased = next_m / beta1_correction;\n        float next_v_unbiased = next_v / beta2_correction;\n        float denom = sqrtf(next_v_unbiased) + eps;\n        float update = next_m_unbiased / denom;\n        m[ii] = next_m;\n        v[ii] = next_v;\n        p[ii] -= lr * update;\n      }\n    } else {  // weight decay\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float scaled_grad = g[ii] * grad_scale;\n        float next_m = _lerp(beta1, scaled_grad, m[ii]);\n        float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]);\n        float next_m_unbiased = next_m / beta1_correction;\n        float next_v_unbiased = next_v / beta2_correction;\n        float denom = sqrtf(next_v_unbiased) + eps;\n        float update = (next_m_unbiased / denom) + (weight_decay * p[ii]);\n        m[ii] = next_m;\n        v[ii] = next_v;\n        p[ii] -= lr * update;\n      }\n    }\n  }\n\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl,\n                                             const float* grad_scale_ptr, const float beta1, const float beta2,\n                                             const float beta1_correction, const float beta2_correction,\n                                             const float eps, const float lr, adamMode_t mode,\n                                             const float weight_decay) const {\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    const float grad_scale = *grad_scale_ptr;\n\n    T* p_in = (T*)tl.addresses[0][tensor_loc];\n    p_in += chunk_idx * chunk_size;\n    T* m = (T*)tl.addresses[1][tensor_loc];\n    m += chunk_idx * 
chunk_size;\n    T* v = (T*)tl.addresses[2][tensor_loc];\n    v += chunk_idx * chunk_size;\n    const GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];\n    g += chunk_idx * chunk_size;\n    PARAM_OUT_T* p_out = (PARAM_OUT_T*)tl.addresses[4][tensor_loc];\n    p_out += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n    n = chunk_size < n ? chunk_size : n;\n\n    const bool aligned =\n        (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out));\n\n    for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) {\n      T local_p[ILP];\n      T local_m[ILP];\n      T local_v[ILP];\n      GRAD_T local_g[ILP];\n      PARAM_OUT_T local_p_out[ILP];\n\n      // Load\n      if (aligned) {\n        load_store(local_p, p_in + i_start);\n        load_store(local_m, m + i_start);\n        load_store(local_v, v + i_start);\n        load_store(local_g, g + i_start);\n      } else {\n#pragma unroll\n        for (int ii = 0, i = i_start; ii < ILP; ii++, i++) {\n          if (i < n) {\n            local_p[ii] = p_in[i];\n            local_m[ii] = m[i];\n            local_v[ii] = v[i];\n            local_g[ii] = g[i];\n          } else {\n            local_p[ii] = 0;\n            local_m[ii] = 0;\n            local_v[ii] = 0;\n            local_g[ii] = 0;\n          }\n        }\n      }\n\n      // Local compute\n      local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, beta1_correction, beta2_correction, eps,\n                 lr, mode, weight_decay);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        local_p_out[ii] = static_cast<PARAM_OUT_T>(local_p[ii]);\n      }\n\n      // Store\n      if (aligned) {\n        load_store(p_in + i_start, local_p);\n        load_store(m + i_start, local_m);\n        load_store(v + i_start, local_v);\n        load_store(p_out + i_start, local_p_out);\n      } else {\n#pragma unroll\n        for (int ii = 0, i = i_start; ii 
< ILP; ii++, i++) {\n          if (i < n) {\n            p_in[i] = local_p[ii];\n            m[i] = local_m[ii];\n            v[i] = local_v[ii];\n            p_out[i] = local_p_out[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\n/* Multi-tensor Adam with CUDA Graph Support\n *\n * Updates params in-place and outputs a copy with a desired datatype.\n */\ntemplate <typename T, typename GRAD_T, typename PARAM_OUT_T>\nstruct DistAdamCapturableFunctor {\n  // Vectorized local compute\n  __device__ __forceinline__ static void local_step(T p[ILP], T m[ILP], T v[ILP], const GRAD_T g[ILP],\n                                                    const float grad_scale, const float beta1, const float beta2,\n                                                    const float beta1_correction, const float beta2_correction,\n                                                    const float eps, const float lr, adamMode_t mode,\n                                                    const float weight_decay) {\n    if (mode == ADAM_MODE_0) {  // L2\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]);\n        float next_m = _lerp(beta1, scaled_grad, m[ii]);\n        float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]);\n        float next_m_unbiased = next_m / beta1_correction;\n        float next_v_unbiased = next_v / beta2_correction;\n        float denom = sqrtf(next_v_unbiased) + eps;\n        float update = next_m_unbiased / denom;\n        m[ii] = next_m;\n        v[ii] = next_v;\n        p[ii] -= lr * update;\n      }\n    } else {  // weight decay\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float scaled_grad = g[ii] * grad_scale;\n        float next_m = _lerp(beta1, scaled_grad, m[ii]);\n        float next_v = _lerp(beta2, scaled_grad * scaled_grad, v[ii]);\n        float next_m_unbiased = next_m / beta1_correction;\n        float next_v_unbiased = next_v / 
beta2_correction;\n        float denom = sqrtf(next_v_unbiased) + eps;\n        float update = (next_m_unbiased / denom) + (weight_decay * p[ii]);\n        m[ii] = next_m;\n        v[ii] = next_v;\n        p[ii] -= lr * update;\n      }\n    }\n  }\n\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl,\n                                             const float* grad_scale_ptr, const float beta1, const float beta2,\n                                             const int* step, const int bias_correction, const float eps,\n                                             const float* lr, adamMode_t mode, const float weight_decay) const {\n    assert(noop_gmem);\n    assert(grad_scale_ptr);\n    assert(step);\n    assert(lr);\n\n    if (*noop_gmem == 1) return;\n\n    float beta1_correction = 1.0f, beta2_correction = 1.0f;\n    if (bias_correction == 1) {\n      beta1_correction = 1 - pow(beta1, *step);\n      beta2_correction = 1 - pow(beta2, *step);\n    }\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    const float grad_scale = *grad_scale_ptr;\n\n    T* p_in = (T*)tl.addresses[0][tensor_loc];\n    p_in += chunk_idx * chunk_size;\n    T* m = (T*)tl.addresses[1][tensor_loc];\n    m += chunk_idx * chunk_size;\n    T* v = (T*)tl.addresses[2][tensor_loc];\n    v += chunk_idx * chunk_size;\n    const GRAD_T* g = (GRAD_T*)tl.addresses[3][tensor_loc];\n    g += chunk_idx * chunk_size;\n    PARAM_OUT_T* p_out = (PARAM_OUT_T*)tl.addresses[4][tensor_loc];\n    p_out += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n    n = chunk_size < n ? 
chunk_size : n;\n\n    const bool aligned =\n        (n % ILP == 0 && is_aligned(p_in) && is_aligned(m) && is_aligned(v) && is_aligned(g) && is_aligned(p_out));\n\n    for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) {\n      T local_p[ILP];\n      T local_m[ILP];\n      T local_v[ILP];\n      GRAD_T local_g[ILP];\n      PARAM_OUT_T local_p_out[ILP];\n\n      // Load\n      if (aligned) {\n        load_store(local_p, p_in + i_start);\n        load_store(local_m, m + i_start);\n        load_store(local_v, v + i_start);\n        load_store(local_g, g + i_start);\n      } else {\n#pragma unroll\n        for (int ii = 0, i = i_start; ii < ILP; ii++, i++) {\n          if (i < n) {\n            local_p[ii] = p_in[i];\n            local_m[ii] = m[i];\n            local_v[ii] = v[i];\n            local_g[ii] = g[i];\n          } else {\n            local_p[ii] = 0;\n            local_m[ii] = 0;\n            local_v[ii] = 0;\n            local_g[ii] = 0;\n          }\n        }\n      }\n\n      // Local compute\n      local_step(local_p, local_m, local_v, local_g, grad_scale, beta1, beta2, beta1_correction, beta2_correction, eps,\n                 *lr, mode, weight_decay);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        local_p_out[ii] = static_cast<PARAM_OUT_T>(local_p[ii]);\n      }\n\n      // Store\n      if (aligned) {\n        load_store(p_in + i_start, local_p);\n        load_store(m + i_start, local_m);\n        load_store(v + i_start, local_v);\n        load_store(p_out + i_start, local_p_out);\n      } else {\n#pragma unroll\n        for (int ii = 0, i = i_start; ii < ILP; ii++, i++) {\n          if (i < n) {\n            p_in[i] = local_p[ii];\n            m[i] = local_m[ii];\n            v[i] = local_v[ii];\n            p_out[i] = local_p_out[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\n/* Functor for multi-tensor Adam with implicit main params\n *\n * If params are BF16 and optimizer state is FP32, 
it is not necessary\n * to store FP32 main params. Instead, store 16-bit param remainder\n * and combine with BF16 param to reconstruct the FP32 main param.\n */\ntemplate <typename GRAD_T>\nstruct DistAdamWithParamRemaindersFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<6>& tl,\n                                             const float* grad_scale_ptr, const float beta1, const float beta2,\n                                             const float beta1_correction, const float beta2_correction,\n                                             const float eps, const float lr, adamMode_t mode,\n                                             const float weight_decay) const {\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    const float grad_scale = *grad_scale_ptr;\n\n    int16_t* p_in = (int16_t*)tl.addresses[0][tensor_loc];\n    p_in += chunk_idx * chunk_size;\n    int16_t* p_rem = (int16_t*)tl.addresses[1][tensor_loc];\n    p_rem += chunk_idx * chunk_size;\n    float* m = (float*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n    float* v = (float*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n    const GRAD_T* g = (GRAD_T*)tl.addresses[4][tensor_loc];\n    g += chunk_idx * chunk_size;\n    int16_t* p_out = (int16_t*)tl.addresses[5][tensor_loc];\n    p_out += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n    n = chunk_size < n ? 
chunk_size : n;\n\n    const bool aligned = (n % ILP == 0 && is_aligned(p_in) && is_aligned(p_rem) && is_aligned(m) && is_aligned(v) &&\n                          is_aligned(g) && is_aligned(p_out));\n\n    for (int i_start = threadIdx.x * ILP; i_start < n; i_start += blockDim.x * ILP) {\n      union fp32_or_int162 {\n        float fp32;\n        int16_t int16[2];\n      };\n      fp32_or_int162 local_p[ILP];\n      int16_t local_p_bf16[ILP];\n      int16_t local_p_rem[ILP];\n      float local_m[ILP];\n      float local_v[ILP];\n      GRAD_T local_g[ILP];\n\n      // Load\n      if (aligned) {\n        load_store(local_p_bf16, p_in + i_start);\n        load_store(local_p_rem, p_rem + i_start);\n        load_store(local_m, m + i_start);\n        load_store(local_v, v + i_start);\n        load_store(local_g, g + i_start);\n      } else {\n#pragma unroll\n        for (int ii = 0, i = i_start; ii < ILP; ii++, i++) {\n          if (i < n) {\n            local_p_bf16[ii] = p_in[i];\n            local_p_rem[ii] = p_rem[i];\n            local_m[ii] = m[i];\n            local_v[ii] = v[i];\n            local_g[ii] = g[i];\n          } else {\n            local_p_bf16[ii] = 0;\n            local_p_rem[ii] = 0;\n            local_m[ii] = 0;\n            local_v[ii] = 0;\n            local_g[ii] = 0;\n          }\n        }\n      }\n\n      // Reconstruct FP32 params\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if (local_p_rem[ii] < 0) local_p_bf16[ii]--;  // Undo rounding\n        local_p[ii].int16[1] = local_p_bf16[ii];\n        local_p[ii].int16[0] = local_p_rem[ii];\n      }\n\n      // Local compute\n      using LocalFunctor = DistAdamFunctor<float, GRAD_T, void>;\n      LocalFunctor::local_step(reinterpret_cast<float*>(local_p), local_m, local_v, local_g, grad_scale, beta1, beta2,\n                               beta1_correction, beta2_correction, eps, lr, mode, weight_decay);\n\n      // Split into BF16 params (rounded-to-nearest) and 
remainders\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        local_p_bf16[ii] = local_p[ii].int16[1];\n        local_p_rem[ii] = local_p[ii].int16[0];\n        if (local_p_rem[ii] < 0) local_p_bf16[ii]++;  // Round up\n      }\n\n      // Store\n      if (aligned) {\n        load_store(p_rem + i_start, local_p_rem);\n        load_store(m + i_start, local_m);\n        load_store(v + i_start, local_v);\n        load_store(p_out + i_start, local_p_bf16);\n      } else {\n#pragma unroll\n        for (int ii = 0, i = i_start; ii < ILP; ii++, i++) {\n          if (i < n) {\n            p_rem[i] = local_p_rem[ii];\n            m[i] = local_m[ii];\n            v[i] = local_v[ii];\n            p_out[i] = local_p_bf16[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag,\n                                  std::vector<std::vector<at::Tensor>> tensor_lists,  // p_in, m, v, g, p_out\n                                  at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step,\n                                  int mode, int bias_correction, float weight_decay) {\n  using namespace at;\n\n  // Expect p_in, m, v, g, p_out\n  size_t tl_sz = tensor_lists.size();\n  TORCH_CHECK(tl_sz == 5, \"expected tensor lists of size 5\");\n  const auto p_in_type = tensor_lists[0][0].scalar_type();\n  const auto g_type = tensor_lists[3][0].scalar_type();\n  const auto p_out_type = tensor_lists[4][0].scalar_type();\n\n  float beta1_correction = 1.0f, beta2_correction = 1.0f;\n  if (bias_correction == 1) {\n    beta1_correction = 1 - std::pow(beta1, step);\n    beta2_correction = 1 - std::pow(beta2, step);\n  }\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      p_in_type, 0, \"dist_adam_cuda_kernel\",\n      DISPATCH_FLOAT_HALF_AND_BFLOAT(\n          g_type, 1, \"dist_adam_cuda_kernel\",\n          DISPATCH_FLOAT_HALF_AND_BFLOAT(\n              p_out_type, 2, \"dist_adam_cuda_kernel\",\n        
      multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                    DistAdamFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(), grad_scale.data_ptr<float>(),\n                                    beta1, beta2, beta1_correction, beta2_correction, eps, lr, (adamMode_t)mode,\n                                    weight_decay);)));\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid multi_tensor_fused_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,\n                                             std::vector<std::vector<at::Tensor>> tensor_lists,  // p_in, m, v, g, p_out\n                                             at::Tensor grad_scale, at::Tensor lr, float beta1, float beta2, float eps,\n                                             at::Tensor step, int mode, int bias_correction, float weight_decay) {\n  using namespace at;\n\n  // Expect p_in, m, v, g, p_out\n  size_t tl_sz = tensor_lists.size();\n  TORCH_CHECK(tl_sz == 5, \"expected tensor lists of size 5\");\n  const auto p_in_type = tensor_lists[0][0].scalar_type();\n  const auto g_type = tensor_lists[3][0].scalar_type();\n  const auto p_out_type = tensor_lists[4][0].scalar_type();\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      p_in_type, 0, \"dist_adam_capturable_cuda_kernel\",\n      DISPATCH_FLOAT_HALF_AND_BFLOAT(\n          g_type, 1, \"dist_adam_capturable_cuda_kernel\",\n          DISPATCH_FLOAT_HALF_AND_BFLOAT(\n              p_out_type, 2, \"dist_adam_capturable_cuda_kernel\",\n              multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                    DistAdamCapturableFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n                                    grad_scale.data_ptr<float>(), beta1, beta2, step.data_ptr<int>(), bias_correction,\n                                    eps, lr.data_ptr<float>(), (adamMode_t)mode, weight_decay);)));\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid 
multi_tensor_fused_adam_with_param_remainders_cuda(\n    int chunk_size, at::Tensor noop_flag,\n    std::vector<std::vector<at::Tensor>> tensor_lists,  // p_in, p_rem, m, v, g, p_out\n    at::Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, int bias_correction,\n    float weight_decay) {\n  using namespace at;\n\n  // Expect p_in, p_rem, m, v, g, p_out\n  size_t tl_sz = tensor_lists.size();\n  TORCH_CHECK(tl_sz == 6, \"expected tensor lists of size 6\");\n  const auto g_type = tensor_lists[4][0].scalar_type();\n\n  float beta1_correction = 1.0f, beta2_correction = 1.0f;\n  if (bias_correction == 1) {\n    beta1_correction = 1 - std::pow(beta1, step);\n    beta2_correction = 1 - std::pow(beta2, step);\n  }\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      g_type, 0, \"dist_adam_with_param_remainders_cuda_kernel\",\n      multi_tensor_apply<6>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                            DistAdamWithParamRemaindersFunctor<scalar_t_0>(), grad_scale.data_ptr<float>(), beta1,\n                            beta2, beta1_correction, beta2_correction, eps, lr, (adamMode_t)mode, weight_decay););\n  C10_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag,\n                                                std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2,\n                                                at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction,\n                                                at::Tensor step, at::Tensor per_tensor_epsilon, const int mode,\n                                                at::Tensor per_tensor_decay, at::Tensor global_scale,\n                                                at::Tensor global_grad_norm, const float max_grad_norm);\n\nvoid multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag,\n                                           std::vector<std::vector<at::Tensor>> tensor_lists,\n                                           at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm,\n                                           at::Tensor update_norm_offset, at::Tensor learning_rate,\n                                           at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"multi_tensor_lamb_compute_update_term\", &multi_tensor_lamb_compute_update_term_cuda,\n        \"Computes update term for LAMB optimizer\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_lamb_update_weights\", &multi_tensor_lamb_update_weights_cuda,\n        \"Applies update term for LAMB optimizer\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename FROM_T, typename TO_T>\n__device__ void convert(const FROM_T vi, TO_T& vo) {\n  vo = static_cast<TO_T>(vi);\n}\n\ntemplate <>\n__device__ void convert(const float vi, uint8_t& vo) {\n  union S {\n    float as_float;\n    int as_int;\n  };\n  S s;\n  s.as_float = vi;\n  s.as_int = s.as_int & 0xFF800000;\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n  vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, float& vo) {\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_byte[0] = 0;\n  t.as_byte[1] = vi;\n  vo = static_cast<float>(t.as_half);\n}\n\ntemplate <>\n__device__ void convert(const at::Half vi, uint8_t& vo) {\n  union S {\n    float as_float;\n    int as_int;\n  };\n  S s;\n  s.as_float = static_cast<float>(vi);\n  s.as_int = s.as_int & 0xFF800000;\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);\n  vo = t.as_byte[1];\n}\n\ntemplate <>\n__device__ void convert(const uint8_t vi, at::Half& vo) {\n  union T {\n    at::Half as_half;\n    uint8_t as_byte[2];\n  };\n  T t;\n  t.as_byte[0] = 0;\n  t.as_byte[1] = 
vi;\n  vo = t.as_half;\n}\n\ntypedef enum {\n  MOMENT_MODE_0 = 0,  // L2 regularization mode\n  MOMENT_MODE_1 = 1   // Decoupled weight decay mode\n} adamMode_t;\n\ntemplate <typename T, typename GRAD_T, typename MATH_T>\nstruct DistOptLAMBStage1Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl,\n                                             const MATH_T* per_tensor_beta1, const MATH_T* per_tensor_beta2,\n                                             const MATH_T* per_tensor_beta3, const int* per_tensor_bias_correction,\n                                             const int* step, const MATH_T* per_tensor_epsilon, adamMode_t mode,\n                                             const MATH_T* per_tensor_decay, const MATH_T* global_scale,\n                                             const MATH_T* global_grad_norm, const float max_grad_norm) {\n    // I'd like this kernel to propagate infs/nans.\n    if (*noop_gmem == 1) return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float combined_scale = *global_scale;\n    if (max_grad_norm > 0) {\n      combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);\n      combined_scale = *global_scale / std::min((float)1.0, combined_scale);\n    }\n\n    MATH_T beta1 = per_tensor_beta1[tensor_num];\n    MATH_T beta2 = per_tensor_beta2[tensor_num];\n    MATH_T beta3 = 1 - beta1;\n    MATH_T beta1_correction, beta2_correction;\n    if (per_tensor_bias_correction[tensor_num] == 1) {\n      beta1_correction = 1 - pow(beta1, *step);\n      beta2_correction = 1 - pow(beta2, *step);\n    } else {\n      beta1_correction = (MATH_T)1.0;\n      beta2_correction = (MATH_T)1.0;\n    }\n    MATH_T epsilon = per_tensor_epsilon[tensor_num];\n    MATH_T decay = 
per_tensor_decay[tensor_num];\n\n    GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc];\n    u += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    MATH_T r_g[ILP];\n    MATH_T r_p[ILP];\n    MATH_T r_m[ILP];\n    MATH_T r_v[ILP];\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && is_aligned(p) && is_aligned(m) && is_aligned(v)) {\n      GRAD_T l_g[ILP];\n      T l_p[ILP];\n      T l_m[ILP];\n      T l_v[ILP];\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(l_g, g, 0, i_start);\n        if (decay != 0) load_store(l_p, p, 0, i_start);\n        load_store(l_m, m, 0, i_start);\n        load_store(l_v, v, 0, i_start);\n        // unpack\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_g[ii] = l_g[ii];\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          } else {\n            r_p[ii] = l_p[ii];\n          }\n          r_m[ii] = l_m[ii];\n          r_v[ii] = l_v[ii];\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay * r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / 
beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          } else {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          l_m[ii] = r_m[ii];\n          l_v[ii] = r_v[ii];\n        }\n        // store\n        load_store(u, r_p, i_start, 0);\n        load_store(m, l_m, i_start, 0);\n        load_store(v, l_v, i_start, 0);\n      }\n    } else {\n      // see note in multi_tensor_scale_kernel.cu\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n        MATH_T r_g[ILP];\n        MATH_T r_p[ILP];\n        MATH_T r_m[ILP];\n        MATH_T r_v[ILP];\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_g[ii] = g[i];\n            // special ?optimization? 
for lamb stage 1\n            if (decay == 0) {\n              r_p[ii] = MATH_T(0);\n            } else {\n              r_p[ii] = p[i];\n            }\n            r_m[ii] = m[i];\n            r_v[ii] = v[i];\n          } else {\n            r_g[ii] = MATH_T(0);\n            r_p[ii] = MATH_T(0);\n            r_m[ii] = MATH_T(0);\n            r_v[ii] = MATH_T(0);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay * r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          } else {\n            MATH_T scaled_grad = r_g[ii] / combined_scale;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            u[i] = r_p[ii];\n            m[i] = r_m[ii];\n            v[i] = r_v[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter value.\ntemplate <typename T, typename GRAD_T, 
typename MATH_T>\nstruct DistOptLAMBStage2Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<3>& tl,\n                                             const MATH_T* per_tensor_param_norm, const MATH_T* per_tensor_update_norm,\n                                             const long* update_norm_offset, const MATH_T* learning_rate,\n                                             const MATH_T* per_tensor_decay, const MATH_T* global_grad_norm,\n                                             bool use_nvlamb) {\n    // I'd like this kernel to propagate infs/nans.\n    if (*noop_gmem == 1) return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T decay = per_tensor_decay[tensor_num];\n\n    MATH_T ratio = *learning_rate;\n    // nvlamb: apply adaptive learning rate to all parameters\n    // otherwise, only apply to those with non-zero weight decay\n    if (use_nvlamb || (decay != (MATH_T)0.0)) {\n      MATH_T param_norm = per_tensor_param_norm[tensor_num];\n      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];\n      ratio =\n          (update_norm != 0.0 && param_norm != 0.0) ? 
(*learning_rate) * (param_norm / update_norm) : (*learning_rate);\n    }\n\n    MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc];\n    p_copy += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update)) {\n      T r_p[ILP];\n      MATH_T r_update[ILP];\n      GRAD_T r_p_copy[ILP];\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_p, p, 0, i_start);\n        load_store(r_update, update, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);\n          convert(r_p[ii], r_p_copy[ii]);\n        }\n        load_store(p, r_p, i_start, 0);\n        load_store(p_copy, r_p_copy, i_start, 0);\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n        MATH_T r_p[ILP];\n        MATH_T r_update[ILP];\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_p[ii] = p[i];\n            r_update[ii] = update[i];\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            p[i] = r_p[ii];\n            convert(r_p[ii], p_copy[i]);\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid 
multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag,\n                                                std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2,\n                                                at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction,\n                                                at::Tensor step, at::Tensor per_tensor_epsilon, const int mode,\n                                                at::Tensor per_tensor_decay, at::Tensor global_scale,\n                                                at::Tensor global_grad_norm, const float max_grad_norm) {\n  using namespace at;\n\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[1][0].scalar_type(), 0, \"lamb_stage_1\",\n      DISPATCH_FLOAT_AND_HALF(\n          tensor_lists[0][0].scalar_type(), 1, \"lamb_stage_1\",\n          DISPATCH_FLOAT_AND_HALF(\n              tensor_lists[4][0].scalar_type(), 2, \"lamb_stage_1\",\n              multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                    DistOptLAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n                                    per_tensor_beta1.data_ptr<scalar_t_2>(), per_tensor_beta2.data_ptr<scalar_t_2>(),\n                                    per_tensor_beta3.data_ptr<scalar_t_2>(), per_tensor_bias_correction.data_ptr<int>(),\n                                    step.data_ptr<int>(), per_tensor_epsilon.data_ptr<scalar_t_2>(), (adamMode_t)mode,\n                                    per_tensor_decay.data_ptr<scalar_t_2>(), global_scale.data_ptr<scalar_t_2>(),\n                                    global_grad_norm.data_ptr<scalar_t_2>(), max_grad_norm);)))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag,\n                                           
std::vector<std::vector<at::Tensor>> tensor_lists,\n                                           at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm,\n                                           at::Tensor update_norm_offset, at::Tensor learning_rate,\n                                           at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb) {\n  using namespace at;\n\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[1][0].scalar_type(), 0, \"lamb_stage_2\",\n      DISPATCH_FLOAT_HALF_AND_BYTE(\n          tensor_lists[2][0].scalar_type(), 1, \"lamb_stage_2\",\n          DISPATCH_FLOAT_AND_HALF(\n              tensor_lists[0][0].scalar_type(), 2, \"lamb_stage_2\",\n              multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                    DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n                                    per_tensor_param_norm.data_ptr<scalar_t_2>(),\n                                    per_tensor_update_norm.data_ptr<scalar_t_2>(), update_norm_offset.data_ptr<long>(),\n                                    learning_rate.data_ptr<scalar_t_2>(), per_tensor_decay.data_ptr<scalar_t_2>(),\n                                    global_grad_norm.data_ptr<scalar_t_2>(), use_nvlamb);)))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/peer_memory/peer_memory.cpp",
    "content": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"peer_memory_cuda.cuh\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"allocate_raw\", &apex::contrib::peer_memory::allocate_raw, \"allocate_raw\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"free_raw\", &apex::contrib::peer_memory::free_raw, \"free_raw\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"zero\", &apex::contrib::peer_memory::zero, \"zero\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"get_raw_ipc_address\", &apex::contrib::peer_memory::get_raw_ipc_address, \"get_raw_ipc_address\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"get_raw_peers\", &apex::contrib::peer_memory::get_raw_peers, \"get_raw_peers\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"blob_view_half\", &apex::contrib::peer_memory::blob_view_half, \"blob_view_half\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"blob_view_float\", &apex::contrib::peer_memory::blob_view_float, \"blob_view_float\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"blob_view_int\", &apex::contrib::peer_memory::blob_view_int, \"blob_view_int\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"push_pull_halos_1d\", &apex::contrib::peer_memory::push_pull_halos_1d, \"push_pull_halos_1d\",\n        
py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <cuda_runtime_api.h>\n#include <torch/extension.h>\n\n#include <cassert>\n#include <cstdio>\n#include <list>\n\n#include \"nccl.h\"\n\n#define CUDACHECK(cmd)                                                                                \\\n  do {                                                                                                \\\n    cudaError_t err = cmd;                                                                            \\\n    if (err != cudaSuccess) {                                                                         \\\n      char hostname[1024];                                                                            \\\n      gethostname(hostname, 1024);                                                                    \\\n      printf(\"%s: CUDA failure %s:%d '%s'\\n\", hostname, __FILE__, __LINE__, cudaGetErrorString(err)); \\\n    }                                                                                                 \\\n  } while (0)\n\nnamespace {\n\nconstexpr int THREADS_PER_CTA = 128;\n\n/* Basic deleter function for from_blob function.\nvoid deleter(void* ptr)\n{\n    printf(\"deleter(ptr=%p)\\n\",ptr);\n    cudaFree(ptr);\n}\n*/\n\ntemplate <class T>\nat::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last) {\n  size_t size = 1;\n  std::vector<int64_t> strides(shape.size());\n  if (channels_last) {\n    assert(shape.size() == 4);\n    strides[0] = shape[1] * shape[2] * shape[3];\n    strides[1] = 1;\n    strides[2] = shape[1] * shape[3];\n    strides[3] = shape[1];\n  } else {\n    int idx = strides.size();\n    for (auto it = shape.rbegin(); it != shape.rend(); ++it) {\n      strides[--idx] = size;\n      size *= *it;\n    }\n  }\n  size *= sizeof(T);\n  // TODO: Implement dynamic reuse of pooled peer memory.\n  // We provide no deleter function 
because all peer memory allocations are static in this implementation.\n  return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);\n}\n\nvoid tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) {\n  if (t.dim() == 3) {\n    N = 1;\n    if (explicit_nhwc) {\n      C = t.size(2);\n      H = t.size(0);\n      W = t.size(1);\n    } else {\n      C = t.size(0);\n      H = t.size(1);\n      W = t.size(2);\n    }\n  } else if (t.dim() == 4) {\n    if (explicit_nhwc) {\n      N = t.size(0);\n      C = t.size(3);\n      H = t.size(1);\n      W = t.size(2);\n    } else {\n      N = t.size(0);\n      C = t.size(1);\n      H = t.size(2);\n      W = t.size(3);\n    }\n  } else {\n    printf(\"%s;%d - t.dim() must be either 3 or 4 (was %d)\\n\", __FILE__, __LINE__, int(t.dim()));\n    assert(t.dim() == 3 || t.dim() == 4);\n  }\n}\n\nvoid tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W) {\n  if (t.dim() == 3) {\n    if (explicit_nhwc) {\n      stride_C = t.stride(2);\n      stride_H = t.stride(0);\n      stride_W = t.stride(1);\n    } else {\n      stride_C = t.stride(0);\n      stride_H = t.stride(1);\n      stride_W = t.stride(2);\n    }\n    stride_N = t.size(0) * t.size(1) * t.size(2);\n  } else if (t.dim() == 4) {\n    if (explicit_nhwc) {\n      stride_N = t.stride(0);\n      stride_C = t.stride(3);\n      stride_H = t.stride(1);\n      stride_W = t.stride(2);\n    } else {\n      stride_N = t.stride(0);\n      stride_C = t.stride(1);\n      stride_H = t.stride(2);\n      stride_W = t.stride(3);\n    }\n  } else {\n    printf(\"%s;%d - t.dim() must be either 3 or 4 (was %d)\\n\", __FILE__, __LINE__, int(t.dim()));\n    assert(t.dim() == 3 || t.dim() == 4);\n  }\n}\n\ntemplate <class T>\ninline __device__ void __zero(T* dst) {\n  *dst = T(0);\n}\n\ninline __device__ void __zero(int2* dst) { *dst = {0, 0}; }\n\ntemplate <class T, bool contiguous>\ninline __device__ void 
zero_tensor(const int dim0, const int dim1, const int dim2, T* __restrict__ data,\n                                   const int data_stride0, const int data_stride1, const int data_stride2,\n                                   const int thread_id, const int block_id, const int num_blocks) {\n  const int global_id = thread_id + block_id * THREADS_PER_CTA;\n  const int num_threads = num_blocks * THREADS_PER_CTA;\n  const int count = dim0 * dim1 * dim2;\n  for (int i = global_id; i < count; i += num_threads) {\n    int offset;\n    if (contiguous) {\n      offset = i;\n    } else {\n      const int j2 = i % dim2;\n      const int k = i / dim2;\n      const int j1 = k % dim1;\n      const int j0 = k / dim1;\n      offset = j0 * data_stride0 + j1 * data_stride1 + j2 * data_stride2;\n    }\n    __zero(data + offset);\n  }\n}\n\ntemplate <class T, bool contiguous>\ninline __device__ void push_pull_tensor(const int dim0, const int dim1, const int dim2, const T* __restrict__ data_in,\n                                        const int data_in_stride0, const int data_in_stride1, const int data_in_stride2,\n                                        T* __restrict__ data_out, const int data_out_stride0,\n                                        const int data_out_stride1, const int data_out_stride2, int4* local_peer,\n                                        int4* remote_peer, const int thread_id, const int block_id,\n                                        const int num_blocks) {\n  // 128b=16B NVLink flit\n  // Note: Use last 4B as a semaphore\n  static_assert(sizeof(T) <= 12);\n  union Flit {\n    T payload;\n    uint uints[4];\n  };\n  // Communication bit indicates whether flit has been received from\n  // a remote GPU\n  constexpr uint communication_mask = 1 << 0;\n  // Status bit is used to choose the active peer buffer in an\n  // alternating double buffer scheme. 
We use buffer 1 if the bits\n  // match, use buffer 2 if the bits differ, and invert the bit\n  // after finishing with a buffer.\n  constexpr uint status_mask = 1 << 1;\n\n  // Split peer memory into two sets of buffers\n  // Note: Each block owns a THREADS_PER_CTA*2*16B chunk of peer\n  // memory\n  const int peer_offset1 = block_id * THREADS_PER_CTA * 2 + thread_id;\n  const int peer_offset2 = peer_offset1 + THREADS_PER_CTA;\n  volatile int* local_peer1 = reinterpret_cast<volatile int*>(local_peer + peer_offset1);\n  volatile int* local_peer2 = reinterpret_cast<volatile int*>(local_peer + peer_offset2);\n  volatile int* remote_peer1 = reinterpret_cast<volatile int*>(remote_peer + peer_offset1);\n  volatile int* remote_peer2 = reinterpret_cast<volatile int*>(remote_peer + peer_offset2);\n\n  // Iterate through tensor entries\n  const int num_threads = num_blocks * THREADS_PER_CTA;\n  const int count = dim0 * dim1 * dim2;\n  for (int i0 = block_id * THREADS_PER_CTA; i0 < count; i0 += num_threads) {\n    const int i = i0 + thread_id;\n    const bool has_data = i < count;\n\n    // Calculate buffer positions\n    int data_in_offset, data_out_offset;\n    if (contiguous) {\n      data_in_offset = i;\n      data_out_offset = i;\n    } else {\n      const int j2 = i % dim2;\n      const int k = i / dim2;\n      const int j1 = k % dim1;\n      const int j0 = k / dim1;\n      data_in_offset = j0 * data_in_stride0 + j1 * data_in_stride1 + j2 * data_in_stride2;\n      data_out_offset = j0 * data_out_stride0 + j1 * data_out_stride1 + j2 * data_out_stride2;\n    }\n\n    // Determine which peer memory buffer to use\n    // Note: The status bit is not affected by asynchronous\n    // communication from the remote GPU.\n    Flit local_message1, local_message2;\n    asm volatile(\"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];\"\n                 : \"=r\"(local_message1.uints[0]), \"=r\"(local_message1.uints[1]), \"=r\"(local_message1.uints[2]),\n                   
\"=r\"(local_message1.uints[3])\n                 : \"l\"(local_peer1)\n                 : \"memory\");\n    asm volatile(\"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];\"\n                 : \"=r\"(local_message2.uints[0]), \"=r\"(local_message2.uints[1]), \"=r\"(local_message2.uints[2]),\n                   \"=r\"(local_message2.uints[3])\n                 : \"l\"(local_peer2)\n                 : \"memory\");\n    const uint status1 = local_message1.uints[3] & status_mask;\n    const uint status2 = local_message2.uints[3] & status_mask;\n    const bool peer1_is_active = (status1 ^ status2) == 0;\n    volatile int* ox = peer1_is_active ? remote_peer1 : remote_peer2;\n    volatile int* ix = peer1_is_active ? local_peer1 : local_peer2;\n    const uint status = peer1_is_active ? status1 : status2;\n    Flit recv_message = peer1_is_active ? local_message1 : local_message2;\n\n    // Send flit to remote GPU\n    // Note: Set communication bit and keep status bit\n    Flit send_message;\n    if (has_data) {\n      send_message.payload = data_in[data_in_offset];\n    }\n    send_message.uints[3] = communication_mask | status;\n    asm volatile(\"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};\" ::\"l\"(ox), \"r\"(send_message.uints[0]),\n                 \"r\"(send_message.uints[1]), \"r\"(send_message.uints[2]), \"r\"(send_message.uints[3])\n                 : \"memory\");\n\n    // Receive flit from peer\n    while ((recv_message.uints[3] & communication_mask) == 0) {\n      asm volatile(\"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];\"\n                   : \"=r\"(recv_message.uints[0]), \"=r\"(recv_message.uints[1]), \"=r\"(recv_message.uints[2]),\n                     \"=r\"(recv_message.uints[3])\n                   : \"l\"(ix)\n                   : \"memory\");\n    }\n    if (has_data) {\n      data_out[data_out_offset] = recv_message.payload;\n    }\n\n    // Reset semaphore\n    // Note: Clear communication bit and invert status bit\n    uint flag = ~status & 
status_mask;\n    asm volatile(\"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};\" ::\"l\"(ix), \"n\"(0), \"n\"(0), \"n\"(0), \"r\"(flag)\n                 : \"memory\");\n    if (i0 + num_threads < count) {\n      __threadfence_system();\n    }\n  }\n}\n\ntemplate <class T, bool contiguous, bool top_zero, bool btm_zero>\n#if __CUDA_ARCH__ >= 700\n__launch_bounds__(THREADS_PER_CTA)\n#endif\n    __global__ void push_pull_halos_1d_kernel(\n        // top halo\n        T* toh, int toh_stride0, int toh_stride1, int toh_stride2,        // top output halo (local)\n        const T* tih, int tih_stride0, int tih_stride1, int tih_stride2,  // top input halo (local)\n        int4* tox,                                                        // top output transfer buffer (remote peer)\n        int4* tix,                                                        // top input transfer buffer (local peer)\n        // btm halo\n        T* boh, int boh_stride0, int boh_stride1, int boh_stride2,        // btm output halo (local)\n        const T* bih, int bih_stride0, int bih_stride1, int bih_stride2,  // btm input halo (local)\n        int4* box,                                                        // btm output transfer buffer (remote peer)\n        int4* bix,                                                        // btm input transfer buffer (local peer)\n        // dimensions\n        int dim0, int dim1, int dim2,\n        bool top_first  // whether to communicate the top halo first\n    ) {\n  const int num_blocks_side = gridDim.x / 2;\n  const int block_id_side = (blockIdx.x < num_blocks_side ? 
blockIdx.x : blockIdx.x - num_blocks_side);\n  const bool in_top_block = top_first == (blockIdx.x < num_blocks_side);\n  if (in_top_block) {\n    if (top_zero) {\n      zero_tensor<T, contiguous>(dim0, dim1, dim2, toh, toh_stride0, toh_stride1, toh_stride2, threadIdx.x,\n                                 block_id_side, num_blocks_side);\n    } else {\n      push_pull_tensor<T, contiguous>(dim0, dim1, dim2, tih, tih_stride0, tih_stride1, tih_stride2, toh, toh_stride0,\n                                      toh_stride1, toh_stride2, tix, tox, threadIdx.x, block_id_side, num_blocks_side);\n    }\n  } else {\n    if (btm_zero) {\n      zero_tensor<T, contiguous>(dim0, dim1, dim2, boh, boh_stride0, boh_stride1, boh_stride2, threadIdx.x,\n                                 block_id_side, num_blocks_side);\n    } else {\n      push_pull_tensor<T, contiguous>(dim0, dim1, dim2, bih, bih_stride0, bih_stride1, bih_stride2, boh, boh_stride0,\n                                      boh_stride1, boh_stride2, bix, box, threadIdx.x, block_id_side, num_blocks_side);\n    }\n  }\n}\n\n__global__ void delay_kernel(int delay_nanoseconds, int* counter) {\n  if (blockIdx.x == 0 && threadIdx.x == 0) {\n    // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.\n    int new_counter = 0;\n    double elapsed = 0;\n    clock_t start = clock();\n    do {\n      clock_t now = clock();\n      elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;\n      ++new_counter;\n    } while (elapsed < (double)delay_nanoseconds);\n    *counter = new_counter;\n  }\n}\n\n}  // namespace\n\nnamespace apex {\nnamespace contrib {\nnamespace peer_memory {\n\nint64_t allocate_raw(int64_t size) {\n  float* ptr = 0L;\n  cudaMalloc(&ptr, size);\n  cudaMemset(ptr, 0, size);\n  return (int64_t)ptr;\n}\n\nvoid free_raw(int64_t raw) { cudaFree((void*)raw); }\n\nvoid zero(int64_t raw, int64_t size) { cudaMemset((void*)raw, 0, size); }\n\nat::Tensor 
get_raw_ipc_address(int64_t raw) {\n  cudaIpcMemHandle_t mem_handle;\n  CUDACHECK(cudaIpcGetMemHandle(&mem_handle, (void*)raw));\n  const int n = sizeof(cudaIpcMemHandle_t);\n  auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));\n  auto address_tensor_p = address_tensor.data_ptr<uint8_t>();\n  memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);\n  return address_tensor;\n}\n\nstd::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw) {\n  int peer_group_size = ipc_addresses.size(0);\n  std::vector<int64_t> results(peer_group_size);\n  for (int i = 0; i < peer_group_size; ++i) {\n    if (i != peer_rank) {\n      cudaIpcMemHandle_t mem_handle;\n      memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));\n      void* p = 0L;\n      CUDACHECK(cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess));\n      results[i] = (int64_t)p;\n    } else {\n      results[i] = (int64_t)raw;\n    }\n  }\n  return results;\n}\n\nat::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last) {\n  return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);\n}\n\nat::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last) {\n  return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);\n}\n\nat::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last) {\n  return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);\n}\n\nvoid push_pull_halos_1d(\n    bool diagnostics, bool explicit_nhwc,\n    int numSM,                    // number of SMs to use (zero corresponds to all SMs)\n    int rank,                     // rank in spatial parallel group\n    bool top_zero,                // if top halo should be zeroed\n    at::Tensor top_in_halo,       // top 
input halo buffer (in local device memory, sent to top neighbor)\n    at::Tensor top_in_transfer,   // top input transfer buffer (in local peer memory)\n    at::Tensor top_out_transfer,  // top output transfer buffer (in top neighbor peer memory)\n    at::Tensor top_out_halo,      // top output halo buffer (in local device memory, received from top neighbor)\n    bool btm_zero,                // if btm halo should be zeroed\n    at::Tensor btm_in_halo,       // btm input halo buffer (in local device memory, sent to btm neighbor)\n    at::Tensor btm_in_transfer,   // btm input transfer buffer (in local peer memory)\n    at::Tensor btm_out_transfer,  // btm output transfer buffer (in btm neighbor peer memory)\n    at::Tensor btm_out_halo       // btm output halo buffer (in local device memory, received from btm neighbor)\n) {\n  // basic checks of inputs\n  TORCH_CHECK(!(top_zero && btm_zero));\n  TORCH_CHECK(top_in_halo.is_cuda());\n  TORCH_CHECK(top_out_transfer.is_cuda());\n  TORCH_CHECK(top_in_transfer.is_cuda());\n  TORCH_CHECK(top_out_halo.is_cuda());\n  TORCH_CHECK(btm_in_halo.is_cuda());\n  TORCH_CHECK(btm_out_transfer.is_cuda());\n  TORCH_CHECK(btm_in_transfer.is_cuda());\n  TORCH_CHECK(btm_out_halo.is_cuda());\n\n  // tensor shapes\n  int tih_N, tih_C, tih_H, tih_W;\n  tensor_shape(top_in_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);\n  int toh_N, toh_C, toh_H, toh_W;\n  tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);\n  int bih_N, bih_C, bih_H, bih_W;\n  tensor_shape(btm_in_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);\n  int boh_N, boh_C, boh_H, boh_W;\n  tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);\n  TORCH_CHECK(toh_N == tih_N && tih_N == boh_N && boh_N == bih_N && toh_C == tih_C && tih_C == boh_C &&\n              boh_C == bih_C && toh_H == tih_H && tih_H == boh_H && boh_H == bih_H && toh_W == tih_W &&\n              tih_W == boh_W && boh_W == bih_W);\n  int NN = toh_N, NC = toh_C, NH = toh_H, NW 
= toh_W;\n  if (diagnostics) {\n    printf(\"rank %d: NN=%d, NC=%d, NH=%d, NW=%d\\n\", rank, NN, NC, NH, NW);\n  }\n  TORCH_CHECK(NN == 1);\n\n  // tensor strides\n  int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;\n  tensor_strides(top_in_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);\n  int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;\n  tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);\n  int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;\n  tensor_strides(btm_in_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);\n  int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;\n  tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);\n  if (diagnostics) {\n    printf(\"rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, tih_stride_N, tih_stride_C, tih_stride_H,\n           tih_stride_W);\n    printf(\"rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, toh_stride_N, toh_stride_C, toh_stride_H,\n           toh_stride_W);\n    printf(\"rank %d: bih_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, bih_stride_N, bih_stride_C, bih_stride_H,\n           bih_stride_W);\n    printf(\"rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, boh_stride_N, boh_stride_C, boh_stride_H,\n           boh_stride_W);\n  }\n\n  // determine if nhwc\n  bool is_nhwc = (toh_stride_C == 1);\n  if (diagnostics) {\n    printf(\"rank %d: is_nhwc = %s\\n\", rank, is_nhwc ? 
\"true\" : \"false\");\n  }\n\n  // determine if contiguous\n  bool contiguous = true;\n  if ((NN - 1) * toh_stride_N + (NC - 1) * toh_stride_C + (NH - 1) * toh_stride_H + (NW - 1) * toh_stride_W !=\n      NN * NC * NH * NW - 1) {\n    contiguous = false;\n  }\n  if ((NN - 1) * boh_stride_N + (NC - 1) * boh_stride_C + (NH - 1) * boh_stride_H + (NW - 1) * boh_stride_W !=\n      NN * NC * NH * NW - 1) {\n    contiguous = false;\n  }\n  if (!top_zero) {\n    if (toh_stride_N != tih_stride_N || toh_stride_C != tih_stride_C || toh_stride_H != tih_stride_H ||\n        toh_stride_W != tih_stride_W) {\n      contiguous = false;\n    }\n  }\n  if (!btm_zero) {\n    if (boh_stride_N != bih_stride_N || boh_stride_C != bih_stride_C || boh_stride_H != bih_stride_H ||\n        boh_stride_W != bih_stride_W) {\n      contiguous = false;\n    }\n  }\n  if (diagnostics) {\n    printf(\"rank %d: contiguous = %s\\n\", rank, contiguous ? \"true\" : \"false\");\n  }\n\n  // determine whether to communicate top halo first\n  bool top_first = rank % 2 != 0;\n  if (diagnostics) {\n    printf(\"rank %d: top_first = %s\\n\", rank, top_first ? 
\"true\" : \"false\");\n  }\n\n  // peer memory buffers\n  int tox_size = top_out_transfer.numel() * top_out_transfer.element_size();\n  int tix_size = top_in_transfer.numel() * top_in_transfer.element_size();\n  int box_size = btm_out_transfer.numel() * btm_out_transfer.element_size();\n  int bix_size = btm_in_transfer.numel() * btm_in_transfer.element_size();\n  if (!top_zero) {\n    TORCH_CHECK(top_out_transfer.is_contiguous());\n    TORCH_CHECK(top_in_transfer.is_contiguous());\n    TORCH_CHECK(tox_size == tix_size);\n  }\n  if (!btm_zero) {\n    TORCH_CHECK(btm_out_transfer.is_contiguous());\n    TORCH_CHECK(btm_in_transfer.is_contiguous());\n    TORCH_CHECK(box_size == bix_size);\n  }\n\n  // figure out launch parameters\n  int device;\n  cudaGetDevice(&device);\n  cudaDeviceProp prop;\n  cudaGetDeviceProperties(&prop, device);\n  if (numSM <= 0 || numSM > prop.multiProcessorCount) {\n    numSM = prop.multiProcessorCount;\n  }\n  auto current_stream = at::cuda::getCurrentCUDAStream();\n  dim3 block(THREADS_PER_CTA, 1, 1);\n\n  // helper macros to launch templated kernel\n#define LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, TOP_ZERO, BTM_ZERO, KERNEL_ARGS, NUM_ELEMENTS)           \\\n  do {                                                                                                            \\\n    /* kernel configuration */                                                                                    \\\n    int numBlocksPerSm;                                                                                           \\\n    cudaOccupancyMaxActiveBlocksPerMultiprocessor(                                                                \\\n        &numBlocksPerSm, push_pull_halos_1d_kernel<T, CONTIGUOUS, TOP_ZERO, BTM_ZERO>, THREADS_PER_CTA, 0);       \\\n    dim3 grid(numSM * numBlocksPerSm, 1, 1);                                                                      \\\n    if (grid.x % 2 != 0) {                                                      
                                  \\\n      /* require even number of blocks (half for top, half for bottom) */                                         \\\n      grid.x -= 1;                                                                                                \\\n    }                                                                                                             \\\n    if ((grid.x / 2) * THREADS_PER_CTA > NUM_ELEMENTS) {                                                          \\\n      /* only need enough blocks to cover top and bottom halo elements */                                         \\\n      grid.x = 2 * ((NUM_ELEMENTS + THREADS_PER_CTA - 1) / THREADS_PER_CTA);                                      \\\n    }                                                                                                             \\\n    if (!TOP_ZERO) {                                                                                              \\\n      /* require 2*128b=32B peer memory per thread */                                                             \\\n      if ((grid.x / 2) * THREADS_PER_CTA * 32 > tox_size) {                                                       \\\n        grid.x = 2 * (tox_size / (THREADS_PER_CTA * 32));                                                         \\\n      }                                                                                                           \\\n    }                                                                                                             \\\n    if (!BTM_ZERO) {                                                                                              \\\n      /* require 2*128b=32B peer memory per thread */                                                             \\\n      if ((grid.x / 2) * THREADS_PER_CTA * 32 > box_size) {                                                       \\\n        grid.x = 2 * (box_size / (THREADS_PER_CTA * 32));                 
                                        \\\n      }                                                                                                           \\\n    }                                                                                                             \\\n    TORCH_CHECK(grid.x >= 2);                                                                                     \\\n                                                                                                                  \\\n    /* launch kernel */                                                                                           \\\n    cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<T, CONTIGUOUS, TOP_ZERO, BTM_ZERO>, grid, block, \\\n                                KERNEL_ARGS, 0, current_stream);                                                  \\\n  } while (false)\n#define LAUNCH_PUSH_PULL_HALO_KERNEL(T, CONTIGUOUS, KERNEL_ARGS, NUM_ELEMENTS)                   \\\n  do {                                                                                           \\\n    if (top_zero) {                                                                              \\\n      LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, true, false, KERNEL_ARGS, NUM_ELEMENTS);  \\\n    } else if (btm_zero) {                                                                       \\\n      LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, true, KERNEL_ARGS, NUM_ELEMENTS);  \\\n    } else {                                                                                     \\\n      LAUNCH_PUSH_PULL_HALO_KERNEL_BASE(T, CONTIGUOUS, false, false, KERNEL_ARGS, NUM_ELEMENTS); \\\n    }                                                                                            \\\n  } while (false)\n\n  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), \"push_pull_halos_1d_kernel\", [&] {\n    if (diagnostics) {\n      printf(\"rank %d: size(scalar_t) = 
%ld\\n\", rank, sizeof(scalar_t));\n    }\n    scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();\n    scalar_t* tih_p = top_in_halo.data_ptr<scalar_t>();\n    int4* tox_p = reinterpret_cast<int4*>(top_out_transfer.data_ptr<scalar_t>());\n    int4* tix_p = reinterpret_cast<int4*>(top_in_transfer.data_ptr<scalar_t>());\n    scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();\n    scalar_t* bih_p = btm_in_halo.data_ptr<scalar_t>();\n    int4* box_p = reinterpret_cast<int4*>(btm_out_transfer.data_ptr<scalar_t>());\n    int4* bix_p = reinterpret_cast<int4*>(btm_in_transfer.data_ptr<scalar_t>());\n    if (diagnostics) printf(\"rank %d: choosing halo exchange kernel\\n\", rank);\n\n    // do int2 vector loads if channel count permits\n    if (contiguous && (NN * NH * NW * NC * sizeof(scalar_t)) % sizeof(int2) == 0) {\n      // can do contiguous int2 transfers\n      if (diagnostics) {\n      }\n      toh_stride_N = toh_stride_H = toh_stride_W = toh_stride_C = 1;\n      tih_stride_N = tih_stride_H = tih_stride_W = tih_stride_C = 1;\n      boh_stride_N = boh_stride_H = boh_stride_W = boh_stride_C = 1;\n      bih_stride_N = bih_stride_H = bih_stride_W = bih_stride_C = 1;\n      NC = (NN * NH * NW * NC * sizeof(scalar_t)) / sizeof(int2);\n      NN = NH = NW = 1;\n      if (diagnostics) {\n        printf(\"rank %d: launching contiguous int2 halo exchange kernel\\n\", rank);\n        printf(\"rank %d: NC=%d, NH=%d, NW=%d\\n\", rank, NC, NH, NW);\n      }\n      void* kernel_args[] = {(int2**)&toh_p,\n                             &toh_stride_H,\n                             &toh_stride_W,\n                             &toh_stride_C,\n                             (int2**)&tih_p,\n                             &tih_stride_H,\n                             &tih_stride_W,\n                             &tih_stride_C,\n                             &tox_p,\n                             &tix_p,\n                             (int2**)&boh_p,\n                             
&boh_stride_H,\n                             &boh_stride_W,\n                             &boh_stride_C,\n                             (int2**)&bih_p,\n                             &bih_stride_H,\n                             &bih_stride_W,\n                             &bih_stride_C,\n                             &box_p,\n                             &bix_p,\n                             &NH,\n                             &NW,\n                             &NC,\n                             &top_first};\n      int num_elem = NN * NH * NW * NC;\n      LAUNCH_PUSH_PULL_HALO_KERNEL(int2, true, kernel_args, num_elem);\n    } else if (is_nhwc && (NC * sizeof(scalar_t)) % sizeof(int2) == 0) {\n      // can do strided int2 transfers\n      int divisor = sizeof(int2) / sizeof(scalar_t);\n      if (diagnostics) {\n        printf(\"rank %d: launching strided int2 halo exchange kernel\\n\", rank);\n      }\n      toh_stride_N /= divisor;\n      toh_stride_H /= divisor;\n      toh_stride_W /= divisor;\n      tih_stride_N /= divisor;\n      tih_stride_H /= divisor;\n      tih_stride_W /= divisor;\n      boh_stride_N /= divisor;\n      boh_stride_H /= divisor;\n      boh_stride_W /= divisor;\n      bih_stride_N /= divisor;\n      bih_stride_H /= divisor;\n      bih_stride_W /= divisor;\n      NC /= divisor;\n      if (diagnostics) {\n        printf(\"rank %d: divisor=%d\\n\", rank, divisor);\n        printf(\"rank %d: tih_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, tih_stride_N, tih_stride_C, tih_stride_H,\n               tih_stride_W);\n        printf(\"rank %d: toh_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, toh_stride_N, toh_stride_C, toh_stride_H,\n               toh_stride_W);\n        printf(\"rank %d: bih_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, bih_stride_N, bih_stride_C, bih_stride_H,\n               bih_stride_W);\n        printf(\"rank %d: boh_stride :: N=%d, C=%d, H=%d, W=%d\\n\", rank, boh_stride_N, boh_stride_C, boh_stride_H,\n               boh_stride_W);\n 
       printf(\"rank %d: NC=%d, NH=%d, NW=%d\\n\", rank, NC, NH, NW);\n      }\n      void* kernel_args[] = {(int2**)&toh_p,\n                             &toh_stride_H,\n                             &toh_stride_W,\n                             &toh_stride_C,\n                             (int2**)&tih_p,\n                             &tih_stride_H,\n                             &tih_stride_W,\n                             &tih_stride_C,\n                             &tox_p,\n                             &tix_p,\n                             (int2**)&boh_p,\n                             &boh_stride_H,\n                             &boh_stride_W,\n                             &boh_stride_C,\n                             (int2**)&bih_p,\n                             &bih_stride_H,\n                             &bih_stride_W,\n                             &bih_stride_C,\n                             &box_p,\n                             &bix_p,\n                             &NH,\n                             &NW,\n                             &NC,\n                             &top_first};\n      int num_elem = NH * NW * NC;\n      LAUNCH_PUSH_PULL_HALO_KERNEL(int2, false, kernel_args, num_elem);\n    } else {\n      // cannot do int2 transfers\n      if (diagnostics) {\n        printf(\"rank %d: launching non-int2 halo exchange kernel\\n\", rank);\n      }\n      int num_elem = NC * NH * NW;\n      if (is_nhwc) {\n        void* kernel_args[] = {&toh_p,        &toh_stride_H, &toh_stride_W, &toh_stride_C, &tih_p,        &tih_stride_H,\n                               &tih_stride_W, &tih_stride_C, &tox_p,        &tix_p,        &boh_p,        &boh_stride_H,\n                               &boh_stride_W, &boh_stride_C, &bih_p,        &bih_stride_H, &bih_stride_W, &bih_stride_C,\n                               &box_p,        &bix_p,        &NH,           &NW,           &NC,           &top_first};\n        LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, 
num_elem);\n      } else {\n        void* kernel_args[] = {&toh_p,        &toh_stride_C, &toh_stride_H, &toh_stride_W, &tih_p,        &tih_stride_C,\n                               &tih_stride_H, &tih_stride_W, &tox_p,        &tix_p,        &boh_p,        &boh_stride_C,\n                               &boh_stride_H, &boh_stride_W, &bih_p,        &bih_stride_C, &bih_stride_H, &bih_stride_W,\n                               &box_p,        &bix_p,        &NC,           &NH,           &NW,           &top_first};\n        LAUNCH_PUSH_PULL_HALO_KERNEL(scalar_t, false, kernel_args, num_elem);\n      }\n    }\n  });\n\n#undef LAUNCH_PUSH_PULL_HALO_KERNEL_BASE\n#undef LAUNCH_PUSH_PULL_HALO_KERNEL\n}\n\n}  // namespace peer_memory\n}  // namespace contrib\n}  // namespace apex\n"
  },
  {
    "path": "apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh",
    "content": "/**\n * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n#include <torch/extension.h>\n#ifndef _peer_memory_h_\n#define _peer_memory_h_\n\nnamespace apex {\nnamespace contrib {\nnamespace peer_memory {\nint64_t allocate_raw(int64_t size);\nvoid free_raw(int64_t raw);\nvoid zero(int64_t raw, int64_t size);\nat::Tensor get_raw_ipc_address(int64_t raw);\nstd::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);\nat::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);\nat::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);\nat::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);\nvoid push_pull_halos_1d(\n    bool diagnostics, bool explicit_nhwc,\n    int numSM,                    // number of SMs to use\n    int peer_rank,                // rank in spatial parallel group\n    bool top_zero,                // if top halo should be zeroed\n    at::Tensor top_out_halo,      // top output halo buffer (in local device memory, received from top neighbor)\n    at::Tensor top_inp_transfer,  // top input transfer buffer (in local peer memory)\n    at::Tensor top_out_transfer,  // top output transfer buffer (in top neighbor peer memory)\n    at::Tensor top_inp_halo,      // top input halo buffer (in local device memory, sent to top 
neighbor)\n    bool btm_zero,                // if btm halo should be zeroed\n    at::Tensor btm_out_halo,      // btm output halo buffer (in local device memory, received from btm neighbor)\n    at::Tensor btm_inp_transfer,  // btm input transfer buffer (in local peer memory)\n    at::Tensor btm_out_transfer,  // btm output transfer buffer (in btm neighbor peer memory)\n    at::Tensor btm_inp_halo       // btm input halo buffer (in local device memory, sent to btm neighbor)\n);\n}  // namespace peer_memory\n}  // namespace contrib\n}  // namespace apex\n#endif\n"
  },
  {
    "path": "apex/contrib/csrc/transducer/transducer_joint.cpp",
    "content": "#include <ATen/Functions.h>\r\n#include <torch/extension.h>\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x) \\\r\n  CHECK_CUDA(x);       \\\r\n  CHECK_CONTIGUOUS(x)\r\n\r\nstd::vector<torch::Tensor> transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen,\r\n                                                         torch::Tensor gLen, torch::Tensor batchOffset,\r\n                                                         int64_t packedBatch, int opt, bool packOutput, bool relu,\r\n                                                         bool dropout, float dropoutProb, int tileSize);\r\n\r\nstd::vector<torch::Tensor> transducer_joint_cuda_backward(std::vector<torch::Tensor> in, torch::Tensor fLen,\r\n                                                          torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen,\r\n                                                          int maxGLen, bool packOutput, float scale);\r\n\r\nstd::vector<torch::Tensor> transducer_joint_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen,\r\n                                                    torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch,\r\n                                                    int opt, bool packOutput, bool relu, bool dropout,\r\n                                                    float dropoutProb, int tileSize) {\r\n  CHECK_INPUT(f);\r\n  CHECK_INPUT(g);\r\n  CHECK_INPUT(fLen);\r\n  CHECK_INPUT(gLen);\r\n  if (packOutput) CHECK_INPUT(batchOffset);\r\n  return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout,\r\n                                       dropoutProb, tileSize);\r\n}\r\n\r\nstd::vector<torch::Tensor> transducer_joint_backward(std::vector<torch::Tensor> in, torch::Tensor fLen,\r\n 
                                                    torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen,\r\n                                                     int maxGLen, bool packOutput, float scale) {\r\n  for (auto t : in) {\r\n    CHECK_INPUT(t);\r\n  }\r\n  CHECK_INPUT(fLen);\r\n  CHECK_INPUT(gLen);\r\n  if (packOutput) CHECK_INPUT(batchOffset);\r\n  return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n  m.def(\"forward\", &transducer_joint_forward, \"transducer joint forward (CUDA)\",\r\n        py::call_guard<py::gil_scoped_release>());\r\n  m.def(\"backward\", &transducer_joint_backward, \"transducer joint backward (CUDA)\",\r\n        py::call_guard<py::gil_scoped_release>());\r\n}\r\n"
  },
  {
    "path": "apex/contrib/csrc/transducer/transducer_joint_kernel.cu",
    "content": "#include <ATen/AccumulateType.h>\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n#include <curand_kernel.h>\r\n#include <torch/extension.h>\r\n\r\n#ifdef OLD_GENERATOR_PATH\r\n#include <ATen/CUDAGeneratorImpl.h>\r\n#else\r\n#include <ATen/cuda/CUDAGeneratorImpl.h>\r\n#endif\r\n\r\n#include <ATen/cuda/CUDAContext.h>\r\n#include <c10/macros/Macros.h>\r\n\r\n#include <ATen/cuda/CUDAGraphsUtils.cuh>\r\n\r\n#include \"philox.cuh\"\r\n\r\n// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.\r\n// width should be a power of 2 and should be less than warpSize.\r\ntemplate <typename scalar_t>\r\n__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width = C10_WARP_SIZE) {\r\n  for (unsigned offset = width / 2; offset > 0; offset /= 2) {\r\n    x += __shfl_down_sync(0xffffffff, x, offset, width);\r\n  }\r\n  return x;\r\n}\r\n\r\ninline int largestPowerOfTwo(int x) {\r\n  int y = 1;\r\n  while (y <= x) y <<= 1;\r\n  return y >> 1;\r\n}\r\n\r\n/*\r\nFigure out vectorization type for masks.\r\nSimilar to how PyTorch figures out acc_t here:\r\naten/src/ATen/AccumulateType.h\r\n*/\r\ntemplate <int V>\r\nstruct MaskVecType {};\r\n\r\ntemplate <>\r\nstruct MaskVecType<1> {\r\n  using type = uint8_t;\r\n};\r\ntemplate <>\r\nstruct MaskVecType<2> {\r\n  using type = uint16_t;\r\n};\r\ntemplate <>\r\nstruct MaskVecType<4> {\r\n  using type = uint32_t;\r\n};\r\n\r\ntemplate <int V>\r\nusing mvec_type = typename MaskVecType<V>::type;\r\n\r\n// Helper class to calculate pointer offset that can be shared by different flavors of kernels.\r\n// For fwd, batch offset and stride are different for packing and non-packing mode.\r\nstruct OffsetCalFwd {\r\n  __device__ __forceinline__ OffsetCalFwd(int64_t batch, const int64_t* batchOffset, int64_t maxFLen, int64_t maxGLen,\r\n                                          int64_t gLen, int64_t hiddenSize, bool packOutput)\r\n      : batch(batch),\r\n        
batchOffset(batchOffset),\r\n        maxFLen(maxFLen),\r\n        maxGLen(maxGLen),\r\n        gLen(gLen),\r\n        hiddenSize(hiddenSize),\r\n        packOutput(packOutput) {}\r\n\r\n  int64_t batch;\r\n  const int64_t* batchOffset;\r\n  int64_t maxFLen;\r\n  int64_t maxGLen;\r\n  int64_t gLen;\r\n  int64_t hiddenSize;\r\n  bool packOutput;\r\n\r\n  __device__ __forceinline__ int64_t getBatchOffset() {\r\n    return packOutput ? ((batch == 0) ? 0 : batchOffset[batch - 1]) * hiddenSize\r\n                      : batch * maxFLen * maxGLen * hiddenSize;\r\n  }\r\n\r\n  __device__ __forceinline__ int64_t getStrideF() { return packOutput ? gLen * hiddenSize : maxGLen * hiddenSize; }\r\n};\r\n\r\n// Helper class to calculate pointer offset that can be shared by different flavors of kernels.\r\n// For bwd, batch offset and stride are different for packing and non-packing mode.\r\n// The reduction is done for two input tensors. Therefore, generating two sets of offsets\r\n// according to bwdFasterDim can lead to a unified implementation in the actual kernel.\r\nstruct OffsetCalBwd {\r\n  __device__ __forceinline__ OffsetCalBwd(int64_t batch, const int64_t* batchOffset, const int* fLen, const int* gLen,\r\n                                          int64_t maxFLen, int64_t maxGLen, int64_t hiddenSize, bool packOutput,\r\n                                          bool bwdFasterDim)\r\n      : batch(batch),\r\n        batchOffset(batchOffset),\r\n        maxFLen(maxFLen),\r\n        maxGLen(maxGLen),\r\n        fLen(fLen),\r\n        gLen(gLen),\r\n        hiddenSize(hiddenSize),\r\n        packOutput(packOutput),\r\n        bwdFasterDim(bwdFasterDim) {}\r\n\r\n  int64_t batch;\r\n  const int64_t* batchOffset;\r\n  const int* fLen;\r\n  const int* gLen;\r\n  int64_t maxFLen;\r\n  int64_t maxGLen;\r\n  int64_t hiddenSize;\r\n  bool packOutput;\r\n  bool bwdFasterDim;  // whether doing bwd on the faster moving dimension\r\n\r\n  __device__ __forceinline__ int64_t 
getBatchOffset() {\r\n    return packOutput ? ((batch == 0) ? 0 : batchOffset[batch - 1]) * hiddenSize\r\n                      : batch * maxFLen * maxGLen * hiddenSize;\r\n  }\r\n\r\n  __device__ __forceinline__ int64_t getMaxXLen() { return bwdFasterDim ? maxGLen : maxFLen; }\r\n\r\n  __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]) {\r\n    return bwdFasterDim ? gLen[batch] : fLen[batch];\r\n  }\r\n\r\n  __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]) {\r\n    return bwdFasterDim ? fLen[batch] : gLen[batch];\r\n  }\r\n\r\n  __device__ __forceinline__ int64_t getStrideX() {\r\n    return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);\r\n  }\r\n\r\n  __device__ __forceinline__ int64_t getStrideY() {\r\n    return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;\r\n  }\r\n};\r\n\r\n// Vanilla transducer joint forward kernel\r\n// Details of this joint function can be found in:\r\n// [1] Sequence Transduction with Recurrent Neural Networks.\r\n\r\n// f is a tensor of shape [batch, T, H]\r\n// g is a tensor of shape [batch, U, H]\r\n// the transducer joint does\r\n// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)\r\n// The resultant tensor is of shape [batch, T, U, H]\r\n// Each thread block is working on one \"batch\" of data in the output tensor, [batch, t, u, :]\r\n\r\n// This joint function can optionally pack the output where the output tensor with a shape of\r\n// [B, T, U, H] is packed into [B_packed, H].\r\n// Don't-care region (t > fLen) or (u > gLen) is removed.\r\n// To enable packing, the starting offset for each batch needs to be specified with batchOffset.\r\ntemplate <typename scalar_t, class OffsetCal>\r\n__global__ void transducer_joint_forward(const scalar_t* f, const scalar_t* g, const int* fLen, const int* gLen,\r\n                                         const int64_t* batchOffset, int64_t maxFLen, int64_t maxGLen,\r\n                            
             int64_t hiddenSize, bool packOutput, scalar_t* sum) {\r\n  const int batch = blockIdx.z;\r\n  const int t = blockIdx.y;\r\n  const int u = blockIdx.x;\r\n  const auto myFLen = fLen[batch];\r\n  const auto myGLen = gLen[batch];\r\n\r\n  OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);\r\n  const auto myBatchOffset = offsetCal.getBatchOffset();\r\n  const auto strideF = offsetCal.getStrideF();\r\n  scalar_t const* myF = f + batch * maxFLen * hiddenSize + t * hiddenSize;\r\n  scalar_t const* myG = g + batch * maxGLen * hiddenSize + u * hiddenSize;\r\n  scalar_t* mySum = sum + myBatchOffset + t * strideF + u * hiddenSize;\r\n\r\n  if (t < myFLen and u < myGLen) {\r\n#pragma unroll\r\n    for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) {\r\n      if (h < hiddenSize) {\r\n        mySum[h] = myF[h] + myG[h];\r\n      }\r\n    }\r\n  } else if (packOutput == false and t < maxFLen and u < maxGLen) {\r\n// Need to write finite data to don't-care region because we instantiate the result tensor\r\n// with torch::empty for performance reasons. 
Even though it is a don't-care region, the\r\n// contents need to be finite, otherwise they could lead to NaN in WGRAD.\r\n// In packing mode, this write is no longer necessary as we remove the don't-care region\r\n// from the output.\r\n// Picking -1 (over 0) here for ease of testing.\r\n#pragma unroll\r\n    for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x) {\r\n      if (h < hiddenSize) {\r\n        mySum[h] = -1;\r\n      }\r\n    }\r\n  }\r\n}\r\n\r\n/*\r\nTiled version of the joint forward kernel\r\nDetails of this joint function can be found in:\r\n[1] Sequence Transduction with Recurrent Neural Networks.\r\n\r\nf is a tensor of shape [batch, T, H]\r\ng is a tensor of shape [batch, U, H]\r\nthe transducer joint does\r\nsum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)\r\nThe resultant tensor is of shape [batch, T, U, H]\r\nEach thread is working on a tile of the shape of tileF x tileG in the result tensor.\r\nThe input for the tile is first loaded into registers and is reused tileG and tileF times.\r\n\r\nThis joint function can optionally pack the output where the output tensor with a shape of\r\n[B, T, U, H] is packed into [B_packed, H].\r\nDon't-care region (t > fLen) or (u > gLen) is removed.\r\nTo enable packing, the starting offset for each batch needs to be specified with batchOffset.\r\n\r\nOptionally this joint function performs ReLU and/or dropout on the joint output, which is\r\ncontrolled by arguments relu and dropout, respectively. philoxArgs is the argument used for generating\r\npseudorandom numbers. When at least one of ReLU and dropout is activated, the joint\r\nfunction is a masked operation, which is controlled by the template argument masked. 
In this case,\r\nmasks are saved to backward.\r\n*/\r\ntemplate <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>\r\n__global__ void transducer_joint_tiled_forward(const scalar_t* f, const scalar_t* g, const int* fLen, const int* gLen,\r\n                                               const int64_t* batchOffset, int64_t maxFLen, int64_t maxGLen,\r\n                                               int64_t hiddenSize, int64_t hiddenPerBlock, bool packOutput, bool relu,\r\n                                               bool dropout, float p, at::PhiloxCudaState philoxArgs, scalar_t* sum,\r\n                                               uint8_t* mask) {\r\n  static_assert(U == 4, \"U has to be 4, as random numbers are generated in batch of 4\");\r\n\r\n  const int batch = blockIdx.z;\r\n  const int t = blockIdx.y * tileF;\r\n  const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;\r\n  const int u = blockIdx.x / hiddenBlock * tileG;\r\n  const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;\r\n  const int h = threadIdx.x;\r\n  const auto myFLen = fLen[batch];\r\n  const auto myGLen = gLen[batch];\r\n\r\n  OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);\r\n  const auto myBatchOffset = offsetCal.getBatchOffset();\r\n  const auto strideF = offsetCal.getStrideF();\r\n\r\n  scalar_t const* myF = f + batch * maxFLen * hiddenSize + t * hiddenSize + hOffset;\r\n  scalar_t const* myG = g + batch * maxGLen * hiddenSize + u * hiddenSize + hOffset;\r\n  scalar_t* mySum = sum + myBatchOffset + t * strideF + u * hiddenSize + hOffset;\r\n  uint8_t* myMask = mask + myBatchOffset + t * strideF + u * hiddenSize + hOffset;\r\n\r\n  // The following code is only needed for dropout. We try to bypass them as much as possible.\r\n  auto seeds = masked ? 
at::cuda::philox::unpack(philoxArgs)\r\n                      : std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));\r\n  uint64_t tid =\r\n      masked ? (static_cast<uint64_t>(blockIdx.z) * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x) *\r\n                       blockDim.x +\r\n                   threadIdx.x\r\n             : 0;\r\n  Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));\r\n  scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;\r\n  bool dropoutMask[U];\r\n\r\n  if (t < myFLen and u < myGLen and hOffset + h < hiddenSize) {\r\n    // register buffers for tiled input reuse\r\n    scalar_t fBuffer[tileF], gBuffer[tileG];\r\n    for (int i = 0; i < tileF; ++i) {\r\n      if (t + i < myFLen) fBuffer[i] = myF[i * hiddenSize + h];\r\n    }\r\n    for (int j = 0; j < tileG; ++j) {\r\n      if (u + j < myGLen) gBuffer[j] = myG[j * hiddenSize + h];\r\n    }\r\n#pragma unroll\r\n    for (int i = 0; i < tileF; ++i) {\r\n      if (t + i < myFLen) {\r\n#pragma unroll\r\n        for (int j = 0; j < tileG; ++j) {\r\n          int idx = i * tileG + j;\r\n          if (masked and dropout and idx % U == 0) {\r\n            // For performance, generate 4 random numbers in one shot\r\n            // auto rand4 = curand_uniform4(&state);\r\n            auto rand4 = uniform4(ph());\r\n            dropoutMask[0] = rand4.x < p;\r\n            dropoutMask[1] = rand4.y < p;\r\n            dropoutMask[2] = rand4.z < p;\r\n            dropoutMask[3] = rand4.w < p;\r\n          }\r\n\r\n          if (u + j < myGLen) {\r\n            scalar_t out = fBuffer[i] + gBuffer[j];\r\n            if (masked) {\r\n              // Apply ReLU here when relu is True\r\n              bool localMask = relu ? (out > 0) : 1;\r\n              localMask = dropout ? localMask & dropoutMask[idx % U] : localMask;\r\n              out = dropout ? 
out * localMask * scale : out * localMask;\r\n              myMask[i * strideF + j * hiddenSize + h] = static_cast<uint8_t>(localMask);\r\n            }\r\n            mySum[i * strideF + j * hiddenSize + h] = out;\r\n          } else if (packOutput == false and u + j < maxGLen)\r\n            mySum[i * strideF + j * hiddenSize + h] = -1;\r\n        }\r\n      } else if (packOutput == false and t + i < maxFLen) {\r\n// Again need to write finite data to don't-care region\r\n#pragma unroll\r\n        for (int j = 0; j < tileG; ++j) {\r\n          if (u + j < maxGLen) mySum[i * strideF + j * hiddenSize + h] = -1;\r\n        }\r\n      }\r\n    }\r\n  } else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset + h < hiddenSize) {\r\n// Only need to ensure the finity in normal mode\r\n#pragma unroll\r\n    for (int i = 0; i < tileF; ++i) {\r\n      if (t + i < maxFLen) {\r\n#pragma unroll\r\n        for (int j = 0; j < tileG; ++j) {\r\n          if (u + j < maxGLen) mySum[i * strideF + j * hiddenSize + h] = -1;\r\n        }\r\n      }\r\n    }\r\n  }\r\n}\r\n\r\n/*\r\nBwd operation (reduction) on one input tensor. 
Since the operation performed for the two input\r\ntensors are exactly the same, only one kernel is needed, and the different indexing offsets\r\nand strides are handled by OffsetCalBwd.\r\n\r\nWhen packing is enabled in the fwd op, unpacking is needed to restore the gradients in a\r\nnon-packed form.\r\n\r\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\r\nand mask contains the mask information.\r\n*/\r\ntemplate <typename scalar_t, typename acc_t, class OffsetCal, bool masked>\r\n__device__ void transducer_joint_single_backward(const scalar_t* grad, const uint8_t* mask, const int* fLen,\r\n                                                 const int* gLen, const int64_t* batchOffset, int64_t maxFLen,\r\n                                                 int64_t maxGLen, int64_t hiddenSize, bool packOutput,\r\n                                                 bool bwdFasterDim,  // whether bwd on the faster moving dimension (u)\r\n                                                 float scale, scalar_t* inGrad, int yBlockOffset = 0) {\r\n  const int batch = blockIdx.z;\r\n  // For the second input tensor, this offset need to be subtracted because the first yBlockOffset\r\n  // sets of thread blocks are for the first input tensor.\r\n  const int x = blockIdx.y - yBlockOffset;\r\n  const int hOffset = blockIdx.x * C10_WARP_SIZE;\r\n  const int wid = threadIdx.y;\r\n  const int lid = threadIdx.x;\r\n  const int numWarp = blockDim.y;\r\n  extern __shared__ char smem8[];\r\n  auto smem = reinterpret_cast<acc_t*>(smem8);\r\n\r\n  OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim);\r\n  const auto maxXLen = offsetCal.getMaxXLen();\r\n  const auto myXLen = offsetCal.getMyXLen();\r\n  const auto myYLen = offsetCal.getMyYLen();\r\n  scalar_t* myInGrad = inGrad + batch * maxXLen * hiddenSize + x * hiddenSize + hOffset;\r\n\r\n  if (x < myXLen) {\r\n    const auto 
myBatchOffset = offsetCal.getBatchOffset();\r\n    const auto strideX = offsetCal.getStrideX();\r\n    const auto strideY = offsetCal.getStrideY();\r\n    const scalar_t* myGrad = grad + myBatchOffset + x * strideX + hOffset;\r\n    const uint8_t* myMask = masked ? mask + myBatchOffset + x * strideX + hOffset : nullptr;\r\n\r\n    // Each warp reduces numYPerWarp \"y\" first\r\n    acc_t warpSum = 0;\r\n    auto numYPerWarp = (myYLen + numWarp - 1) / numWarp;\r\n#pragma unroll\r\n    for (int warpY = 0; warpY < numYPerWarp; ++warpY) {\r\n      auto y = wid * numYPerWarp + warpY;\r\n      if (y < myYLen and (hOffset + lid) < hiddenSize)\r\n        if (masked)\r\n          warpSum += static_cast<acc_t>(myGrad[y * strideY + lid]) * myMask[y * strideY + lid] * scale;\r\n        else\r\n          warpSum += myGrad[y * strideY + lid];\r\n    }\r\n\r\n    // transpose partial sum in SMEM and reduce further using warpReduce\r\n    smem[lid * numWarp + wid] = warpSum;\r\n    __syncthreads();\r\n    auto sum = smem[wid * C10_WARP_SIZE + lid];\r\n    sum = warpReduce(sum, numWarp);\r\n\r\n    // a a b b c c d d\r\n    // a a b b c c d d\r\n    // a a b b c c d d\r\n    // a a b b c c d d\r\n    // example of 4 warps (a, b, c, d) with 8 threads per warp\r\n    // Each warp needs 8 / 4 = 2 threads to write the results.\r\n    if (hOffset + wid * C10_WARP_SIZE / numWarp + lid / numWarp < hiddenSize) {\r\n      if (lid % numWarp == 0) {\r\n        myInGrad[wid * C10_WARP_SIZE / numWarp + lid / numWarp] = sum;\r\n      }\r\n    }\r\n  } else if (wid == 0 and hOffset + lid < hiddenSize) {\r\n    // Need to ensure the grad is zero for the don't-care region\r\n    myInGrad[lid] = 0;\r\n  }\r\n}\r\n\r\n/*\r\nThe actual bwd (reduction) kernel that gets launched.\r\nCalls transducer_joint_single_backward once for each of the two input tensors.\r\nThe two bwd ops are launched together; the first op uses blockIdx.y < maxFLen, and the second op\r\nuses the rest.\r\nWhen ReLU and/or dropout are performed in the fwd 
pass, this operation becomes a masked operation,\r\nand mask contains the mask information.\r\n*/\r\ntemplate <typename scalar_t, typename acc_t, class OffsetCal, bool masked>\r\n__global__ void transducer_joint_combined_backward(const scalar_t* grad, const uint8_t* mask, const int* fLen,\r\n                                                   const int* gLen, const int64_t* batchOffset, int64_t maxFLen,\r\n                                                   int64_t maxGLen, int64_t hiddenSize, bool packOutput, float scale,\r\n                                                   scalar_t* fGrad, scalar_t* gGrad) {\r\n  if (blockIdx.y < maxFLen) {\r\n    transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(\r\n        grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, false, scale, fGrad);\r\n  } else {\r\n    transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(\r\n        grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, true, scale, gGrad, maxFLen);\r\n  }\r\n}\r\n\r\n/*\r\nVectorized version of transducer_joint_single_backward\r\nDoing exact same operation as transducer_joint_single_backward except the load and store are\r\nvectorized.\r\nWhen packing is enabled in the fwd op, unpacking is needed to restore the gradients in a\r\nnon-packed form.\r\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\r\nand mask contains the mask information.\r\n*/\r\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>\r\n__device__ void transducer_joint_single_vec_backward(const scalar_t* grad, const uint8_t* mask, const int* fLen,\r\n                                                     const int* gLen, const int64_t* batchOffset, int64_t maxFLen,\r\n                                                     int64_t maxGLen, int64_t hiddenSize, bool packOutput,\r\n                                                
     bool bwdFasterDim, float scale, scalar_t* inGrad,\r\n                                                     int yBlockOffset = 0) {\r\n  const int batch = blockIdx.z;\r\n  const int x = blockIdx.y - yBlockOffset;\r\n  const int hOffset = blockIdx.x * C10_WARP_SIZE * V;\r\n  const int wid = threadIdx.y;\r\n  const int lid = threadIdx.x;\r\n  const int numWarp = blockDim.y;\r\n\r\n  // Figure out the vectorization type for mask\r\n  using mvec_t = mvec_type<V>;\r\n\r\n  OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim);\r\n  const auto maxXLen = offsetCal.getMaxXLen();\r\n  const auto myXLen = offsetCal.getMyXLen();\r\n  const auto myYLen = offsetCal.getMyYLen();\r\n  scalar_t* myInGrad = inGrad + batch * maxXLen * hiddenSize + x * hiddenSize + hOffset;\r\n  extern __shared__ char smem8[];\r\n  auto smem = reinterpret_cast<acc_t*>(smem8);\r\n\r\n  acc_t warpSum[V];\r\n  scalar_t inBuffer[V];\r\n  uint8_t maskBuffer[V];\r\n  scalar_t outBuffer[V];\r\n  auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);\r\n  auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);\r\n\r\n  if (x < myXLen) {\r\n    const auto myBatchOffset = offsetCal.getBatchOffset();\r\n    const auto strideX = offsetCal.getStrideX();\r\n    const auto strideY = offsetCal.getStrideY();\r\n    const scalar_t* myGrad = grad + myBatchOffset + x * strideX + hOffset;\r\n    const uint8_t* myMask = masked ? mask + myBatchOffset + x * strideX + hOffset : nullptr;\r\n\r\n    for (int i = 0; i < V; ++i) warpSum[i] = 0;\r\n\r\n    // Each warp reduces numYPerWarp \"y\" first\r\n    auto numYPerWarp = (myYLen + numWarp - 1) / numWarp;\r\n    for (int warpY = 0; warpY < numYPerWarp; ++warpY) {\r\n      auto y = wid * numYPerWarp + warpY;\r\n      auto myGradVec = reinterpret_cast<vec_t const*>(myGrad + y * strideY);\r\n      auto myMaskVec = masked ? 
reinterpret_cast<mvec_t const*>(myMask + y * strideY) : nullptr;\r\n      auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);\r\n      auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);\r\n      if (hOffset + lid * V < hiddenSize and y < myYLen) {\r\n        *inBufferVec = myGradVec[lid];  // vectorized load\r\n        if (masked) {\r\n          *maskBufferVec = myMaskVec[lid];\r\n#pragma unroll\r\n          for (int i = 0; i < V; ++i) warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;\r\n        } else {\r\n#pragma unroll\r\n          for (int i = 0; i < V; ++i) warpSum[i] += inBuffer[i];\r\n        }\r\n      }\r\n    }\r\n\r\n    // transpose partial sum in SMEM and reduce further using warpReduce\r\n    for (int i = 0; i < V; ++i) {\r\n      smem[lid * numWarp + wid] = warpSum[i];\r\n      __syncthreads();\r\n      auto sum = smem[wid * C10_WARP_SIZE + lid];\r\n\r\n      if (hOffset + (wid * C10_WARP_SIZE / numWarp) * V < hiddenSize) {\r\n        sum = warpReduce(sum, numWarp);\r\n        if (lid % numWarp == 0) {\r\n          outBuffer[i] = sum;\r\n        }\r\n      }\r\n      __syncthreads();\r\n    }\r\n\r\n    // a a b b c c d d\r\n    // a a b b c c d d\r\n    // a a b b c c d d\r\n    // a a b b c c d d\r\n    // example of 4 warps (a, b, c, d) with 8 threads per warp\r\n    // Each warp needs 8 / 4 = 2 threads to write the results.\r\n    if (lid % numWarp == 0 and hOffset + (wid * C10_WARP_SIZE / numWarp + lid / numWarp) * V < hiddenSize)\r\n      myInGradVec[wid * C10_WARP_SIZE / numWarp + lid / numWarp] = *outBufferVec;\r\n  } else if (wid == 0 and hOffset + lid * V < hiddenSize) {\r\n    // Need to ensure the grad is zero for the don't-care region\r\n    myInGradVec[lid] = 0;\r\n  }\r\n}\r\n\r\n/*\r\nVectorized version of transducer_joint_combined_backward\r\nCall transducer_joint_single_vec_backward twice on two input tensors.\r\nThe two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second 
op\r\nuses the rest.\r\nWhen ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,\r\nand mask contains the mask information.\r\n*/\r\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>\r\n__global__ void transducer_joint_combined_vec_backward(const scalar_t* grad, const uint8_t* mask, const int* fLen,\r\n                                                       const int* gLen, const int64_t* batchOffset, int64_t maxFLen,\r\n                                                       int64_t maxGLen, int64_t hiddenSize, bool packOutput,\r\n                                                       float scale, scalar_t* fGrad, scalar_t* gGrad) {\r\n  if (blockIdx.y < maxFLen) {\r\n    transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(\r\n        grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, false, scale, fGrad);\r\n  } else {\r\n    transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(\r\n        grad, mask, fLen, gLen, batchOffset, maxFLen, maxGLen, hiddenSize, packOutput, true, scale, gGrad, maxFLen);\r\n  }\r\n}\r\n\r\nstd::vector<torch::Tensor> transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen,\r\n                                                         torch::Tensor gLen, torch::Tensor batchOffset,\r\n                                                         int64_t packedBatch, int opt, bool packOutput, bool relu,\r\n                                                         bool dropout, float dropoutProb, int tileSize) {\r\n  auto tensorOpt = f.options();\r\n  auto dtype = f.scalar_type();\r\n  const auto batchSize = f.size(0);\r\n  const auto maxFLen = f.size(1);\r\n  const auto maxGLen = g.size(1);\r\n  const auto hiddenSize = f.size(2);\r\n  bool masked = dropout or relu;\r\n\r\n  int64_t* batchOffsetPtr = nullptr;\r\n  torch::Tensor sum, mask;\r\n  auto 
maskOpt = tensorOpt.dtype(torch::kUInt8);\r\n  if (!packOutput) {\r\n    sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);\r\n    batchOffsetPtr = nullptr;\r\n    if (masked) mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);\r\n  } else {\r\n    sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);\r\n    batchOffsetPtr = batchOffset.data_ptr<int64_t>();\r\n    if (masked) mask = torch::empty({packedBatch, hiddenSize}, maskOpt);\r\n  }\r\n  uint8_t* maskPtr = masked ? mask.data_ptr<uint8_t>() : nullptr;\r\n\r\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\r\n\r\n  TORCH_CHECK(opt == 0 or opt == 1, \"Got an invalid optimization level \", opt);\r\n  // Simple heuristics\r\n  const int numThread =\r\n      std::min(128, (static_cast<int>(hiddenSize) + C10_WARP_SIZE - 1) / C10_WARP_SIZE * C10_WARP_SIZE);\r\n\r\n  if (opt == 0) {\r\n    // vanilla kernel\r\n    const int threads = numThread;\r\n    const dim3 blocks(maxGLen, maxFLen, batchSize);\r\n\r\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\r\n        dtype, \"transducer_joint_forward\", ([&] {\r\n          transducer_joint_forward<scalar_t, OffsetCalFwd><<<blocks, threads, 0, stream>>>(\r\n              f.data_ptr<scalar_t>(), g.data_ptr<scalar_t>(), fLen.data_ptr<int>(), gLen.data_ptr<int>(),\r\n              batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, sum.data_ptr<scalar_t>());\r\n        }));\r\n  }\r\n  if (opt == 1) {\r\n    // tiled version. 
For simplicity, assume tileF == tileG, even though the kernel can\r\n    // support more general cases.\r\n    const int threads = numThread;\r\n    const int hiddenPerBlock = numThread;\r\n    const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;\r\n    const dim3 blocks((maxGLen + tileSize - 1) / tileSize * hiddenBlock, (maxFLen + tileSize - 1) / tileSize,\r\n                      batchSize);\r\n\r\n    TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, \"Expected tileSize to be in [1, 2, 4], but got \",\r\n                tileSize);\r\n\r\n    at::PhiloxCudaState rng_engine_inputs;\r\n    if (masked) {\r\n      // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler\r\n      // for non-masked calls.\r\n      // Therefore no need to initialize.\r\n      c10::optional<at::Generator> gen_;\r\n      auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());\r\n      // counterOffset records how many cuRAND calls each thread makes. 
For a tiled kernel,\r\n      // each thread processes tileF * tileG output elements.\r\n      int64_t counterOffset = tileSize * tileSize;\r\n      {\r\n        std::lock_guard<std::mutex> lock(gen->mutex_);\r\n        rng_engine_inputs = gen->philox_cuda_state(counterOffset);\r\n      }\r\n    }\r\n\r\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\r\n        dtype, \"transducer_joint_forward\", ([&] {\r\n          void (*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, int64_t, int64_t,\r\n                         int64_t, int64_t, bool, bool, bool, float, at::PhiloxCudaState, scalar_t*, uint8_t*);\r\n          if (masked) {\r\n            switch (tileSize) {\r\n              case 2:\r\n                kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, true>;\r\n                break;\r\n              case 4:\r\n                kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, true>;\r\n                break;\r\n            }\r\n          } else {\r\n            switch (tileSize) {\r\n              case 1:\r\n                kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd, false>;\r\n                break;\r\n              case 2:\r\n                kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd, false>;\r\n                break;\r\n              case 4:\r\n                kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd, false>;\r\n                break;\r\n            }\r\n          }\r\n\r\n          kernel<<<blocks, threads, 0, stream>>>(f.data_ptr<scalar_t>(), g.data_ptr<scalar_t>(), fLen.data_ptr<int>(),\r\n                                                 gLen.data_ptr<int>(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize,\r\n                                                 hiddenPerBlock, packOutput, relu, dropout, 1.0f - dropoutProb,\r\n                                                 rng_engine_inputs, 
sum.data_ptr<scalar_t>(), maskPtr);\r\n        }));\r\n  }\r\n\r\n  C10_CUDA_CHECK(cudaGetLastError());\r\n  if (masked)\r\n    return {sum, mask};\r\n  else\r\n    return {sum};\r\n}\r\n\r\nstd::vector<torch::Tensor> transducer_joint_cuda_backward(std::vector<torch::Tensor> in, torch::Tensor fLen,\r\n                                                          torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen,\r\n                                                          int maxGLen, bool packOutput, float scale) {\r\n  auto grad = in[0];\r\n  bool masked = (in.size() == 2);\r\n  uint8_t* maskPtr = masked ? in[1].data_ptr<uint8_t>() : nullptr;\r\n\r\n  auto tensorOpt = grad.options();\r\n  auto dtype = grad.scalar_type();\r\n  const int batchSize = fLen.size(0);\r\n  const int hiddenSize = grad.size(-1);\r\n\r\n  const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\r\n  const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;\r\n\r\n  torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);\r\n  torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);\r\n\r\n  int64_t* batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>();\r\n\r\n  // The number \"y\" I would like each thread to work on\r\n  const int workPerThread = 32;\r\n  // Since the bwd for f and g have the same thread block size, we need to use the max of the two.\r\n  int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread - 1) / workPerThread);\r\n  // Would like to have at least 2 warps\r\n  numWarp = std::max(2, numWarp);\r\n  // cap on the maximum number of warps allowed\r\n  numWarp = std::min(maxNumWarp, numWarp);\r\n\r\n  // Need smem for transposing the partial sum. 
The partial sum is in a matrix of the shape\r\n  // numWarp x warpSize\r\n  const int smemSize = numWarp * C10_WARP_SIZE;\r\n  const dim3 threads(C10_WARP_SIZE, numWarp, 1);\r\n\r\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\r\n      dtype, \"transducer_joint_cuda_backward_kernel\", ([&] {\r\n        auto gradPtr = grad.data_ptr<scalar_t>();\r\n        auto fLenPtr = fLen.data_ptr<int>();\r\n        auto gLenPtr = gLen.data_ptr<int>();\r\n        auto fGradPtr = fGrad.data_ptr<scalar_t>();\r\n        auto gGradPtr = gGrad.data_ptr<scalar_t>();\r\n\r\n        // resolve the acc_t type\r\n        using acc_t = at::acc_type<scalar_t, true>;\r\n        using vec_t = uint64_t;\r\n\r\n        constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);\r\n        constexpr int vecAlignment = std::alignment_of<vec_t>::value;\r\n\r\n        // if all input and output tensors meet the alignment requirement\r\n        bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0) and\r\n                        (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0) and\r\n                        (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);\r\n\r\n        if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) {\r\n          // If vectorization helps and the alignment requirement is met, use the vectorized\r\n          // kernel. 
For simplicity, hiddenSize needs to be a multiple of vectFactor.\r\n          const dim3 blocks((hiddenSize + C10_WARP_SIZE * vectFactor - 1) / (C10_WARP_SIZE * vectFactor),\r\n                            maxFLen + maxGLen, batchSize);\r\n          if (masked) {\r\n            transducer_joint_combined_vec_backward<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>\r\n                <<<blocks, threads, smemSize * sizeof(acc_t)>>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,\r\n                                                                maxFLen, maxGLen, hiddenSize, packOutput, scale,\r\n                                                                fGradPtr, gGradPtr);\r\n          } else {\r\n            transducer_joint_combined_vec_backward<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>\r\n                <<<blocks, threads, smemSize * sizeof(acc_t)>>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,\r\n                                                                maxFLen, maxGLen, hiddenSize, packOutput, scale,\r\n                                                                fGradPtr, gGradPtr);\r\n          }\r\n        } else {\r\n          const dim3 blocks((hiddenSize + C10_WARP_SIZE - 1) / C10_WARP_SIZE, maxFLen + maxGLen, batchSize);\r\n          if (masked) {\r\n            transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>\r\n                <<<blocks, threads, smemSize * sizeof(acc_t)>>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,\r\n                                                                maxFLen, maxGLen, hiddenSize, packOutput, scale,\r\n                                                                fGradPtr, gGradPtr);\r\n          } else {\r\n            transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>\r\n                <<<blocks, threads, smemSize * sizeof(acc_t)>>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr,\r\n                                    
                            maxFLen, maxGLen, hiddenSize, packOutput, scale,\r\n                                                                fGradPtr, gGradPtr);\r\n          }\r\n        }\r\n      }));\r\n\r\n  return {fGrad, gGrad};\r\n}\r\n"
  },
  {
    "path": "apex/contrib/csrc/transducer/transducer_loss.cpp",
    "content": "#include <torch/extension.h>\n\n#include <vector>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen,\n                                                        torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen,\n                                                        int blankIdx, int opt, bool packedInput);\n\ntorch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha,\n                                            torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen,\n                                            torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx,\n                                            int opt, bool fuseSoftmaxBackward, bool packedInput);\n\nstd::vector<torch::Tensor> transducer_loss_forward(torch::Tensor x, torch::Tensor label, torch::Tensor fLen,\n                                                   torch::Tensor yLen, torch::Tensor batchOffset, int maxFLen,\n                                                   int blankIdx, int opt, bool packedInput) {\n  CHECK_INPUT(x);\n  CHECK_INPUT(label);\n  CHECK_INPUT(fLen);\n  CHECK_INPUT(yLen);\n  if (packedInput) CHECK_INPUT(batchOffset);\n  return transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset, maxFLen, blankIdx, opt, packedInput);\n}\n\ntorch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta,\n                                       torch::Tensor fLen, torch::Tensor yLen, torch::Tensor label,\n                                       torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt,\n                   
                    bool fuseSoftmaxBackward, bool packedInput) {\n  CHECK_INPUT(x);\n  CHECK_INPUT(label);\n  CHECK_INPUT(lossGrad);\n  CHECK_INPUT(alpha);\n  CHECK_INPUT(beta);\n  CHECK_INPUT(fLen);\n  CHECK_INPUT(yLen);\n  if (packedInput) CHECK_INPUT(batchOffset);\n\n  return transducer_loss_cuda_backward(x, lossGrad, alpha, beta, fLen, yLen, label, batchOffset, maxFLen, blankIdx, opt,\n                                       fuseSoftmaxBackward, packedInput);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &transducer_loss_forward, \"transducer loss forward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &transducer_loss_backward, \"transducer loss backward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/csrc/transducer/transducer_loss_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include <vector>\n\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {\n  // standard log-sum-exp trick is used here to provide better numerical stability\n  return (a >= b) ? a + std::log1p(exp(b - a)) : b + std::log1p(exp(a - b));\n}\n\n// Vanilla transducer loss function (i.e. forward-backward algorithm)\n// Detail of this loss function can be found in:\n// [1] Sequence Transduction with Recurrent Neural Networks.\n\n// Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted\n// into log scale by the preceding log_softmax layer\n// Diagonal wavefront advancing usually used in dynamic programming is leveraged here.\n// alpha and beta are of acc_t type, as they are essentially accumulators.\n\n// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into\n// [B_packed, H].\n// Don't-care region (t > audLen) or (u > txtLen) is removed.\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_forward(const scalar_t* x, const int* label, const int* audLen, const int* txtLen,\n                                        const int64_t* batchOffset,\n                                        int64_t dictSize,  // 64-bit indexing for data tensor\n                                        int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,\n                                        acc_t* alpha, acc_t* beta, scalar_t* loss) {\n  const int batch = blockIdx.y;\n  const int tid = threadIdx.x;\n  const auto myFLen = audLen[batch];\n  // Note that start of the sentence is added as 1 here\n  const auto myGLen = txtLen[batch] + 
1;\n  const auto myLabel = label + batch * (maxGLen - 1);\n  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;\n  const int64_t myStrideT = packedInput ? myGLen : maxGLen;\n  const scalar_t* myX = x + myBatchOffset * dictSize;\n  int u = tid;\n\n  if (blockIdx.x == 0) {\n    // alpha path\n    acc_t* myAlpha = alpha + batch * maxFLen * maxGLen;\n    if (u == 0) myAlpha[0] = 0;\n    __syncthreads();\n\n    for (int64_t step = 1; step < myFLen + myGLen - 1; ++step) {\n      // Move along the diagonal wavefront to leverage available parallelism\n      for (u = tid; u < myGLen; u += blockDim.x) {\n        int64_t t = step - u;\n        if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {\n          // Eq(16) in [1]\n          if (u == 0) {\n            // alpha(t, u) = alpha(t-1, u) * null(t-1, u)\n            myAlpha[t * maxGLen + u] = myAlpha[(t - 1) * maxGLen] + myX[((t - 1) * myStrideT) * dictSize + blankIdx];\n          } else if (t == 0) {\n            // alpha(t, u) = alpha(t, u-1) * y(t, u-1)\n            myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];\n          } else {\n            // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)\n            acc_t current = myAlpha[(t - 1) * maxGLen + u] + myX[((t - 1) * myStrideT + u) * dictSize + blankIdx];\n            acc_t next = myAlpha[t * maxGLen + u - 1] + myX[(t * myStrideT + u - 1) * dictSize + myLabel[u - 1]];\n            myAlpha[t * maxGLen + u] = logSumExp(next, current);\n          }\n        }\n      }\n      __syncthreads();\n    }\n  } else if (blockIdx.x == 1) {\n    // beta path\n    acc_t* myBeta = beta + batch * maxFLen * maxGLen;\n    if (u == 0) {\n      myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];\n    }\n    __syncthreads();\n\n    for (int64_t step = myFLen + myGLen - 3; step >= 0; --step) {\n      for (u = tid; 
u < myGLen; u += blockDim.x) {\n        int64_t t = step - u;\n        if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {\n          // Eq(18) in [1]\n          if (u == myGLen - 1) {\n            // beta(t, u) = beta(t+1, u) * null(t, u)\n            myBeta[t * maxGLen + u] = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx];\n          } else if (t == myFLen - 1) {\n            // beta(t, u) = beta(t, u+1) * y(t, u)\n            myBeta[t * maxGLen + u] = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]];\n          } else {\n            // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)\n            acc_t current = myBeta[(t + 1) * maxGLen + u] + myX[(t * myStrideT + u) * dictSize + blankIdx];\n            acc_t next = myBeta[t * maxGLen + u + 1] + myX[(t * myStrideT + u) * dictSize + myLabel[u]];\n            myBeta[t * maxGLen + u] = logSumExp(next, current);\n          }\n        }\n      }\n      __syncthreads();\n    }\n    if (tid == 0) loss[batch] = -myBeta[0];\n  }\n}\n\n// Transducer loss function (i.e. forward-backward algorithm) with batch loading optimization.\n// Compared to the vanilla version, there are two optimizations:\n// 1. Load x in batches through loop unrolling to reduce the latency.\n// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.\n// For simplicity, this kernel currently only supports U <= maxThread, which should be the common\n// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.\n\n// Detail of this loss function can be found in:\n// [1] Sequence Transduction with Recurrent Neural Networks.\n// Forward (alpha) and backward (beta) path are launched together. 
Input is assumed to be converted\n// into log scale by the preceding log_softmax layer\n// Diagonal wavefront advancing usually used in dynamic programming is leveraged here.\n// alpha and beta are of acc_t type, as they are essentially accumulators.\n\n// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into\n// [B_packed, H].\n// Don't-care region (t > audLen) or (u > txtLen) is removed.\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t, int batchLdSize>\n__global__ void transducer_loss_batch_load_forward(const scalar_t* x, const int* label, const int* audLen,\n                                                   const int* txtLen, const int64_t* batchOffset, int64_t dictSize,\n                                                   int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,\n                                                   acc_t* alpha, acc_t* beta, scalar_t* loss) {\n  const int batch = blockIdx.y;\n  int u = threadIdx.x;\n  const auto myFLen = audLen[batch];\n  const auto myGLen = txtLen[batch] + 1;\n  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;\n  const int64_t myStrideT = packedInput ? myGLen : maxGLen;\n  const scalar_t* myX = x + myBatchOffset * dictSize;\n  scalar_t next[batchLdSize], current[batchLdSize];\n  extern __shared__ char smem8[];\n  auto smem = reinterpret_cast<acc_t*>(smem8);\n\n  if (blockIdx.x == 0) {\n    // alpha path\n    acc_t* myAlpha = alpha + batch * maxFLen * maxGLen;\n    // two SMEM regions for double buffering read and write data to avoid data race\n    acc_t* const sharedAlpha[2] = {smem, smem + maxGLen};\n\n    sharedAlpha[0][u] = 0;\n    __syncthreads();\n\n    if (u == 0) myAlpha[0] = 0;\n\n    auto myAlphaLabel = (u == 0) ? 
0 : label[batch * (maxGLen - 1) + u - 1];\n    // register used to pass value to the next step for the same thread\n    acc_t prvStepAlpha = 0;\n    for (int64_t step = 1; step < myFLen + myGLen - 1 + batchLdSize; step += batchLdSize) {\n// Move along the diagonal wavefront to leverage available parallelism\n// Batch loading X through loop unrolling\n#pragma unroll\n      for (int i = 0; i < batchLdSize; ++i) {\n        if (step + i < myFLen + myGLen - 1) {\n          // index computing\n          int64_t t = step + i - u;\n          int64_t currentId = ((t - 1) * myStrideT + u) * dictSize + blankIdx;\n          int64_t nextId = (t * myStrideT + u - 1) * dictSize + myAlphaLabel;\n          // main loading loop\n          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {\n            if (u == 0) {\n              current[i] = myX[currentId];\n            } else if (t == 0) {\n              next[i] = myX[nextId];\n            } else {\n              current[i] = myX[currentId];\n              next[i] = myX[nextId];\n            }\n          }\n        }\n      }\n      // main computing loop\n      for (int i = 0; i < batchLdSize; ++i) {\n        // swap the pointer for double buffering\n        auto sharedAlphaRd = sharedAlpha[(step + i - 1) % 2];\n        auto sharedAlphaWr = sharedAlpha[(step + i) % 2];\n        if (step + i < myFLen + myGLen - 1) {\n          int64_t t = step + i - u;\n          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {\n            // Eq(16) in [1]\n            if (u == 0)\n              prvStepAlpha = prvStepAlpha + current[i];\n            else if (t == 0)\n              prvStepAlpha = sharedAlphaRd[u - 1] + next[i];\n            else\n              prvStepAlpha = logSumExp(prvStepAlpha + current[i], sharedAlphaRd[u - 1] + next[i]);\n            sharedAlphaWr[u] = prvStepAlpha;\n            myAlpha[t * maxGLen + u] = prvStepAlpha;\n          }\n        }\n        __syncthreads();\n      }\n    }\n  } else if (blockIdx.x == 1) 
{\n    // beta path\n    acc_t* myBeta = beta + batch * maxFLen * maxGLen;\n    // two SMEM regions for double buffering read and write data to avoid data race\n    acc_t* const sharedBeta[2] = {smem, smem + maxGLen};\n    sharedBeta[0][u] = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];\n    __syncthreads();\n\n    auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch * (maxGLen - 1) + u];\n    // register used to pass value to the next step for the same thread\n    acc_t prvStepBeta = myX[((myFLen - 1) * myStrideT + myGLen - 1) * dictSize + blankIdx];\n    if (u == 0) myBeta[(myFLen - 1) * maxGLen + myGLen - 1] = prvStepBeta;\n\n    for (int64_t step = 1; step < myFLen + myGLen - 1; step += batchLdSize) {\n// Move along the diagonal wavefront to leverage available parallelism\n// Batch loading X\n#pragma unroll\n      for (int i = 0; i < batchLdSize; ++i) {\n        if (step + i < myFLen + myGLen - 1) {\n          // index computing\n          int64_t t = myFLen + myGLen - (step + i) - 2 - u;\n          int64_t currentId = (t * myStrideT + u) * dictSize + blankIdx;\n          int64_t nextId = (t * myStrideT + u) * dictSize + myBetaLabel;\n          // main loading loop\n          if (t >= 0 and t < myFLen and u >= 0 and u < myGLen) {\n            if (u == myGLen - 1) {\n              current[i] = myX[currentId];\n            } else if (t == myFLen - 1) {\n              next[i] = myX[nextId];\n            } else {\n              current[i] = myX[currentId];\n              next[i] = myX[nextId];\n            }\n          }\n        }\n      }\n      // main computing loop\n      for (int i = 0; i < batchLdSize; ++i) {\n        // swap the pointer for double buffering\n        auto sharedBetaRd = sharedBeta[(step + i - 1) % 2];\n        auto sharedBetaWr = sharedBeta[(step + i) % 2];\n        if (step + i < myFLen + myGLen - 1) {\n          int64_t t = myFLen + myGLen - (step + i) - 2 - u;\n          if (t >= 0 and t < myFLen and u >= 0 and u 
< myGLen) {\n            // Eq(18) in [1]\n            if (u == myGLen - 1)\n              prvStepBeta = prvStepBeta + current[i];\n            else if (t == myFLen - 1)\n              prvStepBeta = sharedBetaRd[u + 1] + next[i];\n            else\n              prvStepBeta = logSumExp(prvStepBeta + current[i], sharedBetaRd[u + 1] + next[i]);\n            sharedBetaWr[u] = prvStepBeta;\n            myBeta[t * maxGLen + u] = prvStepBeta;\n          }\n        }\n        __syncthreads();\n      }\n    }\n    if (u == 0) loss[batch] = -prvStepBeta;\n  }\n}\n\n// Vanilla transducer loss backward operation.\n// Details of this loss function can be found in:\n// [1] Sequence Transduction with Recurrent Neural Networks.\n// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,\n// hence only Eq(20) in [1] is implemented in this kernel.\n\n// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time.\n// Since only gradients for the correct token and null token need to be updated, gradients at other\n// locations are initialized to 0.\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen,\n                                         const int* txtLen, const int* label, const acc_t* alpha, const acc_t* beta,\n                                         const int64_t* batchOffset, int64_t dictSize, int64_t blankIdx,\n                                         int64_t maxFLen, int64_t maxGLen, bool packedInput, scalar_t* xGrad) {\n  const int tid = threadIdx.x;\n  const int t = blockIdx.x;\n  const int batch = blockIdx.y;\n  const int64_t myFLen = audLen[batch];\n  const int64_t myGLen = txtLen[batch] + 1;\n  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 
0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;\n  const int64_t myStrideT = packedInput ? myGLen : maxGLen;\n  auto myX = x + (myBatchOffset + t * myStrideT) * dictSize;\n  auto myAlpha = alpha + batch * maxFLen * maxGLen;\n  auto myBeta = beta + batch * maxFLen * maxGLen;\n  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT) * dictSize;\n  auto myLabel = label + batch * (maxGLen - 1);\n\n  int64_t u = tid;\n  while (t < myFLen and u < myGLen) {\n    // Do the update\n    // loss = -ln(Pr(y*|x))\n    acc_t grad = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];\n    if (u != myGLen - 1)\n      myXGrad[u * dictSize + myLabel[u]] =\n          -std::exp(grad + myBeta[t * maxGLen + u + 1] + myX[u * dictSize + myLabel[u]]);\n    if (t == myFLen - 1 and u == myGLen - 1)\n      myXGrad[u * dictSize + blankIdx] = -std::exp(grad + myX[u * dictSize + blankIdx]);\n    else if (t != myFLen - 1)\n      myXGrad[u * dictSize + blankIdx] = -std::exp(grad + myBeta[(t + 1) * maxGLen + u] + myX[u * dictSize + blankIdx]);\n\n    u += blockDim.x;\n  }\n}\n\n// Fused transducer loss backward operation.\n// Details of this loss function can be found in:\n// [1] Sequence Transduction with Recurrent Neural Networks.\n// The bwd op of the preceding softmax layer is fused in this kernel.\n// Each thread block works on [batch, t, u, :] of data. 
Each thread works on a specific h at a time\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t>\n__global__ void transducer_loss_fused_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen,\n                                               const int* txtLen, const int* label, const acc_t* alpha,\n                                               const acc_t* beta, const int64_t* batchOffset, int64_t dictSize,\n                                               int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,\n                                               scalar_t* xGrad) {\n  const int tid = threadIdx.x;\n  const int u = blockIdx.x;\n  const int t = blockIdx.y;\n  const int batch = blockIdx.z;\n  const int64_t myFLen = audLen[batch];\n  const int64_t myGLen = txtLen[batch] + 1;\n  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;\n  const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n\n  __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;\n  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;\n\n  if (t < myFLen and u < myGLen) {\n    auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;\n    auto myAlpha = alpha + batch * maxFLen * maxGLen;\n    auto myBeta = beta + batch * maxFLen * maxGLen;\n    auto myLabel = label + batch * (maxGLen - 1);\n\n    // load and store shared variables in SMEM\n    if (tid == 0) {\n      commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];\n      myBetaTU = myBeta[t * maxGLen + u];\n      // Guard the loads below so they never read past the valid (t, u) region;\n      // the guarded values are only consumed under the same conditions.\n      if (t != myFLen - 1) myBetaTp1U = myBeta[(t + 1) * maxGLen + u];\n      if (u != myGLen - 1) {\n        myBetaTUp1 = myBeta[t * maxGLen + u + 1];\n        myLabelShared = myLabel[u];\n      }\n    }\n\n    __syncthreads();\n\n    for (int64_t h = tid; h < dictSize; h += blockDim.x) {\n      // Do the update\n      acc_t grad = commonFactor + myX[h];  // loss = -ln(Pr(y*|x))\n      acc_t myGrad = std::exp(grad + myBetaTU);\n      if (u != myGLen - 1 and h == myLabelShared) {\n        myGrad -= std::exp(grad + myBetaTUp1);\n      } else if (h == blankIdx) {\n        if (t == myFLen - 1 and u == myGLen - 1)\n          myGrad -= std::exp(grad);\n        else if (t != myFLen - 1)\n          myGrad -= std::exp(grad + myBetaTp1U);\n      }\n      myXGrad[h] = myGrad;\n    }\n  } else if (!packedInput) {\n    // In non-pack mode, need to make sure the gradients for don't-care regions are zero.\n    for (int64_t h = tid; h < dictSize; h += blockDim.x) {\n      myXGrad[h] = 0;\n    }\n  }\n}\n\n// Vectorized version of fused transducer loss backward operation.\n// Details of this loss function can be found in:\n// [1] Sequence Transduction with Recurrent Neural Networks.\n// The bwd op of the preceding softmax layer is fused in this kernel.\n// Each thread block works on [batch, t, u, :] of data. 
Each thread works on a specific h at a time\n\n// To support the packed input, the starting offsets for each batch need to be specified with\n// batchOffset.\ntemplate <typename scalar_t, typename acc_t, typename vec_t, int V>\n__global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scalar_t* lossGrad, const int* audLen,\n                                                   const int* txtLen, const int* label, const acc_t* alpha,\n                                                   const acc_t* beta, const int64_t* batchOffset, int64_t dictSize,\n                                                   int64_t blankIdx, int64_t maxFLen, int64_t maxGLen, bool packedInput,\n                                                   scalar_t* xGrad) {\n  const int tid = threadIdx.x;\n  const int u = blockIdx.x;\n  const int t = blockIdx.y;\n  const int batch = blockIdx.z;\n  const int64_t myFLen = audLen[batch];\n  const int64_t myGLen = txtLen[batch] + 1;\n  const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch - 1]) : batch * maxFLen * maxGLen;\n  const int64_t myStrideT = packedInput ? 
myGLen : maxGLen;\n\n  __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;\n  auto myXGrad = xGrad + (myBatchOffset + t * myStrideT + u) * dictSize;\n  auto myX = x + (myBatchOffset + t * myStrideT + u) * dictSize;\n  auto myAlpha = alpha + batch * maxFLen * maxGLen;\n  auto myBeta = beta + batch * maxFLen * maxGLen;\n  auto myLabel = label + batch * (maxGLen - 1);\n\n  // Variables for vectorization\n  scalar_t myXBuffer[V], myXGradBuffer[V];\n  auto myXVec = reinterpret_cast<vec_t const*>(myX);\n  auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);\n  auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);\n  auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);\n  if (t < myFLen and u < myGLen) {\n    // load and store shared variables in SMEM\n    if (tid == 0) {\n      commonFactor = std::log(lossGrad[batch]) + myAlpha[t * maxGLen + u] - myBeta[0];\n      myBetaTU = myBeta[t * maxGLen + u];\n      if (t != myFLen - 1) myBetaTp1U = myBeta[(t + 1) * maxGLen + u];\n      if (u != myGLen - 1) {\n        myBetaTUp1 = myBeta[t * maxGLen + u + 1];\n        myLabelShared = myLabel[u];\n      }\n    }\n\n    __syncthreads();\n\n#pragma unroll\n    for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {\n      // Load myX in a vector form\n      *myXBufferVec = myXVec[h0 / V];\n// Do the update for a vector of input\n#pragma unroll\n      for (int i = 0; i < V; ++i) {\n        auto h = h0 + i;\n        acc_t grad = commonFactor + myXBuffer[i];  // loss = -ln(Pr(y*|x))\n        acc_t myGrad = std::exp(grad + myBetaTU);\n        if (u != myGLen - 1 and h == myLabelShared) {\n          myGrad -= std::exp(grad + myBetaTUp1);\n        } else if (h == blankIdx) {\n          if (t == myFLen - 1 and u == myGLen - 1)\n            myGrad -= std::exp(grad);\n          else if (t != myFLen - 1)\n            myGrad -= std::exp(grad + myBetaTp1U);\n        }\n        myXGradBuffer[i] = myGrad;\n      }\n\n      // Store myXGrad in a 
vector form\n      myXGradVec[h0 / V] = *myXGradBufferVec;\n    }\n  } else if (!packedInput) {\n    // In non-pack mode, need to make sure the gradients for don't-care regions are zero.\n    for (int64_t h0 = tid * V; h0 < dictSize; h0 += blockDim.x * V) {\n      myXGradVec[h0 / V] = 0;\n    }\n  }\n}\n\nstd::vector<torch::Tensor> transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen,\n                                                        torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen,\n                                                        int blankIdx, int opt, bool packedInput) {\n  auto scalarType = x.scalar_type();\n  auto tensorOpt = x.options();\n  const int batchSize = label.size(0);\n  const int maxGLen = label.size(1) + 1;\n  const int dictSize = x.size(-1);\n\n  TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, \"Expected blank index to be in the range of 0 to \", dictSize - 1,\n              \", but got \", blankIdx);\n  TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, \"Got an invalid optimization level \", opt);\n\n  // The data type of alpha and beta will be resolved at dispatch time,\n  // hence defined here and assigned later\n  torch::Tensor alpha;\n  torch::Tensor beta;\n  torch::Tensor loss = torch::empty({batchSize}, tensorOpt);\n  const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n  const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;\n  const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;\n  const auto batchOffsetPtr = packedInput ? 
batchOffset.data_ptr<int64_t>() : nullptr;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      scalarType, \"transducer_loss_cuda_forward\", ([&] {\n        // resolve accumulation type\n        using acc_t = at::acc_type<scalar_t, true>;\n        auto accType = c10::CppTypeToScalarType<acc_t>::value;\n        auto accTensorOpt = tensorOpt.dtype(accType);\n        alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);\n        beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);\n\n        // decide what kernel to launch based on the problem size\n        // if the required SMEM size or number threads exceeds the limit, fall back to the vanilla\n        // kernel.\n        const auto smemSize = 2 * maxGLen * sizeof(acc_t);\n        const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0\n                                 : (opt == -1)                                               ? 
1\n                                                                                             : opt;\n        const int threads = std::min(maxThreadPerBlock, maxGLen);\n        const dim3 blocks(2, batchSize, 1);\n\n        if (optFallBack == 0)\n          transducer_loss_forward<<<blocks, threads, 0, stream>>>(\n              x.data_ptr<scalar_t>(), label.data_ptr<int>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),\n              batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr<acc_t>(),\n              beta.data_ptr<acc_t>(), loss.data_ptr<scalar_t>());\n        else if (optFallBack == 1)\n          transducer_loss_batch_load_forward<scalar_t, acc_t, 4><<<blocks, threads, smemSize, stream>>>(\n              x.data_ptr<scalar_t>(), label.data_ptr<int>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),\n              batchOffsetPtr, dictSize, blankIdx, maxFLen, maxGLen, packedInput, alpha.data_ptr<acc_t>(),\n              beta.data_ptr<acc_t>(), loss.data_ptr<scalar_t>());\n      }));\n  C10_CUDA_CHECK(cudaGetLastError());\n\n  return {alpha, beta, loss};\n}\n\ntorch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha,\n                                            torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen,\n                                            torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx,\n                                            int opt, bool fuseSoftmaxBackward, bool packedInput) {\n  auto dtype = x.scalar_type();\n  torch::Tensor xGrad;\n  const int batchSize = label.size(0);\n  const int maxGLen = label.size(1) + 1;\n  const int dictSize = x.size(-1);\n  const auto deviceProperties = at::cuda::getCurrentDeviceProperties();\n  const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;\n  const int warpSize = deviceProperties->warpSize;\n  const auto batchOffsetPtr = packedInput ? 
batchOffset.data_ptr<int64_t>() : nullptr;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (fuseSoftmaxBackward) {\n    // Allocate an uninitialized tensor for performance; the kernel must therefore write zeros\n    // to the don't-care region itself.\n    xGrad = torch::empty_like(x);\n\n    // Would like each thread to work on 4 hidden units\n    const int workPerThread = 4;\n    // Don't want to have more than 128 threads per thread block\n    const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);\n    const int threads = std::min(maxThreadPerElmt, std::max(warpSize, (dictSize + workPerThread - 1) / workPerThread));\n    const dim3 blocks(maxGLen, maxFLen, batchSize);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n        dtype, \"transducer_loss_cuda_backward\", ([&] {\n          using vec_t = uint64_t;\n          using acc_t = at::acc_type<scalar_t, true>;\n          constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);\n          constexpr int vecAlignment = std::alignment_of<vec_t>::value;\n          // if all input and output tensors meet the alignment requirement\n          bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0 and\n                          reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>()) % vecAlignment == 0;\n\n          if (vectFactor > 1 and dictSize % vectFactor == 0 and memAlign) {\n            transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor><<<blocks, threads, 0, stream>>>(\n                x.data_ptr<scalar_t>(), lossGrad.data_ptr<scalar_t>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),\n                label.data_ptr<int>(), alpha.data_ptr<acc_t>(), beta.data_ptr<acc_t>(), batchOffsetPtr, dictSize,\n                blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr<scalar_t>());\n          } else {\n            transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(\n                x.data_ptr<scalar_t>(), 
lossGrad.data_ptr<scalar_t>(), audLen.data_ptr<int>(), txtLen.data_ptr<int>(),\n                label.data_ptr<int>(), alpha.data_ptr<acc_t>(), beta.data_ptr<acc_t>(), batchOffsetPtr, dictSize,\n                blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr<scalar_t>());\n          }\n        }));\n  } else {\n    // For the non-fused kernel, the gradients that need to be written are very sparse, hence\n    // initialize the tensor with all zeros.\n    xGrad = torch::zeros_like(x);\n    // don't launch more threads than needed.\n    const int threads = std::min(maxThreadPerBlock, maxGLen);\n    const dim3 blocks(maxFLen, batchSize);\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, \"transducer_loss_cuda_backward\", ([&] {\n                                          using acc_t = at::acc_type<scalar_t, true>;\n                                          transducer_loss_backward<<<blocks, threads, 0, stream>>>(\n                                              x.data_ptr<scalar_t>(), lossGrad.data_ptr<scalar_t>(),\n                                              audLen.data_ptr<int>(), txtLen.data_ptr<int>(), label.data_ptr<int>(),\n                                              alpha.data_ptr<acc_t>(), beta.data_ptr<acc_t>(), batchOffsetPtr, dictSize,\n                                              blankIdx, maxFLen, maxGLen, packedInput, xGrad.data_ptr<scalar_t>());\n                                        }));\n  }\n  C10_CUDA_CHECK(cudaGetLastError());\n\n  return xGrad;\n}\n"
  },
  {
    "path": "apex/contrib/csrc/xentropy/interface.cpp",
    "content": "#include <torch/extension.h>\n\n#include <string>\n\n// CUDA forward declarations\n\nstd::vector<at::Tensor> softmax_xentropy_cuda(const at::Tensor& input, const at::Tensor& labels, const float smoothing,\n                                              const bool half_to_float);\n\nat::Tensor softmax_xentropy_backward_cuda(const at::Tensor& grad_loss, const at::Tensor& logits,\n                                          const at::Tensor& max_log_sum_exp, const at::Tensor& labels,\n                                          const float smoothing);\n\n// C++ interface\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\nstd::vector<at::Tensor> softmax_xentropy_forward(const at::Tensor& input, const at::Tensor& labels,\n                                                 const float smoothing, const bool half_to_float) {\n  CHECK_CUDA(input);\n  CHECK_INPUT(labels);\n\n  return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);\n}\n\nat::Tensor softmax_xentropy_backward(const at::Tensor& grad_loss, const at::Tensor& logits,\n                                     const at::Tensor& max_log_sum_exp, const at::Tensor& labels,\n                                     const float smoothing) {\n  CHECK_CUDA(grad_loss);\n  CHECK_CUDA(logits);\n  CHECK_INPUT(max_log_sum_exp);\n  CHECK_INPUT(labels);\n\n  return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &softmax_xentropy_forward, \"Softmax cross entropy loss with label smoothing forward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &softmax_xentropy_backward, \"Softmax cross entropy loss with label smoothing backward (CUDA)\",\n        
py::call_guard<py::gil_scoped_release>());\n  // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables\n  py::object version = py::cast(\n#ifdef XENTROPY_VER\n      XENTROPY_VER\n#else\n      std::string{}\n#endif\n  );\n  m.attr(\"__version__\") = version;\n}\n"
  },
  {
    "path": "apex/contrib/csrc/xentropy/xentropy_kernel.cu",
    "content": "/**\n * From PyTorch:\n *\n * Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n * Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n * Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n * Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n * Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n *\n * From Caffe2:\n *\n * Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n *\n * All contributions by Facebook:\n * Copyright (c) 2016 Facebook Inc.\n *\n * All contributions by Google:\n * Copyright (c) 2015 Google Inc.\n * All rights reserved.\n *\n * All contributions by Yangqing Jia:\n * Copyright (c) 2015 Yangqing Jia\n * All rights reserved.\n *\n * All contributions from Caffe:\n * Copyright(c) 2013, 2014, 2015, the respective contributors\n * All rights reserved.\n *\n * All other contributions:\n * Copyright(c) 2015, 2016 the respective contributors\n * All rights reserved.\n *\n * Caffe2 uses a copyright model similar to Caffe: each contributor holds\n * copyright over their contributions to Caffe2. The project versioning records\n * all such contribution and copyright details. If a contributor wants to further\n * mark their specific copyright on a particular contribution, they should\n * indicate their copyright solely in the commit message of the change when it is\n * committed.\n *\n * All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. 
Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n *\n * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n *    and IDIAP Research Institute nor the names of its contributors may be\n *    used to endorse or promote products derived from this software without\n *    specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n * POSSIBILITY OF SUCH DAMAGE.\n */\n#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <ATen/cuda/NumericLimits.cuh>\n\n#include \"type_shim.h\"\n\n#define ALIGN_BYTES 16\n\nusing Tensor = at::Tensor;\nusing TensorList = at::TensorList;\nusing ScalarType = at::ScalarType;\nusing at::acc_type;\n\ntemplate <typename T, typename AccumT, typename OutT>\nstruct LogSoftMaxForwardEpilogue {\n  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)\n      : logsum(max_input + std::log(sum)) 
{}\n\n  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) : logsum(max_log_sum_exp) {}\n\n  __device__ __forceinline__ OutT operator()(T input) const { return static_cast<OutT>(input - logsum); }\n\n  const AccumT logsum;\n};\n\ntemplate <typename T, typename AccumT, typename OutT>\nstruct LogSoftMaxBackwardEpilogue {\n  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) : sum(sum) {}\n\n  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {\n    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);\n  }\n\n  const AccumT sum;\n};\n\nconst int max_threads = 1024;\n\ninline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {\n  uint64_t block_size = 1;\n  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));\n  while (block_size < (max_block_size / 2)) block_size *= 2;\n  // Launch at least a single warp - the kernel assumes that.\n  block_size = std::max(block_size, static_cast<uint64_t>(32));\n  return dim3(block_size);\n}\n\ntemplate <typename T>\nstruct Add {\n  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }\n};\n\ntemplate <typename T>\nstruct Max {\n  __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; }\n};\n\n////////////////////////////////////////////////////////////////////////////////\n// Regular kernel (fast when dim_size is large; requires inner_size == 1)\n////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename T, typename AccumT>\nstruct MaxFloat {\n  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { return ::max(max, (AccumT)v); }\n};\n\ntemplate <typename T, typename AccumT>\nstruct AddFloat {\n  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + v; }\n};\n\ntemplate <typename T, typename AccumT>\nstruct SumExpFloat {\n  __device__ __forceinline__ SumExpFloat(AccumT v) : max_k(v) {}\n\n  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { return sum + std::exp(v - max_k); }\n\n  const AccumT max_k;\n};\n\ntemplate <template <typename> class Reduction, typename AccumT>\n__device__ __forceinline__ AccumT blockReduce(AccumT* smem, AccumT val, const Reduction<AccumT>& r, AccumT defaultVal) {\n  // To avoid RaW races from chaining blockReduce calls together, we need a sync here\n  __syncthreads();\n\n  smem[threadIdx.x] = val;\n\n  __syncthreads();\n\n  AccumT warpVal = defaultVal;\n\n  // First warp will perform per-warp reductions for the remaining warps\n  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;\n  if (threadIdx.x < 32) {\n    int lane = threadIdx.x % 32;\n    if (lane < blockDim.x / 32) {\n#pragma unroll\n      for (int i = 0; i < 32; ++i) {\n        warpVal = r(warpVal, smem[lane * 32 + i]);\n      }\n      __syncwarp(mask);\n      smem[lane] = warpVal;\n    }\n  }\n\n  __syncthreads();\n\n  // First thread will perform a reduction of the above per-warp reductions\n  AccumT blockVal = defaultVal;\n\n  if (threadIdx.x == 0) {\n    for (int i = 0; i < blockDim.x / 32; ++i) {\n      blockVal = r(blockVal, smem[i]);\n    }\n    smem[0] = blockVal;\n  }\n\n  // Sync and broadcast\n  __syncthreads();\n  return 
smem[0];\n}\n\ntemplate <template <typename> class Reduction1, template <typename> class Reduction2, typename AccumT>\n__device__ __forceinline__ void blockReduce(AccumT* smem, AccumT* reducVal1, AccumT val1, const Reduction1<AccumT>& r1,\n                                            AccumT defaultVal1, AccumT* reducVal2, AccumT val2,\n                                            const Reduction2<AccumT>& r2, AccumT defaultVal2) {\n  // To avoid RaW races from chaining blockReduce calls together, we need a sync here\n  __syncthreads();\n\n  smem[threadIdx.x] = val1;\n  smem[blockDim.x + threadIdx.x] = val2;\n\n  __syncthreads();\n\n  AccumT warpVal1 = defaultVal1;\n  AccumT warpVal2 = defaultVal2;\n\n  // First warp will perform per-warp reductions for the remaining warps\n  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;\n  if (threadIdx.x < 32) {\n    int lane = threadIdx.x % 32;\n    if (lane < blockDim.x / 32) {\n#pragma unroll\n      for (int i = 0; i < 32; ++i) {\n        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);\n        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);\n      }\n      __syncwarp(mask);\n      smem[lane] = warpVal1;\n      smem[lane + blockDim.x] = warpVal2;\n    }\n  }\n\n  __syncthreads();\n\n  // First thread will perform a reduction of the above per-warp reductions\n  AccumT blockVal1 = defaultVal1;\n  AccumT blockVal2 = defaultVal2;\n\n  if (threadIdx.x == 0) {\n    for (int i = 0; i < blockDim.x / 32; ++i) {\n      blockVal1 = r1(blockVal1, smem[i]);\n      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);\n    }\n    smem[0] = blockVal1;\n    smem[blockDim.x] = blockVal2;\n  }\n\n  // Sync and broadcast\n  __syncthreads();\n  *reducVal1 = smem[0];\n  *reducVal2 = smem[blockDim.x];\n  __syncthreads();\n}\n\ntemplate <template <typename, typename> class Reduction, int ILP, typename T, typename AccumT>\n__device__ __forceinline__ AccumT ilpReduce(int shift, T* data, int size, const Reduction<T, AccumT>& r,\n           
                                 AccumT defaultVal) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LoadT;\n  AccumT threadVal = defaultVal;\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if (shift > 0) {\n    data -= shift;\n    size += shift;\n    if (threadIdx.x >= shift) {\n      threadVal = r(threadVal, data[offset]);\n    }\n    size -= blockDim.x;\n    data += blockDim.x;\n  }\n  int last = size % (ILP * blockDim.x);\n\n  T v[ILP];\n  LoadT* value = reinterpret_cast<LoadT*>(&v);\n\n  for (; offset * ILP < (size - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(data)[offset];\n\n    for (int j = 0; j < ILP; ++j) {\n      threadVal = r(threadVal, v[j]);\n    }\n  }\n\n  offset = size - last + threadIdx.x;\n  // Epilogue\n  for (; offset < size; offset += blockDim.x) threadVal = r(threadVal, data[offset]);\n\n  return threadVal;\n}\n\ntemplate <template <typename, typename> class Reduction1, template <typename, typename> class Reduction2, int ILP,\n          typename T, typename AccumT>\n__device__ __forceinline__ void ilpReduce(int shift, T* data, int size, AccumT* reducVal1,\n                                          const Reduction1<T, AccumT>& r1, AccumT defaultVal1, AccumT* reducVal2,\n                                          const Reduction2<T, AccumT>& r2, AccumT defaultVal2) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LoadT;\n\n  AccumT threadVal1 = defaultVal1;\n  AccumT threadVal2 = defaultVal2;\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if (shift > 0) {\n    data -= shift;\n    size += shift;\n    if (threadIdx.x >= shift) {\n      threadVal1 = r1(threadVal1, data[offset]);\n      threadVal2 = r2(threadVal2, data[offset]);\n    }\n    size -= blockDim.x;\n    data += blockDim.x;\n  }\n  int last = size % (ILP * blockDim.x);\n\n  T v[ILP];\n  LoadT* value = reinterpret_cast<LoadT*>(&v);\n\n  for (; offset * ILP < (size - last); offset 
+= blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(data)[offset];\n\n    for (int j = 0; j < ILP; ++j) {\n      threadVal1 = r1(threadVal1, v[j]);\n      threadVal2 = r2(threadVal2, v[j]);\n    }\n  }\n\n  offset = size - last + threadIdx.x;\n  // Epilogue\n  for (; offset < size; offset += blockDim.x) {\n    threadVal1 = r1(threadVal1, data[offset]);\n    threadVal2 = r2(threadVal2, data[offset]);\n  }\n\n  *reducVal1 = threadVal1;\n  *reducVal2 = threadVal2;\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,\n          template <typename, typename, typename> class Epilogue>\n__global__ void cunn_SoftMaxXEntropyForward(accscalar_t* losses, outscalar_t* max_log_sum_exp, scalar_t* input,\n                                            int64_t* labels, int64_t classes, const float smoothing) {\n  extern __shared__ unsigned char smem[];\n  auto sdata = reinterpret_cast<accscalar_t*>(smem);\n  // forward pointers to batch[blockIdx.x]\n  // each block handles a sample in the mini-batch\n  input += blockIdx.x * classes;\n  // output += blockIdx.x * classes;\n  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);\n\n  int64_t label = labels[blockIdx.x];\n\n  // find the max and sum\n  accscalar_t threadMax, threadSum, max_k, sum_k;\n  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(\n      shift, input, classes, &threadMax, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max(),\n      &threadSum, AddFloat<scalar_t, accscalar_t>(), static_cast<accscalar_t>(0));\n\n  blockReduce<Max, Add, accscalar_t>(sdata, &max_k, threadMax, Max<accscalar_t>(),\n                                     -at::numeric_limits<accscalar_t>::max(), &sum_k, threadSum, Add<accscalar_t>(),\n                                     static_cast<accscalar_t>(0));\n\n  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(\n      shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), 
static_cast<accscalar_t>(0));\n  accscalar_t sumAll = blockReduce<Add, accscalar_t>(sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));\n\n  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);\n\n  // calculate per element loss with label smoothing\n  // reserve max + log_sum_exp for bprop\n  if (threadIdx.x == 0) {\n    accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));\n    losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) * smoothing - log_prob * (1 - smoothing);\n    max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll);\n  }\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>\n__device__ __forceinline__ void apply(scalar_t* gradInput, scalar_t* logits, outscalar_t* max_log_sum_exp,\n                                      outscalar_t* gradOutput, int64_t* labels, const float smoothing, int classes) {\n  accscalar_t smooth_positives = 1.0 - smoothing;\n  accscalar_t smooth_negatives = smoothing / classes;\n  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];\n  int64_t label = labels[blockIdx.x];\n  accscalar_t coeff = max_log_sum_exp[blockIdx.x];\n\n  int offset = threadIdx.x;\n  int last = classes % (ILP * blockDim.x);\n\n  for (; offset < classes - last; offset += blockDim.x * ILP) {\n    accscalar_t tmpLogits[ILP];\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j) {\n      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);\n    }\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j)\n      gradInput[offset + j * blockDim.x] =\n          tmpGradOutput *\n          (std::exp(tmpLogits[j] - coeff) -\n           static_cast<accscalar_t>((offset + j * blockDim.x == label) ? 
1 : 0) * smooth_positives - smooth_negatives);\n  }\n\n  for (; offset < classes; offset += blockDim.x)\n    gradInput[offset] =\n        tmpGradOutput * (std::exp(static_cast<accscalar_t>(logits[offset]) - coeff) -\n                         static_cast<accscalar_t>((offset == label) ? 1 : 0) * smooth_positives - smooth_negatives);\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>\n__device__ __forceinline__ void aligned_apply(int shift, scalar_t* gradInput, scalar_t* logits,\n                                              outscalar_t* max_log_sum_exp, outscalar_t* gradOutput, int64_t* labels,\n                                              const float smoothing, int classes) {\n  accscalar_t smooth_positives = 1.0 - smoothing;\n  accscalar_t smooth_negatives = smoothing / classes;\n  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];\n  int64_t label = labels[blockIdx.x];\n  accscalar_t coeff = max_log_sum_exp[blockIdx.x];\n\n  int offset = threadIdx.x;\n\n  // shift and do 1\n  if (shift > 0) {\n    logits -= shift;\n    gradInput -= shift;\n    classes += shift;\n    if (threadIdx.x >= shift) {\n      gradInput[offset] =\n          tmpGradOutput *\n          (std::exp(static_cast<accscalar_t>(logits[offset]) - coeff) -\n           static_cast<accscalar_t>(((offset - shift) == label) ? 
1 : 0) * smooth_positives - smooth_negatives);\n    }\n    classes -= blockDim.x;\n    gradInput += blockDim.x;\n    logits += blockDim.x;\n    shift -= blockDim.x;\n  }\n\n  int last = classes % (ILP * blockDim.x);\n\n  typedef typename std::aligned_storage<ILP * sizeof(scalar_t), ILP * alignof(scalar_t)>::type LoadT;\n  // input\n  scalar_t v[ILP];\n  LoadT* value = reinterpret_cast<LoadT*>(&v);\n  // output\n  scalar_t r[ILP];\n  LoadT* result = reinterpret_cast<LoadT*>(&r);\n\n  for (; offset * ILP < (classes - last); offset += blockDim.x) {\n    *value = reinterpret_cast<LoadT*>(logits)[offset];\n\n#pragma unroll\n    for (int j = 0; j < ILP; ++j) {\n      r[j] =\n          tmpGradOutput * (std::exp(static_cast<accscalar_t>(v[j]) - coeff) -\n                           static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) * smooth_positives -\n                           smooth_negatives);\n    }\n    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;\n  }\n\n  offset = classes - last + threadIdx.x;\n  for (; offset < classes; offset += blockDim.x)\n    gradInput[offset] =\n        tmpGradOutput *\n        (std::exp(static_cast<accscalar_t>(logits[offset]) - coeff) -\n         static_cast<accscalar_t>(((offset - shift) == label) ? 
1 : 0) * smooth_positives - smooth_negatives);\n}\n\ntemplate <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,\n          template <typename, typename, typename> class Epilogue>\n__global__ void cunn_SoftMaxXEntropyBackward(scalar_t* gradInput, scalar_t* logits, outscalar_t* max_log_sum_exp,\n                                             outscalar_t* gradOutput, int64_t* labels, const float smoothing,\n                                             int classes) {\n  gradInput += blockIdx.x * classes;\n  logits += blockIdx.x * classes;\n\n  // Do vectorized load/store when input/output have same alignment\n  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);\n  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);\n  if (shift == shift_) {\n    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput,\n                                                           labels, smoothing, classes);\n  } else {\n    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing,\n                                                   classes);\n  }\n}\n\ntemplate <template <typename, typename, typename> class Epilogue>\nstd::vector<Tensor> host_softmax_xentropy(const Tensor& input_, const Tensor& labels_, const float smoothing,\n                                          const bool half_to_float) {\n  if (half_to_float)\n    TORCH_CHECK(input_.scalar_type() == ScalarType::Half, \"conversion is supported for Half type only\");\n  TORCH_CHECK(labels_.scalar_type() == ScalarType::Long, \"Label type should be CUDA Long\");\n\n  auto input = input_.contiguous();\n  Tensor max_log_sum_exp =\n      at::empty_like(labels_, half_to_float ? 
input.options().dtype(ScalarType::Float) : input.options());\n  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));\n\n  static_assert(\n      std::is_same<acc_type<at::Half, true>, float>::value || std::is_same<acc_type<at::Half, true>, double>::value,\n      \"accscalar_t for half should be float or double\");\n  TORCH_CHECK(input.dim() == 2, \"Currently only 2 dim input supported\");\n  TORCH_CHECK(labels_.dim() == 1, \"Labels should be 1 dimensional\");\n  TORCH_CHECK(input.size(0) == labels_.size(0), \"Input and label should have same number of examples\");\n  TORCH_CHECK(input.numel() > 0, \"Number of classes in input should not be 0\");\n\n  const int64_t dim = 1;\n  int64_t outer_size = 1;\n  int64_t dim_size = input.size(dim);\n  int64_t inner_size = 1;\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  for (int64_t i = 0; i < dim; ++i) outer_size *= input.size(i);\n  for (int64_t i = dim + 1; i < input.dim(); ++i) inner_size *= input.size(i);\n  // This kernel spawns a block per each element in the batch.\n  // XXX: it assumes that inner_size == 1\n  TORCH_CHECK(inner_size == 1, \"Currently only inner size 1 supported\");\n\n  dim3 grid(outer_size);\n\n  using namespace at;\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      input.scalar_type(), 0, \"host_softmax_xentropy\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n      const int ILP = sizeof(float4) / sizeof(scalar_t_0); dim3 block = SoftMax_getBlockSize(ILP, dim_size);\n      if (!half_to_float) {\n        cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>\n            <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(\n                losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<scalar_t_0>(), input.data_ptr<scalar_t_0>(),\n                labels_.data_ptr<int64_t>(), dim_size, smoothing);\n      } else {\n        cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>\n            
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(\n                losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(), input.data_ptr<scalar_t_0>(),\n                labels_.data_ptr<int64_t>(), dim_size, smoothing);\n      });\n\n  C10_CUDA_CHECK(cudaGetLastError());\n\n  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};\n  return ret;\n}\n\ntemplate <template <typename, typename, typename> class Epilogue>\nTensor host_softmax_xentropy_backward(const at::Tensor& grad_loss, const at::Tensor& logits_,\n                                      const at::Tensor& max_log_sum_exp, const at::Tensor& labels,\n                                      const float smoothing, bool half_to_float) {\n  const int64_t dim = 1;\n  Tensor gI = at::empty_like(logits_);\n  if (grad_loss.numel() == 0) {\n    return gI;\n  }\n\n  auto grad = grad_loss.contiguous();\n  auto logits = logits_.contiguous();\n\n  static_assert(\n      std::is_same<acc_type<at::Half, true>, float>::value || std::is_same<acc_type<at::Half, true>, double>::value,\n      \"accscalar_t for half should be float or double\");\n  if (grad.dim() == 0) grad = grad.view(1);\n\n  TORCH_CHECK(logits_.dim() == 2, \"Currently only 2 dim input supported\");\n  TORCH_CHECK(labels.dim() == 1, \"Labels should be 1 dimensional\");\n  TORCH_CHECK(logits_.numel() > 0, \"Number of classes in input should not be 0\");\n  TORCH_CHECK(logits_.size(0) == labels.size(0), \"Input and label should have same number of examples\");\n  TORCH_CHECK(labels.size(0) == grad.size(0), \"Label and loss should have same number of examples\");\n\n  int64_t outer_size = 1;\n  int64_t dim_size = logits.size(dim);\n  int64_t inner_size = 1;\n  for (int64_t i = 0; i < dim; ++i) outer_size *= logits.size(i);\n  for (int64_t i = dim + 1; i < logits.dim(); ++i) inner_size *= logits.size(i);\n  // See descriptions of kernels above.\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  TORCH_CHECK(inner_size == 1, 
\"Currently only inner size 1 supported\");\n\n  dim3 grid(outer_size);\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      gI.scalar_type(), 0, \"host_softmax_xentropy_backward\", using accscalar_t = acc_type<scalar_t_0, true>;\n      const int ILP = sizeof(float4) / sizeof(scalar_t_0); dim3 block = SoftMax_getBlockSize(ILP, dim_size);\n      if (!half_to_float) {\n        cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>\n            <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n                gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(), max_log_sum_exp.data_ptr<scalar_t_0>(),\n                grad.data_ptr<scalar_t_0>(), labels.data_ptr<int64_t>(), smoothing, dim_size);\n      } else {\n        cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>\n            <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(\n                gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(), max_log_sum_exp.data_ptr<accscalar_t>(),\n                grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(), smoothing, dim_size);\n      });\n\n  C10_CUDA_CHECK(cudaGetLastError());\n  return gI;\n}\n\nstd::vector<Tensor> softmax_xentropy_cuda(const Tensor& input, const Tensor& labels, const float smoothing,\n                                          const bool half_to_float) {\n  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, half_to_float);\n}\n\nat::Tensor softmax_xentropy_backward_cuda(const at::Tensor& grad_loss, const at::Tensor& logits,\n                                          const at::Tensor& max_log_sum_exp, const at::Tensor& labels,\n                                          const float smoothing) {\n  bool half_to_float = grad_loss.scalar_type() != logits.scalar_type();\n  if (half_to_float) {\n    TORCH_CHECK((grad_loss.scalar_type() == ScalarType::Float && logits.scalar_type() == ScalarType::Half),\n                \"expected input and grad 
types to match, or input to be at::Half and grad to be at::Float\");\n  }\n  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels,\n                                                                    smoothing, half_to_float);\n}\n"
  },
  {
    "path": "apex/contrib/cudnn_gbn/__init__.py",
    "content": "from .batch_norm import GroupBatchNorm2d\n"
  },
  {
    "path": "apex/contrib/cudnn_gbn/batch_norm.py",
    "content": "import torch\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn import functional as F\nfrom torch import Tensor\nimport peer_memory_cuda as pm\nimport cudnn_gbn_lib\nfrom torch.cuda.amp import custom_fwd, custom_bwd\n\n\nclass _GroupBatchNorm2d(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(\n        ctx,\n        input,\n        weight,\n        bias,\n        running_mean,\n        running_variance,\n        minibatch_mean,\n        minibatch_inv_var,\n        momentum,\n        eps,\n        group_size,\n        group_rank,\n        fwd_buffers,\n        bwd_buffers,\n    ):\n        ctx.save_for_backward(input, weight, minibatch_mean, minibatch_inv_var)\n        ctx.eps = eps\n        ctx.bn_group = group_size\n        ctx.rank_id = group_rank\n        ctx.peer_buffers = bwd_buffers\n        return cudnn_gbn_lib.forward(\n            input,\n            weight,\n            bias,\n            running_mean,\n            running_variance,\n            minibatch_mean,\n            minibatch_inv_var,\n            momentum,\n            eps,\n            group_size,\n            group_rank,\n            fwd_buffers,\n        )\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_output):\n        x, scale, minibatch_mean, minibatch_inv_var = ctx.saved_variables\n        eps = ctx.eps\n        bn_group = ctx.bn_group\n        rank_id = ctx.rank_id\n        peer_buffers = ctx.peer_buffers\n        dx, dscale, dbias = cudnn_gbn_lib.backward(\n            x,\n            grad_output,\n            scale,\n            minibatch_mean,\n            minibatch_inv_var,\n            eps,\n            bn_group,\n            rank_id,\n            peer_buffers,\n        )\n        return (\n            dx,\n            dscale,\n            dbias,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        
    None,\n            None,\n        )\n\n\nclass GroupBatchNorm2d(_BatchNorm):\n    \"\"\"\n    Synchronized batch normalization module extended from ``torch.nn.BatchNormNd``\n    with added reduction of statistics across multiple processes.\n\n    When running in training mode, the layer reduces statistics across the processes\n    of a group to increase the effective batch size for the normalization layer. This\n    is useful in applications where the batch size on a given process is so small that\n    it would diminish the converged accuracy of the model.\n\n    When running in evaluation mode, the layer falls back to\n    ``torch.nn.functional.batch_norm``.\n\n    Args:\n        num_features: :math:`C` from an expected input of size\n            :math:`(N, C, H, W)`; must be a multiple of 8\n        group_size: number of consecutive ranks whose statistics are reduced together\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Can be set to ``None`` for cumulative moving average\n            (i.e. simple average). Default: 0.1\n        affine: a boolean value that when set to ``True``, this module has\n            learnable affine parameters. Default: ``True``\n        track_running_stats: a boolean value that when set to ``True``, this\n            module tracks the running mean and variance, and when set to ``False``,\n            this module does not track such statistics and always uses batch\n            statistics in both training and eval modes. 
Default: ``True``\n\n    Example::\n\n        >>> # assumes torch.distributed is already initialized; group_size=2 is illustrative\n        >>> sbn = apex.contrib.cudnn_gbn.GroupBatchNorm2d(64, group_size=2).cuda()\n        >>> inp = torch.randn(10, 64, 14, 14, dtype=torch.float16, device=\"cuda\")\n        >>> inp = inp.contiguous(memory_format=torch.channels_last)\n        >>> out = sbn(inp)\n    \"\"\"\n\n    def __init__(\n        self,\n        num_features,\n        group_size,\n        eps=1e-5,\n        momentum=0.1,\n        affine=True,\n        track_running_stats=True,\n    ):\n        super(GroupBatchNorm2d, self).__init__(\n            num_features,\n            eps=eps,\n            momentum=momentum,\n            affine=affine,\n            track_running_stats=track_running_stats,\n        )\n        self.group_size = group_size\n        rank = torch.distributed.get_rank()\n        self.group_id = rank // group_size\n        self.group_rank = rank % group_size\n        self.fwd_peer_buffers = self.get_peer_buffers(num_features)\n        self.bwd_peer_buffers = self.get_peer_buffers(num_features)\n        self.minibatch_mean = torch.cuda.FloatTensor(num_features)\n        self.minibatch_inv_var = torch.cuda.FloatTensor(num_features)\n\n    def get_peer_buffers(self, num_features):\n        # group_size * 2 (low-latency algo) * 2 (mean+var) * channels * 4 (float32)\n        peer_size = self.group_size * 4 * num_features * 4\n        raw = pm.allocate_raw(peer_size)\n        # exchange peer pointers with nccl\n        world_size = torch.distributed.get_world_size()\n        raw_ipc = pm.get_raw_ipc_address(raw).cuda()\n        raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]\n        torch.distributed.all_gather(raw_ipcs, raw_ipc)\n        group_ipcs = [\n            raw_ipcs[x]\n            for x in range(\n                self.group_id * self.group_size,\n                (self.group_id * self.group_size) + self.group_size,\n            )\n        ]\n        peer_raw_ipcs = torch.stack(group_ipcs).cpu()\n        return pm.get_raw_peers(peer_raw_ipcs, self.group_rank, 
raw)\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError(\"expected 4D input (got {}D input)\".format(input.dim()))\n\n    def _check_input_channels(self, input):\n        if input.size(1) % 8 != 0:\n            raise ValueError(\"GroupBatchNorm2d number of input channels should be a multiple of 8\")\n\n    def forward(self, input: Tensor) -> Tensor:\n        # currently only GPU input is supported\n        if not input.is_cuda:\n            raise ValueError(\"GroupBatchNorm2d expected input tensor to be on GPU\")\n        if not input.is_contiguous(memory_format=torch.channels_last):\n            raise ValueError(\n                \"GroupBatchNorm2d expected input tensor to be in channels last memory format\"\n            )\n        if torch.is_autocast_enabled():\n            input = input.to(torch.get_autocast_gpu_dtype())\n        if input.dtype != torch.float16:\n            raise ValueError(\"GroupBatchNorm2d expected input tensor in float16\")\n        self._check_input_dim(input)\n        self._check_input_channels(input)\n\n        if not self.training:\n            # fall back to pytorch implementation for inference\n            return F.batch_norm(\n                input,\n                self.running_mean,\n                self.running_var,\n                self.weight,\n                self.bias,\n                False,\n                self.momentum,\n                self.eps,\n            )\n\n        return _GroupBatchNorm2d.apply(\n            input,\n            self.weight,\n            self.bias,\n            self.running_mean,\n            self.running_var,\n            self.minibatch_mean,\n            self.minibatch_inv_var,\n            self.momentum,\n            self.eps,\n            self.group_size,\n            self.group_rank,\n            self.fwd_peer_buffers,\n            self.bwd_peer_buffers,\n        )\n"
  },
  {
    "path": "apex/contrib/examples/gpu_direct_storage/benchmark_load.py",
    "content": "import timeit\nimport torch\nimport apex.contrib.gpu_direct_storage as gds\n\ndef run_benchmark_torch_load():\n    sizes = [2 ** i for i in range(16, 28)]\n    for size in sizes:\n        torch.cuda.empty_cache()\n        s = torch.cuda.Stream()\n        x = torch.empty(size, device = \"cuda\")\n        y = torch.linspace(0, 1, size, device = \"cuda\")\n        torch.save(y, f\"{size}.data\")\n\n        # warmup\n        torch.cuda.synchronize()\n        for _ in range(10):\n            x = torch.load(f\"{size}.data\")\n\n        torch.cuda.synchronize()\n        start_time = timeit.default_timer()\n        for _ in range(10):\n            x = torch.load(f\"{size}.data\")\n        torch.cuda.synchronize()\n        end_time = timeit.default_timer()\n        print(f\"torch.load: size = {size}, {end_time - start_time}\")\n        assert(torch.allclose(x, y))\n\ndef run_benchmark(func):\n    sizes = [2 ** i for i in range(16, 28)]\n    for size in sizes:\n        torch.cuda.empty_cache()\n        s = torch.cuda.Stream()\n        x = torch.empty(size, device = \"cuda\")\n        y = torch.linspace(0, 1, size, device = \"cuda\")\n\n        with gds.GDSFile(f\"{size}.data\", \"w\") as f:\n            f.save_data(y)\n\n        # warmup\n        torch.cuda.synchronize()\n        for _ in range(10):\n            func(x, f\"{size}.data\")\n\n        torch.cuda.synchronize()\n        start_time = timeit.default_timer()\n        for _ in range(10):\n            func(x, f\"{size}.data\")\n        torch.cuda.synchronize()\n        end_time = timeit.default_timer()\n        print(f\"{func.__name__}: size = {size}, {end_time - start_time}\")\n        assert(torch.allclose(x, y))\n\ndef load_data_yes_gds(tensor, filename):\n    with gds.GDSFile(filename, \"r\") as f:\n        f.load_data(tensor)\n\ndef load_data_no_gds(tensor, filename):\n    with gds.GDSFile(filename, \"rn\") as f:\n        f.load_data_no_gds(tensor)\n\nif __name__ == '__main__':\n    
run_benchmark_torch_load()\n    run_benchmark(load_data_yes_gds)\n    run_benchmark(load_data_no_gds)\n"
  },
  {
    "path": "apex/contrib/examples/gpu_direct_storage/benchmark_save.py",
    "content": "import os\nimport timeit\nimport torch\nimport apex.contrib.gpu_direct_storage as gds\n\ndef run_benchmark(func):\n    sizes = [2 ** i for i in range(16, 28)]\n    for size in sizes:\n        torch.cuda.empty_cache()\n        s = torch.cuda.Stream()\n        x = torch.linspace(0, 1, size, device = \"cuda\")\n\n        # warmup\n        torch.cuda.synchronize()\n        for _ in range(10):\n            func(x, f\"{size}.data\")\n            os.remove(f\"{size}.data\")\n\n        torch.cuda.synchronize()\n        start_time = timeit.default_timer()\n        for _ in range(10):\n            func(x, f\"{size}.data\")\n            os.remove(f\"{size}.data\")\n        torch.cuda.synchronize()\n        end_time = timeit.default_timer()\n        print(f\"{func.__name__}: size = {size}, {end_time - start_time}\")\n\ndef save_data_yes_gds(tensor, filename):\n    with gds.GDSFile(filename, \"w\") as f:\n        f.save_data(tensor)\n\ndef save_data_no_gds(tensor, filename):\n    with gds.GDSFile(filename, \"wn\") as f:\n        f.save_data_no_gds(tensor)\n\nif __name__ == '__main__':\n    run_benchmark(torch.save)\n    run_benchmark(save_data_yes_gds)\n    run_benchmark(save_data_no_gds)\n"
  },
  {
    "path": "apex/contrib/examples/gpu_direct_storage/example_load.py",
    "content": "import torch\nimport apex.contrib.gpu_direct_storage as gds\n\nfor size in [128, 1024, 8192]:\n    x = torch.empty(size, device = \"cuda\")\n    with gds.GDSFile(f\"{size}.data\", \"r\") as f:\n        f.load_data(x)\n    xx = torch.linspace(0, 1, size, device = \"cuda\")\n    assert(torch.allclose(x, xx))\n"
  },
  {
    "path": "apex/contrib/examples/gpu_direct_storage/example_save.py",
    "content": "import torch\nimport apex.contrib.gpu_direct_storage as gds\n\nfor size in [128, 1024, 8192]:\n    x = torch.linspace(0, 1, size, device = \"cuda\")\n    with gds.GDSFile(f\"{size}.data\", \"w\") as f:\n        f.save_data(x)\n"
  },
  {
    "path": "apex/contrib/examples/multihead_attn/func_test_multihead_attn.py",
    "content": "import torch\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nparser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')\nparser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')\nparser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences')\nparser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences')\nparser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')\nparser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')\nparser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')\nparser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--seed-start', default=1, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--seed-end', default=100, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')\nparser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')\nparser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')\nparser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')\nparser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')\nparser.add_argument('--native', action='store_true', help='torch.nn.MultitheadAttention Version.')\nparser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')\nparser.add_argument('--eval', 
action='store_true', help='Inference only, no backward pass.')\n\nargs = parser.parse_args()\nassert args.seq_length % 64 == 0, \"Sequence Length should be a multiple of 64!\"\n\nif not torch.cuda.is_available():\n    raise NotImplementedError('Running on CPU is not supported')\ntorch.cuda.set_device(0)\n\ndropout_prob = 0.1\n\nfor seed in range(args.seed_start, args.seed_end+1) :\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    ref_layer = None\n    if args.encdec_attn :\n        ref_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')\n    else :\n        ref_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')\n    ref_layer.cuda()\n    ref_layer.half()\n    ref_layer.reset_parameters()\n\n    ref_inputs    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    ref_inputs_kv = None\n    if args.encdec_attn :\n        ref_inputs_kv    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    ref_grads         = torch.randn_like(ref_inputs)\n\n    ref_outputs,_ = ref_layer.forward(ref_inputs,\n                                      ref_inputs_kv,\n                                      ref_inputs_kv,\n                                      key_padding_mask=None,\n                                      need_weights=False,\n                                      attn_mask=None,\n                                      is_training=(not args.eval))\n\n    ref_outputs.backward(ref_grads)\n\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    tst_layer = None\n    if args.encdec_attn :\n        tst_layer = 
EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')\n    else:\n        tst_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')\n    tst_layer.cuda()\n    tst_layer.half()\n    tst_layer.reset_parameters()\n\n    tst_inputs    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    tst_inputs_kv = None\n    if args.encdec_attn :\n        tst_inputs_kv    = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n\n    assert torch.equal(ref_inputs,tst_inputs), \"ERROR: Inputs are different!\"\n\n    tst_grads         = torch.randn_like(tst_inputs)\n\n    tst_outputs,_ = tst_layer.forward(tst_inputs,\n                                      tst_inputs_kv,\n                                      tst_inputs_kv,\n                                      key_padding_mask=None,\n                                      need_weights=False,\n                                      attn_mask=None,\n                                      is_training=(not args.eval))\n\n    tst_outputs.backward(tst_grads)\n\n    fwd_close = torch.equal(ref_outputs, tst_outputs)\n    bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad)\n\n    diff_fwd = ref_outputs - tst_outputs\n    diff_cnt_fwd = diff_fwd.ne(0.0).sum()\n    diff_accum_fwd = diff_fwd.abs().sum()\n\n    diff_bwd = ref_inputs.grad - tst_inputs.grad\n    diff_cnt_bwd = diff_bwd.ne(0.0).sum()\n    diff_accum_bwd = diff_bwd.abs().sum()\n\n    print(\">>> Seed: \", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(), bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item())\n"
  },
  {
    "path": "apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py",
    "content": "import torch\nimport argparse\n\nfrom apex.contrib.multihead_attn import SelfMultiheadAttn\nfrom apex.contrib.multihead_attn import EncdecMultiheadAttn\n\nparser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')\nparser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')\nparser.add_argument('--num-seqs-start', default=10, type=int, help='Start Range of Number of Sequences')\nparser.add_argument('--num-seqs-stop', default=120, type=int, help='Stop Range of Number of Sequences')\nparser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')\nparser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')\nparser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')\nparser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')\nparser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')\nparser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')\nparser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')\nparser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')\nparser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')\nparser.add_argument('--native', action='store_true', help='torch.nn.MultitheadAttention Version.')\nparser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')\nparser.add_argument('--biases', action='store_true', help='Execute multihead attention with Linear Biases.')\n\nargs = parser.parse_args()\n\nif not torch.cuda.is_available():\n    raise NotImplementedError('Running on CPU is not 
supported')\ntorch.cuda.set_device(0)\n\ntorch.manual_seed(111)\nif torch.cuda.is_available():\n    torch.cuda.manual_seed_all(111)\n\nattn_layers = []\nfor idx in range(0, args.layers) :\n    if args.encdec_attn :\n        if args.ref :\n            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=False, impl='default'))\n        else :\n            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))\n    else :\n        if args.native :\n            attn_layers.append(torch.nn.MultiheadAttention(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases))\n        elif args.ref :\n            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='default'))\n        else :\n            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))\n    attn_layers[idx].cuda()\n    attn_layers[idx].half()\n    if not args.native :\n        attn_layers[idx].reset_parameters()\n\nstart_evt_fwd = []\nstart_evt_bwd = []\nstop_evt_bwd  = []\nfor recorded_trial in range(0, args.trials) :\n    start_evt_fwd.append(torch.cuda.Event(enable_timing=True))\n    start_evt_bwd.append(torch.cuda.Event(enable_timing=True))\n    stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))\n\nfor sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc) :\n    inputs        = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device(\"cuda\")).requires_grad_(True)\n    grads         = torch.randn_like(inputs)\n   \n    for trial in range(0, args.trials + args.warmup_trials) :\n        layer_inputs  = inputs\n        evt_idx       = trial - args.warmup_trials\n    \n        if evt_idx >= 0 
:\n            start_evt_fwd[evt_idx].record()\n    \n        for lyr_idx in range(0, args.layers) :\n            if args.native :\n                outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, \n                                                         layer_inputs, \n                                                         layer_inputs, \n                                                         key_padding_mask=None, \n                                                         need_weights=False, \n                                                         attn_mask=None)\n            else :\n                outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, \n                                                         layer_inputs, \n                                                         layer_inputs,\n                                                         key_padding_mask=None, \n                                                         need_weights=False, \n                                                         attn_mask=None,\n                                                         is_training=True)\n            layer_inputs = outputs\n    \n        if evt_idx >= 0 :\n            start_evt_bwd[evt_idx].record()\n\n        if not args.fwd :\n            layer_inputs.backward(grads)\n    \n        if evt_idx >= 0 :\n            stop_evt_bwd[evt_idx].record()\n   \n    torch.cuda.synchronize()\n    elapsed_time_fwd = 0.0\n    elapsed_time_bwd = 0.0\n    for evt_idx in range(0, args.trials) :\n        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])\n        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])\n   \n    print(\"[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms\".format(\n        'Encdec' if args.encdec_attn else 'Self',              \\\n        'Norm&Add' if args.norm_add else '',                   
\\\n        sequences*args.seq_length,                             \\\n        sequences,                                             \\\n        args.seq_length,                                       \\\n        elapsed_time_fwd / ( args.trials * args.layers ),      \\\n        elapsed_time_bwd / ( args.trials * args.layers )))\n\n"
  },
  {
    "path": "apex/contrib/examples/nccl_allocator/allreduce.py",
    "content": "import os\nimport torch\nimport torch.distributed as dist\nimport apex.contrib.nccl_allocator as nccl_allocator\n\nassert os.getenv(\"WORLD_SIZE\") is not None, \"Please use: torchrun --nproc-per-node=8 allreduce.py\"\n\nrank = int(os.getenv(\"RANK\"))\nlocal_rank = int(os.getenv(\"LOCAL_RANK\"))\nworld_size = int(os.getenv(\"WORLD_SIZE\"))\n\nnccl_allocator.init()\n\ntorch.cuda.set_device(local_rank)\ndist.init_process_group(backend=\"nccl\")\npool = nccl_allocator.create_nccl_mem_pool()\nwith nccl_allocator.nccl_mem(pool):\n    a = torch.ones(1024 * 1024 * 2, device=\"cuda\")\ndist.all_reduce(a)\n\ntorch.cuda.synchronize()\n\n"
  },
  {
    "path": "apex/contrib/examples/nccl_allocator/cache.py",
    "content": "import torch\nimport apex.contrib.nccl_allocator as nccl_allocator\nfrom pynvml.smi import nvidia_smi\n\ndef set_device(dev):\n    import ctypes\n    handle = ctypes.CDLL(\"libcudart.so\")\n    result = handle.cudaSetDevice(ctypes.c_int(dev))\n    assert result == 0\n\ndef print_used_mem(string, nvsmi, device_id = 0):\n    print(f\"{string}:\", nvsmi.DeviceQuery('memory.used')['gpu'][device_id])\n\nnccl_allocator.init()\nnrep = 6\nnccl_mem = []\n\nset_device(0)\nnvsmi = nvidia_smi.getInstance()\n\nprint_used_mem(\"\", nvsmi)\n\npool = nccl_allocator.create_nccl_mem_pool()\nwith nccl_allocator.nccl_mem(pool):\n    for i in range(nrep):\n      out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB\n      nccl_mem.append(out)\n\nprint_used_mem(\"after nccl alloc (+>=2400)\", nvsmi) # + 2400+ MB\n\ncudart_mem = []\nfor i in range(nrep):\n  out = torch.randn(1024 * 1024 * 50 ).cuda() # == 200 MB\n  cudart_mem.append(out)\n\nprint_used_mem(\"after cudart alloc (+1200)\", nvsmi)\n\ndel cudart_mem\ntorch.cuda.empty_cache()\ntorch.cuda.empty_cache()\nprint_used_mem(\"release cudart mem (-1200)\", nvsmi) # - 1200 MB\n\ndel nccl_mem\nnccl_mem2 = []\nwith nccl_allocator.nccl_mem(pool):\n    for i in range(nrep):\n      out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB\n      nccl_mem2.append(out)\nprint_used_mem(\"reuse nccl cache (same)\", nvsmi) # + 0 MB\ndel nccl_mem2\ntorch.cuda.empty_cache()\nprint_used_mem(\"release nccl_mem (-2400)\", nvsmi) # - 2400 MB\n\ntorch.cuda.empty_cache()\n"
  },
  {
    "path": "apex/contrib/examples/nccl_allocator/change_cuda_allocator.py",
    "content": "import torch\nimport apex.contrib.nccl_allocator as nccl_allocator\n\nnccl_allocator.init()\nnrep = 6\npool = nccl_allocator.create_nccl_mem_pool()\nwith nccl_allocator.nccl_mem(pool):\n    for i in range(nrep):\n      out = torch.randn(1024).cuda()\n\nfor i in range(nrep):\n  out = torch.randn(1024).cuda()\n\ntorch.cuda.empty_cache()\ntorch.cuda.empty_cache()\n"
  },
  {
    "path": "apex/contrib/examples/nccl_allocator/toy_ddp.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\nimport apex.contrib.nccl_allocator as nccl_allocator\n\nassert os.getenv(\"WORLD_SIZE\") is not None, \"Please use: torchrun --nproc-per-node=8 toy_ddp.py\"\n\nclass ToyModel(nn.Module):\n    def __init__(self):\n        super(ToyModel, self).__init__()\n        self.net1 = nn.Linear(10, 10)\n        self.relu = nn.ReLU()\n        self.net2 = nn.Linear(10, 5)\n\n    def forward(self, x):\n        return self.net2(self.relu(self.net1(x)))\n\n\nrank = int(os.getenv(\"RANK\"))\nlocal_rank = int(os.getenv(\"LOCAL_RANK\"))\nworld_size = int(os.getenv(\"WORLD_SIZE\"))\n\nnccl_allocator.init()\n\ntorch.cuda.set_device(local_rank)\ndist.init_process_group(backend=\"nccl\")\n\ndevice = torch.device(\"cuda\", local_rank)\nmodel = ToyModel().to(device)\nddp_model = DDP(model, device_ids=[rank])\nloss_fn = nn.MSELoss()\noptimizer = optim.SGD(ddp_model.parameters(), lr=0.001)\n\ndata_ptrs = []\npool = nccl_allocator.create_nccl_mem_pool()\nwith nccl_allocator.nccl_mem(pool):\n    for param in ddp_model.parameters():\n        param.grad = torch.empty_like(param)\n        data_ptrs.append(param.grad.data_ptr())\n\nfor _ in range(10):\n    optimizer.zero_grad(set_to_none=False)\n    outputs = ddp_model(torch.randn(20, 10))\n    labels = torch.randn(20, 5).to(rank)\n    loss_fn(outputs, labels).backward()\n    optimizer.step()\n\nfor data_ptr, param in zip(data_ptrs, ddp_model.parameters()):\n    assert(data_ptr == param.grad.data_ptr())\ndist.destroy_process_group()\n"
  },
  {
    "path": "apex/contrib/fmha/__init__.py",
    "content": "from .fmha import FMHAFun\n"
  },
  {
    "path": "apex/contrib/fmha/fmha.py",
    "content": "###############################################################################\n# Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#     * Redistributions of source code must retain the above copyright\n#       notice, this list of conditions and the following disclaimer.\n#     * Redistributions in binary form must reproduce the above copyright\n#       notice, this list of conditions and the following disclaimer in the\n#       documentation and/or other materials provided with the distribution.\n#     * Neither the name of the NVIDIA CORPORATION nor the\n#       names of its contributors may be used to endorse or promote products\n#       derived from this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n#\n###############################################################################\n\n\nimport torch\nimport fmhalib as mha\n\n\nclass FMHAFun(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.fmha` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention. \"\n            \"The documentation is available at https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        batch_size = cu_seqlens.numel() - 1\n        if batch_size < 4:\n            max_s = 512\n            context, S_dmask = mha.fwd_nl(\n                qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None\n            )\n        else:\n            context, S_dmask = mha.fwd(\n                qkv,\n                cu_seqlens,\n                p_dropout,\n                max_s,\n                is_training,\n                False,\n                zero_tensors,\n                None,\n            )\n        ctx.save_for_backward(qkv, S_dmask)\n        ctx.cu_seqlens = cu_seqlens\n        ctx.p_dropout = p_dropout\n        ctx.max_s = max_s\n        ctx.zero_tensors = zero_tensors\n        return context\n\n    @staticmethod\n    def backward(ctx, dout):\n        qkv, S_dmask = 
ctx.saved_tensors\n        batch_size = ctx.cu_seqlens.numel() - 1\n        if batch_size < 4:\n            dqkv, dp, _ = mha.bwd_nl(\n                dout,\n                qkv,\n                S_dmask,\n                ctx.cu_seqlens,\n                ctx.p_dropout,\n                ctx.max_s,\n                ctx.zero_tensors,\n            )\n        else:\n            dqkv, dp = mha.bwd(\n                dout,\n                qkv,\n                S_dmask,\n                ctx.cu_seqlens,\n                ctx.p_dropout,\n                ctx.max_s,\n                ctx.zero_tensors,\n            )\n\n        return dqkv, None, None, None, None, None\n\n\nclass FMHA(torch.nn.Module):\n    def __init__(self, config):\n        super(FMHA, self).__init__()\n\n        self.p_dropout = config.attention_probs_dropout_prob\n        self.h = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.d = self.hidden_size // self.h\n        assert self.d * self.h == self.hidden_size, \"Invalid hidden size/num_heads\"\n\n    def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):\n        ctx = FMHAFun.apply(\n            qkv.view(-1, 3, self.h, self.d),\n            cu_seqlens,\n            self.p_dropout,\n            max_s,\n            is_training,\n            zero_tensors,\n        )\n\n        return ctx.view(-1, self.hidden_size)\n"
  },
  {
    "path": "apex/contrib/focal_loss/__init__.py",
    "content": "try:\n    import torch\n    import focal_loss_cuda\n    from .focal_loss import focal_loss\n\n    del torch\n    del focal_loss_cuda\n    del focal_loss\nexcept ImportError:\n    print(\"apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available\")\n"
  },
  {
    "path": "apex/contrib/focal_loss/focal_loss.py",
    "content": "import torch\n\nimport focal_loss_cuda\n\n\nclass FocalLoss(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        cls_output,\n        cls_targets_at_level,\n        num_positives_sum,\n        num_real_classes,\n        alpha,\n        gamma,\n        label_smoothing=0.0,\n    ):\n        loss, partial_grad = focal_loss_cuda.forward(\n            cls_output,\n            cls_targets_at_level,\n            num_positives_sum,\n            num_real_classes,\n            alpha,\n            gamma,\n            label_smoothing,\n        )\n\n        ctx.save_for_backward(partial_grad, num_positives_sum)\n        return loss\n\n    @staticmethod\n    def backward(ctx, grad_loss):\n        partial_grad, num_positives_sum = ctx.saved_tensors\n\n        # The backward kernel is actually in-place to save memory space,\n        # partial_grad and grad_input are the same tensor.\n        grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum)\n\n        return grad_input, None, None, None, None, None, None\n\n\ndef focal_loss(\n    cls_output: torch.Tensor,\n    cls_targets_at_level: torch.Tensor,\n    num_positive_sum: torch.Tensor,\n    num_real_classes: int,\n    alpha: float,\n    gamma: float,\n    label_smoothing: float = 0.0,\n) -> torch.Tensor:\n    \"\"\"Fused focal loss function.\"\"\"\n    return FocalLoss.apply(\n        cls_output,\n        cls_targets_at_level,\n        num_positive_sum,\n        num_real_classes,\n        alpha,\n        gamma,\n        label_smoothing,\n    )\n"
  },
  {
    "path": "apex/contrib/gpu_direct_storage/README.md",
    "content": "# APEX GPUDirect Storage\n\nThis module aims to add a PyTorch extension for [GPUDirect Storage](https://developer.nvidia.com/blog/gpudirect-storage/) (GDS) support through utilizing the [cuFile](https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html) library.\n\n# Build command\n```\npip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--gpu_direct_storage\" ./\n```\n\nAlternatively:\n```\npython setup.py install --gpu_direct_storage\n```\n\nCheck installation:\n```\npython -c \"import torch; import apex.contrib.gpu_direct_storage\"\n```\n"
  },
  {
    "path": "apex/contrib/gpu_direct_storage/__init__.py",
    "content": "from _apex_gpu_direct_storage import _GDSFile\nfrom contextlib import contextmanager\n\n\n@contextmanager\ndef GDSFile(filename, mode):\n    assert type(filename) == str\n    assert type(mode) == str\n    try:\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`gpu_direct_storage.GDSFile` is deprecated and will be removed in September 2025. \"\n            \"We encourage you to use `torch.cuda.gds` module of PyTorch as a replacement. \"\n            \"Its documentation is available at https://docs.pytorch.org/docs/stable/cuda.html#gpudirect-storage-prototype\"\n        )\n        file_handle = _GDSFile(filename, mode)\n        yield file_handle\n    finally:\n        file_handle.close()\n        del file_handle\n"
  },
  {
    "path": "apex/contrib/group_norm/__init__.py",
    "content": "from .group_norm import *\n"
  },
  {
    "path": "apex/contrib/group_norm/group_norm.py",
    "content": "#!/usr/bin/env python\n# coding: utf-8\n\n#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: BSD-3-Clause\n#\n\nimport functools\nimport os\nimport torch\nimport torch.nn.init as init\nimport group_norm_cuda\nimport group_norm_v2_cuda\n\nfrom torch import Tensor\nfrom torch.nn.parameter import Parameter\n\n__all__ = [\"GroupNorm\"]\n\n\ndef one_time_warning(msg: str):\n    if not hasattr(one_time_warning, \"has_been_called\"):\n        one_time_warning.has_been_called = True\n        print(f\"\\033[93m{msg}\\033[0m\")  # hightlight with yellow color\n\n\n@functools.cache\ndef get_cc_and_sm_count(device_index: int):\n    props = torch.cuda.get_device_properties(device_index)\n    CC = (props.major, props.minor)\n    SM_COUNT = props.multi_processor_count\n    return CC, SM_COUNT\n\n\n# pytorch group norm requires same input type\ndef torch_group_norm(x, g, w, b, eps, act=\"\"):\n    xdtype, wdtype = x.dtype, w.dtype\n    if xdtype != wdtype:\n        x = x.to(dtype=wdtype)\n    y = torch.nn.functional.group_norm(x, g, w, b, eps)\n    if act in [\"silu\", \"swish\"]:\n        y = torch.nn.functional.silu(y)\n    if xdtype != wdtype and y.dtype != xdtype:\n        y = y.to(dtype=xdtype)\n    return y\n\n\n@torch.library.custom_op(\"apex::group_norm_nhwc_fprop\", mutates_args=())\ndef group_norm_nhwc_fprop(\n    x: torch.Tensor,\n    G: int,\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    eps: float,\n    act: str | None = None,\n    passes: int = 1,\n    use_group_norm_v2: bool = False,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    # sanity check\n    act = act.lower() if act else act\n    assert x.is_contiguous(memory_format=torch.channels_last), \"Only support NHWC layout.\"\n    assert weight.numel() == x.shape[1], \"Unexpected parameter count.\"\n    assert bias.numel() == x.shape[1], \"Unexpected parameter count.\"\n    assert x.shape[1] % G == 0, \"C % G != 
0.\"\n    assert act in [None, \"\", \"silu\", \"swish\"], \"Unsupported activation.\"\n    assert passes in [1, 2], \"Invalid number of passes for algorithm.\"\n\n    with_swish = act in (\"silu\", \"swish\")\n    sm_margin = int(os.environ.get(\"APEX_GROUP_NORM_FPROP_SM_MARGIN\", \"0\"))\n\n    # enqueue fprop kernel\n    if use_group_norm_v2:\n        sums = torch.empty(x.shape[0] * G * 2, device=x.device)\n        y = group_norm_v2_cuda.gn(\n            x, weight, bias, eps, with_swish, G, mean_var_out=sums, sm_margin=sm_margin\n        )\n    else:\n        if sm_margin:\n            raise NotImplementedError(\"sm_margin is not supported for GroupNorm v1\")\n        y, sums = group_norm_cuda.forward(x, G, weight, bias, eps, passes, with_swish)\n    return y, sums\n\n\n@group_norm_nhwc_fprop.register_fake\ndef fake_group_norm_nhwc_fprop(\n    x, G, weight, bias, eps, act=None, passes=1, use_group_norm_v2=False\n):\n    # sanity check\n    act = act.lower() if act else act\n    assert x.is_contiguous(memory_format=torch.channels_last), \"Only support NHWC layout.\"\n    assert weight.numel() == x.shape[1], \"Unexpected parameter count.\"\n    assert bias.numel() == x.shape[1], \"Unexpected parameter count.\"\n    assert x.shape[1] % G == 0, \"C % G != 0.\"\n    assert act in [None, \"\", \"silu\", \"swish\"], \"Unsupported activation.\"\n    assert passes in [1, 2], \"Invalid number of passes for algorithm.\"\n\n    y = torch.empty_like(x)\n    # Allocate on the input's device so the fake output matches the real kernel.\n    sums = torch.empty(2 * x.shape[0] * G, device=x.device, dtype=torch.float32)\n    return y, sums\n\n\n@torch.library.custom_op(\"apex::group_norm_nhwc_bprop\", mutates_args=())\ndef group_norm_nhwc_bprop(\n    grad_output: torch.Tensor,\n    sums: torch.Tensor,\n    x: torch.Tensor,\n    G: int,\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    eps: float,\n    act: str | None = None,\n    passes: int = 1,\n    use_group_norm_v2: bool = False,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    # sanity 
check\n    if not grad_output.is_contiguous(memory_format=torch.channels_last):\n        one_time_warning(\n            \"Warning: GroupNorm NHWC expects an NHWC grad_output but received a different layout, \"\n            \"so a memory format change is introduced. \"\n            \"This may come from the TorchInductor rule that tangents must be \"\n            \"contiguous. Avoiding graph breaks around NHWC tensors \"\n            \"can fix this issue. (Future warnings will be suppressed.)\"\n        )\n        grad_output = grad_output.contiguous(memory_format=torch.channels_last)\n\n    act = act.lower() if act else act\n    with_swish = act in [\"silu\", \"swish\"]\n    sm_margin = int(os.environ.get(\"APEX_GROUP_NORM_BPROP_SM_MARGIN\", \"0\"))\n\n    if use_group_norm_v2:\n        dx, dw, db = group_norm_v2_cuda.gn_bwd(\n            grad_output, x, weight, bias, sums, eps, with_swish, G, sm_margin=sm_margin\n        )\n    else:\n        if sm_margin:\n            raise NotImplementedError(\"sm_margin is not supported for GroupNorm v1\")\n        dx, dw, db = group_norm_cuda.backward(\n            grad_output, sums, x, G, weight, bias, eps, passes, with_swish\n        )\n    return dx, dw, db\n\n\n@group_norm_nhwc_bprop.register_fake\ndef fake_group_norm_nhwc_bprop(\n    grad_output,\n    sums,\n    x,\n    G,\n    weight,\n    bias,\n    eps,\n    act=None,\n    passes=1,\n    use_group_norm_v2=False,\n):\n    dx = torch.empty_like(x)\n    dw = torch.empty_like(weight)\n    db = torch.empty_like(bias)\n    return dx, dw, db\n\n\ndef backward(ctx, grad_output, grad_sums):\n    # retrieve saved info\n    x, w, b, sums = ctx.saved_tensors\n    G = ctx.G\n    eps = ctx.eps\n    passes = ctx.passes\n    act = ctx.act\n    use_group_norm_v2 = ctx.use_group_norm_v2\n\n    dx, dw, db = group_norm_nhwc_bprop(\n        grad_output, sums, x, G, w, b, eps, act, passes, use_group_norm_v2\n    )\n    return dx, None, dw, db, None, None, None, None\n\n\ndef setup_context(ctx, inputs, 
output):\n    x, G, weight, bias, eps, act, passes, use_group_norm_v2 = inputs\n    y, sums = output\n    # save for backward\n    ctx.save_for_backward(x, weight, bias, sums)\n    ctx.G = G\n    ctx.eps = eps\n    ctx.passes = passes\n    ctx.act = act\n    ctx.use_group_norm_v2 = use_group_norm_v2\n\n\ngroup_norm_nhwc_fprop.register_autograd(backward, setup_context=setup_context)\n\n\ndef cuda_group_norm_nhwc_one_pass(x, G, weight, bias, eps, act=None):\n    y, _ = group_norm_nhwc_fprop(x, G, weight, bias, eps, act, passes=1)\n    return y\n\n\ndef cuda_group_norm_nhwc_two_pass(x, G, weight, bias, eps, act=None):\n    y, _ = group_norm_nhwc_fprop(x, G, weight, bias, eps, act, passes=2)\n    return y\n\n\ndef cuda_group_norm_v2_nhwc(x, G, weight, bias, eps, act=None):\n    y, _ = group_norm_nhwc_fprop(x, G, weight, bias, eps, act, use_group_norm_v2=True)\n    return y\n\n\n# We do not directly inherit from torch.nn.GroupNorm since several fusers don't\n# support inheritance. Extends:\n# https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/normalization.py\nclass GroupNorm(torch.nn.Module):\n    \"\"\"Optimized GroupNorm for NHWC layout with optional Swish/SiLU fusion.\n\n    There are two versions of the CUDA kernels under the hood: one pass and two\n    passes. This operator uses a simple heuristic to choose the algorithm.\n\n    Limitations:\n\n    * Designed for 32 groups, also tested with 16 groups; other numbers\n      of groups may also work but are not guaranteed;\n    * Supported numbers of channels C are:\n\n        128, 256, 320, 384, 448, 512, 640, 768, 896, 960, 1024, 1280, 1344,\n        1536, 1792, 1920, 2048, 2240, 2560, 2688, 3072, 3136, 3584, 4096.\n\n      One pass algorithm supports only channels mentioned above. 
Two pass\n      algorithm might automatically support some other channels as well.\n    * N/H/W do not have lower (except >0) and upper bound limitations;\n\n    All the unsupported cases will be forwarded to PyTorch implementation.\n    \"\"\"\n\n    __constants__ = [\n        \"num_groups\",\n        \"num_channels\",\n        \"eps\",\n        \"affine\",\n        \"act\",\n        \"SUPPORTED_CHANNELS\",\n        \"SUPPORTED_GROUPS\",\n    ]\n    num_groups: int\n    num_channels: int\n    eps: float\n    affine: bool\n    act: str | None\n    SUPPORTED_CHANNELS = frozenset(\n        [\n            128,\n            256,\n            320,\n            384,\n            448,\n            512,\n            640,\n            768,\n            896,\n            960,\n            1024,\n            1280,\n            1344,\n            1536,\n            1792,\n            1920,\n            2048,\n            2240,\n            2560,\n            2688,\n            3072,\n            3136,\n            3584,\n            4096,\n        ]\n    )\n    SUPPORTED_GROUPS = frozenset([16, 32])\n    SUPPORTED_DTYPES = frozenset(\n        [\n            # (input dtype, parameter dtype)\n            (torch.float32, torch.float32),\n            (torch.float32, torch.float16),\n            (torch.float32, torch.bfloat16),\n            (torch.float16, torch.float16),\n            (torch.float16, torch.bfloat16),\n            (torch.float16, torch.float32),\n            (torch.bfloat16, torch.bfloat16),\n            (torch.bfloat16, torch.float16),\n            (torch.bfloat16, torch.float32),\n        ]\n    )\n    GN_V2_SUPPORTED_CHANNELS = frozenset(\n        [\n            # (HW, C)\n            (8 * 8, 1280),\n            (8 * 8, 2560),\n            (16 * 16, 640),\n            (16 * 16, 1280),\n            (16 * 16, 1920),\n            (16 * 16, 2560),\n            (32 * 32, 320),\n            (32 * 32, 640),\n            (32 * 32, 960),\n            (32 * 32, 1280),\n    
        (32 * 32, 1920),\n            (64 * 64, 320),\n            (64 * 64, 640),\n            (64 * 64, 960),\n        ]\n    )\n    GN_V2_SUPPORTED_DTYPES = frozenset(\n        [\n            # (input dtype, parameter dtype)\n            (torch.float16, torch.float16),\n            (torch.bfloat16, torch.bfloat16),\n        ]\n    )\n    GN_V2_SUPPORTED_GROUPS_SWISH = frozenset(\n        [\n            # (num_groups, with_swish)\n            (16, True),\n            (32, False),\n        ]\n    )\n    GN_V2_SUPPORTED_LOWER_BOUND_SM_COUNT = {\n        (10, 0): 148,\n    }\n\n    def __init__(\n        self,\n        num_groups: int,\n        num_channels: int,\n        eps: float = 1e-5,\n        affine: bool = True,\n        device=None,\n        dtype=None,\n        act=None,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        if num_channels % num_groups != 0:\n            raise ValueError(\"num_channels must be divisible by num_groups\")\n\n        self.num_groups = num_groups\n        self.num_channels = num_channels\n        self.eps = eps\n        self.affine = affine\n        self.act = act.lower() if act else act\n        if self.affine:\n            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))\n            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n        sm = torch.cuda.get_device_capability(device)\n        self.sm = sm[0] * 10 + sm[1]\n\n    def reset_parameters(self) -> None:\n        if self.affine:\n            init.ones_(self.weight)\n            init.zeros_(self.bias)\n\n    def _check_legality(self, input: Tensor) -> bool:\n        is_nhwc = input.is_contiguous(memory_format=torch.channels_last)\n        is_legal_groups = self.num_groups in self.SUPPORTED_GROUPS\n        
is_legal_channels = self.num_channels in self.SUPPORTED_CHANNELS\n        is_input_half_or_float_or_bf16 = input.dtype in [\n            torch.float16,\n            torch.bfloat16,\n            torch.float32,\n        ]\n        is_supported_dtype_combination = (\n            not self.affine or (input.dtype, self.weight.dtype) in self.SUPPORTED_DTYPES\n        )\n        is_legal_act = self.act in [None, \"\", \"silu\", \"swish\"]\n\n        if (\n            is_nhwc\n            and is_input_half_or_float_or_bf16\n            and is_supported_dtype_combination\n            and is_legal_act\n            and self.affine\n            and is_legal_groups\n            and is_legal_channels\n        ):\n            return True\n        else:\n            return False\n\n    def _check_v2_legality(self, input: Tensor) -> bool:\n        is_legal_channels = (\n            input.shape[2] * input.shape[3],\n            self.num_channels,\n        ) in self.GN_V2_SUPPORTED_CHANNELS\n        is_supported_groups_swish_combination = (\n            self.num_groups,\n            self.act in [\"silu\", \"swish\"],\n        ) in self.GN_V2_SUPPORTED_GROUPS_SWISH\n        is_supported_dtype_combination = (\n            self.affine and (input.dtype, self.weight.dtype) in self.GN_V2_SUPPORTED_DTYPES\n        )\n        cc, sm_count = get_cc_and_sm_count(input.device.index)\n        is_supported_sm_count = (\n            cc in self.GN_V2_SUPPORTED_LOWER_BOUND_SM_COUNT\n            and sm_count >= self.GN_V2_SUPPORTED_LOWER_BOUND_SM_COUNT[cc]\n        )\n\n        if (\n            is_legal_channels\n            and is_supported_groups_swish_combination\n            and is_supported_dtype_combination\n            and is_supported_sm_count\n        ):\n            return True\n        else:\n            return False\n\n    def forward(self, input: Tensor) -> Tensor:\n        can_use_nhwc_group_norm = self._check_legality(input)\n\n        if can_use_nhwc_group_norm:\n            channels 
= input.shape[1]\n            hw = 1\n            for i in range(2, len(input.shape)):\n                hw *= input.shape[i]\n            max_hw_one_pass = 1024 if self.sm >= 80 else 256\n            if (hw >= 512 and channels in (3136, 3584, 4096)) or hw > max_hw_one_pass:\n                passes = 2\n            else:\n                passes = 1\n            use_group_norm_v2 = self._check_v2_legality(input)\n            y, _ = group_norm_nhwc_fprop(\n                input,\n                self.num_groups,\n                self.weight,\n                self.bias,\n                self.eps,\n                self.act,\n                passes,\n                use_group_norm_v2,\n            )\n            return y\n        else:\n            return torch_group_norm(\n                input, self.num_groups, self.weight, self.bias, self.eps, self.act\n            )\n\n    def extra_repr(self) -> str:\n        if self.act:\n            return \"{num_groups}, {num_channels}, eps={eps}, affine={affine}, act={act}\".format(\n                **self.__dict__\n            )\n        else:\n            return \"{num_groups}, {num_channels}, eps={eps}, affine={affine}\".format(\n                **self.__dict__\n            )\n"
  },
  {
    "path": "apex/contrib/groupbn/__init__.py",
    "content": "try:\n    import torch\n    import bnp\n    from .batch_norm import BatchNorm2d_NHWC\n\n    del torch\n    del bnp\n    del batch_norm\nexcept ImportError:\n    print(\"apex was installed without --bnp flag, contrib.groupbn is not available\")\n"
  },
  {
    "path": "apex/contrib/groupbn/batch_norm.py",
    "content": "import torch\nimport numpy as np\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\nimport bnp\n\n\nclass bn_NHWC_impl(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        s,\n        b,\n        rm,\n        riv,\n        mini_m,\n        mini_riv,\n        ret_cta,\n        mom,\n        epsilon,\n        fuse_relu,\n        is_train,\n        bn_group,\n        my_data,\n        pair_data,\n        magic,\n        pair_data2,\n        pair_data3,\n        fwd_occup,\n        fwd_grid_x,\n        bwd_occup,\n        bwd_grid_x,\n        multi_stream,\n    ):\n        if is_train:\n            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)\n            ctx.epsilon = epsilon\n            ctx.momentum = mom\n            ctx.ret_cta = ret_cta\n            ctx.fuse_relu = fuse_relu\n            ctx.my_data = my_data\n            ctx.pair_data = pair_data\n            ctx.magic = magic\n            ctx.pair_data2 = pair_data2\n            ctx.pair_data3 = pair_data3\n            ctx.bn_group = bn_group\n            ctx.bwd_occup = bwd_occup\n            ctx.bwd_grid_x = bwd_grid_x\n            ctx.multi_stream = multi_stream\n\n            res = bnp.bn_fwd_nhwc(\n                x,\n                s,\n                b,\n                rm,\n                riv,\n                mini_m,\n                mini_riv,\n                ret_cta,\n                mom,\n                epsilon,\n                fuse_relu,\n                my_data,\n                pair_data,\n                pair_data2,\n                pair_data3,\n                bn_group,\n                magic,\n                fwd_occup,\n                fwd_grid_x,\n                multi_stream,\n            )\n            return res\n        else:\n            return bnp.bn_fwd_eval_nhwc(\n                x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu\n            )\n\n    @staticmethod\n    def backward(ctx, 
grad_y):\n        x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables\n        epsilon = ctx.epsilon\n        mom = ctx.momentum\n        ret_cta = ctx.ret_cta\n        fuse_relu = ctx.fuse_relu\n        my_data = ctx.my_data\n        pair_data = ctx.pair_data\n        magic = ctx.magic\n        pair_data2 = ctx.pair_data2\n        pair_data3 = ctx.pair_data3\n        bn_group = ctx.bn_group\n        bwd_occup = ctx.bwd_occup\n        bwd_grid_x = ctx.bwd_grid_x\n        multi_stream = ctx.multi_stream\n\n        dx, dscale, dbias = bnp.bn_bwd_nhwc(\n            x,\n            grad_y,\n            s,\n            b,\n            rm,\n            riv,\n            mini_m,\n            mini_riv,\n            ret_cta,\n            mom,\n            epsilon,\n            fuse_relu,\n            my_data,\n            pair_data,\n            pair_data2,\n            pair_data3,\n            bn_group,\n            magic,\n            bwd_occup,\n            bwd_grid_x,\n            multi_stream,\n        )\n\n        return (\n            dx,\n            dscale,\n            dbias,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\nclass bn_addrelu_NHWC_impl(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        z,\n        s,\n        b,\n        rm,\n        riv,\n        mini_m,\n        mini_riv,\n        grid_dim_y,\n        ret_cta,\n        mom,\n        epsilon,\n        is_train,\n        bn_group,\n        my_data,\n        pair_data,\n        magic,\n        pair_data2,\n        pair_data3,\n        fwd_occup,\n        fwd_grid_x,\n        bwd_occup,\n        bwd_grid_x,\n  
      multi_stream,\n    ):\n        if is_train:\n            bitmask = torch.cuda.IntTensor(((x.numel() + 31) // 32) * 2 * grid_dim_y)\n            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)\n            ctx.epsilon = epsilon\n            ctx.momentum = mom\n            ctx.ret_cta = ret_cta\n            ctx.my_data = my_data\n            ctx.pair_data = pair_data\n            ctx.magic = magic\n            ctx.pair_data2 = pair_data2\n            ctx.pair_data3 = pair_data3\n            ctx.bn_group = bn_group\n            ctx.bwd_occup = bwd_occup\n            ctx.bwd_grid_x = bwd_grid_x\n            ctx.multi_stream = multi_stream\n\n            res = bnp.bn_addrelu_fwd_nhwc(\n                x,\n                z,\n                s,\n                b,\n                rm,\n                riv,\n                mini_m,\n                mini_riv,\n                bitmask,\n                ret_cta,\n                mom,\n                epsilon,\n                my_data,\n                pair_data,\n                pair_data2,\n                pair_data3,\n                bn_group,\n                magic,\n                fwd_occup,\n                fwd_grid_x,\n                multi_stream,\n            )\n            return res\n        else:\n            return bnp.bn_addrelu_fwd_eval_nhwc(\n                x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon\n            )\n\n    @staticmethod\n    def backward(ctx, grad_y):\n        x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables\n        epsilon = ctx.epsilon\n        mom = ctx.momentum\n        ret_cta = ctx.ret_cta\n        my_data = ctx.my_data\n        pair_data = ctx.pair_data\n        magic = ctx.magic\n        pair_data2 = ctx.pair_data2\n        pair_data3 = ctx.pair_data3\n        bn_group = ctx.bn_group\n        bwd_occup = ctx.bwd_occup\n        bwd_grid_x = ctx.bwd_grid_x\n        multi_stream = ctx.multi_stream\n\n        dx, dz, dscale, dbias = 
bnp.bn_addrelu_bwd_nhwc(\n            x,\n            grad_y,\n            s,\n            b,\n            rm,\n            riv,\n            mini_m,\n            mini_riv,\n            bitmask,\n            ret_cta,\n            mom,\n            epsilon,\n            my_data,\n            pair_data,\n            pair_data2,\n            pair_data3,\n            bn_group,\n            magic,\n            bwd_occup,\n            bwd_grid_x,\n            multi_stream,\n        )\n\n        return (\n            dx,\n            dz,\n            dscale,\n            dbias,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\nclass BatchNorm2d_NHWC(_BatchNorm):\n    # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True\n    def __init__(\n        self,\n        num_features,\n        fuse_relu=False,\n        bn_group=1,\n        max_cta_per_sm=2,\n        cta_launch_margin=12,\n        multi_stream=False,\n    ):\n        super(BatchNorm2d_NHWC, self).__init__(num_features)\n\n        self.fuse_relu = fuse_relu\n        self.multi_stream = multi_stream\n\n        self.minibatch_mean = torch.cuda.FloatTensor(num_features)\n        self.minibatch_riv = torch.cuda.FloatTensor(num_features)\n\n        # defaut to distributed bn disabled\n        self.bn_group = bn_group\n        self.max_cta_per_sm = max_cta_per_sm  # used only in training fwd and bwd\n        self.cta_launch_margin = cta_launch_margin  # used only in training fwd and bwd\n        self.my_data = None\n        self.pair_data = None\n        self.pair_data2 = None\n        self.pair_data3 = None\n        self.local_rank = 0\n        
self.magic = torch.IntTensor([0])\n\n        # calculate CTA-per-SM occupancies\n        assert max_cta_per_sm > 0  # won't be able to do much with 0 CTAs :)\n        self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)\n        self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)\n        self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)\n        self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)\n\n        # calculate grid dimensions based on occupancy numbers\n        mp_count = torch.cuda.get_device_properties(None).multi_processor_count\n        self.fwd_grid_dim_x = max(mp_count * self.fwd_occupancy - cta_launch_margin, 1)\n        self.bwd_grid_dim_x = max(mp_count * self.bwd_occupancy - cta_launch_margin, 1)\n        self.addrelu_fwd_grid_dim_x = max(\n            mp_count * self.addrelu_fwd_occupancy - cta_launch_margin, 1\n        )\n        self.addrelu_bwd_grid_dim_x = max(\n            mp_count * self.addrelu_bwd_occupancy - cta_launch_margin, 1\n        )\n        self.grid_dim_y = (num_features + 63) // 64\n\n        # allocate scratch space used by implementation\n        # TODO: scratch space that is not supposed to be exposed at user code. We only need one time initialization, the\n        # same buffer could be reused in future iterations. 
Currently we exposed it here instead of requesting new\n        # buffer from cache allocator to avoid unnecessary initialization at future iterations.\n        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)\n\n        # FIXME: turn pair handles into an array\n        if bn_group > 1:\n            local_rank = torch.distributed.get_rank()\n            world_size = torch.distributed.get_world_size()\n            assert world_size >= bn_group\n            assert world_size % bn_group == 0\n\n            bn_sync_steps = 1\n            if bn_group == 4:\n                bn_sync_steps = 2\n            if bn_group == 8:\n                bn_sync_steps = 3\n\n            self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))\n            self.my_data = bnp.get_data_ptr(self.ipc_buffer)\n            # we are walking on very thin ice here by utilizing internal `_share_cuda_()`\n            self.storage = self.ipc_buffer.storage()\n            self.share_cuda = self.storage._share_cuda_()\n            internal_cuda_mem = self.share_cuda\n            # internal_cuda_mem[1]: ipc_mem_handle\n            my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))\n            # internal_cuda_mem[3]: offset\n            my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])\n\n            handles_all = torch.empty(\n                world_size,\n                my_handle.size(0),\n                dtype=my_handle.dtype,\n                device=my_handle.device,\n            )\n            handles_l = list(handles_all.unbind(0))\n            torch.distributed.all_gather(handles_l, my_handle)\n\n            offsets_all = torch.empty(\n                world_size,\n                my_offset.size(0),\n                dtype=my_offset.dtype,\n                device=my_offset.device,\n            )\n            offsets_l = list(offsets_all.unbind(0))\n            torch.distributed.all_gather(offsets_l, my_offset)\n\n            # whom 
do I actually care about? that would be local_rank XOR 1\n            self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()\n            pair_offset = offsets_l[local_rank ^ 1].cpu()\n            self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)\n\n            if bn_group > 2:\n                self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()\n                pair_offset2 = offsets_l[local_rank ^ 2].cpu()\n                self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)\n\n            if bn_group > 4:\n                self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()\n                pair_offset3 = offsets_l[local_rank ^ 4].cpu()\n                self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)\n\n            # FIXME: get magic value into C code and eliminate from here\n            self.magic = torch.IntTensor([2])\n            self.local_rank = local_rank\n\n    def forward(self, x, z=None):\n        if z is not None:\n            assert self.fuse_relu == True\n            return bn_addrelu_NHWC_impl.apply(\n                x,\n                z,\n                self.weight,\n                self.bias,\n                self.running_mean,\n                self.running_var,\n                self.minibatch_mean,\n                self.minibatch_riv,\n                self.grid_dim_y,\n                self.ret_cta,\n                self.momentum,\n                self.eps,\n                self.training,\n                self.bn_group,\n                self.my_data,\n                self.pair_data,\n                (self.magic),\n                self.pair_data2,\n                self.pair_data3,\n                self.addrelu_fwd_occupancy,\n                self.addrelu_fwd_grid_dim_x,\n                self.addrelu_bwd_occupancy,\n                self.addrelu_bwd_grid_dim_x,\n                self.multi_stream,\n            )\n        else:\n            
return bn_NHWC_impl.apply(\n                x,\n                self.weight,\n                self.bias,\n                self.running_mean,\n                self.running_var,\n                self.minibatch_mean,\n                self.minibatch_riv,\n                self.ret_cta,\n                self.momentum,\n                self.eps,\n                self.fuse_relu,\n                self.training,\n                self.bn_group,\n                self.my_data,\n                self.pair_data,\n                (self.magic),\n                self.pair_data2,\n                self.pair_data3,\n                self.fwd_occupancy,\n                self.fwd_grid_dim_x,\n                self.bwd_occupancy,\n                self.bwd_grid_dim_x,\n                self.multi_stream,\n            )\n\n    def __del__(self):\n        if self.bn_group > 1:\n            bnp.close_remote_data(self.pair_handle)\n            if self.bn_group > 2:\n                bnp.close_remote_data(self.pair_handle2)\n                if self.bn_group > 4:\n                    bnp.close_remote_data(self.pair_handle3)\n"
  },
  {
    "path": "apex/contrib/index_mul_2d/__init__.py",
    "content": "from .index_mul_2d import index_mul_2d\n"
  },
  {
    "path": "apex/contrib/index_mul_2d/index_mul_2d.py",
    "content": "import torch\n\nimport fused_index_mul_2d\n\n\nclass IndexMul2d_(torch.autograd.Function):\n    \"\"\"\n    Currently only support index in dimension 0 with a 2-dimension tensor.\n    The shape of indexed in1 must be same with in2. Now this kernel does not support broadcast.\n    The datatype must be float32 or float16.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> torch.Tensor:\n        assert in2.size(0) == idx1.size(0)\n        if (in1.dtype != torch.float32 and in1.dtype != torch.half) or in2.dtype != in1.dtype:\n            raise RuntimeError(\n                \"input1'dtype and input2's dtype must be fp32 or fp16. And input type must be same\"\n            )\n        if in1.dim() != 2 or in2.dim() != 2:\n            raise RuntimeError(\"in1 and in2 must be 2-dimension tensor.\")\n        if idx1.dim() != 1:\n            raise RuntimeError(\"idx1 must be 1-dimension tensor.\")\n\n        if not in1.is_contiguous():\n            in1 = in1.contiguous()\n        if not in2.is_contiguous():\n            in2 = in2.contiguous()\n        if not idx1.is_contiguous():\n            idx1 = idx1.contiguous()\n\n        assert in1.is_contiguous()\n        assert in2.is_contiguous()\n        assert idx1.is_contiguous()\n\n        out = torch.empty_like(in2)\n\n        if in1.dtype == torch.float32:\n            fused_index_mul_2d.float_forward(out, in1, in2, idx1)\n        elif in1.dtype == torch.half:\n            fused_index_mul_2d.half_forward(out, in1, in2, idx1)\n\n        ctx.for_backwards = (in1, in2, idx1)\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_out):\n        in1, in2, idx1 = ctx.for_backwards\n\n        grad_in1, grad_in2 = index_mul_2d_backward(in1, in2, idx1, grad_out)\n\n        return grad_in1, grad_in2, None\n\n\nclass IndexMul2dBackward_(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        in1: torch.Tensor,\n    
    in2: torch.Tensor,\n        idx1: torch.Tensor,\n        grad_out: torch.Tensor,\n    ) -> torch.Tensor:\n        if not in1.is_contiguous():\n            in1 = in1.contiguous()\n        if not in2.is_contiguous():\n            in2 = in2.contiguous()\n        if not idx1.is_contiguous():\n            idx1 = idx1.contiguous()\n        if not grad_out.is_contiguous():\n            grad_out = grad_out.contiguous()\n\n        assert in1.is_contiguous()\n        assert in2.is_contiguous()\n        assert idx1.is_contiguous()\n        assert grad_out.is_contiguous()\n\n        grad_in1 = torch.zeros_like(in1)\n        grad_in2 = torch.empty_like(in2)\n\n        if in1.dtype == torch.float32:\n            fused_index_mul_2d.float_backward(grad_in1, grad_in2, grad_out, in1, in2, idx1)\n        elif in1.dtype == torch.half:\n            fused_index_mul_2d.half_backward(grad_in1, grad_in2, grad_out, in1, in2, idx1)\n\n        ctx.for_backwards = (in1, in2, idx1, grad_out)\n        return grad_in1, grad_in2\n\n    @staticmethod\n    def backward(ctx, grad_grad_in1, grad_grad_in2):\n        if not grad_grad_in1.is_contiguous():\n            grad_grad_in1 = grad_grad_in1.contiguous()\n        if not grad_grad_in2.is_contiguous():\n            grad_grad_in2 = grad_grad_in2.contiguous()\n\n        assert grad_grad_in1.is_contiguous()\n        assert grad_grad_in2.is_contiguous()\n\n        in1, in2, idx1, grad_out = ctx.for_backwards\n\n        grad_in1 = torch.zeros_like(in1)\n        grad_in2 = torch.empty_like(in2)\n        grad_grad_out = torch.empty_like(grad_out)\n\n        if in1.dtype == torch.float32:\n            fused_index_mul_2d.float_backward_backward(\n                grad_grad_out,\n                grad_in1,\n                grad_in2,\n                grad_out,\n                grad_grad_in1,\n                grad_grad_in2,\n                in1,\n                in2,\n                idx1,\n            )\n        elif in1.dtype == torch.half:\n            
fused_index_mul_2d.half_backward_backward(\n                grad_grad_out,\n                grad_in1,\n                grad_in2,\n                grad_out,\n                grad_grad_in1,\n                grad_grad_in2,\n                in1,\n                in2,\n                idx1,\n            )\n\n        return grad_in1, grad_in2, None, grad_grad_out\n\n\nindex_mul_2d = IndexMul2d_.apply\nindex_mul_2d_backward = IndexMul2dBackward_.apply\n"
  },
  {
    "path": "apex/contrib/layer_norm/__init__.py",
    "content": "from .layer_norm import FastLayerNorm\n"
  },
  {
    "path": "apex/contrib/layer_norm/layer_norm.py",
    "content": "import torch\nfrom torch.nn import init\n\nfrom apex._autocast_utils import _cast_if_autocast_enabled\nimport fast_layer_norm\n\n\nclass FastLayerNormFN(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, gamma, beta, epsilon, memory_efficient=False):\n        ctx.x_shape = x.shape\n        ctx.memory_efficient = memory_efficient\n\n        x = x.contiguous()\n        gamma = gamma.contiguous()\n        beta = beta.contiguous()\n        hidden_size = gamma.numel()\n        xmat = x.view((-1, hidden_size))\n        ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)\n        if ctx.memory_efficient:\n            ctx.save_for_backward(ymat, gamma, None, rsigma, beta)\n        else:\n            ctx.save_for_backward(xmat, gamma, mu, rsigma, None)\n        return ymat.view(x.shape)\n\n    @staticmethod\n    def backward(ctx, dy):\n        # assert dy.is_contiguous()\n        dy = dy.contiguous()  # this happens!\n        x_or_y_mat, gamma, mu, rsigma, beta = ctx.saved_tensors\n        dymat = dy.view(x_or_y_mat.shape)\n        dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(\n            dymat, x_or_y_mat, mu, rsigma, gamma, beta, ctx.memory_efficient\n        )\n        dx = dxmat.view(ctx.x_shape)\n        return dx, dgamma, dbeta, None, None\n\n\ndef _fast_layer_norm(x, weight, bias, epsilon, memory_efficient):\n    args = _cast_if_autocast_enabled(x, weight, bias, epsilon, memory_efficient)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        return FastLayerNormFN.apply(*args)\n\n\nclass FastLayerNorm(torch.nn.Module):\n    def __init__(self, hidden_size, eps=1e-5, memory_efficient=False):\n        super().__init__()\n        self.epsilon = eps\n        self.memory_efficient = memory_efficient\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size))\n        self.bias = torch.nn.Parameter(torch.empty(hidden_size))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n     
   init.ones_(self.weight)\n        init.zeros_(self.bias)\n\n    def forward(self, x):\n        return _fast_layer_norm(x, self.weight, self.bias, self.epsilon, self.memory_efficient)\n"
  },
  {
    "path": "apex/contrib/multihead_attn/README.md",
    "content": "# Fast Multihead Attention \n\nThis implementation has two main features :\n* A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes.\n* The removal of all copies and transposes found in standard implementations of Multihead Attention.\n\n|                                            | Python Version | C++ Version |\n| :----------------------------------------- | :------------: | :---------: |\n| Layer Norm and Residual Add Variant        | X              | X           |\n| Includes Linear Biases                     | X              |             |\n| Reduces CPU Overheads                      |                | X           |\n| Fuses masking with Softmax                 |                | X           |\n| Removes Transposes and Copies              | X              | X           |\n| Includes Self and Encoder/Decoder Variants | X              | X           |\n\n## How to Instantiate\n\n`SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`\n`EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`\n\n `impl` has two options:\n * `fast` uses C++ Version\n * `default` uses Python Version\n\n## Instructions to build on Linux\n\n```\n$ git clone https://github.com/NVIDIA/apex\n$ cd apex\n$ pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--fast_multihead_attn\" ./\n```\n## Try Performance Tests Yourself!\nPerf test script is found here!\n```\ncd contrib/examples/multihead_attn\n```\n#### Fast Multihead Attention\n```\npython perf_test_multihead_attn.py --ref\n```\n#### Fast Multihead Attention with C++ Implementation\n```\npython perf_test_multihead_attn.py\n```\n#### Compare with `torch.nn.MultiheadAttn`\n```\npython perf_test_multihead_attn.py --native\n```\n#### Test your own range!\n```\npython perf_test_multihead_attn.py --seq-length 64 
--num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5\n```\n\n## Performance Comparisons\n\n* Performance was measured with a sequence length of 64 tokens on an NVIDIA Titan V card.\n* Time is measured across multiple layers to simulate an in-model scenario.\n\n![Multihead Attention Forward](MHA_fwd.png)\n![Multihead Attention Backward](MHA_bwd.png)\n"
  },
  {
    "path": "apex/contrib/multihead_attn/__init__.py",
    "content": "from .self_multihead_attn import SelfMultiheadAttn\nfrom .encdec_multihead_attn import EncdecMultiheadAttn\nfrom .mask_softmax_dropout_func import fast_mask_softmax_dropout_func\n"
  },
  {
    "path": "apex/contrib/multihead_attn/encdec_multihead_attn.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom .encdec_multihead_attn_func import encdec_attn_func\nfrom .fast_encdec_multihead_attn_func import fast_encdec_attn_func\nfrom .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func\nfrom apex.normalization.fused_layer_norm import FusedLayerNorm\n\n\n@torch.jit.script\ndef jit_dropout_add(x, residual, prob, is_training):\n    # type: (Tensor, Tensor, float, bool) -> Tensor\n    out = F.dropout(x, p=prob, training=True)\n    out = residual + out\n    return out\n\n\nclass EncdecMultiheadAttn(nn.Module):\n    \"\"\"Multi-headed attention.\n\n    See \"Attention Is All You Need\" for more details.\n    \"\"\"\n\n    def __init__(\n        self,\n        embed_dim,\n        num_heads,\n        dropout=0.0,\n        bias=False,\n        include_norm_add=False,\n        impl=\"fast\",\n    ):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, (\n            \"embed_dim must be divisible by num_heads\"\n        )\n        self.bias = bias\n        self.include_norm_add = include_norm_add\n        self.impl = impl\n        self.scaling = self.head_dim**-0.5\n\n        self.in_proj_weight_q = Parameter(torch.empty(embed_dim, embed_dim))\n        self.in_proj_weight_kv = Parameter(torch.empty(2 * embed_dim, embed_dim))\n        self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))\n        if self.bias:\n            assert impl != \"fast\", \"ERROR! 
The Fast implementation does not support biases!\"\n            self.in_proj_bias_q = Parameter(torch.empty(embed_dim))\n            self.in_proj_bias_kv = Parameter(torch.empty(2 * embed_dim))\n            self.out_proj_bias = Parameter(torch.empty(embed_dim))\n        else:\n            self.register_parameter(\"in_proj_bias_q\", None)\n            self.register_parameter(\"in_proj_bias_kv\", None)\n            self.in_proj_bias_q = None\n            self.in_proj_bias_kv = None\n            self.out_proj_bias = None\n        if self.include_norm_add:\n            if impl == \"fast\":\n                self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))\n                self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))\n                self.lyr_nrm = None\n            else:\n                # register under the same names as the attributes assigned below\n                self.register_parameter(\"lyr_nrm_gamma_weights\", None)\n                self.register_parameter(\"lyr_nrm_beta_weights\", None)\n                self.lyr_nrm_gamma_weights = None\n                self.lyr_nrm_beta_weights = None\n                self.lyr_nrm = FusedLayerNorm(embed_dim)\n        self.reset_parameters()\n\n        if self.include_norm_add:\n            if impl == \"fast\":\n                self.attn_func = fast_encdec_attn_norm_add_func\n            elif impl == \"default\":\n                self.attn_func = encdec_attn_func\n            else:\n                assert False, \"Unsupported impl: {} !\".format(impl)\n        else:\n            if impl == \"fast\":\n                self.attn_func = fast_encdec_attn_func\n            elif impl == \"default\":\n                self.attn_func = encdec_attn_func\n            else:\n                assert False, \"Unsupported impl: {} !\".format(impl)\n\n    def reset_parameters(self):\n        nn.init.xavier_uniform_(self.in_proj_weight_q)\n        # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be\n        # initialized like a [hidden, hidden] matrix.\n        # sqrt(6 / 
(hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)\n        # therefore xavier_uniform gain should be set to sqrt(1.5).\n        nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))\n        nn.init.xavier_uniform_(self.out_proj_weight)\n        if self.bias:\n            nn.init.constant_(self.in_proj_bias_q, 0.0)\n            nn.init.constant_(self.in_proj_bias_kv, 0.0)\n            nn.init.constant_(self.out_proj_bias, 0.0)\n        if self.include_norm_add:\n            if self.impl == \"fast\":\n                nn.init.ones_(self.lyr_nrm_gamma_weights)\n                nn.init.zeros_(self.lyr_nrm_beta_weights)\n            else:\n                self.lyr_nrm.reset_parameters()\n\n    def forward(\n        self,\n        query,\n        key,\n        value,\n        key_padding_mask=None,\n        need_weights=False,\n        attn_mask=None,\n        is_training=True,\n    ):\n        \"\"\"Input shape: Time x Batch x Channel\n\n        Self-attention can be implemented by passing in the same arguments for\n        query, key and value. Future timesteps can be masked with the\n        `attn_mask` argument. 
Padding elements can be excluded from\n        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:\n        batch x src_len, where padding elements are indicated by 1s.\n        \"\"\"\n\n        if key_padding_mask is not None:\n            assert attn_mask is None, (\n                \"ERROR attn_mask and key_padding_mask should not be both defined!\"\n            )\n            mask = key_padding_mask\n        elif attn_mask is not None:\n            mask = attn_mask\n        else:\n            mask = None\n\n        if self.include_norm_add:\n            if self.impl == \"fast\":\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    query,\n                    key,\n                    self.lyr_nrm_gamma_weights,\n                    self.lyr_nrm_beta_weights,\n                    self.in_proj_weight_q,\n                    self.in_proj_weight_kv,\n                    self.out_proj_weight,\n                    mask,\n                    self.dropout,\n                )\n            else:\n                lyr_nrm_results = self.lyr_nrm(query)\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    self.scaling,\n                    lyr_nrm_results,\n                    key,\n                    self.in_proj_weight_q,\n                    self.in_proj_weight_kv,\n                    self.out_proj_weight,\n                    self.in_proj_bias_q,\n                    self.in_proj_bias_kv,\n                    self.out_proj_bias,\n                    mask,\n                    self.dropout,\n                )\n                if is_training:\n                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)\n                else:\n                    outputs = outputs + query\n   
     else:\n            if self.impl == \"fast\":\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    query,\n                    key,\n                    self.in_proj_weight_q,\n                    self.in_proj_weight_kv,\n                    self.out_proj_weight,\n                    mask,\n                    self.dropout,\n                )\n            else:\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    self.scaling,\n                    query,\n                    key,\n                    self.in_proj_weight_q,\n                    self.in_proj_weight_kv,\n                    self.out_proj_weight,\n                    self.in_proj_bias_q,\n                    self.in_proj_bias_kv,\n                    self.out_proj_bias,\n                    mask,\n                    self.dropout,\n                )\n\n        return outputs, None\n"
  },
  {
    "path": "apex/contrib/multihead_attn/encdec_multihead_attn_func.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\nclass EncdecAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        use_time_mask,\n        is_training,\n        heads,\n        scale,\n        inputs_q,\n        inputs_kv,\n        input_weights_q,\n        input_weights_kv,\n        output_weights,\n        input_biases_q,\n        input_biases_kv,\n        output_biases,\n        mask,\n        dropout_prob,\n    ):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention. \"\n            \"The documentation is available at https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        use_biases_t = torch.tensor([input_biases_q is not None])\n        heads_t = torch.tensor([heads])\n        scale_t = torch.tensor([scale])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor = torch.tensor([])\n        head_dim = inputs_q.size(2) // heads\n\n        # Input Linear GEMM Q\n        # input1: (activations) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim (1024), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        if use_biases_t[0]:\n            input_lin_q_results = torch.addmm(\n                input_biases_q,\n                inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                input_weights_q.transpose(0, 1),\n                beta=1.0,\n                alpha=1.0,\n            )\n        else:\n            input_lin_q_results = torch.mm(\n                inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                
input_weights_q.transpose(0, 1),\n            )\n        input_lin_q_results = input_lin_q_results.view(\n            inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0)\n        )\n        # Input Linear GEMM KV\n        # input1: (activations) [seql_k, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_k, seqs, embed_dim*2]\n        # GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)\n        if use_biases_t[0]:\n            input_lin_kv_results = torch.addmm(\n                input_biases_kv,\n                inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),\n                input_weights_kv.transpose(0, 1),\n                beta=1.0,\n                alpha=1.0,\n            )\n        else:\n            input_lin_kv_results = torch.mm(\n                inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),\n                input_weights_kv.transpose(0, 1),\n            )\n        input_lin_kv_results = input_lin_kv_results.view(\n            inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0)\n        )\n\n        # Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]\n        # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]\n        queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads, head_dim)\n        input_lin_kv_results = input_lin_kv_results.view(\n            inputs_kv.size(0), inputs_kv.size(1) * heads, 2, head_dim\n        )\n        keys = input_lin_kv_results[:, :, 0, :]\n        values = input_lin_kv_results[:, :, 1, :]\n\n        # Matmul1 Batched GEMMs\n        # The output tensor is specified prior to the 
Batch GEMM because baddbmm requires its specification\n        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of\n        # a separate elementwise operation.\n        # Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (Keys)    [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # output:           [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul1_results = torch.empty(\n            (queries.size(1), queries.size(0), keys.size(0)),\n            dtype=queries.dtype,\n            device=torch.device(\"cuda\"),\n        )\n        matmul1_results = torch.baddbmm(\n            matmul1_results,\n            queries.transpose(0, 1),\n            keys.transpose(0, 1).transpose(1, 2),\n            out=matmul1_results,\n            beta=0.0,\n            alpha=scale_t[0],\n        )\n\n        if mask is not None:\n            # Self Attention Time Mask\n            if use_time_mask:\n                assert len(mask.size()) == 2, \"Timing mask is not 2D!\"\n                assert mask.size(0) == mask.size(1), \"Sequence length should match!\"\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask, float(\"-inf\"))\n            # Key Padding Mask\n            else:\n                batches, seql_q, seql_k = matmul1_results.size()\n                seqs = int(batches / heads)\n                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(\n                    mask.unsqueeze(1).unsqueeze(2), float(\"-inf\")\n                )\n                matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k)\n\n        softmax_results = F.softmax(matmul1_results, dim=-1)\n\n        # Dropout - is not executed for 
inference\n        if is_training:\n            dropout_results, dropout_mask = torch._fused_dropout(\n                softmax_results, p=(1.0 - dropout_prob_t[0])\n            )\n        else:\n            dropout_results = softmax_results\n            dropout_mask = null_tensor\n\n        # Matmul2 Batched GEMMs\n        # The output tensor specification is needed here to specify the non-standard output.\n        # Given that pytorch cannot currently perform autograd with an output tensor specified,\n        # this requires a backward pass specified.\n        # Input1: from_softmax [seqs*heads, seql_q, seql_k]\n        # Input2: (values)     [seql_v, seqs*heads, head_dim] transpose(0,1)\n        # Output:              [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)\n        matmul2_results = torch.empty(\n            (dropout_results.size(1), dropout_results.size(0), values.size(2)),\n            dtype=dropout_results.dtype,\n            device=torch.device(\"cuda\"),\n        ).transpose(1, 0)\n        matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results)\n        matmul2_results = (\n            matmul2_results.transpose(0, 1)\n            .contiguous()\n            .view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))\n        )\n\n        # Output Linear GEMM\n        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        if use_biases_t[0]:\n            outputs = torch.addmm(\n                output_biases,\n                matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                output_weights.transpose(0, 1),\n                beta=1.0,\n 
               alpha=1.0,\n            )\n        else:\n            outputs = torch.mm(\n                matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n                output_weights.transpose(0, 1),\n            )\n        outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))\n\n        ctx.save_for_backward(\n            use_biases_t,\n            heads_t,\n            scale_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            inputs_q,\n            inputs_kv,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_prob_t,\n        )\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        (\n            use_biases_t,\n            heads_t,\n            scale_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            inputs_q,\n            inputs_kv,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_prob_t,\n        ) = ctx.saved_tensors\n\n        head_dim = inputs_q.size(2) // heads_t[0]\n\n        # Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]\n        # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]\n        queries = input_lin_q_results.view(\n            inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim\n        )\n        input_lin_kv_results = input_lin_kv_results.view(\n            inputs_kv.size(0), inputs_kv.size(1) * 
heads_t[0], 2, head_dim\n        )\n        keys = input_lin_kv_results[:, :, 0, :]\n        values = input_lin_kv_results[:, :, 1, :]\n\n        # Slice out k,v from one big set of gradients entering the input linear's bprop  (should only impact meta data, no copies!)\n        # The gradients are identical in size to the Input Linear outputs.\n        # The tensor is declared before hand to properly slice out query, key, and value grads.\n        input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)\n        queries_grads = torch.empty_like(queries)\n        keys_grads = input_lin_kv_results_grads[:, :, 0, :]\n        values_grads = input_lin_kv_results_grads[:, :, 1, :]\n\n        # Output Linear GEMM - DGRAD\n        # Input1: (data grads)  [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ]\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        output_lin_grads = torch.mm(\n            output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),\n            output_weights,\n        )\n        output_lin_grads = output_lin_grads.view(\n            output_grads.size(0), output_grads.size(1), output_weights.size(1)\n        )\n        # Output Linear GEMM - WGRAD\n        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)\n        # Input2: (activations) [seql_q*seqs, embed_dim ]\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )\n        output_weight_grads = torch.mm(\n            output_grads.view(\n                output_grads.size(0) * output_grads.size(1), output_grads.size(2)\n            ).transpose(0, 1),\n            matmul2_results.view(\n                matmul2_results.size(0) * matmul2_results.size(1),\n         
       matmul2_results.size(2),\n            ),\n        )\n        output_lin_grads = output_lin_grads.view(\n            output_grads.size(0), output_grads.size(1) * heads_t[0], head_dim\n        ).transpose(0, 1)\n\n        if use_biases_t[0]:\n            output_bias_grads = torch.sum(\n                output_grads.view(\n                    output_grads.size(0) * output_grads.size(1), output_grads.size(2)\n                ),\n                0,\n            )\n        else:\n            output_bias_grads = None\n\n        # Matmul2 - DGRAD1\n        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)\n        # Output:               [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))\n        # Matmul2 - DGRAD2\n        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        values_grads = torch.bmm(\n            dropout_results.transpose(1, 2),\n            output_lin_grads,\n            out=values_grads.transpose(0, 1),\n        )\n\n        # Mask and Scaling for Dropout (not a publicly documented op)\n        dropout_grads = torch._masked_scale(\n            matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0])\n        )\n\n        # Softmax Grad (not a publicly documented op)\n        softmax_grads = torch._softmax_backward_data(\n            dropout_grads, softmax_results, -1, softmax_results.dtype\n        )\n\n        # Matmul1 - DGRAD1\n        # Input1: (data grads)  [seqs*heads, seql_q, 
seql_k]\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )\n        queries_grads = torch.baddbmm(\n            queries_grads.transpose(0, 1),\n            softmax_grads,\n            keys.transpose(0, 1),\n            out=queries_grads.transpose(0, 1),\n            beta=0.0,\n            alpha=scale_t[0],\n        )\n        # Matmul1 - DGRAD2\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        keys_grads = torch.baddbmm(\n            keys_grads.transpose(0, 1),\n            softmax_grads.transpose(1, 2),\n            queries.transpose(0, 1),\n            out=keys_grads.transpose(0, 1),\n            beta=0.0,\n            alpha=scale_t[0],\n        )\n\n        # Input Q Linear GEMM - DGRAD\n        # input1: (data grads) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)    [embed_dim (1024), embed_dim (1024)]\n        # output:              [seql_q, seqs, embed_dim]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        queries_grads = queries_grads.transpose(0, 1).view(\n            inputs_q.size(0) * inputs_q.size(1), heads_t[0] * head_dim\n        )\n        input_q_grads = torch.mm(queries_grads, input_weights_q)\n        input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))\n        # Input KV Linear GEMM - DGRAD\n        # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]\n        # input2: (weights)    [embed_dim*2 (2048), embed_dim (1024)]\n        # output:           
   [seql_k, seqs, embed_dim]\n        # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim)\n        input_lin_kv_results_grads = input_lin_kv_results_grads.view(\n            inputs_kv.size(0) * inputs_kv.size(1), heads_t[0] * 2 * head_dim\n        )\n        input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)\n        input_kv_grads = input_kv_grads.view(\n            inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2)\n        )\n        # Input Q Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_q*seqs, embed_dim(1024)]\n        # input2: (activations) [seql_q*seqs, embed_dim(1024)]\n        # output:               [embed_dim, embed_dim]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)\n        input_weight_q_grads = torch.mm(\n            queries_grads.transpose(0, 1),\n            inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),\n        )\n        # Input KV Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_k*seqs, 2*embed_dim(2048)]\n        # input2: (activations) [seql_k*seqs, embed_dim(1024)]\n        # output:               [2*embed_dim, embed_dim]\n        # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)\n        input_weight_kv_grads = torch.mm(\n            input_lin_kv_results_grads.transpose(0, 1),\n            inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),\n        )\n\n        if use_biases_t[0]:\n            input_bias_grads_q = torch.sum(queries_grads, 0)\n            input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)\n        else:\n            input_bias_grads_q = None\n            input_bias_grads_kv = None\n\n        return (\n            None,\n            None,\n            None,\n            None,\n            input_q_grads,\n            input_kv_grads,\n            input_weight_q_grads,\n         
   input_weight_kv_grads,\n            output_weight_grads,\n            input_bias_grads_q,\n            input_bias_grads_kv,\n            output_bias_grads,\n            None,\n            None,\n        )\n\n\nencdec_attn_func = EncdecAttnFunc.apply\n"
  },
  {
    "path": "apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py",
    "content": "import torch\n\nimport fast_multihead_attn\n\n\nclass FastEncdecAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        use_time_mask,\n        is_training,\n        heads,\n        inputs_q,\n        inputs_kv,\n        input_weights_q,\n        input_weights_kv,\n        output_weights,\n        pad_mask,\n        dropout_prob,\n    ):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention\"\n            \"The documentation is available in https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        heads_t = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor = torch.tensor([])\n        use_mask = pad_mask is not None\n\n        (\n            input_lin_q_results,\n            input_lin_kv_results,\n            softmax_results,\n            dropout_results,\n            dropout_mask,\n            matmul2_results,\n            outputs,\n        ) = fast_multihead_attn.encdec_multihead_attn_forward(\n            use_mask,\n            use_time_mask,\n            is_training,\n            heads,\n            inputs_q,\n            inputs_kv,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            pad_mask if use_mask else null_tensor,\n            dropout_prob,\n        )\n\n        ctx.save_for_backward(\n            heads_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            inputs_q,\n            inputs_kv,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_prob_t,\n        )\n\n 
       return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        (\n            heads_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            inputs_q,\n            inputs_kv,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_prob_t,\n        ) = ctx.saved_tensors\n\n        (\n            input_q_grads,\n            input_kv_grads,\n            input_weight_q_grads,\n            input_weight_kv_grads,\n            output_weight_grads,\n        ) = fast_multihead_attn.encdec_multihead_attn_backward(\n            heads_t[0],\n            output_grads,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            inputs_q,\n            inputs_kv,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_prob_t[0],\n        )\n\n        return (\n            None,\n            None,\n            None,\n            input_q_grads,\n            input_kv_grads,\n            input_weight_q_grads,\n            input_weight_kv_grads,\n            output_weight_grads,\n            None,\n            None,\n        )\n\n\nfast_encdec_attn_func = FastEncdecAttnFunc.apply\n"
  },
  {
    "path": "apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py",
    "content": "# Copyright (c) 2017-present, Facebook, Inc.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the LICENSE file in\n# the root directory of this source tree. An additional grant of patent rights\n# can be found in the PATENTS file in the same directory.\n\nimport torch\n\nimport fast_multihead_attn\n\n\nclass FastEncdecAttnNormAddFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        use_time_mask,\n        is_training,\n        heads,\n        inputs_q,\n        inputs_kv,\n        lyr_nrm_gamma_weights,\n        lyr_nrm_beta_weights,\n        input_weights_q,\n        input_weights_kv,\n        output_weights,\n        pad_mask,\n        dropout_prob,\n    ):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention. \"\n            \"The documentation is available at https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        heads_t = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor = torch.tensor([])\n        use_mask = pad_mask is not None\n\n        (\n            lyr_nrm_results,\n            lyr_nrm_mean,\n            lyr_nrm_invvar,\n            input_lin_q_results,\n            input_lin_kv_results,\n            softmax_results,\n            dropout_results,\n            dropout_mask,\n            matmul2_results,\n            dropout_add_mask,\n            outputs,\n        ) = fast_multihead_attn.encdec_multihead_attn_norm_add_forward(\n            use_mask,\n            use_time_mask,\n            is_training,\n            heads,\n            inputs_q,\n            inputs_kv,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights_q,\n            
input_weights_kv,\n            output_weights,\n            pad_mask if use_mask else null_tensor,\n            dropout_prob,\n        )\n\n        ctx.save_for_backward(\n            heads_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            lyr_nrm_results,\n            lyr_nrm_mean,\n            lyr_nrm_invvar,\n            inputs_q,\n            inputs_kv,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_add_mask,\n            dropout_prob_t,\n        )\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        (\n            heads_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            lyr_nrm_results,\n            lyr_nrm_mean,\n            lyr_nrm_invvar,\n            inputs_q,\n            inputs_kv,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_add_mask,\n            dropout_prob_t,\n        ) = ctx.saved_tensors\n\n        (\n            input_q_grads,\n            input_kv_grads,\n            lyr_nrm_gamma_grads,\n            lyr_nrm_beta_grads,\n            input_weight_q_grads,\n            input_weight_kv_grads,\n            output_weight_grads,\n        ) = fast_multihead_attn.encdec_multihead_attn_norm_add_backward(\n            heads_t[0],\n            output_grads,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_q_results,\n            input_lin_kv_results,\n            
lyr_nrm_results,\n            lyr_nrm_mean,\n            lyr_nrm_invvar,\n            inputs_q,\n            inputs_kv,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights_q,\n            input_weights_kv,\n            output_weights,\n            dropout_mask,\n            dropout_add_mask,\n            dropout_prob_t[0],\n        )\n\n        # import pdb; pdb.set_trace()\n        return (\n            None,\n            None,\n            None,\n            input_q_grads,\n            input_kv_grads,\n            lyr_nrm_gamma_grads,\n            lyr_nrm_beta_grads,\n            input_weight_q_grads,\n            input_weight_kv_grads,\n            output_weight_grads,\n            None,\n            None,\n        )\n\n\nfast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply\n"
  },
  {
    "path": "apex/contrib/multihead_attn/fast_self_multihead_attn_func.py",
    "content": "import torch\n\nimport fast_multihead_attn\n\n\nclass FastSelfAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        use_time_mask,\n        is_training,\n        heads,\n        inputs,\n        input_weights,\n        output_weights,\n        input_biases,\n        output_biases,\n        pad_mask,\n        mask_additive,\n        dropout_prob,\n    ):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention\"\n            \"The documentation is available in https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        use_biases_t = torch.tensor([input_biases is not None])\n        heads_t = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor = torch.tensor([])\n        use_mask = pad_mask is not None\n        mask_additive_t = torch.tensor([mask_additive])\n\n        if use_biases_t[0]:\n            if not mask_additive:\n                (\n                    input_lin_results,\n                    softmax_results,\n                    dropout_results,\n                    dropout_mask,\n                    matmul2_results,\n                    outputs,\n                ) = fast_multihead_attn.self_attn_bias_forward(\n                    use_mask,\n                    use_time_mask,\n                    is_training,\n                    heads,\n                    inputs,\n                    input_weights,\n                    output_weights,\n                    input_biases,\n                    output_biases,\n                    pad_mask if use_mask else null_tensor,\n                    dropout_prob,\n                )\n                # fast_self_multihead_attn_bias.forward()                           \\\n                
ctx.save_for_backward(\n                    use_biases_t,\n                    heads_t,\n                    matmul2_results,\n                    dropout_results,\n                    softmax_results,\n                    null_tensor,\n                    null_tensor,\n                    mask_additive_t,\n                    input_lin_results,\n                    inputs,\n                    input_weights,\n                    output_weights,\n                    dropout_mask,\n                    dropout_prob_t,\n                )\n\n            else:\n                (\n                    input_lin_results,\n                    bmm1_results,\n                    dropout_results,\n                    dropout_mask,\n                    matmul2_results,\n                    outputs,\n                ) = fast_multihead_attn.self_attn_bias_additive_mask_forward(\n                    use_mask,\n                    use_time_mask,\n                    is_training,\n                    heads,\n                    inputs,\n                    input_weights,\n                    output_weights,\n                    input_biases,\n                    output_biases,\n                    pad_mask if use_mask else null_tensor,\n                    dropout_prob,\n                )\n                # fast_self_multihead_attn_bias_additive_mask.forward(                           \\\n                ctx.save_for_backward(\n                    use_biases_t,\n                    heads_t,\n                    matmul2_results,\n                    dropout_results,\n                    null_tensor,\n                    bmm1_results,\n                    pad_mask,\n                    mask_additive_t,\n                    input_lin_results,\n                    inputs,\n                    input_weights,\n                    output_weights,\n                    dropout_mask,\n                    dropout_prob_t,\n                )\n\n        else:\n            (\n                
input_lin_results,\n                softmax_results,\n                dropout_results,\n                dropout_mask,\n                matmul2_results,\n                outputs,\n            ) = fast_multihead_attn.self_attn_forward(\n                use_mask,\n                use_time_mask,\n                is_training,\n                heads,\n                inputs,\n                input_weights,\n                output_weights,\n                pad_mask if use_mask else null_tensor,\n                dropout_prob,\n            )\n            # fast_self_multihead_attn.forward(                           \\\n            ctx.save_for_backward(\n                use_biases_t,\n                heads_t,\n                matmul2_results,\n                dropout_results,\n                softmax_results,\n                null_tensor,\n                null_tensor,\n                mask_additive_t,\n                input_lin_results,\n                inputs,\n                input_weights,\n                output_weights,\n                dropout_mask,\n                dropout_prob_t,\n            )\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        (\n            use_biases_t,\n            heads_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            bmm1_results,\n            pad_mask,\n            mask_additive_t,\n            input_lin_results,\n            inputs,\n            input_weights,\n            output_weights,\n            dropout_mask,\n            dropout_prob_t,\n        ) = ctx.saved_tensors\n\n        if use_biases_t[0]:\n            if not mask_additive_t[0]:\n                (\n                    input_grads,\n                    input_weight_grads,\n                    output_weight_grads,\n                    input_bias_grads,\n                    output_bias_grads,\n                ) = fast_multihead_attn.self_attn_bias_backward(\n                 
   heads_t[0],\n                    output_grads,\n                    matmul2_results,\n                    dropout_results,\n                    softmax_results,\n                    input_lin_results,\n                    inputs,\n                    input_weights,\n                    output_weights,\n                    dropout_mask,\n                    dropout_prob_t[0],\n                )\n                # fast_self_multihead_attn_bias.backward(                          \\\n\n            else:\n                (\n                    input_grads,\n                    input_weight_grads,\n                    output_weight_grads,\n                    input_bias_grads,\n                    output_bias_grads,\n                ) = fast_multihead_attn.self_attn_bias_additive_mask_backward(\n                    heads_t[0],\n                    output_grads,\n                    matmul2_results,\n                    dropout_results,\n                    bmm1_results,\n                    pad_mask,\n                    input_lin_results,\n                    inputs,\n                    input_weights,\n                    output_weights,\n                    dropout_mask,\n                    dropout_prob_t[0],\n                )\n                # fast_self_multihead_attn_bias_additive_mask.backward(                          \\\n\n        else:\n            input_bias_grads = None\n            output_bias_grads = None\n            input_grads, input_weight_grads, output_weight_grads = (\n                fast_multihead_attn.self_attn_backward(\n                    heads_t[0],\n                    output_grads,\n                    matmul2_results,\n                    dropout_results,\n                    softmax_results,\n                    input_lin_results,\n                    inputs,\n                    input_weights,\n                    output_weights,\n                    dropout_mask,\n                    dropout_prob_t[0],\n                )\n            
)\n            # fast_self_multihead_attn.backward(                          \\\n        return (\n            None,\n            None,\n            None,\n            input_grads,\n            input_weight_grads,\n            output_weight_grads,\n            input_bias_grads,\n            output_bias_grads,\n            None,\n            None,\n            None,\n        )\n\n\nfast_self_attn_func = FastSelfAttnFunc.apply\n"
  },
  {
    "path": "apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py",
    "content": "import torch\n\nimport fast_multihead_attn\n\n\nclass FastSelfAttnNormAddFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        use_time_mask,\n        is_training,\n        heads,\n        inputs,\n        lyr_nrm_gamma_weights,\n        lyr_nrm_beta_weights,\n        input_weights,\n        output_weights,\n        pad_mask,\n        dropout_prob,\n    ):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention\"\n            \"The documentation is available in https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        heads_t = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor = torch.tensor([])\n        use_mask = pad_mask is not None\n\n        (\n            lyr_nrm_results,\n            lyr_nrm_mean,\n            lyr_nrm_invvar,\n            input_lin_results,\n            softmax_results,\n            dropout_results,\n            dropout_mask,\n            matmul2_results,\n            dropout_add_mask,\n            outputs,\n        ) = fast_multihead_attn.self_attn_norm_add_forward(\n            use_mask,\n            use_time_mask,\n            is_training,\n            heads,\n            inputs,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights,\n            output_weights,\n            pad_mask if use_mask else null_tensor,\n            dropout_prob,\n        )\n        # fast_self_multihead_attn_norm_add.forward(                 \\\n\n        ctx.save_for_backward(\n            heads_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_results,\n            lyr_nrm_results,\n            lyr_nrm_mean,\n        
    lyr_nrm_invvar,\n            inputs,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights,\n            output_weights,\n            dropout_mask,\n            dropout_add_mask,\n            dropout_prob_t,\n        )\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        (\n            heads_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_results,\n            lyr_nrm_results,\n            lyr_nrm_mean,\n            lyr_nrm_invvar,\n            inputs,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights,\n            output_weights,\n            dropout_mask,\n            dropout_add_mask,\n            dropout_prob_t,\n        ) = ctx.saved_tensors\n\n        (\n            input_grads,\n            lyr_nrm_gamma_grads,\n            lyr_nrm_beta_grads,\n            input_weight_grads,\n            output_weight_grads,\n        ) = fast_multihead_attn.self_attn_norm_add_backward(\n            heads_t[0],\n            output_grads,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_results,\n            lyr_nrm_results,\n            lyr_nrm_mean,\n            lyr_nrm_invvar,\n            inputs,\n            lyr_nrm_gamma_weights,\n            lyr_nrm_beta_weights,\n            input_weights,\n            output_weights,\n            dropout_mask,\n            dropout_add_mask,\n            dropout_prob_t[0],\n        )\n        # fast_self_multihead_attn_norm_add.backward(                 \\\n\n        return (\n            None,\n            None,\n            None,\n            input_grads,\n            lyr_nrm_gamma_grads,\n            lyr_nrm_beta_grads,\n            input_weight_grads,\n            output_weight_grads,\n            None,\n            None,\n        
)\n\n\nfast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply\n"
  },
  {
    "path": "apex/contrib/multihead_attn/mask_softmax_dropout_func.py",
    "content": "import torch\n\nimport fast_multihead_attn\n\n\nclass MaskSoftmaxDropout(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention\"\n            \"The documentation is available in https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        heads_t = torch.tensor([heads])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor = torch.tensor([])\n        use_mask = pad_mask is not None\n        use_mask_t = torch.tensor([use_mask])\n        mask_additive_t = torch.tensor([mask_additive])\n\n        if mask_additive:\n            dropout_results, dropout_mask, softmax_results = (\n                fast_multihead_attn.additive_mask_softmax_dropout_forward(\n                    use_mask,\n                    is_training,\n                    heads,\n                    inputs,\n                    pad_mask if use_mask else null_tensor,\n                    dropout_prob,\n                )\n            )\n            # fast_additive_mask_softmax_dropout.forward(                           \\\n        else:\n            dropout_results, dropout_mask, softmax_results = (\n                fast_multihead_attn.mask_softmax_dropout_forward(\n                    use_mask,\n                    is_training,\n                    heads,\n                    inputs,\n                    pad_mask if use_mask else null_tensor,\n                    dropout_prob,\n                )\n            )\n            # fast_mask_softmax_dropout.forward(                           \\\n\n        ctx.save_for_backward(\n            use_mask_t,\n            heads_t,\n            softmax_results,\n   
         dropout_mask,\n            pad_mask if use_mask else null_tensor,\n            mask_additive_t,\n            dropout_prob_t,\n        )\n\n        return dropout_results.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        (\n            use_mask_t,\n            heads_t,\n            softmax_results,\n            dropout_mask,\n            pad_mask,\n            mask_additive_t,\n            dropout_prob_t,\n        ) = ctx.saved_tensors\n\n        if mask_additive_t[0]:\n            input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward(\n                use_mask_t[0],\n                heads_t[0],\n                output_grads,\n                softmax_results,\n                dropout_mask,\n                dropout_prob_t[0],\n            )\n            # fast_additive_mask_softmax_dropout.backward(                          \\\n        else:\n            input_grads = fast_multihead_attn.mask_softmax_dropout_backward(\n                use_mask_t[0],\n                heads_t[0],\n                output_grads,\n                softmax_results,\n                dropout_mask,\n                pad_mask,\n                dropout_prob_t[0],\n            )\n            # fast_mask_softmax_dropout.backward(                          \\\n        return None, None, input_grads, None, None, None\n\n\nfast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply\n"
  },
  {
    "path": "apex/contrib/multihead_attn/self_multihead_attn.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom .self_multihead_attn_func import self_attn_func\nfrom .fast_self_multihead_attn_func import fast_self_attn_func\nfrom .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func\nfrom apex.normalization.fused_layer_norm import FusedLayerNorm\n\n\n@torch.jit.script\ndef jit_dropout_add(x, residual, prob, is_training):\n    # type: (Tensor, Tensor, float, bool) -> Tensor\n    out = F.dropout(x, p=prob, training=True)\n    out = residual + out\n    return out\n\n\nclass SelfMultiheadAttn(nn.Module):\n    \"\"\"Multi-headed attention.\n\n    See \"Attention Is All You Need\" for more details.\n    \"\"\"\n\n    def __init__(\n        self,\n        embed_dim,\n        num_heads,\n        dropout=0.0,\n        bias=False,\n        include_norm_add=False,\n        impl=\"fast\",\n        separate_qkv_params=False,\n        mask_additive=False,\n    ):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, (\n            \"embed_dim must be divisible by num_heads\"\n        )\n        self.bias = bias\n        self.include_norm_add = include_norm_add\n        self.impl = impl\n        self.scaling = self.head_dim**-0.5\n        self.separate_qkv_params = separate_qkv_params\n        self.mask_additive = mask_additive\n        if mask_additive:\n            assert self.include_norm_add == False, \"additive mask not supported with layer norm\"\n            assert impl == \"default\" or (impl == \"fast\" and bias), (\n                \"additive mask not supported for fast mode without bias\"\n            )\n        if separate_qkv_params:\n            self.q_weight = Parameter(torch.empty(embed_dim, embed_dim))\n            self.k_weight = 
Parameter(torch.empty(embed_dim, embed_dim))\n            self.v_weight = Parameter(torch.empty(embed_dim, embed_dim))\n        else:\n            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))\n        self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))\n        if self.bias:\n            if separate_qkv_params:\n                self.q_bias = Parameter(torch.empty(embed_dim))\n                self.k_bias = Parameter(torch.empty(embed_dim))\n                self.v_bias = Parameter(torch.empty(embed_dim))\n            else:\n                self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))\n            self.out_proj_bias = Parameter(torch.empty(embed_dim))\n        else:\n            if separate_qkv_params:\n                self.register_parameter(\"q_bias\", None)\n                self.register_parameter(\"k_bias\", None)\n                self.register_parameter(\"v_bias\", None)\n                self.q_bias = None\n                self.k_bias = None\n                self.v_bias = None\n            else:\n                self.register_parameter(\"in_proj_bias\", None)\n                self.in_proj_bias = None\n            self.register_parameter(\"out_proj_bias\", None)\n            self.out_proj_bias = None\n        if self.include_norm_add:\n            if impl == \"fast\":\n                self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))\n                self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))\n                self.lyr_nrm = None\n            else:\n                self.register_parameter(\"lyr_norm_gamma_weights\", None)\n                self.register_parameter(\"lyr_norm_beta_weights\", None)\n                self.lyr_nrm_gamma_weights = None\n                self.lyr_nrm_beta_weights = None\n                self.lyr_nrm = FusedLayerNorm(embed_dim)\n        self.reset_parameters()\n\n        if self.include_norm_add:\n            if impl == \"fast\":\n                
self.attn_func = fast_self_attn_norm_add_func\n            elif impl == \"default\":\n                self.attn_func = self_attn_func\n            else:\n                assert False, \"Unsupported impl: {} !\".format(impl)\n        else:\n            if impl == \"fast\":\n                self.attn_func = fast_self_attn_func\n            elif impl == \"default\":\n                self.attn_func = self_attn_func\n            else:\n                assert False, \"Unsupported impl: {} !\".format(impl)\n\n    def reset_parameters(self):\n        if self.separate_qkv_params:\n            nn.init.xavier_uniform_(self.q_weight)\n            nn.init.xavier_uniform_(self.k_weight)\n            nn.init.xavier_uniform_(self.v_weight)\n        else:\n            # in_proj_weight has shape [3 * hidden, hidden] but it should be\n            # initialized like a [hidden, hidden] matrix.\n            # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)\n            # therefore xavier_uniform gain should be set to sqrt(2).\n            nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))\n        nn.init.xavier_uniform_(self.out_proj_weight)\n        if self.bias:\n            if self.separate_qkv_params:\n                nn.init.constant_(self.q_bias, 0.0)\n                nn.init.constant_(self.k_bias, 0.0)\n                nn.init.constant_(self.v_bias, 0.0)\n            else:\n                nn.init.constant_(self.in_proj_bias, 0.0)\n            nn.init.constant_(self.out_proj_bias, 0.0)\n        if self.include_norm_add:\n            if self.impl == \"fast\":\n                nn.init.ones_(self.lyr_nrm_gamma_weights)\n                nn.init.zeros_(self.lyr_nrm_beta_weights)\n            else:\n                self.lyr_nrm.reset_parameters()\n\n    def forward(\n        self,\n        query,\n        key,\n        value,\n        key_padding_mask=None,\n        need_weights=False,\n        attn_mask=None,\n        is_training=True,\n    ):\n   
     \"\"\"Input shape: Time x Batch x Channel\n\n        Self-attention can be implemented by passing in the same arguments for\n        query, key and value. Future timesteps can be masked with the\n        `mask_future_timesteps` argument. Padding elements can be excluded from\n        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:\n        batch x src_len, where padding elements are indicated by 1s.\n        \"\"\"\n        if self.separate_qkv_params:\n            input_weights = (\n                torch.cat(\n                    [\n                        self.q_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),\n                        self.k_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),\n                        self.v_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),\n                    ],\n                    dim=1,\n                )\n                .reshape(3 * self.embed_dim, self.embed_dim)\n                .contiguous()\n            )\n        else:\n            input_weights = self.in_proj_weight\n        if self.bias:\n            if self.separate_qkv_params:\n                input_bias = (\n                    torch.cat(\n                        [\n                            self.q_bias.view(self.num_heads, 1, self.head_dim),\n                            self.k_bias.view(self.num_heads, 1, self.head_dim),\n                            self.v_bias.view(self.num_heads, 1, self.head_dim),\n                        ],\n                        dim=1,\n                    )\n                    .reshape(3 * self.embed_dim)\n                    .contiguous()\n                )\n            else:\n                input_bias = self.in_proj_bias\n        else:\n            input_bias = None\n        if key_padding_mask is not None:\n            assert attn_mask is None, (\n                \"ERROR attn_mask and key_padding_mask should not be both defined!\"\n            )\n            mask = 
key_padding_mask\n        elif attn_mask is not None:\n            assert self.mask_additive == False, \"additive mask not supported for time mask\"\n            mask = attn_mask\n        else:\n            mask = None\n\n        if self.include_norm_add:\n            if self.impl == \"fast\":\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    query,\n                    self.lyr_nrm_gamma_weights,\n                    self.lyr_nrm_beta_weights,\n                    input_weights,\n                    self.out_proj_weight,\n                    mask,\n                    self.dropout,\n                )\n            else:\n                lyr_nrm_results = self.lyr_nrm(query)\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    self.scaling,\n                    lyr_nrm_results,\n                    input_weights,\n                    self.out_proj_weight,\n                    input_bias,\n                    self.out_proj_bias,\n                    mask,\n                    self.mask_additive,\n                    self.dropout,\n                )\n                if is_training:\n                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)\n                else:\n                    outputs = outputs + query\n        else:\n            if self.impl == \"fast\":\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    query,\n                    input_weights,\n                    self.out_proj_weight,\n                    input_bias,\n                    self.out_proj_bias,\n                    mask,\n                    self.mask_additive,\n                    
self.dropout,\n                )\n            else:\n                outputs = self.attn_func(\n                    attn_mask is not None,\n                    is_training,\n                    self.num_heads,\n                    self.scaling,\n                    query,\n                    input_weights,\n                    self.out_proj_weight,\n                    input_bias,\n                    self.out_proj_bias,\n                    mask,\n                    self.mask_additive,\n                    self.dropout,\n                )\n\n        return outputs, None\n"
  },
  {
    "path": "apex/contrib/multihead_attn/self_multihead_attn_func.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\nclass SelfAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        use_time_mask,\n        is_training,\n        heads,\n        scale,\n        inputs,\n        input_weights,\n        output_weights,\n        input_biases,\n        output_biases,\n        mask,\n        is_additive_mask,\n        dropout_prob,\n    ):\n        from apex import deprecated_warning\n\n        deprecated_warning(\n            \"`apex.contrib.multihead_attn` is deprecated and will be removed in July 2026. \"\n            \"We encourage you to migrate to PyTorch native MultiheadAttention\"\n            \"The documentation is available in https://docs.pytorch.org/docs/main/generated/torch.nn.MultiheadAttention.html\"\n        )\n\n        use_biases_t = torch.tensor([input_biases is not None])\n        heads_t = torch.tensor([heads])\n        scale_t = torch.tensor([scale])\n        dropout_prob_t = torch.tensor([dropout_prob])\n        null_tensor = torch.tensor([])\n        head_dim = inputs.size(2) // heads\n\n        # Input Linear GEMM\n        # input1: (activations) [seql_q, seqs, embed_dim(1024)]\n        # input2: (weights)     [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])\n        # output:               [seql_q, seqs, embed_dim*3]\n        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)\n        if use_biases_t[0]:\n            input_lin_results = torch.addmm(\n                input_biases,\n                inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),\n                input_weights.transpose(0, 1),\n                beta=1.0,\n                alpha=1.0,\n            )\n        else:\n            input_lin_results = torch.mm(\n                inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),\n                input_weights.transpose(0, 1),\n            )\n        input_lin_results = 
input_lin_results.view(\n            inputs.size(0), inputs.size(1), input_weights.size(0)\n        )\n\n        # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]\n        # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]\n        input_lin_results = input_lin_results.view(\n            inputs.size(0), inputs.size(1) * heads, 3, head_dim\n        )\n        queries = input_lin_results[:, :, 0, :]\n        keys = input_lin_results[:, :, 1, :]\n        values = input_lin_results[:, :, 2, :]\n\n        # Matmul1 Batched GEMMs\n        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification\n        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of\n        # a separate elementwise operation.\n        # Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (Keys)    [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)\n        # output:           [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul1_results = torch.empty(\n            (queries.size(1), queries.size(0), keys.size(0)),\n            dtype=queries.dtype,\n            device=torch.device(\"cuda\"),\n        )\n        matmul1_results = torch.baddbmm(\n            matmul1_results,\n            queries.transpose(0, 1),\n            keys.transpose(0, 1).transpose(1, 2),\n            out=matmul1_results,\n            beta=0.0,\n            alpha=scale_t[0],\n        )\n\n        if mask is not None:\n            # Self Attention Time Mask\n            if use_time_mask:\n                assert len(mask.size()) == 2, \"Timing mask is not 2D!\"\n                assert mask.size(0) == mask.size(1), \"Sequence 
length should match!\"\n                mask = mask.to(torch.bool)\n                matmul1_results = matmul1_results.masked_fill_(mask, float(\"-inf\"))\n            # Key Padding Mask\n            else:\n                batches, seql_q, seql_k = matmul1_results.size()\n                seqs = batches // heads\n                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)\n                if is_additive_mask:\n                    matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)\n                else:\n                    mask = mask.to(torch.bool)\n                    matmul1_results = matmul1_results.masked_fill_(\n                        mask.unsqueeze(1).unsqueeze(2), float(\"-inf\")\n                    )\n                matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k)\n\n        softmax_results = F.softmax(matmul1_results, dim=-1)\n\n        # Dropout is not executed for inference\n        if is_training:\n            dropout_results, dropout_mask = torch._fused_dropout(\n                softmax_results, p=(1.0 - dropout_prob_t[0])\n            )\n        else:\n            dropout_results = softmax_results\n            dropout_mask = null_tensor\n\n        # Matmul2 Batched GEMMs\n        # The output tensor specification is needed here to specify the non-standard output.\n        # Given that PyTorch cannot currently perform autograd with an output tensor specified,\n        # this requires an explicitly specified backward pass.\n        # Input1: from_softmax [seqs*heads, seql_q, seql_k]\n        # Input2: (values)     [seql_v, seqs*heads, head_dim] transpose(0,1)\n        # Output:              [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)\n        matmul2_results = torch.empty(\n            (dropout_results.size(1), dropout_results.size(0), values.size(2)),\n            dtype=dropout_results.dtype,\n          
  device=torch.device(\"cuda\"),\n        ).transpose(1, 0)\n        matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results)\n        matmul2_results = (\n            matmul2_results.transpose(0, 1)\n            .contiguous()\n            .view(inputs.size(0), inputs.size(1), inputs.size(2))\n        )\n\n        # Output Linear GEMM\n        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ] transpose(0,1)\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        if use_biases_t[0]:\n            outputs = torch.addmm(\n                output_biases,\n                matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),\n                output_weights.transpose(0, 1),\n                beta=1.0,\n                alpha=1.0,\n            )\n        else:\n            outputs = torch.mm(\n                matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),\n                output_weights.transpose(0, 1),\n            )\n        outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))\n\n        ctx.save_for_backward(\n            use_biases_t,\n            heads_t,\n            scale_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_results,\n            inputs,\n            input_weights,\n            output_weights,\n            dropout_mask,\n            dropout_prob_t,\n        )\n\n        return outputs.detach()\n\n    @staticmethod\n    def backward(ctx, output_grads):\n        (\n            use_biases_t,\n            heads_t,\n            scale_t,\n            matmul2_results,\n            dropout_results,\n            softmax_results,\n            input_lin_results,\n            inputs,\n            input_weights,\n       
     output_weights,\n            dropout_mask,\n            dropout_prob_t,\n        ) = ctx.saved_tensors\n\n        head_dim = inputs.size(2) // heads_t[0]\n\n        # Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)\n        # Sequences and heads are combined to make the batch of the Batched GEMM\n        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]\n        # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]\n        input_lin_results = input_lin_results.view(\n            inputs.size(0), inputs.size(1) * heads_t[0], 3, head_dim\n        )\n        queries = input_lin_results[:, :, 0, :]\n        keys = input_lin_results[:, :, 1, :]\n        values = input_lin_results[:, :, 2, :]\n\n        # Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact metadata, no copies!)\n        # The gradients are identical in size to the Input Linear outputs.\n        # The tensor is declared beforehand to properly slice out query, key, and value grads.\n        input_lin_results_grads = torch.empty_like(input_lin_results)\n        queries_grads = input_lin_results_grads[:, :, 0, :]\n        keys_grads = input_lin_results_grads[:, :, 1, :]\n        values_grads = input_lin_results_grads[:, :, 2, :]\n\n        # Output Linear GEMM - DGRAD\n        # Input1: (data grads)  [seql_q, seqs, embed_dim=heads*head_dim]\n        # Input2: (weights)     [ embed_dim, embed_dim ]\n        # Output:               [ seql_q, seqs, embed_dim ]\n        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )\n        output_lin_grads = torch.mm(\n            output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)),\n            output_weights,\n        )\n        output_lin_grads = output_lin_grads.view(\n            output_grads.size(0), output_grads.size(1), output_weights.size(1)\n        )\n        # 
Output Linear GEMM - WGRAD\n        # Input1: (data grads)  [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)\n        # Input2: (activations) [seql_q*seqs, embed_dim ]\n        # Output:               [ embed_dim, embed_dim ]\n        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )\n        output_weight_grads = torch.mm(\n            output_grads.view(\n                output_grads.size(0) * output_grads.size(1), output_grads.size(2)\n            ).transpose(0, 1),\n            matmul2_results.view(\n                matmul2_results.size(0) * matmul2_results.size(1),\n                matmul2_results.size(2),\n            ),\n        )\n        output_lin_grads = output_lin_grads.view(\n            inputs.size(0), inputs.size(1) * heads_t[0], head_dim\n        ).transpose(0, 1)\n\n        if use_biases_t[0]:\n            output_bias_grads = torch.sum(\n                output_grads.view(\n                    output_grads.size(0) * output_grads.size(1), output_grads.size(2)\n                ),\n                0,\n            )\n        else:\n            output_bias_grads = None\n\n        # Matmul2 - DGRAD1\n        # Input1: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)\n        # Output:               [seqs*heads, seql_q, seql_k]\n        # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )\n        matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))\n        # Matmul2 - DGRAD2\n        # Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (data grads)  [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        values_grads = torch.bmm(\n 
           dropout_results.transpose(1, 2),\n            output_lin_grads,\n            out=values_grads.transpose(0, 1),\n        )\n\n        # Mask and Scaling for Dropout (not a publicly documented op)\n        dropout_grads = torch._masked_scale(\n            matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0])\n        )\n\n        # Softmax Grad (not a publicly documented op)\n        softmax_grads = torch._softmax_backward_data(\n            dropout_grads, softmax_results, -1, softmax_results.dtype\n        )\n\n        # Matmul1 - DGRAD1\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k]\n        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_q, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )\n        queries_grads = torch.baddbmm(\n            queries_grads.transpose(0, 1),\n            softmax_grads,\n            keys.transpose(0, 1),\n            out=queries_grads.transpose(0, 1),\n            beta=0.0,\n            alpha=scale_t[0],\n        )\n        # Matmul1 - DGRAD2\n        # Input1: (data grads)  [seqs*heads, seql_q, seql_k] transpose(1,2)\n        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)\n        # Output:               [seqs*heads, seql_k, head_dim] transpose(0,1)\n        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )\n        keys_grads = torch.baddbmm(\n            keys_grads.transpose(0, 1),\n            softmax_grads.transpose(1, 2),\n            queries.transpose(0, 1),\n            out=keys_grads.transpose(0, 1),\n            beta=0.0,\n            alpha=scale_t[0],\n        )\n\n        # Input Linear GEMM - DGRAD\n        # input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]\n        # input2: (weights)    [embed_dim*3 (3072), embed_dim (1024)]\n        # output:              [seql_q, seqs, 
embed_dim]\n        # GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)\n        input_lin_results_grads = input_lin_results_grads.view(\n            inputs.size(0) * inputs.size(1), heads_t[0] * 3 * head_dim\n        )\n        input_grads = torch.mm(input_lin_results_grads, input_weights)\n        input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))\n        # Input Linear GEMM - WGRAD\n        # input1: (data grads)  [seql_q*seqs, 3*embed_dim(3072)]\n        # input2: (activations) [seql_q*seqs, embed_dim(1024)]\n        # output:               [3*embed_dim, embed_dim]\n        # GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)\n        input_weight_grads = torch.mm(\n            input_lin_results_grads.transpose(0, 1),\n            inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),\n        )\n\n        if use_biases_t[0]:\n            input_bias_grads = torch.sum(input_lin_results_grads, 0)\n        else:\n            input_bias_grads = None\n\n        return (\n            None,\n            None,\n            None,\n            None,\n            input_grads,\n            input_weight_grads,\n            output_weight_grads,\n            input_bias_grads,\n            output_bias_grads,\n            None,\n            None,\n            None,\n        )\n\n\nself_attn_func = SelfAttnFunc.apply\n"
  },
  {
    "path": "apex/contrib/nccl_allocator/README.md",
    "content": "## General information\n\n`nccl_allocator` is a module that enables `ncclMemAlloc`[^1] to be used within PyTorch for faster NCCL NVLS collective communications.\nIt is mainly based on `CUDAPluggableAllocator`.\nThe context manager `nccl_allocator.nccl_mem(enabled=True)` is used as a switch between `cudaMalloc` and `ncclMemAlloc` (if `enabled=True` it will use `cudaMalloc`).\n\n[^1]: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html\n\n### Example usage:\n\nHere is a minimalistic example:\n\n```\nimport os\nimport torch\nimport torch.distributed as dist\nimport apex.contrib.nccl_allocator as nccl_allocator\n\nrank = int(os.getenv(\"RANK\"))\nlocal_rank = int(os.getenv(\"LOCAL_RANK\"))\nworld_size = int(os.getenv(\"WORLD_SIZE\"))\n\nnccl_allocator.init()\n\ntorch.cuda.set_device(local_rank)\ndist.init_process_group(backend=\"nccl\")\n\nwith nccl_allocator.nccl_mem():\n\ta = torch.ones(1024 * 1024 * 2, device=\"cuda\")\ndist.all_reduce(a)\n\ntorch.cuda.synchronize()\n```\n\nPlease visit `apex/contrib/examples/nccl_allocator` for more examples.\n\n\n### IMPORTANT\n\nThere are several strict requirements:\n- PyTorch must include PR [#112850](https://github.com/pytorch/pytorch/pull/112850)\n- NCCL v2.19.4 and newer\n- NCCL NVLS requires CUDA Driver 530 and newer (tested on 535)\n\n"
  },
  {
    "path": "apex/contrib/nccl_allocator/__init__.py",
    "content": "from .nccl_allocator import *\n"
  },
  {
    "path": "apex/contrib/nccl_allocator/nccl_allocator.py",
    "content": "import os\nimport torch\nimport _apex_nccl_allocator\n\nfrom contextlib import nullcontext\n\n\n__all__ = [\"init\", \"nccl_mem\", \"create_nccl_mem_pool\"]\n\n\ndef get_func_args(func):\n    import inspect\n\n    sig = inspect.signature(func)\n    return [arg.name for arg in sig.parameters.values()]\n\n\ndef create_nccl_mem_pool(symmetric: bool | None = None) -> torch.cuda.MemPool:\n    _allocator = _apex_nccl_allocator.get_nccl_allocator()\n    if symmetric is None:\n        _pool = torch.cuda.MemPool(_allocator)\n    else:\n        if \"symmetric\" in get_func_args(torch.cuda.MemPool):\n            _pool = torch.cuda.MemPool(_allocator, symmetric=symmetric)\n        elif \"symm_mem\" in get_func_args(torch.cuda.MemPool):\n            # This path handles argument name divergence between\n            # nvidia pytorch and the official pytorch.\n            _pool = torch.cuda.MemPool(_allocator, symm_mem=symmetric)\n        else:\n            raise ValueError(\n                \"symmetric setting with torch.cuda.MemPool requires higher PyTorch version\"\n            )\n    return _pool\n\n\ndef init() -> None:\n    os.environ[\"NCCL_NVLS_ENABLE\"] = \"1\"\n    os.environ[\"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK\"] = \"0\"\n\n\nclass nccl_mem:\n    def __init__(self, pool, enabled=True, device=None, group=None):\n        self.device = None\n        self.group = None\n        self.mem_context = None\n        self.pool = pool\n\n        if enabled:\n            if device is None:\n                self.device = torch.device(\"cuda\", torch.cuda.current_device())\n            elif isinstance(device, int):\n                self.device = torch.device(\"cuda\", device)\n            elif isinstance(device, str):\n                assert \"cuda\" in device, \"only cuda devices are supported\"\n                self.device = torch.device(device)\n\n            if group is None:\n                self.group = 
torch.distributed.distributed_c10d._get_default_group()\n            else:\n                self.group = group\n\n            self.mem_context = torch.cuda.use_mem_pool(self.pool)\n        else:\n            self.mem_context = nullcontext()\n\n    def __enter__(self):\n        self.mem_context.__enter__()\n        if self.group is not None:\n            backend = self.group._get_backend(self.device)\n            try:\n                backend.deregister_mem_pool(self.pool)\n            except RuntimeError:\n                pass\n\n    def __exit__(self, *args):\n        if self.group is not None:\n            backend = self.group._get_backend(self.device)\n            try:\n                backend.register_mem_pool(self.pool)\n            except RuntimeError:\n                pass\n        self.mem_context.__exit__(*args)\n"
  },
  {
    "path": "apex/contrib/openfold_triton/README.md",
    "content": "# OpenFold triton kernels\n\nThis subpackage is a collection of Triton kernels written specifically for the OpenFold model architecture initial training mode.\n\nTo use this subpackage, you must install additional dependencies:\n\n```bash\npip install einops\n```\n\nThe following sections list all main features and show how to use them.\n\n## Multi-Head Attention\n\n```python\nimport apex.contrib.openfold_triton.mha as mha\nfrom apex.contrib.openfold_triton import AttnBiasJIT, AttnNoBiasJIT, AttnTri, CanSchTriMHA\n\n# Integration with Attention module:\nclass SelfAttentionWithGate(nn.Module):\n    # ...\n\n    def _attention_forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        mask: torch.Tensor,\n        bias: Optional[torch.Tensor],\n    ) -> torch.Tensor:\n        if self.chunk_size is None:\n            if mha.is_enabled() and CanSchTriMHA(\n                list(query.shape),\n                bias is not None,\n                inf=self.inf,\n                training=self.training,\n            ):\n                if mask is not None:\n                    mask = mask.contiguous()\n                if bias is not None:\n                    bias = bias.contiguous()\n                return AttnTri(\n                    query, key, value, mask, bias, self.inf, torch.is_grad_enabled()\n                )\n            elif mha.is_enabled() and bias is not None and self.training:\n                return AttnBiasJIT(query, key, value, mask, bias, self.inf)\n            elif mha.is_enabled() and bias is None and self.training:\n                return AttnNoBiasJIT(query, key, value, mask, self.inf)\n\n# Switch on/off MHA dynamically at runtime via:\nmha.enable()\nmha.disable()\n\n```\n\n## LayerNorm\n\n```python\nfrom apex.contrib.openfold_triton import LayerNormSmallShapeOptImpl\n\n# Integration with LayerNorm module:\nclass LayerNorm(nn.Module):\n    # ...\n\n    def 
_should_use_triton_kernels(self, x: torch.Tensor) -> bool:\n        ln_triton_shapes = (\n            (256, 128),\n            (256, 256),\n        )\n        ln_triton_dim = 4\n        return (\n            self.training\n            and x.dim() == ln_triton_dim\n            and x.shape[-2:] in ln_triton_shapes\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self._should_use_triton_kernels(x):\n            return LayerNormSmallShapeOptImpl.apply(\n                x, self.normalized_shape, self.weight, self.bias, self.eps\n            )\n        else:\n            return F.layer_norm(\n                x, self.normalized_shape, self.weight, self.bias, self.eps\n            )\n\n# To load auto tuned cache:\nfrom apex.contrib.openfold_triton._layer_norm_config_ampere import _auto_tuned_config_ampere\nfrom apex.contrib.openfold_triton._layer_norm_config_hopper import _auto_tuned_config_hopper\nfrom apex.contrib.openfold_triton import _tuneable_triton_kernels\n\ndef load_triton_auto_tuned_cache(dap_size: int, arch_type: str) -> None:\n    auto_tuned_config = {\n        \"hopper\": _auto_tuned_config_hopper,\n        \"ampere\": _auto_tuned_config_ampere,\n    }[arch_type]\n    config_for_current_dap = auto_tuned_config[dap_size]\n    for func_name, cache in config_for_current_dap.items():\n        _tuneable_triton_kernels[func_name].cache = cache\n\nload_triton_auto_tuned_cache(\n    dap_size=4,  # supported values: 0, 1, 2, 4, 8\n    arch_type=\"hopper\",\n)\n\n```\n\n## FusedAdamSWA\n\n```python\nfrom apex.contrib.openfold_triton.fused_adam_swa import FusedAdamSWA\n\nfused_optimizer = FusedAdamSWA.from_optim(\n    adam_optimizer=adam_optimizer,  # standard pytorch optimizer\n    fp32_params=fp32_params,  # FP32 used in weight update\n    bf16_params=bf16_params,  # BF16 used in forward, backward, reduction\n    swa_params=swa_params,  # SWA used for evaluation\n    swa_decay_rate=swa_decay_rate,  # for example: 0.9, 0.99, 
0.999\n)\n\nfused_optimizer.step()  # fused optimizer step: casting BF16/FP32 + param updates + SWA\n\n```\n"
  },
  {
    "path": "apex/contrib/openfold_triton/__init__.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nimport json\nimport warnings\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom io import BytesIO\nfrom typing import BinaryIO, Union\n\nimport torch\nfrom triton.runtime.autotuner import Autotuner, Config, Heuristics\nfrom triton.runtime.jit import JITFunction\n\nfrom apex.contrib.openfold_triton._layer_norm_backward_kernels import (\n    _layer_norm_backward_dw_db_partial,\n    _layer_norm_backward_dw_db_partial_strided,\n    _layer_norm_backward_dx,\n    _layer_norm_backward_dx_strided,\n)\nfrom apex.contrib.openfold_triton._layer_norm_forward_kernels import (\n    _layer_norm_forward,\n    _layer_norm_forward_strided,\n)\nfrom apex.contrib.openfold_triton.layer_norm import LayerNormSmallShapeOptImpl\nfrom apex.contrib.openfold_triton.mha import (\n    AttnBiasJIT,\n    AttnNoBiasJIT,\n    AttnTri,\n    CanSchTriMHA,\n)\n\n__all__ = (\n    \"LayerNormSmallShapeOptImpl\",\n    \"sync_triton_auto_tune_cache_across_gpus\",\n    \"CanSchTriMHA\",\n    \"AttnTri\",\n    \"AttnBiasJIT\",\n    \"AttnNoBiasJIT\",\n)\n\n\ndef _get_tuneable_triton_func_name(f: Union[Autotuner, Heuristics, JITFunction]) -> str:\n    if isinstance(f, JITFunction):\n        return f.__name__\n    else:\n        return _get_tuneable_triton_func_name(f.fn)\n\n\n_tuneable_triton_kernels = OrderedDict(\n    (_get_tuneable_triton_func_name(func), func)\n    for func in (\n        _layer_norm_backward_dw_db_partial,\n        _layer_norm_backward_dw_db_partial_strided,\n        _layer_norm_backward_dx,\n        _layer_norm_backward_dx_strided,\n        _layer_norm_forward,\n        _layer_norm_forward_strided,\n    )\n)\n\n\ndef _save_triton_auto_tune_cache(strict: bool = True, verbose: bool = False) -> BytesIO:\n    caches = OrderedDict()\n    for func_name, func in _tuneable_triton_kernels.items():\n        if len(func.cache) < 1:\n            msg = f\"Triton JIT kernel {func_name} didn't have tuning cache\"\n          
  if strict:\n                raise ValueError(msg)\n            else:\n                warnings.warn(msg)\n        else:\n            caches[func_name] = [\n                (keys, vals.all_kwargs())\n                for keys, vals in zip(func.cache.keys(), func.cache.values())\n            ]\n    f = BytesIO(json.dumps(caches).encode(\"utf-8\"))\n    if verbose:\n        print(f\"Triton kernel auto-tuning caches written to {f}\")\n    return f\n\n\ndef _load_triton_auto_tune_cache(f: BinaryIO, strict: bool = True, verbose: bool = False) -> None:\n    caches = json.load(f)\n    if strict:\n        loaded_func_name = set(caches.keys())\n        tuneable_func_name = set(_tuneable_triton_kernels.keys())\n        if loaded_func_name != tuneable_func_name:\n            raise ValueError(\n                f\"Tuneable Triton kernels don't match with provided auto-tuning cache file {f}\\n\"\n                f\"Missing kernel caches: {tuneable_func_name - loaded_func_name}\\n\"\n                f\"Unexpected kernel caches: {loaded_func_name - tuneable_func_name}\"\n            )\n    for func_name, func_cache in caches.items():\n        if func_name not in _tuneable_triton_kernels:\n            raise ValueError(f\"{func_name} from {f} doesn't match any tuneable Triton kernels\")\n        for key, val in func_cache:\n            _tuneable_triton_kernels[func_name].cache[tuple(key)] = Config(val)\n    if verbose:\n        print(f\"Triton kernel auto-tuning caches loaded from {f}\")\n\n\ndef sync_triton_auto_tune_cache_across_gpus(strict: bool = True, verbose: bool = False) -> None:\n    if not torch.distributed.is_initialized():\n        return\n    if torch.distributed.get_rank() == 0:\n        print(\"Broadcasting Triton auto-tuning cache from rank 0 to other ranks...\")\n        cache = _save_triton_auto_tune_cache(strict=strict, verbose=verbose)\n        cache.seek(0)\n        cache_list = [\n            cache,\n        ]\n    else:\n        print(\n            f\"Rank 
{torch.distributed.get_rank()} is waiting for Triton auto-tuning cache from rank 0...\"\n        )\n        cache_list = [\n            None,\n        ]\n    torch.distributed.broadcast_object_list(cache_list)\n    _load_triton_auto_tune_cache(cache_list[0], strict=strict, verbose=verbose)\n    print(\"Succeed!\")\n"
  },
  {
    "path": "apex/contrib/openfold_triton/_layer_norm_backward_kernels.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton import Config\n\n# %% Constants for efficient memory access.\nCACHE_SECTOR_SIZE = 32 * 8\nBF16_LOAD_SIZE = int(CACHE_SECTOR_SIZE / torch.finfo(torch.bfloat16).bits)\nPARTIAL_REDUCE_MIN = 32\n\n\n# %% Separated backward kernels for contiguous inputs. We choose to not fusing them because dX and\n# d{W, b} reduce along different directions.\n@triton.autotune(\n    configs=[\n        Config({\"M_BLOCK\": 1}, num_warps=1),\n        Config({\"M_BLOCK\": 2}, num_warps=1),\n        Config({\"M_BLOCK\": 4}, num_warps=2),\n        Config({\"M_BLOCK\": 8}, num_warps=4),\n        Config({\"M_BLOCK\": 16}, num_warps=8),\n        Config({\"M_BLOCK\": 32}, num_warps=8),\n        Config({\"M_BLOCK\": 64}, num_warps=8),\n    ],\n    key=[\"M\", \"N\"],\n)\n@triton.heuristics(\n    values={\n        \"N_BLOCK\": lambda kwargs: triton.next_power_of_2(kwargs[\"N\"]),\n    },\n)\n@triton.jit\ndef _layer_norm_backward_dx(\n    dy_ptr,\n    x_ptr,\n    w_ptr,\n    x_invstd_ptr,\n    x_mean_ptr,\n    dx_ptr,\n    M: tl.constexpr,\n    N: tl.constexpr,\n    M_BLOCK: tl.constexpr,\n    N_BLOCK: tl.constexpr,\n):\n    m_idx = (tl.program_id(0) * M_BLOCK + tl.arange(0, M_BLOCK))[:, None]\n    m_mask = m_idx < M\n    n_idx = tl.arange(0, N_BLOCK)[None, :]\n    n_mask = n_idx < N\n    mask = m_mask & n_mask\n    x = tl.load(x_ptr + N * m_idx + n_idx, mask, other=0).to(tl.float32)\n    x_mean = tl.load(x_mean_ptr + m_idx, m_mask, other=0).to(tl.float32)\n    x_invstd = tl.load(x_invstd_ptr + m_idx, m_mask, other=0).to(tl.float32)\n    x_hat = (x - x_mean) * x_invstd\n    dy = tl.load(dy_ptr + N * m_idx + n_idx, mask, other=0).to(tl.float32)\n    w = tl.load(w_ptr + n_idx, n_mask, other=0).to(tl.float32)\n    c1 = tl.sum(x_hat * dy * w, axis=1) / N\n    c2 = tl.sum(dy * w, axis=1) / N\n    dx = x_invstd * (dy * w - c1[:, None] * x_hat - c2[:, None])\n    
tl.store(dx_ptr + N * m_idx + n_idx, dx, mask)\n\n\n@triton.autotune(\n    configs=[\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN},\n            num_warps=2,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 2},\n            num_warps=4,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 4},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 8},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 16},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN},\n            num_warps=4,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 2},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 4},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 8},\n            num_warps=8,\n        ),\n        Config(\n            {\n                \"N_BLOCK\": BF16_LOAD_SIZE * 2,\n                \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 16,\n            },\n            num_warps=8,\n        ),\n    ],\n    key=[\"M\", \"N\"],\n)\n@triton.jit\ndef _layer_norm_backward_dw_db_partial(\n    dy_ptr,\n    x_ptr,\n    x_invstd_ptr,\n    x_mean_ptr,\n    dw_partial_buf_ptr,\n    db_partial_buf_ptr,\n    M: tl.constexpr,\n    N: tl.constexpr,\n    BUF_N_STRIDE: tl.constexpr,\n    N_BLOCK: tl.constexpr,\n    M_PARTIAL_REDUCE: tl.constexpr,\n):\n    m_idx = (tl.program_id(0) * M_PARTIAL_REDUCE + 
tl.arange(0, M_PARTIAL_REDUCE))[:, None]\n    m_mask = m_idx < M\n    n_idx = tl.program_id(1) * N_BLOCK + tl.arange(0, N_BLOCK)\n    n_mask = n_idx < N\n    idx = N * m_idx + n_idx[None, :]\n    mask = m_mask & n_mask[None, :]\n    x = tl.load(x_ptr + idx, mask, other=0).to(tl.float32)\n    x_mean = tl.load(x_mean_ptr + m_idx, m_mask, other=0).to(tl.float32)\n    x_invstd = tl.load(x_invstd_ptr + m_idx, m_mask, other=0).to(tl.float32)\n    x_hat = (x - x_mean) * x_invstd\n    dy = tl.load(dy_ptr + idx, mask, other=0).to(tl.float32)\n    dw_partial = tl.sum(dy * x_hat, axis=0)\n    db_partial = tl.sum(dy, axis=0)\n    tl.store(dw_partial_buf_ptr + BUF_N_STRIDE * n_idx + tl.program_id(0), dw_partial, n_mask)\n    tl.store(db_partial_buf_ptr + BUF_N_STRIDE * n_idx + tl.program_id(0), db_partial, n_mask)\n\n\n# %% Backward kernels for noncontiguous inputs. Using similar strided access logic as in forward.\n@triton.autotune(\n    configs=[\n        Config({\"M_BLOCK\": 1}, num_warps=1),\n        Config({\"M_BLOCK\": 2}, num_warps=1),\n        Config({\"M_BLOCK\": 4}, num_warps=2),\n        Config({\"M_BLOCK\": 8}, num_warps=4),\n        Config({\"M_BLOCK\": 16}, num_warps=8),\n        Config({\"M_BLOCK\": 32}, num_warps=8),\n        Config({\"M_BLOCK\": 64}, num_warps=8),\n    ],\n    key=[\"M\", \"N\"],\n)\n@triton.heuristics(\n    values={\n        \"N_BLOCK\": lambda kwargs: triton.next_power_of_2(kwargs[\"N\"]),\n    },\n)\n@triton.jit\ndef _layer_norm_backward_dx_strided(\n    dy_ptr,\n    x_ptr,\n    w_ptr,\n    x_invstd_ptr,\n    x_mean_ptr,\n    dx_ptr,\n    M: tl.constexpr,\n    N: tl.constexpr,\n    M_BLOCK: tl.constexpr,\n    N_BLOCK: tl.constexpr,\n    D0: tl.constexpr,\n    D1: tl.constexpr,\n    D2: tl.constexpr,\n    D3: tl.constexpr,\n    S0: tl.constexpr,\n    S1: tl.constexpr,\n    S2: tl.constexpr,\n    S3: tl.constexpr,\n):\n    m_logic_idx = tl.program_id(0) * M_BLOCK + tl.arange(0, M_BLOCK)\n    m_mask = m_logic_idx < M\n    m_logic_idx_0 = 
m_logic_idx // (D1 * D2) % D0\n    m_logic_idx_1 = m_logic_idx // D2 % D1\n    m_logic_idx_2 = m_logic_idx % D2\n    m_idx = m_logic_idx_0 * S0 + m_logic_idx_1 * S1 + m_logic_idx_2 * S2\n    n_logic_idx = tl.arange(0, N_BLOCK)\n    n_mask = n_logic_idx < N\n    n_idx = n_logic_idx * S3\n    mask = m_mask[:, None] & n_mask[None, :]\n    x_idx = m_idx[:, None] + n_idx[None, :]\n    x = tl.load(x_ptr + x_idx, mask, other=0).to(tl.float32)\n    x_mean = tl.load(x_mean_ptr + m_logic_idx, m_mask, other=0).to(tl.float32)[:, None]\n    x_invstd = tl.load(x_invstd_ptr + m_logic_idx, m_mask, other=0).to(tl.float32)[:, None]\n    x_hat = (x - x_mean) * x_invstd\n    dy_idx = N * m_logic_idx[:, None] + n_logic_idx[None, :]\n    dy = tl.load(dy_ptr + dy_idx, mask, other=0).to(tl.float32)\n    w = tl.load(w_ptr + n_logic_idx, n_mask, other=0).to(tl.float32)[None, :]\n    c1 = tl.sum(x_hat * dy * w, axis=1) / N\n    c2 = tl.sum(dy * w, axis=1) / N\n    dx = x_invstd * (dy * w - c1[:, None] * x_hat - c2[:, None])\n    tl.store(dx_ptr + x_idx, dx, mask)\n\n\n@triton.autotune(\n    configs=[\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN},\n            num_warps=2,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 2},\n            num_warps=4,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 4},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 8},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 16},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN},\n            num_warps=4,\n        ),\n        Config(\n            
{\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 2},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 4},\n            num_warps=8,\n        ),\n        Config(\n            {\"N_BLOCK\": BF16_LOAD_SIZE * 2, \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 8},\n            num_warps=8,\n        ),\n        Config(\n            {\n                \"N_BLOCK\": BF16_LOAD_SIZE * 2,\n                \"M_PARTIAL_REDUCE\": PARTIAL_REDUCE_MIN * 16,\n            },\n            num_warps=8,\n        ),\n    ],\n    key=[\"M\", \"N\"],\n)\n@triton.jit\ndef _layer_norm_backward_dw_db_partial_strided(\n    dy_ptr,\n    x_ptr,\n    x_invstd_ptr,\n    x_mean_ptr,\n    dw_partial_buf_ptr,\n    db_partial_buf_ptr,\n    M: tl.constexpr,\n    N: tl.constexpr,\n    BUF_N_STRIDE: tl.constexpr,\n    N_BLOCK: tl.constexpr,\n    M_PARTIAL_REDUCE: tl.constexpr,\n    D0: tl.constexpr,\n    D1: tl.constexpr,\n    D2: tl.constexpr,\n    D3: tl.constexpr,\n    S0: tl.constexpr,\n    S1: tl.constexpr,\n    S2: tl.constexpr,\n    S3: tl.constexpr,\n):\n    m_logic_idx = tl.program_id(0) * M_PARTIAL_REDUCE + tl.arange(0, M_PARTIAL_REDUCE)\n    m_mask = m_logic_idx < M\n    m_logic_idx_0 = m_logic_idx // (D1 * D2) % D0\n    m_logic_idx_1 = m_logic_idx // D2 % D1\n    m_logic_idx_2 = m_logic_idx % D2\n    m_idx = m_logic_idx_0 * S0 + m_logic_idx_1 * S1 + m_logic_idx_2 * S2\n    n_logic_idx = tl.program_id(1) * N_BLOCK + tl.arange(0, N_BLOCK)\n    n_mask = n_logic_idx < N\n    n_idx = n_logic_idx * S3\n    mask = m_mask[:, None] & n_mask[None, :]\n    x_idx = m_idx[:, None] + n_idx[None, :]\n    x = tl.load(x_ptr + x_idx, mask, other=0).to(tl.float32)\n    x_mean = tl.load(x_mean_ptr + m_logic_idx, m_mask, other=0).to(tl.float32)[:, None]\n    x_invstd = tl.load(x_invstd_ptr + m_logic_idx, m_mask, other=0).to(tl.float32)[:, None]\n    x_hat = (x - x_mean) * x_invstd\n    dy_idx = N 
* m_logic_idx[:, None] + n_logic_idx[None, :]\n    dy = tl.load(dy_ptr + dy_idx, mask, other=0).to(tl.float32)\n    dw_partial = tl.sum(dy * x_hat, axis=0)\n    db_partial = tl.sum(dy, axis=0)\n    tl.store(\n        dw_partial_buf_ptr + BUF_N_STRIDE * n_logic_idx + tl.program_id(0),\n        dw_partial,\n        n_mask,\n    )\n    tl.store(\n        db_partial_buf_ptr + BUF_N_STRIDE * n_logic_idx + tl.program_id(0),\n        db_partial,\n        n_mask,\n    )\n\n\n# %% Reduce partial accumulator buffers along the row dimension. Straightforward.\n@triton.jit\ndef _layer_norm_backward_buf_reduce(\n    partial_buf_ptr,\n    output_ptr,\n    N: tl.constexpr,\n    M: tl.constexpr,\n    N_STRIDE: tl.constexpr,\n    M_STRIDE: tl.constexpr,\n):\n    idx = N_STRIDE * tl.program_id(0) + M_STRIDE * tl.arange(0, M)\n    mask = tl.program_id(0) < N\n    x = tl.sum(tl.load(partial_buf_ptr + idx, mask, other=0).to(tl.float32), axis=0)\n    tl.store(output_ptr + tl.program_id(0), x, mask)\n"
  },
  {
    "path": "apex/contrib/openfold_triton/_layer_norm_config_ampere.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nfrom triton import Config\n\n# Mapping schema: Dict[\n#   dap_size: int, Dict[\n#     function_name: str, Dict[\n#       input_shape: Tuple[int, int], config: triton.Config\n#     ]\n#   ]\n# ]\n_auto_tuned_config_ampere = {\n    0: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            ),\n            (32768, 256): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            )\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 4}, num_warps=2, num_stages=2),\n            (32768, 256): Config({\"M_BLOCK\": 4}, num_warps=2, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (65536, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (32768, 256): Config({\"M_BLOCK\": 16}, num_warps=8, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (65536, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n    2: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            ),\n            (32768, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (16384, 256): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, 
num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (32768, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            )\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n            (32768, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n            (16384, 256): Config({\"M_BLOCK\": 4}, num_warps=2, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (32768, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (32768, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (16384, 256): Config({\"M_BLOCK\": 16}, num_warps=8, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (32768, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n    4: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            ),\n            (16384, 128): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (8192, 256): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (16384, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            )\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n            (16384, 128): Config({\"M_BLOCK\": 
4}, num_warps=2, num_stages=2),\n            (8192, 256): Config({\"M_BLOCK\": 1}, num_warps=1, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (16384, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (16384, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (8192, 256): Config({\"M_BLOCK\": 16}, num_warps=8, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (16384, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n    8: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            ),\n            (8192, 128): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (4096, 256): Config(\n                {\"N_BLOCK\": 16, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (8192, 128): Config({\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2)\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n            (8192, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2),\n            (4096, 256): Config({\"M_BLOCK\": 1}, num_warps=1, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (8192, 128): Config({\"M_BLOCK\": 1}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (8192, 128): Config({\"M_BLOCK\": 16}, num_warps=8, num_stages=2),\n           
 (4096, 256): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (8192, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n}\n\n_auto_tuned_config_ampere[1] = _auto_tuned_config_ampere[0]\n"
  },
  {
    "path": "apex/contrib/openfold_triton/_layer_norm_config_hopper.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nfrom triton import Config\n\n# Mapping schema: Dict[\n#   dap_size: int, Dict[\n#     function_name: str, Dict[\n#       input_shape: Tuple[int, int], config: triton.Config\n#     ]\n#   ]\n# ]\n_auto_tuned_config_hopper = {\n    0: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (32768, 256): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            )\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n            (32768, 256): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (65536, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 64}, num_warps=8, num_stages=2),\n            (32768, 256): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (65536, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n    2: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (32768, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (16384, 256): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, 
num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (32768, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            )\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 4}, num_warps=2, num_stages=2),\n            (32768, 128): Config({\"M_BLOCK\": 16}, num_warps=8, num_stages=2),\n            (16384, 256): Config({\"M_BLOCK\": 4}, num_warps=2, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (32768, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 64}, num_warps=8, num_stages=2),\n            (32768, 128): Config({\"M_BLOCK\": 64}, num_warps=8, num_stages=2),\n            (16384, 256): Config({\"M_BLOCK\": 64}, num_warps=8, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (32768, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n    4: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (16384, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (8192, 256): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (16384, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 512}, num_warps=8, num_stages=2\n            )\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 4}, num_warps=2, num_stages=2),\n            (16384, 128): Config({\"M_BLOCK\": 
32}, num_warps=8, num_stages=2),\n            (8192, 256): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (16384, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 64}, num_warps=8, num_stages=2),\n            (16384, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (8192, 256): Config({\"M_BLOCK\": 16}, num_warps=8, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (16384, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n    8: {\n        \"_layer_norm_backward_dw_db_partial\": {\n            (65536, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (8192, 128): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n            (4096, 256): Config(\n                {\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2\n            ),\n        },\n        \"_layer_norm_backward_dw_db_partial_strided\": {\n            (8192, 128): Config({\"N_BLOCK\": 32, \"M_PARTIAL_REDUCE\": 256}, num_warps=8, num_stages=2)\n        },\n        \"_layer_norm_backward_dx\": {\n            (65536, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n            (8192, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n            (4096, 256): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2),\n        },\n        \"_layer_norm_backward_dx_strided\": {\n            (8192, 128): Config({\"M_BLOCK\": 2}, num_warps=1, num_stages=2)\n        },\n        \"_layer_norm_forward\": {\n            (65536, 128): Config({\"M_BLOCK\": 64}, num_warps=8, num_stages=2),\n            (8192, 128): Config({\"M_BLOCK\": 32}, num_warps=8, num_stages=2),\n         
   (4096, 256): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2),\n        },\n        \"_layer_norm_forward_strided\": {\n            (8192, 128): Config({\"M_BLOCK\": 8}, num_warps=4, num_stages=2)\n        },\n    },\n}\n\n_auto_tuned_config_hopper[1] = _auto_tuned_config_hopper[0]\n"
  },
  {
    "path": "apex/contrib/openfold_triton/_layer_norm_forward_kernels.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nfrom packaging.version import Version\n\nimport triton\nimport triton.language as tl\nfrom triton import Config\n\nif Version(\"2.0.0\") < Version(triton.__version__):\n    rsqrt = tl.math.rsqrt\nelse:\n    rsqrt = tl.libdevice.rsqrt\n\n\n# %% Forward kernel for contiguous inputs.\n@triton.autotune(\n    configs=[\n        Config({\"M_BLOCK\": 1}, num_warps=1),\n        Config({\"M_BLOCK\": 2}, num_warps=1),\n        Config({\"M_BLOCK\": 4}, num_warps=2),\n        Config({\"M_BLOCK\": 8}, num_warps=4),\n        Config({\"M_BLOCK\": 16}, num_warps=8),\n        Config({\"M_BLOCK\": 32}, num_warps=8),\n        Config({\"M_BLOCK\": 64}, num_warps=8),\n    ],\n    key=[\"M\", \"N\"],\n)\n@triton.heuristics(\n    values={\n        \"N_BLOCK\": lambda kwargs: triton.next_power_of_2(kwargs[\"N\"]),\n    },\n)\n@triton.jit\ndef _layer_norm_forward(\n    x_ptr,\n    w_ptr,\n    b_ptr,\n    eps,\n    x_invstd_ptr,\n    x_mean_ptr,\n    y_ptr,\n    M: tl.constexpr,\n    N: tl.constexpr,\n    M_BLOCK: tl.constexpr,\n    N_BLOCK: tl.constexpr,\n):\n    m_idx = tl.program_id(0) * M_BLOCK + tl.arange(0, M_BLOCK)\n    m_mask = m_idx < M\n    n_idx = tl.arange(0, N_BLOCK)\n    n_mask = n_idx < N\n    mask = m_mask[:, None] & n_mask[None, :]\n    x = tl.load(x_ptr + N * m_idx[:, None] + n_idx[None, :], mask, other=0).to(tl.float32)\n    x_mean = tl.sum(x, 1) / N\n    tl.store(x_mean_ptr + m_idx, x_mean, m_mask)\n    x_bar = x - x_mean[:, None]\n    x_var = tl.sum(x_bar * x_bar, 1) / N\n    x_invstd = rsqrt(x_var + eps)\n    tl.store(x_invstd_ptr + m_idx, x_invstd, m_mask)\n    x_hat = x_bar * x_invstd[:, None]\n    w = tl.load(w_ptr + n_idx, n_mask, other=0).to(tl.float32)[None, :]\n    b = tl.load(b_ptr + n_idx, n_mask, other=0).to(tl.float32)[None, :]\n    y = w * x_hat + b\n    tl.store(y_ptr + N * m_idx[:, None] + n_idx[None, :], y, mask)\n\n\n# %% Forward kernel for noncontiguous inputs. 
Using strided access to avoid extra memory overhead.\n@triton.autotune(\n    configs=[\n        Config({\"M_BLOCK\": 1}, num_warps=1),\n        Config({\"M_BLOCK\": 2}, num_warps=1),\n        Config({\"M_BLOCK\": 4}, num_warps=2),\n        Config({\"M_BLOCK\": 8}, num_warps=4),\n        Config({\"M_BLOCK\": 16}, num_warps=8),\n        Config({\"M_BLOCK\": 32}, num_warps=8),\n        Config({\"M_BLOCK\": 64}, num_warps=8),\n    ],\n    key=[\"M\", \"N\"],\n)\n@triton.heuristics(\n    values={\n        \"N_BLOCK\": lambda kwargs: triton.next_power_of_2(kwargs[\"N\"]),\n    },\n)\n@triton.jit\ndef _layer_norm_forward_strided(\n    x_ptr,\n    w_ptr,\n    b_ptr,\n    eps,\n    x_invstd_ptr,\n    x_mean_ptr,\n    y_ptr,\n    M: tl.constexpr,\n    N: tl.constexpr,\n    M_BLOCK: tl.constexpr,\n    N_BLOCK: tl.constexpr,\n    D0: tl.constexpr,\n    D1: tl.constexpr,\n    D2: tl.constexpr,\n    D3: tl.constexpr,\n    S0: tl.constexpr,\n    S1: tl.constexpr,\n    S2: tl.constexpr,\n    S3: tl.constexpr,\n):\n    m_logic_idx = tl.program_id(0) * M_BLOCK + tl.arange(0, M_BLOCK)\n    m_mask = m_logic_idx < M\n    m_logic_idx_0 = m_logic_idx // (D1 * D2) % D0\n    m_logic_idx_1 = m_logic_idx // D2 % D1\n    m_logic_idx_2 = m_logic_idx % D2\n    m_idx = m_logic_idx_0 * S0 + m_logic_idx_1 * S1 + m_logic_idx_2 * S2\n    n_logic_idx = tl.arange(0, N_BLOCK)\n    n_mask = n_logic_idx < N\n    n_idx = n_logic_idx * S3\n    mask = m_mask[:, None] & n_mask[None, :]\n    x_idx = m_idx[:, None] + n_idx[None, :]\n    x = tl.load(x_ptr + x_idx, mask, other=0).to(tl.float32)\n    x_mean = tl.sum(x, 1) / N\n    tl.store(x_mean_ptr + m_logic_idx, x_mean, m_mask)\n    x_bar = x - x_mean[:, None]\n    x_var = tl.sum(x_bar * x_bar, 1) / N\n    x_invstd = rsqrt(x_var + eps)\n    tl.store(x_invstd_ptr + m_logic_idx, x_invstd, m_mask)\n    x_hat = x_bar * x_invstd[:, None]\n    w = tl.load(w_ptr + n_logic_idx, n_mask, other=0).to(tl.float32)[None, :]\n    b = tl.load(b_ptr + n_logic_idx, n_mask, 
other=0).to(tl.float32)[None, :]\n    y = w * x_hat + b\n    tl.store(y_ptr + N * m_logic_idx[:, None] + n_logic_idx[None, :], y, mask)\n"
  },
  {
    "path": "apex/contrib/openfold_triton/_mha_kernel.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nimport triton\nimport triton.language as tl\n\n\ndef init_to_zero(name):\n    return lambda nargs: nargs[name].zero_()\n\n\ndef get_configs_fwd():\n    configs = []\n    for num_stages in [0, 1, 2, 3, 4]:\n        for block_m in [32, 64, 128]:\n            for block_n in [16, 32, 64, 128]:\n                if block_n > block_m:\n                    continue\n                for num_warps in [1, 2, 4, 8]:\n                    if 32 * num_warps * 32 > block_m * block_n:\n                        continue\n                    configs.append(\n                        triton.Config(\n                            {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n},\n                            num_stages=num_stages,\n                            num_warps=num_warps,\n                        )\n                    )\n    return configs\n\n\n\"\"\"\n@triton.autotune(\n    configs=get_configs_fwd(), \n    key=['Z', 'H', 'N_CTX', 'H_DIM', 'IS_TRAINING'],\n)\n\"\"\"\n\n\n@triton.heuristics(\n    {\n        \"EVEN_M\": lambda args: args[\"N_CTX\"] % args[\"BLOCK_M\"] == 0,\n        \"EVEN_N\": lambda args: args[\"N_CTX\"] % args[\"BLOCK_N\"] == 0,\n        \"EVEN_HEADDIM\": lambda args: args[\"H_DIM\"] == args[\"BLOCK_DMODEL\"],\n    }\n)\n@triton.jit\ndef _attention_core(\n    Q,\n    K,\n    V,\n    Mask,\n    Bias,\n    sm_scale,\n    L,\n    M,\n    Out,\n    stride_qz,\n    stride_qh,\n    stride_qm,\n    stride_qk,\n    stride_kz,\n    stride_kh,\n    stride_kn,\n    stride_kk,\n    stride_vz,\n    stride_vh,\n    stride_vk,\n    stride_vn,\n    stride_oz,\n    stride_oh,\n    stride_om,\n    stride_on,\n    stride_bz,\n    stride_bh,\n    stride_bm,\n    stride_bn,\n    stride_mz,\n    stride_mh,\n    stride_mm,\n    stride_mn,\n    Z,\n    H,\n    N_CTX,\n    H_DIM,\n    BATCH,  # 256 8 128 32 1\n    inf: tl.constexpr,\n    IS_TRAINING: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    
BLOCK_DMODEL: tl.constexpr,\n    use_mask: tl.constexpr,\n    use_bias: tl.constexpr,\n    EVEN_M: tl.constexpr,\n    EVEN_N: tl.constexpr,\n    EVEN_HEADDIM: tl.constexpr,\n):\n    start_m = tl.program_id(0)\n    off_hz = tl.program_id(1)\n    off_b = off_hz // H\n    off_h = off_hz % H\n    # initialize offsets\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    off_q = (\n        off_b * stride_qz\n        + off_h * stride_qh\n        + offs_m[:, None] * stride_qm\n        + offs_d[None, :] * stride_qk\n    )\n    off_k = (\n        off_b * stride_kz\n        + off_h * stride_kh\n        + offs_n[None, :] * stride_kn\n        + offs_d[:, None] * stride_kk\n    )\n    off_v = (\n        off_b * stride_vz\n        + off_h * stride_vh\n        + offs_n[:, None] * stride_vk\n        + offs_d[None, :] * stride_vn\n    )\n    # Initialize pointers to Q, K, V\n    q_ptrs = Q + off_q\n    k_ptrs = K + off_k\n    v_ptrs = V + off_v\n\n    # Initialize pointers to bias, mask\n    if use_bias:\n        batch_2 = Z // BATCH\n        off_hz_bias = (off_hz // (batch_2 * H) * H) + (off_hz % H)\n        offs_base_bias = off_hz_bias * (N_CTX * N_CTX) + offs_m[:, None] * N_CTX + offs_n[None, :]\n        \"\"\"\n        off_b = off_hz // H\n        off_h = off_hz % H\n        bias_ptrs = Bias + off_b * stride_bz + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn)\n        \"\"\"\n\n    if use_mask:\n        # off_hz_mask = (off_hz // H)\n        # offs_base_mask = off_hz_mask * N_CTX\n        off_b = off_hz // H\n        off_h = off_hz % H\n        mask_ptrs = (\n            Mask\n            + off_b * stride_mz\n            + off_h * stride_mh\n            + (offs_m[:, None] * stride_mm + offs_n[None, :] * stride_mn)\n        )\n\n    # initialize pointer to m and l\n    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    l_prev = 
tl.zeros([BLOCK_M], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n    # load q: it will stay in SRAM throughout\n    if EVEN_M & EVEN_N:\n        if EVEN_HEADDIM:\n            q = tl.load(q_ptrs)\n        else:\n            q = tl.load(q_ptrs, mask=offs_d[None, :] < H_DIM, other=0.0)\n    else:\n        if EVEN_HEADDIM:\n            q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)\n        else:\n            q = tl.load(\n                q_ptrs,\n                mask=(offs_m[:, None] < N_CTX) & (offs_d[None, :] < H_DIM),\n                other=0.0,\n            )\n\n    # loop over k, v and update accumulator\n    #  (start_m + 1) * BLOCK_M\n    for start_n in range(0, N_CTX, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        # -- compute qk ----\n        if EVEN_N & EVEN_M:  # If we just do \"if EVEN_N\", there seems to be some race condition\n            if EVEN_HEADDIM:\n                k = tl.load(k_ptrs)\n            else:\n                k = tl.load(k_ptrs, mask=offs_d[:, None] < H_DIM, other=0.0)\n        else:\n            if EVEN_HEADDIM:\n                k = tl.load(k_ptrs, mask=(start_n + offs_n)[None, :] < N_CTX, other=0.0)\n            else:\n                k = tl.load(\n                    k_ptrs,\n                    mask=((start_n + offs_n)[None, :] < N_CTX) & (offs_d[:, None] < H_DIM),\n                    other=0.0,\n                )\n\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n\n        # qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n        if use_bias:\n            qk += tl.dot(q * sm_scale.to(tl.bfloat16), k).to(tl.bfloat16)\n            qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, -inf).to(tl.bfloat16)\n            if EVEN_M & EVEN_N:\n                bias_data = tl.load(Bias + offs_base_bias + start_n)\n            else:\n                bias_load_mask = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n     
           bias_load_mask = tl.where(offs_m[:, None] >= N_CTX, 1.0, bias_load_mask)\n                bias_load_mask = tl.where((start_n + offs_n)[None, :] >= N_CTX, 1.0, bias_load_mask)\n                bias_data = tl.load(\n                    Bias + offs_base_bias + start_n,\n                    mask=(bias_load_mask == 0.0),\n                    other=0.0,\n                )\n            qk = qk + bias_data\n        else:\n            qk += tl.dot(q, k)\n            qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, -inf)\n\n        qk = qk.to(tl.bfloat16)\n\n        if use_mask:\n            if EVEN_M & EVEN_N:\n                mask_data = tl.load(mask_ptrs + start_n).to(tl.int32)\n            else:\n                mask_data = tl.load(\n                    mask_ptrs + start_n,\n                    mask=(offs_m[:, None] < N_CTX) & ((start_n + offs_n)[None, :] < N_CTX),\n                    other=0,\n                ).to(tl.int32)\n            qk += tl.where(mask_data == 0, -inf, 0.0)\n\n        if use_bias:\n            # compute new m\n            m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n            # correct old l\n            l_prev *= tl.exp(m_prev - m_curr)\n            # attention weights\n            p = tl.exp(qk - m_curr[:, None])\n        else:\n            m_curr = tl.maximum(tl.max(qk, 1) * sm_scale, m_prev)\n            l_prev *= tl.exp(m_prev - m_curr)\n            p = tl.exp(qk * sm_scale - m_curr[:, None])\n\n        l_curr = tl.sum(p, 1) + l_prev\n        # rescale operands of matmuls\n        l_rcp = 1.0 / l_curr\n        p *= l_rcp[:, None]\n        acc *= (l_prev * l_rcp)[:, None]\n        # update acc\n        p = p.to(Q.dtype.element_ty)\n\n        if EVEN_N & EVEN_M:  # If we just do \"if EVEN_N\", there seems to be some race condition\n            if EVEN_HEADDIM:\n                v = tl.load(v_ptrs)\n            else:\n                v = tl.load(v_ptrs, mask=offs_d[None, :] < H_DIM, other=0.0)\n        else:\n            if 
EVEN_HEADDIM:\n                v = tl.load(v_ptrs, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)\n            else:\n                v = tl.load(\n                    v_ptrs,\n                    mask=((start_n + offs_n)[:, None] < N_CTX) & (offs_d[None, :] < H_DIM),\n                    other=0.0,\n                )\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        l_prev = l_curr\n        m_prev = m_curr\n        # update pointers\n        k_ptrs += BLOCK_N * stride_kn\n        v_ptrs += BLOCK_N * stride_vk\n    # rematerialize offsets to save registers\n    start_m = tl.program_id(0)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    # write back l and m\n    if IS_TRAINING:\n        l_ptrs = L + off_hz * N_CTX + offs_m\n        m_ptrs = M + off_hz * N_CTX + offs_m\n        tl.store(l_ptrs, l_prev)\n        tl.store(m_ptrs, m_prev)\n    # initialize pointers to output\n    offs_n = tl.arange(0, BLOCK_DMODEL)\n    off_o = (\n        off_b * stride_oz\n        + off_h * stride_oh\n        + offs_m[:, None] * stride_om\n        + offs_n[None, :] * stride_on\n    )\n    out_ptrs = Out + off_o\n    if EVEN_M:\n        if EVEN_HEADDIM:\n            tl.store(out_ptrs, acc.to(Q.dtype.element_ty))\n        else:\n            tl.store(out_ptrs, acc.to(Q.dtype.element_ty), mask=offs_n[None, :] < H_DIM)\n    else:\n        if EVEN_HEADDIM:\n            tl.store(out_ptrs, acc.to(Q.dtype.element_ty), mask=offs_m[:, None] < N_CTX)\n        else:\n            tl.store(\n                out_ptrs,\n                acc.to(Q.dtype.element_ty),\n                mask=(offs_m[:, None] < N_CTX) & (offs_n[None, :] < H_DIM),\n            )\n    # tl.store(out_ptrs, acc.to(Q.dtype.element_ty), mask=out_store_mask)\n\n\n@triton.jit\ndef _bwd_preprocess(\n    Out,\n    DO,\n    L,\n    NewDO,\n    Delta,\n    stride_ob,\n    stride_oh,\n    stride_om,\n    stride_ok,\n    stride_dob,\n    stride_doh,\n    stride_dom,\n    stride_dok,\n    BLOCK_M: 
tl.constexpr,\n    D_HEAD: tl.constexpr,\n):\n    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n    off_n = tl.arange(0, D_HEAD)\n    # load\n    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    denom = tl.load(L + off_m).to(tl.float32)\n    # compute\n    do = do / denom[:, None]\n    delta = tl.sum(o * do, axis=1)\n    # write-back\n    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n    tl.store(Delta + off_m, delta)\n\n\ndef get_configs_bwd():\n    configs = []\n    for num_stages in [0, 1, 2, 3, 4]:\n        for block_m in [32, 64, 128]:\n            for block_n in [16, 32, 64, 128]:\n                if block_n > block_m:\n                    continue\n                for num_warps in [1, 2, 4, 8]:\n                    if 32 * num_warps * 32 > block_m * block_n:\n                        continue\n                    configs.append(\n                        triton.Config(\n                            {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n},\n                            num_stages=num_stages,\n                            num_warps=num_warps,\n                            pre_hook=init_to_zero(\"DQ\"),\n                        )\n                    )\n    return configs\n\n\n\"\"\"\n@triton.autotune(\n    configs=get_configs_bwd(),\n    key=['Z', 'H', 'N_CTX', 'H_DIM'],\n)\n\"\"\"\n\n\n@triton.heuristics(\n    {\n        \"EVEN_M\": lambda args: args[\"N_CTX\"] % args[\"BLOCK_M\"] == 0,\n        \"EVEN_N\": lambda args: args[\"N_CTX\"] % args[\"BLOCK_N\"] == 0,\n        \"EVEN_HEADDIM\": lambda args: args[\"H_DIM\"] == args[\"BLOCK_DMODEL\"],\n    }\n)\n@triton.jit\ndef _bwd_kernel(\n    Q,\n    K,\n    V,\n    Mask,\n    Bias,\n    sm_scale,\n    Out,\n    DO,\n    DQ,\n    DK,\n    DV,\n    DP,\n    L,\n    M,\n    D,\n    stride_qz,\n    stride_qh,\n    stride_qm,\n    stride_qk,\n    stride_kz,\n    stride_kh,\n    
stride_kn,\n    stride_kk,\n    stride_vz,\n    stride_vh,\n    stride_vk,\n    stride_vn,\n    stride_mz,\n    stride_mh,\n    stride_mm,\n    stride_mn,\n    stride_bz,\n    stride_bh,\n    stride_bm,\n    stride_bn,\n    stride_dpz,\n    stride_dph,\n    stride_dpm,\n    stride_dpn,\n    stride_dob,\n    stride_doh,\n    stride_dom,\n    stride_dok,\n    stride_dqb,\n    stride_dqh,\n    stride_dqm,\n    stride_dqk,\n    stride_dkb,\n    stride_dkh,\n    stride_dkn,\n    stride_dkk,\n    stride_dvb,\n    stride_dvh,\n    stride_dvn,\n    stride_dvk,\n    Z,\n    H,\n    N_CTX,\n    H_DIM,\n    # num_block,\n    inf: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    use_mask: tl.constexpr,\n    use_bias: tl.constexpr,\n    EVEN_M: tl.constexpr,\n    EVEN_N: tl.constexpr,\n    EVEN_HEADDIM: tl.constexpr,\n    SEQUENCE_PARALLEL: tl.constexpr,\n):\n    off_hz = tl.program_id(0)\n    off_b = off_hz // H\n    off_h = off_hz % H\n\n    # offset pointers for batch/head\n    Q += off_b * stride_qz + off_h * stride_qh\n    K += off_b * stride_kz + off_h * stride_kh\n    V += off_b * stride_vz + off_h * stride_vh\n    DO += off_b * stride_dob + off_h * stride_doh\n    DQ += off_b * stride_dqb + off_h * stride_dqh\n    DK += off_b * stride_dkb + off_h * stride_dkh\n    DV += off_b * stride_dvb + off_h * stride_dvh\n    DP += off_b * stride_dpz + off_h * stride_dph\n\n    if use_bias:\n        Bias += off_b * stride_bz + off_h * stride_bh\n    if use_mask:\n        # offs_base_mask = off_b * N_CTX\n        Mask += off_b * stride_mz + off_h * stride_mh\n\n    num_block_n = tl.cdiv(N_CTX, BLOCK_N)\n    for start_n in range(0, num_block_n):\n        # lo = start_n * BLOCK_M\n        lo = 0\n        # initialize row/col offsets\n        offs_qm = lo + tl.arange(0, BLOCK_M)\n        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)  # BLOCK_M\n        offs_m = tl.arange(0, BLOCK_M)  # BLOCK_N\n        offs_k = 
tl.arange(0, BLOCK_DMODEL)\n        # initialize pointers to value-like data\n        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n        v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)\n        do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_k[None, :] * stride_dok)\n        dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_k[None, :] * stride_dqk)\n        dp_ptrs = DP + (offs_qm[:, None] * stride_dpm + offs_n[None, :] * stride_dpn)\n        if use_bias:\n            b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :] * stride_bn)\n        if use_mask:\n            mask_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :] * stride_mn)\n        # pointer to row-wise quantities in value-like data\n        D_ptrs = D + off_hz * N_CTX\n        m_ptrs = M + off_hz * N_CTX\n        # initialize dv and dk\n        dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)  # BLOCK_M\n        dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)  # BLOCK_M\n        # k and v stay in SRAM throughout\n        if EVEN_N & EVEN_M:\n            if EVEN_HEADDIM:\n                k = tl.load(k_ptrs)\n                v = tl.load(v_ptrs)\n            else:\n                k = tl.load(k_ptrs, mask=offs_k[None, :] < H_DIM, other=0.0)\n                v = tl.load(v_ptrs, mask=offs_k[None, :] < H_DIM, other=0.0)\n        else:\n            if EVEN_HEADDIM:\n                k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX, other=0.0)\n                v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX, other=0.0)\n            else:\n                k = tl.load(\n                    k_ptrs,\n                    mask=(offs_n[:, None] < N_CTX) & (offs_k[None, :] < H_DIM),\n                    other=0.0,\n                )\n                v = tl.load(\n                    v_ptrs,\n                    mask=(offs_n[:, None] < 
N_CTX) & (offs_k[None, :] < H_DIM),\n                    other=0.0,\n                )\n        # loop over rows\n        num_block_m = tl.cdiv(N_CTX, BLOCK_M)\n        for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):\n            start_m = tl.multiple_of(start_m, BLOCK_M)\n            offs_m_curr = start_m + offs_m\n            # load q, k, v, do on-chip\n            if EVEN_M & EVEN_HEADDIM:\n                q = tl.load(q_ptrs)\n            else:\n                if EVEN_HEADDIM:\n                    q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)\n                else:\n                    q = tl.load(\n                        q_ptrs,\n                        mask=(offs_m_curr[:, None] < N_CTX) & (offs_k[None, :] < H_DIM),\n                        other=0.0,\n                    )\n            # recompute p = softmax(qk, dim=-1).T\n            # NOTE: `do` is pre-divided by `l`; no normalization here\n            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n            qk += tl.dot(q, tl.trans(k))\n\n            if use_bias:\n                tl.debug_barrier()  # Race condition otherwise\n                if EVEN_M & EVEN_N:\n                    bias = tl.load(b_ptrs).to(tl.float32)\n                else:\n                    bias = tl.load(\n                        b_ptrs,\n                        mask=(offs_m_curr[:, None] < N_CTX) & (offs_n[None, :] < N_CTX),\n                        other=0.0,\n                    ).to(tl.float32)\n                qk = qk * sm_scale + bias\n\n            if use_mask:\n                # tl.debug_barrier()  # Race condition otherwise\n                # qk = tl.where(offs_m_curr[:, None] >= N_CTX, float(\"-1e20\"), qk)\n                # qk = tl.where(offs_n[None, :] >= N_CTX, float(\"-1e20\"), qk)\n                # mask_data = tl.load(Mask + offs_base_mask + offs_n)\n                # qk = tl.where(mask_data[None, :] == 0., float(\"-1e20\"), qk)\n                if EVEN_M & EVEN_N:\n          
          mask_data = tl.load(mask_ptrs).to(tl.float32)\n                else:\n                    mask_data = tl.load(\n                        mask_ptrs,\n                        mask=(offs_m_curr[:, None] < N_CTX) & (offs_n[None, :] < N_CTX),\n                        other=0.0,\n                    ).to(tl.float32)\n\n                qk += tl.where(mask_data == 0.0, -inf, 0.0)\n                # qk = tl.where(mask_data == 0., -inf, qk)\n\n            m = tl.load(m_ptrs + offs_m_curr)\n            if use_bias:\n                p = tl.exp(qk - m[:, None])\n            else:\n                p = tl.exp(qk * sm_scale - m[:, None])\n            # compute dv\n            if EVEN_M & EVEN_HEADDIM:\n                do = tl.load(do_ptrs)  # .to(tl.float32)\n            else:\n                do = tl.load(\n                    do_ptrs,\n                    mask=(offs_m_curr[:, None] < N_CTX) & (offs_k[None, :] < H_DIM),\n                    other=0.0,\n                )\n\n            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)\n            # compute dp = dot(v, do)\n            Di = tl.load(D_ptrs + offs_m_curr)\n\n            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n            dp += tl.dot(do, tl.trans(v))\n\n            # compute ds = p * (dp - delta[:, None])\n            ds = p * dp\n            if use_bias:\n                tl.store(dp_ptrs, ds)\n            ds = ds * sm_scale\n\n            # compute dk = dot(ds.T, q)\n            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)\n\n            # compute dq\n            # can we remove .to(tl.float32)\n            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M\n                dq = tl.load(dq_ptrs).to(tl.float32)\n                dq += tl.dot(ds.to(Q.dtype.element_ty), k)\n                tl.store(dq_ptrs, dq)\n            else:\n                if EVEN_HEADDIM:\n                    dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0).to(\n    
                    tl.float32\n                    )\n                    dq += tl.dot(ds.to(Q.dtype.element_ty), k)\n                    tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)\n                else:\n                    dq = tl.load(\n                        dq_ptrs,\n                        mask=(offs_m_curr[:, None] < N_CTX) & (offs_k[None, :] < H_DIM),\n                        other=0.0,\n                    ).to(tl.float32)\n                    dq += tl.dot(ds.to(Q.dtype.element_ty), k)\n                    tl.store(\n                        dq_ptrs,\n                        dq,\n                        mask=(offs_m_curr[:, None] < N_CTX) & (offs_k[None, :] < H_DIM),\n                    )\n            # increment pointers\n            dq_ptrs += BLOCK_M * stride_dqm\n            q_ptrs += BLOCK_M * stride_qm\n            do_ptrs += BLOCK_M * stride_dom\n\n            dp_ptrs += BLOCK_M * stride_dpm\n            if use_bias:\n                b_ptrs += BLOCK_M * stride_bm\n            if use_mask:\n                mask_ptrs += BLOCK_M * stride_mm\n        # write-back\n        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk)\n        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk)\n\n        if EVEN_N & EVEN_M:\n            if EVEN_HEADDIM:\n                tl.store(dv_ptrs, dv)\n                tl.store(dk_ptrs, dk)\n            else:\n                tl.store(dv_ptrs, dv, mask=offs_k[None, :] < H_DIM)\n                tl.store(dk_ptrs, dk, mask=offs_k[None, :] < H_DIM)\n        else:\n            if EVEN_HEADDIM:\n                tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)\n                tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)\n            else:\n                tl.store(\n                    dv_ptrs,\n                    dv,\n                    mask=(offs_n[:, None] < N_CTX) & (offs_k[None, :] < H_DIM),\n                )\n                tl.store(\n            
        dk_ptrs,\n                    dk,\n                    mask=(offs_n[:, None] < N_CTX) & (offs_k[None, :] < H_DIM),\n                )\n"
  },
  {
    "path": "apex/contrib/openfold_triton/fused_adam_swa.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom enum import Enum, unique\nfrom itertools import chain\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport triton\nimport triton.language as tl\nfrom torch.optim import Adam, Optimizer\n\n# The most common parameter size in open-fold.\nCHUNK_SIZE = torch.tensor(128, dtype=torch.int64)\n\n\n# Data type enumerates. tl.constexpr arg doesn't accept Triton data types.\n@unique\nclass _DTypeEnum(Enum):\n    FP16 = 0\n    BF16 = 1\n    FP32 = 2\n    FP64 = 3\n\n\n_TORCH2DTYPE = {\n    torch.float16: _DTypeEnum.FP16,\n    torch.bfloat16: _DTypeEnum.BF16,\n    torch.float32: _DTypeEnum.FP32,\n    torch.float64: _DTypeEnum.FP64,\n}\n\n\n_DTYPE2TRITON = {\n    _DTypeEnum.FP16: tl.float16,\n    _DTypeEnum.BF16: tl.bfloat16,\n    _DTypeEnum.FP32: tl.float32,\n    _DTypeEnum.FP64: tl.float64,\n}\n\n\n# Adam math impl enumerates. 
There're minor impl differences between Apex and official PyTorch.\n@unique\nclass AdamMathType(Enum):\n    ApexAdam = 0\n    ApexAdamW = 1\n    PyTorchAdam = 2\n\n\n@triton.jit\ndef _adam_math(\n    param,\n    grad,\n    moment,\n    velocity,\n    beta1,\n    beta2,\n    beta1_correction,\n    beta2_correction,\n    eps,\n    lr,\n    weight_decay,\n    adam_math_mode: tl.constexpr,\n):\n    if adam_math_mode == tl.constexpr(AdamMathType.ApexAdam.value):\n        grad += weight_decay * param\n        moment *= beta1\n        moment += (1.0 - beta1) * grad\n        velocity *= beta2\n        velocity += (1.0 - beta2) * grad * grad\n        update = (moment / beta1_correction) / (tl.math.sqrt(velocity / beta2_correction) + eps)\n        param -= lr * update\n    elif adam_math_mode == tl.constexpr(AdamMathType.ApexAdamW.value):\n        moment *= beta1\n        moment += (1.0 - beta1) * grad\n        velocity *= beta2\n        velocity += (1.0 - beta2) * grad * grad\n        update = (moment / beta1_correction) / (tl.math.sqrt(velocity / beta2_correction) + eps)\n        update += weight_decay * param\n        param -= lr * update\n    elif adam_math_mode == tl.constexpr(AdamMathType.PyTorchAdam.value):\n        grad += weight_decay * param\n        moment *= beta1\n        moment += (1.0 - beta1) * grad\n        velocity *= beta2\n        velocity += (1.0 - beta2) * grad * grad\n        # PyTorch computes step_size and denominator separately so it can use addcdiv later.\n        step_size = -lr / beta1_correction\n        beta2_correction_sqrt = tl.math.sqrt(beta2_correction)\n        denom = tl.math.sqrt(velocity) / beta2_correction_sqrt + eps\n        param += step_size * (moment / denom)\n    else:\n        raise ValueError(f\"Unknown Adam math mode: {adam_math_mode}\")\n    return param, moment, velocity\n\n\n# OpenFold model doesn't use buffers, so only update parameters.\n@triton.jit\ndef _swa_math(\n    param,\n    swa_param,\n    decay_rate,\n    
n_averaged,\n):\n    if n_averaged == 0:\n        swa_param = param\n    else:\n        swa_param += (1.0 - decay_rate) * (param - swa_param)\n    return swa_param\n\n\n@triton.jit\ndef _multi_tensor_adam_swa(\n    state_param_ptr_per_chunk,\n    compute_param_ptr_per_chunk,\n    swa_param_ptr_per_chunk,\n    grad_ptr_per_chunk,\n    moment_ptr_per_chunk,\n    velocity_ptr_per_chunk,\n    chunk_local_idx_ptr,\n    chunk_numel_ptr,\n    grad_clip_scale_ptr,\n    lr,\n    beta1,\n    beta2,\n    eps,\n    weight_decay,\n    beta1_correction,\n    beta2_correction,\n    swa_decay_rate,\n    swa_n_averaged,\n    adam_math_mode: tl.constexpr,\n    MODEL_COMPUTE_DTYPE: tl.constexpr,\n    MODEL_STATE_DTYPE: tl.constexpr,\n    CHUNK_SIZE: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    chunk_idx = tl.program_id(0)\n    chunk_local_idx = tl.load(chunk_local_idx_ptr + chunk_idx)\n    chunk_numel = tl.load(chunk_numel_ptr + chunk_idx)\n\n    compute_dtype = _DTYPE2TRITON[MODEL_COMPUTE_DTYPE.value]\n    compute_pointer_type = tl.pointer_type(compute_dtype)\n    state_dtype = _DTYPE2TRITON[MODEL_STATE_DTYPE.value]\n    state_pointer_type = tl.pointer_type(state_dtype)\n\n    state_param_ptr = tl.load(state_param_ptr_per_chunk + chunk_idx).to(state_pointer_type)\n    swa_param_ptr = tl.load(swa_param_ptr_per_chunk + chunk_idx).to(state_pointer_type)\n    moment_ptr = tl.load(moment_ptr_per_chunk + chunk_idx).to(state_pointer_type)\n    velocity_ptr = tl.load(velocity_ptr_per_chunk + chunk_idx).to(state_pointer_type)\n    compute_param_ptr = tl.load(compute_param_ptr_per_chunk + chunk_idx).to(compute_pointer_type)\n    grad_ptr = tl.load(grad_ptr_per_chunk + chunk_idx).to(compute_pointer_type)\n    grad_clip_scale = tl.load(grad_clip_scale_ptr)\n\n    ptr_base_offset = chunk_local_idx * CHUNK_SIZE\n    state_param_ptr += ptr_base_offset\n    compute_param_ptr += ptr_base_offset\n    swa_param_ptr += ptr_base_offset\n    grad_ptr += ptr_base_offset\n    moment_ptr += 
ptr_base_offset\n    velocity_ptr += ptr_base_offset\n\n    for i in range(0, CHUNK_SIZE, BLOCK_SIZE):\n        idx = i + tl.arange(0, BLOCK_SIZE)\n        mask = idx < chunk_numel\n        # Gradient clip step.\n        grad = tl.load(grad_ptr + idx, mask).to(state_dtype)\n        grad *= grad_clip_scale\n        # Adam step.\n        param = tl.load(state_param_ptr + idx, mask)\n        moment = tl.load(moment_ptr + idx, mask)\n        velocity = tl.load(velocity_ptr + idx, mask)\n        param, moment, velocity = _adam_math(\n            param=param,\n            grad=grad,\n            moment=moment,\n            velocity=velocity,\n            beta1=beta1,\n            beta2=beta2,\n            beta1_correction=beta1_correction,\n            beta2_correction=beta2_correction,\n            eps=eps,\n            lr=lr,\n            weight_decay=weight_decay,\n            adam_math_mode=adam_math_mode,\n        )\n        # SWA step.\n        swa_param = tl.load(swa_param_ptr + idx, mask)\n        swa_param = _swa_math(\n            param=param,\n            swa_param=swa_param,\n            decay_rate=swa_decay_rate,\n            n_averaged=swa_n_averaged,\n        )\n        # Write results. 
BF16 and SWA parameters are updated as well.\n        tl.store(state_param_ptr + idx, param, mask)\n        tl.store(moment_ptr + idx, moment, mask)\n        tl.store(velocity_ptr + idx, velocity, mask)\n        tl.store(compute_param_ptr + idx, param, mask)\n        tl.store(swa_param_ptr + idx, swa_param, mask)\n\n\n# Note:\n# - Gradients are attached to BF16 tensors\n# - Assume all parameters are all updated at each step, i.e., they share the same step number\nclass FusedAdamSWA(Optimizer):\n    def __init__(\n        self,\n        params: List[nn.Parameter],\n        compute_params: List[nn.Parameter],\n        swa_params: List[nn.Parameter],\n        swa_decay_rate: float,\n        lr: float = 1e-3,\n        bias_correction: bool = True,\n        betas: Tuple[float, float] = (0.9, 0.999),\n        eps: float = 1e-8,\n        adam_math_mode: AdamMathType = AdamMathType.PyTorchAdam,\n        weight_decay: float = 0.0,\n        amsgrad: bool = False,\n        set_grad_none: bool = True,\n        capturable: bool = False,\n        master_weights: bool = False,\n    ):\n        if not isinstance(params, list):\n            params = list(params)\n        if not isinstance(compute_params, list):\n            compute_params = list(compute_params)\n        if not isinstance(swa_params, list):\n            swa_params = list(swa_params)\n        if not compute_params or not swa_params:\n            raise ValueError(\"FusedAdamSWA requires both BF16 and SWA parameters.\")\n        if not len(params) == len(compute_params) == len(swa_params):\n            raise ValueError(\n                \"FusedAdamSWA expects params, bf16_params, and swa_params to have same length\"\n            )\n        if not all(\n            p.shape == b.shape == s.shape for p, b, s in zip(params, compute_params, swa_params)\n        ):\n            raise ValueError(\n                \"FusedAdamSWA expects each state in params, bf16_params, abd swa_params to have same shape\"\n            )\n     
   if not all(p.dtype == s.dtype for p, s in zip(params, swa_params)):\n            raise ValueError(\"FusedAdamSWA expects all params and swa_params to have same dtype\")\n        if not all(p.is_contiguous() for p in chain(params, compute_params, swa_params)):\n            raise ValueError(\"FusedAdamSWA expects all input params to be contiguous\")\n        if amsgrad:\n            raise NotImplementedError(\"amsgrad is not supported by FusedAdamSWA\")\n        if capturable:\n            raise NotImplementedError(\"capturable is not supported by FusedAdamSWA\")\n        if master_weights:\n            raise NotImplementedError(\"master_weights is not supported by FusedAdamSWA\")\n        if not isinstance(adam_math_mode, AdamMathType):\n            raise ValueError(\n                f\"Unknown Adam math mode {adam_math_mode}, expect to be any of:\\n\"\n                f\"\\t- {AdamMathType.ApexAdam}: NVIDIA Apex Adam math;\\n\"\n                f\"\\t- {AdamMathType.ApexAdamW}: NVIDIA Apex Adam math with adam_w set to True;\\n\"\n                f\"\\t- {AdamMathType.PyTorchAdam}: The official PyTorch Adam math.\\n\"\n            )\n\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n        )\n        super().__init__(params, defaults)\n        self.adam_math_mode = adam_math_mode\n        self.set_grad_none = set_grad_none\n        self.compute_param_groups = [{\"params\": compute_params}]\n        self.swa_param_groups = [{\"params\": swa_params, \"n_averaged\": 0}]\n        self.swa_decay_rate = swa_decay_rate\n\n        # We assume that parameter and buffer pointers won't change throughout the training, only\n        # gradients could be re-allocated due to set_grad_none.\n        self._pointer_buffers_initialized = False\n\n    def _build_pointer_buffers(self):\n        # Loading checkpoint to optimizer re-allocates param and 
states, so pointer logic should be\n        # at the first step of training, where we assume all states are ready.\n        if not all(\n            len(pg) == 1\n            for pg in (\n                self.param_groups,\n                self.compute_param_groups,\n                self.swa_param_groups,\n            )\n        ):\n            raise RuntimeError(\"FusedAdamSWA does not support multiple param groups\")\n\n        # `bf16_params` contains both BF16 and FP32 data types, thus we have to group parameters\n        # and other states into different buffers and launch respective kernels.\n        params, compute_params, swa_params = (\n            self.param_groups[0][\"params\"],\n            self.compute_param_groups[0][\"params\"],\n            self.swa_param_groups[0][\"params\"],\n        )\n        self.pointer_buffer_groups = defaultdict(dict)\n        for i, p in enumerate(compute_params):\n            compute_dtype = p.dtype\n            state_dtype = params[i].dtype\n            self.pointer_buffer_groups[(compute_dtype, state_dtype)].setdefault(\"tensor_idx\", [])\n            self.pointer_buffer_groups[(compute_dtype, state_dtype)][\"tensor_idx\"].append(i)\n\n        for (_, state_dtype), buffer_group in self.pointer_buffer_groups.items():\n            # Select tensors by dtype.\n            t_idx = buffer_group[\"tensor_idx\"]\n            params_this_group = [params[i] for i in t_idx]\n            compute_params_this_group = [compute_params[i] for i in t_idx]\n            swa_params_this_group = [swa_params[i] for i in t_idx]\n\n            # Build parameter pointer buffers.\n            param_ptrs = torch.tensor([p.data_ptr() for p in params_this_group], dtype=torch.int64)\n            compute_param_ptrs = torch.tensor(\n                [b.data_ptr() for b in compute_params_this_group], dtype=torch.int64\n            )\n            swa_param_ptrs = torch.tensor(\n                [s.data_ptr() for s in swa_params_this_group], 
dtype=torch.int64\n            )\n\n            param_numels = torch.tensor([p.numel() for p in params_this_group], dtype=torch.int64)\n            chunks_per_param = param_numels.float().div_(CHUNK_SIZE).ceil_().long()\n            chunk_local_idx = torch.cat(\n                [torch.arange(chunks, dtype=torch.int64) for chunks in chunks_per_param]\n            )\n            chunk_numel = torch.minimum(\n                param_numels.repeat_interleave(chunks_per_param) - chunk_local_idx * CHUNK_SIZE,\n                CHUNK_SIZE,\n            )\n            param_ptr_per_chunk = torch.repeat_interleave(param_ptrs, chunks_per_param)\n            compute_param_ptr_per_chunk = torch.repeat_interleave(\n                compute_param_ptrs, chunks_per_param\n            )\n            swa_param_ptr_per_chunk = torch.repeat_interleave(swa_param_ptrs, chunks_per_param)\n\n            device = params_this_group[0].device\n            buffer_group[\"device\"] = device\n            buffer_group[\"chunks_per_param\"] = chunks_per_param\n            buffer_group[\"chunk_local_idx\"] = chunk_local_idx.to(device)\n            buffer_group[\"chunk_numel\"] = chunk_numel.to(device)\n            buffer_group[\"param_ptr_per_chunk\"] = param_ptr_per_chunk.to(device)\n            buffer_group[\"compute_param_ptr_per_chunk\"] = compute_param_ptr_per_chunk.to(device)\n            buffer_group[\"swa_param_ptr_per_chunk\"] = swa_param_ptr_per_chunk.to(device)\n            buffer_group[\"total_chunks\"] = chunks_per_param.sum().item()\n            buffer_group[\"default_grad_clip_scale\"] = torch.tensor(1.0, dtype=state_dtype).to(\n                device\n            )\n\n            # Build moment pointer buffers.\n            moment, velocity = [], []\n            for p in params_this_group:\n                state = self.state[p]\n                if \"exp_avg\" not in state or \"exp_avg_sq\" not in state:\n                    state[\"exp_avg\"] = torch.zeros_like(p.detach(), 
dtype=state_dtype)\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p.detach(), dtype=state_dtype)\n                moment.append(state[\"exp_avg\"].data_ptr())\n                velocity.append(state[\"exp_avg_sq\"].data_ptr())\n            moment = torch.tensor(moment, dtype=torch.int64)\n            velocity = torch.tensor(velocity, dtype=torch.int64)\n            buffer_group[\"exp_avg_ptr_per_chunk\"] = torch.repeat_interleave(\n                moment, chunks_per_param\n            ).to(device)\n            buffer_group[\"exp_avg_sq_ptr_per_chunk\"] = torch.repeat_interleave(\n                velocity, chunks_per_param\n            ).to(device)\n\n        self._pointer_buffers_initialized = True\n\n    def step(\n        self,\n        closure: Optional[Callable[[], torch.Tensor]] = None,\n        grad_clip_scale: Optional[Union[torch.Tensor, float]] = None,\n    ):\n        if not self._pointer_buffers_initialized:\n            self._build_pointer_buffers()\n\n        loss = closure() if closure is not None else None\n\n        group = self.param_groups[0]\n        compute_group = self.compute_param_groups[0]\n        swa_group = self.swa_param_groups[0]\n        if \"step\" in group:\n            group[\"step\"] += 1\n        else:\n            group[\"step\"] = 1\n        (beta1, beta2), step = group[\"betas\"], group[\"step\"]\n        if group[\"bias_correction\"]:\n            beta1_correction = 1.0 - beta1**step\n            beta2_correction = 1.0 - beta2**step\n        else:\n            beta1_correction = beta2_correction = 1.0\n\n        grad_ptr = []\n        for p in compute_group[\"params\"]:\n            if p.grad is None:\n                continue\n            if p.grad.detach().is_sparse:\n                raise RuntimeError(\n                    \"FusedAdamSWA does not support sparse gradients, please consider SparseAdam instead\"\n                )\n            grad_ptr.append(p.grad.data_ptr())\n\n        for (\n            
compute_dtype,\n            state_dtype,\n        ), buffer_group in self.pointer_buffer_groups.items():\n            device = buffer_group[\"device\"]\n            t_idx = buffer_group[\"tensor_idx\"]\n            grad_ptr_this_group = [grad_ptr[i] for i in t_idx]\n            grad_ptr_this_group = torch.tensor(grad_ptr_this_group, dtype=torch.int64)\n            grad_ptr_per_chunk = torch.repeat_interleave(\n                grad_ptr_this_group, buffer_group[\"chunks_per_param\"]\n            ).to(device, non_blocking=True)\n            if grad_clip_scale is None:\n                grad_clip_scale_this_group = buffer_group[\"default_grad_clip_scale\"]\n            elif not torch.is_tensor(grad_clip_scale):\n                grad_clip_scale_this_group = torch.tensor(grad_clip_scale).to(\n                    device, non_blocking=True\n                )\n            else:\n                grad_clip_scale_this_group = grad_clip_scale\n\n            grid = (buffer_group[\"total_chunks\"],)\n            _multi_tensor_adam_swa[grid](\n                state_param_ptr_per_chunk=buffer_group[\"param_ptr_per_chunk\"],\n                compute_param_ptr_per_chunk=buffer_group[\"compute_param_ptr_per_chunk\"],\n                swa_param_ptr_per_chunk=buffer_group[\"swa_param_ptr_per_chunk\"],\n                grad_ptr_per_chunk=grad_ptr_per_chunk,\n                moment_ptr_per_chunk=buffer_group[\"exp_avg_ptr_per_chunk\"],\n                velocity_ptr_per_chunk=buffer_group[\"exp_avg_sq_ptr_per_chunk\"],\n                chunk_local_idx_ptr=buffer_group[\"chunk_local_idx\"],\n                chunk_numel_ptr=buffer_group[\"chunk_numel\"],\n                grad_clip_scale_ptr=grad_clip_scale_this_group,\n                lr=group[\"lr\"],\n                beta1=beta1,\n                beta2=beta2,\n                eps=group[\"eps\"],\n                weight_decay=group[\"weight_decay\"],\n                beta1_correction=beta1_correction,\n                
beta2_correction=beta2_correction,\n                swa_decay_rate=self.swa_decay_rate,\n                swa_n_averaged=swa_group[\"n_averaged\"],\n                adam_math_mode=self.adam_math_mode.value,\n                MODEL_COMPUTE_DTYPE=_TORCH2DTYPE[compute_dtype],\n                MODEL_STATE_DTYPE=_TORCH2DTYPE[state_dtype],\n                # TODO: Find optimal hyper-parameters.\n                CHUNK_SIZE=CHUNK_SIZE.item(),\n                BLOCK_SIZE=128,\n                num_warps=1,\n            )\n\n        swa_group[\"n_averaged\"] += 1\n\n        return loss\n\n    @classmethod\n    def from_optim(\n        cls,\n        adam_optimizer: Adam,\n        fp32_params: List[nn.Parameter],\n        bf16_params: List[nn.Parameter],\n        swa_params: List[nn.Parameter],\n        swa_decay_rate: float,\n    ) -> FusedAdamSWA:\n        assert len(adam_optimizer.param_groups) == 1\n        param_group = adam_optimizer.param_groups[0]\n        lr = param_group[\"lr\"]\n        betas = param_group[\"betas\"]\n        eps = param_group[\"eps\"]\n        weight_decay = param_group[\"weight_decay\"]\n        amsgrad = param_group[\"amsgrad\"]\n        fused_adam_swa_optimizer = cls(\n            params=fp32_params,\n            compute_params=bf16_params,\n            swa_params=swa_params,\n            swa_decay_rate=swa_decay_rate,\n            lr=lr,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            amsgrad=amsgrad,\n            adam_math_mode=AdamMathType.PyTorchAdam,\n        )\n        adam_state_dict = adam_optimizer.state_dict()\n        adam_state_dict[\"param_groups\"][0].setdefault(\"bias_correction\", True)\n        steps = [v[\"step\"] for v in adam_state_dict[\"state\"].values()]\n        if len(steps) == 0:  # Did not load optimizer checkpoint.\n            steps = [torch.tensor(1)]\n        elif not all(s == steps[0] for s in steps):\n            raise ValueError(\"FusedAdamSWA requires all 
parameters to have the same step count!\")\n        step = int(steps[0].item())\n        adam_state_dict[\"param_groups\"][0].setdefault(\"step\", step)\n        fused_adam_swa_optimizer.load_state_dict(adam_state_dict)\n        return fused_adam_swa_optimizer\n"
  },
  {
    "path": "apex/contrib/openfold_triton/layer_norm.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nfrom math import prod\n\nimport torch\nimport triton\nfrom torch.autograd import Function\n\nfrom apex.contrib.openfold_triton._layer_norm_backward_kernels import (\n    PARTIAL_REDUCE_MIN,\n    _layer_norm_backward_buf_reduce,\n    _layer_norm_backward_dw_db_partial,\n    _layer_norm_backward_dw_db_partial_strided,\n    _layer_norm_backward_dx,\n    _layer_norm_backward_dx_strided,\n)\nfrom apex.contrib.openfold_triton._layer_norm_forward_kernels import (\n    _layer_norm_forward,\n    _layer_norm_forward_strided,\n)\n\n# TODO: Find a more elegant approach to cache tuned results.\n_M_BUFSIZE_CACHE = dict()\n\n\nclass LayerNormSmallShapeOptImpl(Function):\n    @staticmethod\n    def forward(ctx, inputs, normalized_shape, weight, bias, eps=1e-05):\n        if not inputs.is_contiguous() and normalized_shape != inputs.shape[-1:]:\n            raise ValueError(\n                f\"This implementation only support normalizing along the last dimension for \"\n                f\"noncontiguous inputs. 
I.e., we expect \"\n                f\"normalized_shape={tuple(inputs.shape[-1:])}, but got {normalized_shape} instead\"\n            )\n        if not inputs.is_contiguous() and inputs.dim() != 4:\n            raise ValueError(\n                f\"This implementation only supports 4-dim noncontiguous inputs, but got \"\n                f\"{inputs.dim()} instead\"\n            )\n\n        normalized_degree = len(normalized_shape)\n        layer_shape = inputs.shape[:-normalized_degree]\n        M, N = prod(layer_shape), prod(normalized_shape)\n\n        x_invstd = torch.empty(M, dtype=torch.float32, device=inputs.device)\n        x_mean = torch.empty(M, dtype=torch.float32, device=inputs.device)\n        y = torch.empty(inputs.shape, dtype=inputs.dtype, device=inputs.device)\n\n        grid = lambda kwargs: (triton.cdiv(kwargs[\"M\"], kwargs[\"M_BLOCK\"]),)\n        if inputs.is_contiguous():\n            _layer_norm_forward[grid](\n                x_ptr=inputs,\n                w_ptr=weight,\n                b_ptr=bias,\n                eps=eps,\n                x_invstd_ptr=x_invstd,\n                x_mean_ptr=x_mean,\n                y_ptr=y,\n                M=M,\n                N=N,\n            )\n        else:\n            D0, D1, D2, D3 = inputs.shape\n            S0, S1, S2, S3 = inputs.stride()\n            _layer_norm_forward_strided[grid](\n                x_ptr=inputs,\n                w_ptr=weight,\n                b_ptr=bias,\n                eps=eps,\n                x_invstd_ptr=x_invstd,\n                x_mean_ptr=x_mean,\n                y_ptr=y,\n                M=M,\n                N=N,\n                D0=D0,\n                D1=D1,\n                D2=D2,\n                D3=D3,\n                S0=S0,\n                S1=S1,\n                S2=S2,\n                S3=S3,\n            )\n\n        ctx.save_for_backward(inputs, weight, x_invstd, x_mean)\n        ctx.flatten_shape = M, N\n        return y\n\n    @staticmethod\n    def 
backward(ctx, d_y):\n        inputs, weight, x_invstd, x_mean = ctx.saved_tensors\n        M, N = ctx.flatten_shape\n        d_inputs = torch.empty_like(inputs)\n        d_weight = torch.empty_like(weight)\n        d_bias = torch.empty_like(weight)\n\n        # %% Separated kernels, similar to Inductor.\n        # 1. dX.\n        grid = lambda kwargs: (triton.cdiv(kwargs[\"M\"], kwargs[\"M_BLOCK\"]),)\n        if inputs.is_contiguous():\n            _layer_norm_backward_dx[grid](\n                dy_ptr=d_y,\n                x_ptr=inputs,\n                w_ptr=weight,\n                x_invstd_ptr=x_invstd,\n                x_mean_ptr=x_mean,\n                dx_ptr=d_inputs,\n                M=M,\n                N=N,\n            )\n        else:\n            D0, D1, D2, D3 = inputs.shape\n            S0, S1, S2, S3 = inputs.stride()\n            _layer_norm_backward_dx_strided[grid](\n                dy_ptr=d_y,\n                x_ptr=inputs,\n                w_ptr=weight,\n                x_invstd_ptr=x_invstd,\n                x_mean_ptr=x_mean,\n                dx_ptr=d_inputs,\n                M=M,\n                N=N,\n                D0=D0,\n                D1=D1,\n                D2=D2,\n                D3=D3,\n                S0=S0,\n                S1=S1,\n                S2=S2,\n                S3=S3,\n            )\n        # 2. 
dW and db.\n        key = (M, N, inputs.is_contiguous())\n        M_BUFSIZE = _M_BUFSIZE_CACHE.get(key, triton.cdiv(M, PARTIAL_REDUCE_MIN))\n        dw_partial_buf = torch.empty([N, M_BUFSIZE], dtype=torch.float32, device=d_y.device)\n        db_partial_buf = torch.empty([N, M_BUFSIZE], dtype=torch.float32, device=d_y.device)\n        grid = lambda kwargs: (\n            triton.cdiv(M, kwargs[\"M_PARTIAL_REDUCE\"]),\n            triton.cdiv(N, kwargs[\"N_BLOCK\"]),\n        )\n        if inputs.is_contiguous():\n            _layer_norm_backward_dw_db_partial[grid](\n                dy_ptr=d_y,\n                x_ptr=inputs,\n                x_invstd_ptr=x_invstd,\n                x_mean_ptr=x_mean,\n                dw_partial_buf_ptr=dw_partial_buf,\n                db_partial_buf_ptr=db_partial_buf,\n                M=M,\n                N=N,\n                BUF_N_STRIDE=M_BUFSIZE,\n            )\n            M_PARTIAL_REDUCE = _layer_norm_backward_dw_db_partial.best_config.kwargs[\n                \"M_PARTIAL_REDUCE\"\n            ]\n        else:\n            _layer_norm_backward_dw_db_partial_strided[grid](\n                dy_ptr=d_y,\n                x_ptr=inputs,\n                x_invstd_ptr=x_invstd,\n                x_mean_ptr=x_mean,\n                dw_partial_buf_ptr=dw_partial_buf,\n                db_partial_buf_ptr=db_partial_buf,\n                M=M,\n                N=N,\n                BUF_N_STRIDE=M_BUFSIZE,\n                D0=D0,\n                D1=D1,\n                D2=D2,\n                D3=D3,\n                S0=S0,\n                S1=S1,\n                S2=S2,\n                S3=S3,\n            )\n            M_PARTIAL_REDUCE = _layer_norm_backward_dw_db_partial_strided.best_config.kwargs[\n                \"M_PARTIAL_REDUCE\"\n            ]\n        # 2.1. 
Reduce partial buffers; the dW and db reductions are independent, so they can be overlapped.\n        M_BUFSIZE = triton.cdiv(M, M_PARTIAL_REDUCE)\n        _M_BUFSIZE_CACHE[key] = M_BUFSIZE\n        grid = (triton.next_power_of_2(N),)\n        _layer_norm_backward_buf_reduce[grid](\n            partial_buf_ptr=dw_partial_buf,\n            output_ptr=d_weight,\n            N=N,\n            M=M_BUFSIZE,\n            N_STRIDE=dw_partial_buf.stride(0),\n            M_STRIDE=dw_partial_buf.stride(1),\n            num_warps=1,\n        )\n        _layer_norm_backward_buf_reduce[grid](\n            partial_buf_ptr=db_partial_buf,\n            output_ptr=d_bias,\n            N=N,\n            M=M_BUFSIZE,\n            N_STRIDE=db_partial_buf.stride(0),\n            M_STRIDE=db_partial_buf.stride(1),\n            num_warps=1,\n        )\n\n        return d_inputs, None, d_weight, d_bias, None\n"
  },
  {
    "path": "apex/contrib/openfold_triton/mha.py",
    "content": "# © 2023 NVIDIA CORPORATION & AFFILIATES\n\nimport math\nfrom typing import Optional\n\nimport torch\nimport triton\nfrom einops import rearrange\n\nfrom apex.contrib.openfold_triton._mha_kernel import (\n    _attention_core,\n    _bwd_kernel,\n    _bwd_preprocess,\n)\n\n# whether TRITON MHA is enabled or not\n_TRI_MHA_ENABLED = False\n\n\ndef is_enabled() -> Optional[bool]:\n    global _TRI_MHA_ENABLED\n    return _TRI_MHA_ENABLED\n\n\ndef enable() -> None:\n    global _TRI_MHA_ENABLED\n    _TRI_MHA_ENABLED = True\n\n\ndef disable() -> None:\n    global _TRI_MHA_ENABLED\n    _TRI_MHA_ENABLED = False\n\n\n# TODO: support q.shape [1, 1024, 8, 256, 8]\ndef CanSchTriMHA(in_shape, has_bias=True, inf=1e9, training=True):\n    if has_bias == False:  # skip bias is None\n        return False\n    if inf != 1e9:  # skip inf != 1e9\n        return False\n\n    lst_3d = in_shape[-3:]\n    skip_neg2_dim = in_shape[:3] + in_shape[-1:]\n    if not training and (\n        in_shape == [1, 538, 4, 538, 16]\n        or in_shape == [1, 585, 4, 585, 16]\n        or in_shape == [1, 538, 4, 538, 32]\n        or in_shape == [1, 585, 4, 585, 32]\n        or in_shape == [1, 128, 8, 585, 32]\n        or in_shape == [1, 128, 8, 538, 32]\n        or lst_3d == [8, 128, 32]\n        or skip_neg2_dim == [1, 1024, 8, 8]\n        or skip_neg2_dim == [1, 128, 4, 32]\n        or skip_neg2_dim == [1, 128, 8, 32]\n    ):  # eval\n        return False  # skip eval\n    if (\n        in_shape == [1, 256, 4, 256, 16]\n        or in_shape == [1, 128, 4, 256, 16]\n        or in_shape == [1, 64, 4, 256, 16]\n        or in_shape == [1, 32, 4, 256, 16]\n    ):  # 7.26%\n        return True\n    elif (\n        in_shape == [1, 128, 8, 256, 32]\n        or in_shape == [1, 64, 8, 256, 32]\n        or in_shape == [1, 32, 8, 256, 32]\n        or in_shape == [1, 16, 8, 256, 32]\n    ):  # 21.77%\n        return True\n    elif (\n        in_shape == [1, 256, 8, 128, 32]\n        or in_shape == [1, 
128, 8, 128, 32]\n        or in_shape == [1, 64, 8, 128, 32]\n        or in_shape == [1, 32, 8, 128, 32]\n    ):  # 21.77% no bias\n        return True\n    elif (\n        in_shape == [1, 256, 4, 256, 32]\n        or in_shape == [1, 128, 4, 256, 32]\n        or in_shape == [1, 64, 4, 256, 32]\n        or in_shape == [1, 32, 4, 256, 32]\n    ):  # 47.17%\n        return True\n    else:  # not support\n        return False\n\n\n# tune hyper params for each workload\ndef schedule_triton_mha(in_shape, fwd=True):\n    # default\n    ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 0]\n    if in_shape == [256, 4, 256, 16]:\n        ret = [64, 32, 2, 4] if fwd else [64, 64, 4, 0]\n    elif in_shape == [128, 4, 256, 16]:\n        ret = [64, 32, 2, 4] if fwd else [64, 64, 4, 0]\n    elif in_shape == [64, 4, 256, 16]:\n        ret = [64, 32, 2, 4] if fwd else [64, 64, 4, 0]\n    elif in_shape == [32, 4, 256, 16]:\n        ret = [64, 32, 2, 4] if fwd else [64, 64, 4, 0]\n    # [*, 8, 256, 32]\n    elif in_shape == [128, 8, 256, 32]:  # DAP1\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 1]\n    elif in_shape == [64, 8, 256, 32]:  # DAP2\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 1]\n    elif in_shape == [32, 8, 256, 32]:  # DAP4\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 1]\n    elif in_shape == [16, 8, 256, 32]:  # DAP8\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 1]\n    # [*, 8, 128, 32]\n    elif in_shape == [256, 8, 128, 32]:  # DAP1\n        ret = [64, 64, 4, 3] if fwd else [128, 64, 4, 1]\n    elif in_shape == [128, 8, 128, 32]:  # DAP2\n        ret = [128, 64, 4, 2] if fwd else [64, 64, 2, 0]\n    elif in_shape == [64, 8, 128, 32]:  # DAP4\n        ret = [128, 64, 4, 2] if fwd else [64, 64, 2, 0]\n    elif in_shape == [32, 8, 128, 32]:  # DAP8\n        ret = [128, 64, 4, 2] if fwd else [64, 64, 2, 0]\n    # [*, 4, 256, 32]\n    elif in_shape == [256, 4, 256, 32]:  # DAP1\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 
0]\n    elif in_shape == [128, 4, 256, 32]:  # DAP2\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 1]\n    elif in_shape == [64, 4, 256, 32]:  # DAP4\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 1]\n    elif in_shape == [32, 4, 256, 32]:  # DAP8\n        ret = [64, 32, 2, 3] if fwd else [128, 64, 8, 0]\n    return ret[0], ret[1], ret[2], ret[3]\n\n\nclass FusedAttenionCoreFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, q, k, v, mask=None, bias=None, inf=1000000000.0, is_training=True):\n        q_ori_size = len(q.size())\n        if q_ori_size == 5:\n            q = rearrange(q, \"1 b2 h n d -> (1 b2) h n d\")\n            k = rearrange(k, \"1 b2 h n d -> (1 b2) h n d\")\n            v = rearrange(v, \"1 b2 h n d -> (1 b2) h n d\")\n        if bias is not None:\n            if len(bias.size()) == 5:\n                bias = rearrange(bias, \"1 b2 h n d -> (1 b2) h n d\")\n\n        if mask is not None and len(mask.size()) == 5:\n            mask = rearrange(mask, \"1 b 1 1 e -> b 1 1 e\")\n\n        batch = 1\n        sm_scale = 1.0 / math.sqrt(q.size(-1))\n        # q *= sm_scale\n        # shape constraints\n        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n        assert Lq == Lk and Lk == Lv\n\n        if not is_training:\n            Lk = max(triton.next_power_of_2(Lk), 16)\n\n        assert Lk in {16, 32, 64, 128}\n\n        o = torch.empty_like(q)\n\n        Z, H, N_CTX, H_DIM = q.shape\n        grid = lambda META: (triton.cdiv(N_CTX, META[\"BLOCK_M\"]), Z * H)\n        l = torch.empty(\n            (q.shape[-4], q.shape[-3], q.shape[-2]),\n            device=q.device,\n            dtype=torch.float32,\n        )\n        m = torch.empty(\n            (q.shape[-4], q.shape[-3], q.shape[-2]),\n            device=q.device,\n            dtype=torch.float32,\n        )\n        # BLOCK_M, BLOCK_N, num_warps, num_stages  = 64, 64, 2, 3\n        BLOCK_M, BLOCK_N, num_warps, num_stages = 
schedule_triton_mha(list(q.shape), fwd=True)\n        if bias != None:\n            bias = bias.expand(Z, H, N_CTX, N_CTX)\n        bias_strides = (\n            (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))\n            if bias is not None\n            else (0, 0, 0, 0)\n        )\n        if mask != None:\n            mask = mask.expand(-1, q.shape[1], q.shape[2], -1)\n        mask_strides = (\n            (mask.stride(0), mask.stride(1), mask.stride(2), mask.stride(3))\n            if mask is not None\n            else (0, 0, 0, 0)\n        )\n\n        _attention_core[grid](\n            q,\n            k,\n            v,\n            mask,\n            bias,\n            sm_scale,\n            l,\n            m,\n            o,\n            q.stride(0),\n            q.stride(1),\n            q.stride(2),\n            q.stride(3),\n            k.stride(0),\n            k.stride(1),\n            k.stride(2),\n            k.stride(3),\n            v.stride(0),\n            v.stride(1),\n            v.stride(2),\n            v.stride(3),\n            o.stride(0),\n            o.stride(1),\n            o.stride(2),\n            o.stride(3),\n            *bias_strides,\n            *mask_strides,\n            q.shape[0],\n            q.shape[1],\n            q.shape[2],\n            q.shape[3],\n            batch,  # 256 8 128 1\n            inf=inf,\n            IS_TRAINING=is_training,\n            BLOCK_M=BLOCK_M,\n            BLOCK_N=BLOCK_N,\n            BLOCK_DMODEL=Lk,\n            use_mask=(mask != None),\n            use_bias=(bias != None),\n            num_warps=num_warps,\n            num_stages=num_stages,\n        )\n        o = o.contiguous()\n        # print(h.asm[\"ttgir\"])\n        if is_training:\n            ctx.save_for_backward(q, k, v, o, m, l, bias)\n            ctx.grid = grid\n            ctx.sm_scale = sm_scale\n            ctx.BLOCK_DMODEL = Lk\n            ctx.mask = mask\n            ctx.inf = inf\n        if 
q_ori_size == 5:\n            o = rearrange(o, \"a b c d -> 1 a b c d\")\n        return o\n\n    @staticmethod\n    def backward(ctx, do):\n        q, k, v, o, m, l, bias = ctx.saved_tensors\n        ori_do_size = len(do.size())\n        if ori_do_size == 5:\n            do = rearrange(do, \"1 a b c d -> a b c d\")\n        do = do.contiguous()\n        dq = torch.zeros_like(q, dtype=torch.float32)\n        dk = torch.empty_like(k)\n        dv = torch.empty_like(v)\n        # bias.dtype\n        Z, H, N_CTX, H_DIM = q.shape[-4], q.shape[-3], q.shape[-2], q.shape[-1]\n        dp = torch.zeros((Z, H, N_CTX, N_CTX), dtype=torch.float32, device=\"cuda\")\n\n        do_scaled = torch.empty_like(do)\n        delta = torch.empty_like(l)\n        mask = ctx.mask\n        inf = ctx.inf\n\n        BLOCK = 128\n        BLOCK_HEADDIM = max(triton.next_power_of_2(H_DIM), 16)\n        grid = (triton.cdiv(N_CTX, BLOCK) * Z * H, 1)\n        _bwd_preprocess[grid](\n            o,\n            do,\n            l,\n            do_scaled,\n            delta,\n            o.stride(0),\n            o.stride(1),\n            o.stride(2),\n            o.stride(3),\n            do.stride(0),\n            do.stride(1),\n            do.stride(2),\n            do.stride(3),\n            BLOCK_M=BLOCK,\n            D_HEAD=BLOCK_HEADDIM,\n        )\n\n        if bias is not None:\n            assert bias.dtype in [q.dtype, torch.float]\n            assert bias.is_cuda\n            assert bias.dim() == 4\n            assert bias.stride(-1) == 1\n            bias = bias.expand(Z, H, N_CTX, N_CTX)\n\n        # if mask is not None:\n        #    mask = mask.expand(Z, H, N_CTX, N_CTX)\n\n        bias_strides = (\n            (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))\n            if bias is not None\n            else (0, 0, 0, 0)\n        )\n        mask_strides = (\n            (mask.stride(0), mask.stride(1), mask.stride(2), mask.stride(3))\n            if mask is not None\n 
           else (0, 0, 0, 0)\n        )\n\n        # BLOCK_M, BLOCK_N = 128, 64\n        BLOCK_M, BLOCK_N, num_warps, num_stages = schedule_triton_mha(list(q.shape), fwd=False)\n        # grid = lambda META: (triton.cdiv(N_CTX, META[\"BLOCK_N\"]), Z * H)\n        # grid = lambda META: (Z * H, triton.cdiv(N_CTX, META[\"BLOCK_N\"]))\n        # grid = lambda META: (triton.cdiv(N_CTX, META[\"BLOCK_N\"]) if META[\"SEQUENCE_PARALLEL\"] else 1,\n        #            Z * H)\n        grid = lambda META: (Z * H,)\n        _bwd_kernel[grid](\n            q,\n            k,\n            v,\n            mask,\n            bias,\n            ctx.sm_scale,\n            o,\n            do_scaled,\n            dq,\n            dk,\n            dv,\n            dp,\n            l,\n            m,\n            delta,\n            q.stride(0),\n            q.stride(1),\n            q.stride(2),\n            q.stride(3),\n            k.stride(0),\n            k.stride(1),\n            k.stride(2),\n            k.stride(3),\n            v.stride(0),\n            v.stride(1),\n            v.stride(2),\n            v.stride(3),\n            *mask_strides,\n            *bias_strides,\n            dp.stride(0),\n            dp.stride(1),\n            dp.stride(2),\n            dp.stride(3),\n            do.stride(0),\n            do.stride(1),\n            do.stride(2),\n            do.stride(3),\n            dq.stride(0),\n            dq.stride(1),\n            dq.stride(2),\n            dq.stride(3),\n            dk.stride(0),\n            dk.stride(1),\n            dk.stride(2),\n            dk.stride(3),\n            dv.stride(0),\n            dv.stride(1),\n            dv.stride(2),\n            dv.stride(3),\n            q.shape[0],\n            q.shape[1],\n            q.shape[2],\n            q.shape[3],\n            # ctx.grid[0], # to delete\n            inf=inf,\n            BLOCK_M=BLOCK_M,\n            BLOCK_N=BLOCK_N,\n            BLOCK_DMODEL=ctx.BLOCK_DMODEL,\n            
use_mask=(mask != None),\n            use_bias=(bias != None),\n            num_warps=num_warps,\n            num_stages=num_stages,\n            SEQUENCE_PARALLEL=False,\n        )\n        dB = None\n        if bias is not None:\n            dB = torch.sum(dp, dim=-4, keepdim=True)\n            if len(bias.size()) == 4:\n                dB = rearrange(dB, \"b2 h n d -> 1 b2 h n d\")\n        # print(h.asm[\"ttgir\"])\n\n        if ori_do_size == 5:\n            dq = rearrange(dq, \"b2 h n d -> 1 b2 h n d\")\n            dk = rearrange(dk, \"b2 h n d -> 1 b2 h n d\")\n            dv = rearrange(dv, \"b2 h n d -> 1 b2 h n d\")\n\n        return dq, dk, dv, None, dB, None, None\n\n\nAttnTri = FusedAttenionCoreFunc.apply\n\n\ndef _attention_bias(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    mask: torch.Tensor,\n    bias: Optional[torch.Tensor],\n    inf: float,\n) -> torch.Tensor:\n    # query:  [*, num_heads, Q, c_hidden]\n    # key:    [*, num_heads, K, c_hidden]\n    # value:  [*, num_heads, V, c_hidden]\n    # mask:   Logit mask tensor broadcastable to [*, num_heads, Q, K]\n    # bias:   Optional logit bias tensor broadcastable to [*, num_heads, Q, K]\n    # inf:    Safe infinity value.\n    # assuming K == V\n\n    key = torch.swapdims(key, -2, -1)\n    # key: [*, num_heads, c_hidden, K]\n\n    scaling = 1.0 / math.sqrt(query.size(-1))\n    a = torch.matmul(query * scaling, key)\n    # a: [*, num_heads, Q, K]\n\n    a += (mask - 1.0) * inf\n    # a: [*, num_heads, Q, K]\n\n    a += bias\n    # a: [*, num_heads, Q, K]\n\n    a = torch.softmax(a, dim=-1)\n    # a: [*, num_heads, Q, K]\n\n    a = torch.matmul(a, value)\n    # a: [*, num_heads, Q, c_hidden]\n\n    return a\n\n\ndef _attention_no_bias(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    mask: torch.Tensor,\n    inf: float,\n) -> torch.Tensor:\n    # query:  [*, num_heads, Q, c_hidden]\n    # key:    [*, num_heads, K, c_hidden]\n    # 
value:  [*, num_heads, V, c_hidden]\n    # mask:   Logit mask tensor broadcastable to [*, num_heads, Q, K]\n    # inf:    Safe infinity value.\n    # assuming K == V\n\n    key = torch.swapdims(key, -2, -1)\n    # key: [*, num_heads, c_hidden, K]\n\n    scaling = 1.0 / math.sqrt(query.size(-1))\n    a = torch.matmul(query * scaling, key)\n    # a: [*, num_heads, Q, K]\n\n    a += (mask - 1.0) * inf\n    # a: [*, num_heads, Q, K]\n\n    a = torch.softmax(a, dim=-1)\n    # a: [*, num_heads, Q, K]\n\n    a = torch.matmul(a, value)\n    # a: [*, num_heads, Q, c_hidden]\n\n    return a\n\n\nAttnBiasJIT = torch.compile(_attention_bias)\nAttnNoBiasJIT = torch.compile(_attention_no_bias)\n"
  },
  {
    "path": "apex/contrib/optimizers/__init__.py",
    "content": "from .fp16_optimizer import FP16_Optimizer\nfrom .fused_adam import FusedAdam\nfrom .fused_lamb import FusedLAMB\n"
  },
  {
    "path": "apex/contrib/optimizers/distributed_fused_adam.py",
    "content": "import collections\nimport contextlib\nfrom dataclasses import dataclass\nimport enum\nimport inspect\nimport io\nimport itertools\nimport threading\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Optional,\n    Set,\n    Tuple,\n    Union,\n)\nimport warnings\n\nimport torch\nfrom torch.distributed.distributed_c10d import _get_default_group\n\ntry:\n    import apex.contrib.nccl_allocator as nccl_allocator\nexcept ImportError:\n    nccl_allocator = None\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\nimport amp_C\nimport distributed_adam_cuda\n\n# Fallback to private functions if using PyTorch <1.13.0\ntry:\n    from torch.distributed.distributed_c10d import get_global_rank\nexcept ImportError:\n    from torch.distributed.distributed_c10d import _get_global_rank\n\n    get_global_rank = _get_global_rank\ntry:\n    from torch.distributed.distributed_c10d import reduce_scatter_tensor\nexcept ImportError:\n    from torch.distributed.distributed_c10d import _reduce_scatter_base\n\n    reduce_scatter_tensor = _reduce_scatter_base\ntry:\n    from torch.distributed.distributed_c10d import all_gather_into_tensor\nexcept ImportError:\n    from torch.distributed.distributed_c10d import _all_gather_base\n\n    all_gather_into_tensor = _all_gather_base\n\n# Import context manager to coalesce NCCL calls\n# Note: Replace these backward compatibility shims once PyTorch\n# exposes a stable public API for coalescing communication.\nfrom torch.distributed.distributed_c10d import _coalescing_manager\n\nif \"device\" not in inspect.signature(_coalescing_manager).parameters:\n    # PyTorch <=1.13.1 does not have device arg\n    _coalescing_manager_no_device_arg = _coalescing_manager\n\n    @contextlib.contextmanager\n    def _coalescing_manager(group, device, reqs):\n        with _coalescing_manager_no_device_arg(group, reqs):\n            yield\n\n\nif \"reqs\" in inspect.signature(_coalescing_manager).parameters:\n 
   # PyTorch <=2.0.1 handles synchronization externally to coalescing\n    # manager\n    _coalescing_manager_with_reqs_arg = _coalescing_manager\n\n    class _CoalescingManager:\n        def __init__(self):\n            self.works: List[torch.distributed.Work] = []\n\n        def append(self, work: torch.distributed.Work) -> None:\n            if work:\n                self.works.append(work)\n\n        def wait(self) -> None:\n            for work in self.works:\n                work.wait()\n\n    @contextlib.contextmanager\n    def _coalescing_manager(\n        group: Optional[torch.distributed.ProcessGroup] = None,\n        device: Optional[torch.device] = None,\n        async_ops: bool = False,\n    ) -> contextlib.AbstractContextManager:\n        assert device is not None\n        cm = _CoalescingManager()\n        with _coalescing_manager_with_reqs_arg(\n            group,\n            device,\n            cm.works,\n        ):\n            yield cm\n        if not async_ops:\n            cm.wait()\n\n    def _coalescing_manager_append_work(\n        cm: _CoalescingManager,\n        work: torch.distributed.Work,\n    ) -> None:\n        \"\"\"Add asynchronous request to coalescing manager\"\"\"\n        cm.append(work)\n\nelse:\n    # PyTorch >2.0.1 handles synchronization within coalescing\n    # manager\n    def _coalescing_manager_append_work(\n        cm: torch.distributed._CoalescingManager,\n        work: torch.distributed.Work,\n    ) -> None:\n        \"\"\"Dummy function for backward compatibility\n\n        Coalescing manager already keeps track of asynchronous\n        communication.\n\n        \"\"\"\n        pass\n\n\n# Import optional CUDA kernels\n_FOUND_DEPRECATED_FUSED_ADAM: bool = False\ntry:\n    import fused_adam_cuda\n\n    _FOUND_DEPRECATED_FUSED_ADAM = True\nexcept ImportError:\n    warnings.warn(\n        \"Could not find recommended CUDA kernels when importing \"\n        \"`DistributedFusedAdam`. 
\"\n        \"For best performance, Apex should be installed with \"\n        \"`--deprecated_fused_adam`.\"\n    )\n\n\ndef _round_to_multiple(\n    number: int,\n    multiple: int,\n    round_up: bool = True,\n) -> int:\n    \"\"\"Assumes arguments are positive integers\"\"\"\n    return (number + multiple - 1 if round_up else number) // multiple * multiple\n\n\ndef _devices_match(device1: torch.device, device2: torch.device) -> bool:\n    \"\"\"Whether two PyTorch devices are equivalent\"\"\"\n    device1 = torch.device(device1)\n    device2 = torch.device(device2)\n    if device1.type != device2.type:\n        return False\n    if device1.type == \"cuda\":\n        index1 = device1.index\n        index2 = device2.index\n        if index1 is None:\n            index1 = torch.cuda.current_device()\n        if index2 is None:\n            index2 = torch.cuda.current_device()\n        if index1 != index2:\n            return False\n    return True\n\n\ndef _multi_tensor_copy(\n    buffers_in: List[torch.Tensor],\n    buffers_out: List[torch.Tensor],\n    dummy_overflow_buf: Optional[torch.Tensor] = None,\n) -> None:\n    \"\"\"Copy between corresponding buffers\n\n    Uses fused copy kernel if possible.\n    \"\"\"\n\n    # Group buffers by device and dtype\n    buffer_groups = collections.defaultdict(list)\n    for buf_in, buf_out in zip(buffers_in, buffers_out):\n        if buf_in.data_ptr() == buf_out.data_ptr() or buf_in.numel() == 0:\n            # Nothing to be done if input and output buffers are same\n            # or have no entries\n            continue\n        if buf_in.dtype == buf_out.dtype:\n            # Just copy bytes if dtypes are same\n            buf_in = buf_in.view(torch.uint8)\n            buf_out = buf_out.view(torch.uint8)\n        is_cuda = _devices_match(buf_in.device, \"cuda\") and _devices_match(buf_out.device, \"cuda\")\n        is_contiguous = buf_in.is_contiguous() and buf_out.is_contiguous()\n        key = (\n            
buf_in.dtype,\n            buf_out.dtype,\n            is_cuda,\n            is_contiguous,\n        )\n        buffer_groups[key].append((buf_in, buf_out))\n\n    # Copy each group of buffers\n    for key, buffers in buffer_groups.items():\n        # Check if buffers support fused kernel\n        dtype_in, dtype_out, is_cuda, is_contiguous = key\n        supported_dtypes = (torch.float32, torch.float16)\n        use_fused_kernel = (dtype_in in supported_dtypes and dtype_out in supported_dtypes) or (\n            dtype_in == torch.uint8 and dtype_out == torch.uint8\n        )\n        use_fused_kernel = use_fused_kernel and is_cuda and is_contiguous\n\n        # Copy buffers\n        if use_fused_kernel and _FOUND_DEPRECATED_FUSED_ADAM:\n            if dummy_overflow_buf is None:\n                dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=\"cuda\")\n            multi_tensor_applier(\n                fused_adam_cuda.maybe_cast_mt,\n                dummy_overflow_buf,\n                list(zip(*buffers)),\n            )\n        else:\n            # Warning: dummy_overflow_buf was not set in such case\n            for buf_in, buf_out in buffers:\n                buf_out.copy_(buf_in)\n\n\n@contextlib.contextmanager\ndef _disable_pre_forward_hook(\n    param: torch.nn.Parameter,\n) -> contextlib.AbstractContextManager:\n    \"\"\"Prevent parameter from calling pre-forward hook\"\"\"\n    hook_is_enabled = getattr(\n        param,\n        \"_pre_forward_hook_is_enabled\",\n        False,\n    )\n    if hook_is_enabled:\n        param._pre_forward_hook_is_enabled = False\n    try:\n        yield\n    finally:\n        if hook_is_enabled:\n            param._pre_forward_hook_is_enabled = True\n\n\n@torch.no_grad()\ndef _bf16_rem_to_fp32(\n    bf16: torch.Tensor,\n    rem: torch.Tensor,\n    fp32: torch.Tensor,\n) -> None:\n    \"\"\"Pack BF16 tensor and 16-bit remainders into FP32 tensor\"\"\"\n\n    # Check inputs\n    assert bf16.size() == 
rem.size() == fp32.size(), (\n        \"Tensor dimensions do not match: \"\n        f\"bf16={list(bf16.size())}, \"\n        f\"rem={list(rem.size())}, \"\n        f\"fp32={list(fp32.size())}, \"\n    )\n    assert bf16.dtype is torch.bfloat16, f\"bf16 buffer has invalid dtype ({bf16.dtype})\"\n    assert rem.dtype is torch.int16, f\"rem buffer has invalid dtype ({rem.dtype})\"\n    assert fp32.dtype is torch.float32, f\"fp32 buffer has invalid dtype ({fp32.dtype})\"\n\n    # Undo bf16 rounding\n    bf16 = bf16.view(torch.int16) - torch.where(rem < 0, 1, 0)\n\n    # Pack bf16 and remainder into little-endian fp32\n    fp32 = fp32.unsqueeze(-1).view(torch.int16)\n    fp32 = torch.stack((rem, bf16), dim=-1, out=fp32)\n\n\nclass DistributedFusedAdam(torch.optim.Optimizer):\n    \"\"\"Adam optimizer with ZeRO algorithm.\n\n    Currently GPU-only. Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext --distributed_adam --deprecated_fused_adam``.\n\n    This implements the ZeRO-2 algorithm, which distributes the\n    optimizer state and gradients between parallel processes. In\n    particular, the parameters are flattened, grouped into fixed-size\n    buckets, and the optimizer state for each bucket is sharded over\n    the parallel processes. Options are provided to overlap the\n    gradient synchronization with the backward pass compute.\n\n    Adam was proposed in `Adam: A Method for Stochastic\n    Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,\n    and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion\n    Parameter Models`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts\n            defining parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        bias_correction (bool, optional): apply correction factor to\n            moment estimates. 
(default: True)\n        betas (Tuple[float, float], optional): coefficients used for\n            computing running averages of gradient and its square.\n            (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        adam_w_mode (boolean, optional): Decouple weight decay\n            regularization (also known as AdamW algorithm) (default:\n            True)\n        weight_decay (float, optional): weight decay (L2 penalty)\n            (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad\n            variant of this algorithm from the paper\n            `On the Convergence of Adam and Beyond`_ (default: False).\n            This is not yet supported.\n        dtype (torch.dtype, optional): datatype for optimizer state\n            (default: torch.float32)\n        grad_sync_dtype (torch.dtype, optional): datatype for gradient\n            synchronization (default: same as dtype)\n        param_sync_dtype (torch.dtype, optional): datatype for\n            parameter synchronization (default: same as dtype)\n        device (torch.device, optional): device for optimizer state\n            (default: cuda). Currently only supports GPU with one GPU\n            per process.\n        process_group (torch.distributed.ProcessGroup, optional):\n            parallel processes participating in optimizer (default:\n            default group in torch.distributed). 
This group is\n            interpreted as a 2D grid with dimensions\n            distributed_size x redundant_size.\n        distributed_process_group (torch.distributed.ProcessGroup,\n            optional): parallel processes to distribute optimizer\n            state over (default: same as process_group)\n        redundant_process_group (torch.distributed.ProcessGroup,\n            optional): parallel processes to replicate optimizer state\n            over (default: group only containing calling process)\n        average_grad_sync (bool, optional): whether to use average\n            reduction for gradient synchronization rather than sum\n            (default: True)\n        overlap_grad_sync (boolean, optional): whether to overlap\n            gradient synchronization with backward pass compute\n            (default: True)\n        overlap_param_sync (boolean, optional): whether to overlap\n            parameter synchronization with forward pass compute\n            (default: False). This is an experimental feature.\n        bucket_cap_mb (float, optional): bucket size in megabytes\n            (default: 100)\n        pipeline_size (int, optional): number of buckets to process\n            simultaneously in optimizer step (default: 2)\n        contiguous_param_buffer (bool, optional): convert parameters\n            into views into large persistent buffers (default: False).\n            This enables some performance optimizations (e.g. avoiding\n            some memory copies), but may add memory overhead (e.g. if\n            the memory allocator can't reuse the original parameter\n            buffers).\n        contiguous_grad_buffer (bool, optional): allocate gradient\n            buckets out of a large persistent buffer (default:\n            False). This allows individual parameter gradients to be\n            accessed externally (see grad_buffer_view function). It\n            enables some performance optimizations (e.g. 
avoiding some\n            memory copies), but prevents some memory optimizations\n            (e.g. the memory allocator can't reuse buffers for\n            gradient buckets).\n        store_params (bool, optional): store a distributed copy of the\n            parameters as optimizer state (default: True). This may be\n            desirable if the optimizer dtype has higher precision than\n            the parameter dtype.\n        store_param_remainders (bool, optional): if model is BF16 and\n            optimizer is FP32, store bits required to reconstruct FP32\n            params (default: False). This is an experimental feature.\n        with_scaled_states (bool, optional): apply per-tensor scaling\n            factors to the optimizer state (default: False). As\n            discussed in `FP8-LM: Training FP8 Large Language\n            Models`_, this helps maintain a reasonable dynamic range\n            even when the state is in a low-precision datatype like\n            FP16.\n        nccl_ub (bool, optional): enable NCCL user buffers for zero-copy\n            (default: False). It allows the collectives to use only 1 SM\n            when IB SHARP is enabled in a one-rank-per-node communication\n            group. This helps speed up the GEMMs overlapped with\n            data-parallel communications.\n        capturable (bool, optional): whether to use the version of the\n            optimizer that can be used with CUDA Graphs (default: False).\n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. _On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101\n    .. _ZeRO\\: Memory Optimizations Toward Training Trillion Parameter Models:\n        https://arxiv.org/abs/1910.02054\n    .. 
_FP8-LM\\: Training FP8 Large Language Models:\n        https://arxiv.org/pdf/2310.18313v2.pdf\n\n    \"\"\"\n\n    @dataclass\n    class ParameterFragment:\n        \"\"\"Buffer ranges for a parameter fragment\n\n        Describes corresponding regions in parameter buffer and\n        parameter bucket.\n\n        \"\"\"\n\n        # Parameter group index\n        param_group_id: int\n        # Parameter index within parameter group\n        param_id: int\n        # Bucket index\n        bucket_id: int\n        # Range within flattened parameter buffer\n        param_range: Tuple[int, int]\n        # Range within bucket\n        bucket_range: Tuple[int, int]\n        # Whether fragment is in local shard of bucket\n        in_local_shard: bool\n        # Range within local shard\n        shard_range: Optional[Tuple[int, int]]\n        # Range of local fragment shard within bucket\n        shard_bucket_range: Optional[Tuple[int, int]]\n        # Range of local fragment shard within parameter\n        shard_param_range: Optional[Tuple[int, int]]\n\n    class StateBucket:\n        \"\"\"Optimizer state for a bucket\"\"\"\n\n        def __init__(\n            self,\n            bucket_size: int,\n            shard_size: int,\n            dtype: torch.dtype,\n            device: torch.device,\n            grad_sync_dtype: torch.dtype,\n            param_sync_dtype: torch.dtype,\n            contiguous_buffer_offset: int = 0,\n            store_params: bool = False,\n            store_param_remainders: bool = False,\n        ):\n            # Size of parameter bucket\n            self.bucket_size: int = bucket_size\n            # Size of local shard of parameter bucket\n            self.shard_size: int = shard_size\n            # Data type for state\n            self.dtype = dtype\n            # Data type for gradient synchronization\n            self.grad_sync_dtype = grad_sync_dtype\n            # Data type for parameter synchronization\n            
self.param_sync_dtype = param_sync_dtype\n            # Size of the filled region in the bucket\n            self.filled_size: int = 0\n            # Is it able to continue filling\n            self.able_to_fill: bool = True\n            # Offset to bucket in contiguous buffers\n            self.contiguous_buffer_offset: int = contiguous_buffer_offset\n            # Buffer ranges corresponding to parameter fragments\n            self.fragments: List[ParameterFragment] = []\n            # Local shard of parameters\n            self.params_shard: Optional[torch.Tensor] = None\n            if store_params:\n                self.params_shard = torch.zeros(\n                    [shard_size],\n                    dtype=self.dtype,\n                    device=device,\n                )\n            # Local shard of parameter remainders\n            self.param_remainders_shard: Optional[torch.Tensor] = None\n            if store_param_remainders:\n                self.param_remainders_shard = torch.zeros(\n                    [shard_size],\n                    dtype=torch.int16,\n                    device=device,\n                )\n            # Local shard of first moment estimate\n            self.exp_avg_shard: torch.Tensor = torch.zeros(\n                [shard_size],\n                dtype=self.dtype,\n                device=device,\n            )\n            # Local shard of second moment estimate\n            self.exp_avg_sq_shard: torch.Tensor = torch.zeros(\n                [shard_size],\n                dtype=self.dtype,\n                device=device,\n            )\n\n        def dtypes(self) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:\n            \"\"\"Datatypes for the bucket's compute and communication\"\"\"\n            return (\n                self.dtype,\n                self.grad_sync_dtype,\n                self.param_sync_dtype,\n            )\n\n    class GradientStatus(enum.Enum):\n        \"\"\"Status of gradients within a bucket\"\"\"\n\n 
       # Gradients are ready to use\n        READY = enum.auto()\n        # Bucket is partially filled with unreduced gradients\n        PARTIALLY_FILLED = enum.auto()\n        # Bucket is fully filled with unreduced gradients\n        FULLY_FILLED = enum.auto()\n        # Asynchronous reduction is in progress\n        SYNCING = enum.auto()\n\n    class GradientBucket:\n        \"\"\"Gradient buffers and state for a bucket\"\"\"\n\n        def __init__(self):\n            # Local shard of gradients\n            self.grads_shard: Optional[torch.Tensor] = None\n            # Local contribution to gradients\n            self.grads_bucket: Optional[torch.Tensor] = None\n            # Buffer for gradient reduce-scatter\n            self.sync_grads_shard: Optional[torch.Tensor] = None\n            # Status of gradients\n            self.status: GradientStatus = DistributedFusedAdam.GradientStatus.READY\n            # Params that have generated grads\n            self.grads_generated: Set[torch.nn.Parameter] = set()\n\n    class ParameterStatus(enum.Enum):\n        \"\"\"Status of parameters within a bucket\"\"\"\n\n        # Parameters are sharded between processes\n        SHARDED = enum.auto()\n        # Asynchronous communication is in progress\n        SYNCING = enum.auto()\n        # Parameters are ready to use\n        READY = enum.auto()\n\n    class ParameterBucket:\n        \"\"\"Parameter buffers and state for a bucket\"\"\"\n\n        def __init__(self):\n            # Local shard of parameters\n            self.params_shard: Optional[torch.Tensor] = None\n            # Gathered parameter values\n            self.params_bucket: Optional[torch.Tensor] = None\n            # Status of parameters\n            self.status: ParameterStatus = DistributedFusedAdam.ParameterStatus.SHARDED\n            # Params that have been updated\n            self.params_updated: Set[torch.nn.Parameter] = set()\n\n    # Enable custom logic for AMP grad scaling\n    
_step_supports_amp_scaling: bool = True\n    _custom_amp_unscale_grads: bool = True\n\n    def __init__(\n        self,\n        params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],\n        lr: float = 1e-3,\n        bias_correction: bool = True,\n        betas: Tuple[float, float] = (0.9, 0.999),\n        eps: float = 1e-8,\n        adam_w_mode: bool = True,\n        weight_decay: float = 0.0,\n        amsgrad: bool = False,\n        dtype: torch.dtype = torch.float32,\n        grad_sync_dtype: Optional[torch.dtype] = None,\n        param_sync_dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = \"cuda\",\n        process_group: Optional[torch.distributed.ProcessGroup] = None,\n        distributed_process_group: Optional[torch.distributed.ProcessGroup] = None,\n        redundant_process_group: Optional[torch.distributed.ProcessGroup] = None,\n        average_grad_sync: bool = True,\n        overlap_grad_sync: bool = True,\n        overlap_param_sync: bool = False,\n        bucket_cap_mb: float = 100.0,\n        pipeline_size: int = 2,\n        contiguous_param_buffer: bool = False,\n        contiguous_grad_buffer: bool = False,\n        store_params: bool = True,\n        store_param_remainders: bool = False,\n        with_scaled_states: bool = False,\n        nccl_ub: bool = False,\n        capturable: bool = False,\n    ):\n        if (with_scaled_states or store_param_remainders) and capturable:\n            raise Exception(\n                f\"{self.__class__.__name__} with scaled states \"\n                \"or storing param remainders doesn't support CUDA graph yet.\"\n            )\n\n        if capturable and not _FOUND_DEPRECATED_FUSED_ADAM:\n            raise Exception(\n                f\"Capturable {self.__class__.__name__} relies on \"\n                \"multi_tensor_copy to set dummy_overflow_buf to indicate \"\n                \"whether there's gradient Inf/NaN, build APEX with \"\n                
\"`--deprecated_fused_adam` is essential.\"\n            )\n\n        # If capturable for CUDA graph\n        self.capturable: bool = capturable\n        # If the optimizer is capturable then LR should be a tensor (on GPU)\n        if capturable:\n            lr = torch.tensor(lr, dtype=torch.float32, device=device)\n\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n        )\n        super().__init__(params, defaults)\n\n        # Adam options\n        self.adam_w_mode: bool = adam_w_mode\n        self.amsgrad: bool = amsgrad\n        if amsgrad:\n            raise RuntimeError(\"DistributedFusedAdam does not support the AMSGrad variant.\")\n\n        # Datatype options\n        if grad_sync_dtype is None:\n            grad_sync_dtype = dtype\n        if param_sync_dtype is None:\n            param_sync_dtype = dtype\n        supported_dtypes = (torch.float32, torch.float16, torch.bfloat16)\n        if dtype not in supported_dtypes or grad_sync_dtype not in supported_dtypes:\n            raise ValueError(\n                \"Unsupported dtypes for DistributedFusedAdam \"\n                f\"(dtype={dtype}, \"\n                f\"grad_sync_dtype={grad_sync_dtype}, \"\n                f\"param_sync_dtype={param_sync_dtype}))\"\n            )\n        self.dtype: torch.dtype = dtype\n        self.grad_sync_dtype: torch.dtype = grad_sync_dtype\n        self.param_sync_dtype: torch.dtype = param_sync_dtype\n\n        # Device options\n        if not _devices_match(device, \"cuda\"):\n            raise RuntimeError(f\"Invalid device for DistributedFusedAdam (device={device})\")\n        self.device: torch.device = torch.device(\"cuda\", torch.cuda.current_device())\n\n        # Process groups\n        self.process_group: torch.distributed.ProcessGroup = (\n            _get_default_group() if process_group is None else process_group\n        
)\n        self.distributed_process_group: torch.distributed.ProcessGroup = (\n            self.process_group if distributed_process_group is None else distributed_process_group\n        )\n        self.redundant_process_group: Optional[torch.distributed.ProcessGroup] = (\n            redundant_process_group\n        )\n        self.process_group_size: int = torch.distributed.get_world_size(self.process_group)\n        self.distributed_rank: int = torch.distributed.get_rank(self.distributed_process_group)\n        self.distributed_size: int = torch.distributed.get_world_size(\n            self.distributed_process_group\n        )\n        self.redundant_size: int = (\n            1\n            if self.redundant_process_group is None\n            else torch.distributed.get_world_size(self.redundant_process_group)\n        )\n        if self.process_group_size != self.distributed_size * self.redundant_size:\n            raise RuntimeError(\n                \"Invalid process group configuration \"\n                f\"(process group size = {self.process_group_size}, \"\n                f\"distributed process group size = {self.distributed_size}, \"\n                f\"redundant process group size = {self.redundant_size})\"\n            )\n        self.process_group_root: int = get_global_rank(self.process_group, 0)\n\n        # Use average reduction for grad sync\n        self.average_grad_sync: bool = average_grad_sync\n        # Copy param grads to bucket as soon as available\n        self.greedy_grad_copy: bool = True\n        # Synchronize grad buckets as soon as their grads are available\n        self.overlap_grad_sync: bool = overlap_grad_sync\n        # Try synchronizing param buckets just before param is needed\n        self.overlap_param_sync: bool = overlap_param_sync\n        # Number of buckets to synchronize at a time\n        self.pipeline_size: int = pipeline_size\n\n        # Store params or param remainders\n        if store_param_remainders:\n        
    if store_params:\n                raise RuntimeError(\n                    \"Attempted to construct DistributedFusedAdam \"\n                    \"with store_params=True and store_param_remainders=True\"\n                )\n            if self.dtype != torch.float32 or self.param_sync_dtype != torch.bfloat16:\n                raise RuntimeError(\n                    \"DistributedFusedAdam requires \"\n                    \"BF16 params and FP32 optimizer state \"\n                    \"when storing parameter remainders \"\n                    f\"(dtype={self.dtype}, \"\n                    f\"param_sync_dtype={self.param_sync_dtype})\"\n                )\n        self.store_params: bool = store_params\n        self.store_param_remainders: bool = store_param_remainders\n\n        # Whether to scale optimizer state\n        self.with_scaled_states: bool = with_scaled_states\n        if self.with_scaled_states:\n            if not self.store_params:\n                raise RuntimeError(\n                    \"Attempted to construct DistributedFusedAdam \"\n                    \"with with_scaled_states=True and store_params=False\"\n                )\n            if self.store_param_remainders:\n                raise RuntimeError(\n                    \"Attempted to construct DistributedFusedAdam \"\n                    \"with with_scaled_states=True and store_param_remainders=True\"\n                )\n            if self.dtype not in (torch.float16, torch.bfloat16):\n                raise RuntimeError(\n                    \"Attempted to construct DistributedFusedAdam \"\n                    f\"with with_scaled_states=True and dtype={self.dtype} \"\n                    \"(only fp16 and bf16 are supported)\"\n                )\n            if self.param_sync_dtype == torch.float32:\n                # _local_step_with_scaled_states applies Adam kernel\n                # to fp32 workspace buffer and relies on\n                # _check_params_shard_dtypes to copy to 
param sync\n                # workspace buffer. However,\n                # _check_params_shard_dtypes does nothing if\n                # param_sync_dtype is fp32.\n                raise RuntimeError(\n                    \"Attempted to construct DistributedFusedAdam \"\n                    f\"with with_scaled_states=True and param_sync_dtype={self.param_sync_dtype}\"\n                )\n        # Scaling factors to apply to recover unscaled optimizer state\n        self._state_scales: dict = {}\n\n        # Determine bucket sizes\n        dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8\n        self.alignment: int = 128 // dtype_size\n        self.bucket_cap_mb: float = bucket_cap_mb\n        bucket_size = 1024 * 1024 * bucket_cap_mb / dtype_size\n        shard_size = int(bucket_size / self.distributed_size)\n        shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False)\n        shard_size = max(shard_size, self.alignment)\n        self.default_shard_size: int = shard_size\n\n        # Optimizer state\n        self.state[\"buckets\"]: List[StateBucket] = []\n        self.state[\"step\"]: torch.Tensor | int = (\n            torch.tensor([0], dtype=torch.int, device=self.device) if self.capturable else 0\n        )\n\n        # Gradient state\n        self._grads_buckets: Dict[int, GradientBucket] = collections.defaultdict(\n            self.GradientBucket\n        )\n        # Param state\n        self._params_buckets: Dict[int, ParameterBucket] = collections.OrderedDict()\n\n        # Whether to allocate contiguous buffers for parameters\n        self.contiguous_param_buffer: bool = contiguous_param_buffer\n        # Whether to allocate contiguous buffers for gradients\n        self.contiguous_grad_buffer: bool = contiguous_grad_buffer\n        # Whether to use NCCL User Buffer\n        self.nccl_ub: bool = nccl_ub\n        # Contiguous buffers for parameters\n        self._param_buffers: Dict[Tuple[torch.dtype, torch.dtype, 
torch.dtype], torch.Tensor] = {}\n        # Contiguous buffers for gradients\n        self._grad_buffers: Dict[Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor] = {}\n        # Output buffer for gradient shards, only required for NCCL user buffer\n        if self.nccl_ub:\n            if not nccl_allocator:\n                raise RuntimeError(\"NCCL allocator importing failed but nccl ub is still requested\")\n            elif not self.contiguous_grad_buffer:\n                raise RuntimeError(\"NCCL user buffers require contiguous grad buffers\")\n            else:\n                self._shard_grad_buffers: Dict[\n                    Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor\n                ] = {}\n\n        # Side streams for state dict communication\n        self._pipeline_streams: List[torch.cuda.Stream] = [\n            torch.cuda.Stream() for _ in range(self.pipeline_size)\n        ]\n        # Side streams for gradients and parameters communication\n        self._comm_streams: List[torch.cuda.Stream] = [\n            torch.cuda.Stream() for _ in range(self.pipeline_size)\n        ]\n        self._last_comm_stream_id: int = -1\n\n        # Scale by factor before optimizer step. Used for grad\n        # clipping and gradient scaler.\n        self._grad_scale: torch.Tensor = torch.full(\n            [], 1.0, dtype=torch.float32, device=self.device\n        )\n        # Norm of parameter gradients. Used for gradient clipping and\n        # gradient scaler.\n        self._grad_norm: Optional[torch.Tensor] = None\n\n        # Dummy flag for multi-tensor kernels\n        # Note: Apex multi-tensor kernels have a noop_flag argument\n        # that is intended to detect non-finite values. 
It shouldn't\n        # have any effect with the kernels used in the optimizer, but\n        # we still set it to zero out of an abundance of caution.\n        self._dummy_overflow_buf: torch.Tensor = torch.zeros(\n            [1], dtype=torch.int32, device=self.device\n        )\n\n        # Check if collectives have no_copy option\n        self._gather_no_copy: bool = (\n            \"no_copy\" in inspect.getfullargspec(torch.distributed.gather).args\n        )\n\n        # Make sure parameter values are same across processes\n        self._broadcast_params()\n\n        # Lock for callbacks\n        self._lock: threading.Lock = threading.Lock()\n        # Attach hooks for gradient synchronization\n        self._register_post_backward_hooks()\n        # Attach hooks for param synchronization\n        if self.overlap_param_sync:\n            self._register_pre_forward_hooks()\n\n        # Move LR to device\n        if capturable:\n            for idx, group in enumerate(self.param_groups):\n                if len(group[\"params\"]) == 0:\n                    continue\n                for item in [\"lr\"]:\n                    if torch.is_tensor(group[item]):\n                        self.param_groups[idx][item] = group[item].to(device=self.device)\n                    else:\n                        self.param_groups[idx][item] = torch.tensor(group[item], device=self.device)\n\n        # For better representation string\n        arg_names = inspect.getfullargspec(DistributedFusedAdam.__init__).args\n        arg_names.remove(\"self\")\n        arg_names.remove(\"params\")\n        for i, group in enumerate(self.param_groups):\n            for key in sorted(group.keys()):\n                if key in arg_names:\n                    arg_names.remove(key)\n        self.args_dict = {name: getattr(self, name) for name in arg_names}\n\n    def __repr__(self) -> str:\n        # Based on: https://github.com/pytorch/pytorch/blob/v2.3.0-rc12/torch/optim/optimizer.py#L315\n       
 format_string = self.__class__.__name__ + \" (\"\n        for i, group in enumerate(self.param_groups):\n            format_string += \"\\n\"\n            format_string += f\"Parameter Group {i}\\n\"\n            for key in sorted(group.keys()):\n                if key != \"params\":\n                    format_string += f\"    {key}: {group[key]}\\n\"\n\n        for key, val in self.args_dict.items():\n            if \"process_group\" in key and val:\n                format_string += f\"{key}: {hex(id(val))}, world size {val.size()}\\n\"\n            else:\n                format_string += f\"{key}: {val}\\n\"\n\n        format_string += \")\"\n        return format_string\n\n    @torch.no_grad()\n    def _broadcast_params(self) -> None:\n        \"\"\"Broadcast parameter values from root rank\"\"\"\n        process_group = self.process_group\n        with _coalescing_manager(process_group, self.device, async_ops=True) as cm:\n            for param_group in self.param_groups:\n                for param in param_group[\"params\"]:\n                    _coalescing_manager_append_work(\n                        cm,\n                        torch.distributed.broadcast(\n                            param,\n                            src=self.process_group_root,\n                            group=process_group,\n                            async_op=True,\n                        ),\n                    )\n        cm.wait()\n\n    def _make_post_backward_hook(\n        self,\n        param: torch.nn.Parameter,\n        param_group_id: int,\n        param_id: int,\n    ) -> Callable:\n        \"\"\"Create callback function to call after param generates grad\n\n        Lazily initialize parameter and try launching grad sync.\n\n        \"\"\"\n\n        def post_backward_hook(*unused) -> None:\n            if getattr(param, \"_pre_forward_hook_is_enabled\", False):\n                raise RuntimeError(\n                    \"A parameter called its post-backward hook \"\n   
                 \"before its pre-forward hook. \"\n                    \"Please manually interact with the parameter \"\n                    \"before the forward pass (e.g. by calling data_ptr) \"\n                    \"or run DistributedFusedAdam with overlap_param_sync=False.\"\n                )\n            with self._lock:\n                need_to_initialize = \"fragments\" not in self.state[param]\n                if need_to_initialize:\n                    self._init_param_state(param, param_group_id, param_id)\n                if self.greedy_grad_copy:\n                    self._grad_copy(param)\n                    if self.overlap_grad_sync:\n                        self._try_start_bucket_grad_sync(\n                            params=[param],\n                            ignore_last_bucket=need_to_initialize,\n                        )\n\n        return post_backward_hook\n\n    def _register_post_backward_hooks(self) -> None:\n        \"\"\"Attach hooks for gradient synchronization\"\"\"\n        self._grad_accs = []\n        for param_group_id, group in enumerate(self.param_groups):\n            for param_id, param in enumerate(group[\"params\"]):\n                if param.requires_grad:\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n                    hook = self._make_post_backward_hook(\n                        param,\n                        param_group_id,\n                        param_id,\n                    )\n                    grad_acc.register_hook(hook)\n                    self._grad_accs.append(grad_acc)\n\n    def _make_pre_forward_hook(\n        self,\n        param: torch.nn.Parameter,\n        param_group_id: int,\n        param_id: int,\n    ) -> Callable:\n        \"\"\"Create callback function to call before param forward pass\n\n        Make sure param has been synchronized and try launching next\n        param sync.\n\n        \"\"\"\n\n        
def pre_forward_hook(*unused) -> None:\n            with self._lock:\n                if \"fragments\" not in self.state[param]:\n                    return\n                self._param_copy(param)\n                if self.overlap_param_sync:\n                    self._try_start_bucket_param_sync()\n\n        return pre_forward_hook\n\n    def _register_pre_forward_hooks(self) -> None:\n        \"\"\"Attach hooks for parameter synchronization\n\n        If _pre_forward_hook_is_enabled is set in a parameter, then\n        the callback will be called the first time any of its\n        attributes are accessed. This is hackily done by\n        monkey-patching the parameter class, so proceed with caution.\n\n        \"\"\"\n        for param_group_id, group in enumerate(self.param_groups):\n            for param_id, param in enumerate(group[\"params\"]):\n                # Monkey-patch parameter class\n                cls = param.__class__\n                if not getattr(cls, \"_has_pre_forward_hook\", False):\n                    # Monkey-patch magic methods to call __getattribute__\n                    special_funcs = [\n                        \"__abs__\",\n                        \"__add__\",\n                        \"__and__\",\n                        \"__bool__\",\n                        \"__complex__\",\n                        \"__contains__\",\n                        \"__deepcopy__\",\n                        \"__delitem__\",\n                        \"__div__\",\n                        \"__eq__\",\n                        \"__float__\",\n                        \"__floordiv__\",\n                        \"__ge__\",\n                        \"__getitem__\",\n                        \"__gt__\",\n                        \"__iadd__\",\n                        \"__iand__\",\n                        \"__idiv__\",\n                        \"__ifloordiv__\",\n                        \"__ilshift__\",\n                        \"__imod__\",\n                        
\"__imul__\",\n                        \"__index__\",\n                        \"__int__\",\n                        \"__invert__\",\n                        \"__ior__\",\n                        \"__ipow__\",\n                        \"__irshift__\",\n                        \"__isub__\",\n                        \"__iter__\",\n                        \"__itruediv__\",\n                        \"__ixor__\",\n                        \"__le__\",\n                        \"__len__\",\n                        \"__long__\",\n                        \"__lshift__\",\n                        \"__lt__\",\n                        \"__matmul__\",\n                        \"__mod__\",\n                        \"__mul__\",\n                        \"__neg__\",\n                        \"__nonzero__\",\n                        \"__or__\",\n                        \"__pos__\",\n                        \"__pow__\",\n                        \"__radd__\",\n                        \"__rand__\",\n                        \"__rdiv__\",\n                        \"__reduce__\",\n                        \"__reduce_ex__\",\n                        \"__reversed__\",\n                        \"__rfloordiv__\",\n                        \"__rlshift__\",\n                        \"__rmatmul__\",\n                        \"__rmod__\",\n                        \"__rmul__\",\n                        \"__ror__\",\n                        \"__rpow__\",\n                        \"__rrshift__\",\n                        \"__rshift__\",\n                        \"__rsub__\",\n                        \"__rtruediv__\",\n                        \"__rxor__\",\n                        \"__setitem__\",\n                        \"__sizeof__\",\n                        \"__sub__\",\n                        \"__truediv__\",\n                        \"__xor__\",\n                    ]\n                    for func_name in special_funcs:\n\n                        def make_augmented_func() -> Callable:\n          
                  base_func_name = f\"_base_{func_name}\"\n\n                            def augmented_func(self, *args, **kwargs):\n                                return getattr(self, base_func_name)(*args, **kwargs)\n\n                            return augmented_func\n\n                        setattr(cls, f\"_base_{func_name}\", getattr(cls, func_name))\n                        setattr(cls, func_name, make_augmented_func())\n\n                    # Monkey-patch __getattribute__ to call pre-forward hook\n                    def make_getattribute() -> Callable[[str], Any]:\n                        special_attrs = {\n                            \"_pre_forward_hook_is_enabled\",\n                            \"_pre_forward_hook\",\n                            \"__del__\",\n                            \"__delattr__\",\n                            \"__dir__\",\n                            \"__getattr__\",\n                            \"__getattribute__\",\n                            \"__hash__\",\n                            \"__init__\",\n                            \"__new__\",\n                            \"__setattr__\",\n                        }\n\n                        def getattribute_with_pre_forward_hook(self, name: str):\n                            \"\"\"Variant of __getattribute__ that can call pre-forward hook\"\"\"\n                            if name not in special_attrs:\n                                if getattr(self, \"_pre_forward_hook_is_enabled\", False):\n                                    self._pre_forward_hook_is_enabled = False\n                                    self._pre_forward_hook()\n                            return object.__getattribute__(self, name)\n\n                        return getattribute_with_pre_forward_hook\n\n                    cls.__getattribute__ = make_getattribute()\n                    cls._has_pre_forward_hook = True\n\n                # Register pre-forward callback\n                
param._pre_forward_hook_is_enabled = False\n                param._pre_forward_hook = self._make_pre_forward_hook(\n                    param,\n                    param_group_id,\n                    param_id,\n                )\n\n    @torch.no_grad()\n    def init_param_buffer(self) -> None:\n        \"\"\"Allocate contiguous buffers for param buckets\n\n        This converts the parameters into views into contiguous\n        buffers. This enables some performance optimizations (e.g.\n        avoiding some memory copies), but may add memory overhead\n        (e.g. if the memory allocator can't reuse the original\n        parameter buffers). To minimize memory overhead, this buffer\n        should be initialized before the first training step.\n\n        \"\"\"\n\n        # Make sure all params are initialized\n        self.contiguous_param_buffer = True\n        self.init_params()\n\n        # Construct param buffers\n        buffer_sizes = collections.defaultdict(lambda: 0)\n        for bucket in self.state[\"buckets\"]:\n            dtypes = bucket.dtypes()\n            buffer_sizes[dtypes] = max(\n                bucket.contiguous_buffer_offset + bucket.bucket_size,\n                buffer_sizes[dtypes],\n            )\n        for dtypes, buffer_size in buffer_sizes.items():\n            _, _, param_sync_dtype = dtypes\n            self._param_buffers[dtypes] = torch.zeros(\n                [buffer_size],\n                dtype=param_sync_dtype,\n                device=self.device,\n            )\n\n        # Figure out corresponding positions in params and param buffer\n        params = list(self.parameters())\n        param_flat_views = []\n        param_buffer_views = []\n        for i, param in enumerate(params):\n            fragment = self.state[param][\"fragments\"][0]\n            bucket_id = fragment.bucket_id\n            bucket = self.state[\"buckets\"][bucket_id]\n            param_size = param.numel()\n            bucket_start, _ = 
fragment.bucket_range\n            buffer_offset = bucket.contiguous_buffer_offset\n            buffer_start = buffer_offset + bucket_start\n            buffer_end = buffer_start + param_size\n            param_buffer = self._param_buffers[bucket.dtypes()]\n            param_buffer_view = param_buffer[buffer_start:buffer_end].detach()\n            if not _devices_match(param_buffer_view.device, param.device):\n                raise RuntimeError(\n                    f\"Attempted to change a parameter with device={param.device} \"\n                    f\"into a buffer view with device={param_buffer_view.device}\"\n                )\n            if param_buffer_view.dtype != param.dtype:\n                if (\n                    not torch.is_floating_point(param_buffer_view)\n                    and param_buffer_view.element_size() == param.element_size()\n                ):\n                    param_buffer_view = param_buffer_view.view(dtype=param.dtype)\n                else:\n                    raise RuntimeError(\n                        f\"Attempted to change a parameter with dtype={param.dtype} \"\n                        f\"into a buffer view with dtype={param_buffer_view.dtype}\"\n                    )\n            if param.is_contiguous(memory_format=torch.channels_last):\n                param = param.permute(0, 2, 3, 1)\n            param_flat_views.append(param.detach().view(-1))\n            param_buffer_views.append(param_buffer_view)\n\n        # Copy values into param buffer\n        _multi_tensor_copy(\n            param_flat_views,\n            param_buffer_views,\n            dummy_overflow_buf=self._dummy_overflow_buf,\n        )\n\n        # Make all params a view into the param buffer\n        for param, buffer_view in zip(params, param_buffer_views):\n            # Preserve memory format for param here, i.e. 
NHWC tensors\n            # `param.data.set_()` failed to change storage.\n            # `param.set_()` invalidates bprop hook.\n            param.data = buffer_view.as_strided(param.size(), param.stride())\n\n    def _init_grad_buffer(self) -> None:\n        \"\"\"Allocate contiguous buffer for grad buckets\"\"\"\n\n        # Make sure all params are initialized\n        self.contiguous_grad_buffer = True\n        self.init_params()\n\n        # Construct grad buffers\n        buffer_sizes = collections.defaultdict(lambda: 0)\n        for bucket in self.state[\"buckets\"]:\n            dtypes = bucket.dtypes()\n            buffer_sizes[dtypes] = max(\n                bucket.contiguous_buffer_offset + bucket.bucket_size,\n                buffer_sizes[dtypes],\n            )\n        for dtypes, buffer_size in buffer_sizes.items():\n            _, grad_sync_dtype, _ = dtypes\n            if not self.nccl_ub:\n                self._grad_buffers[dtypes] = torch.zeros(\n                    [buffer_size],\n                    dtype=grad_sync_dtype,\n                    device=self.device,\n                )\n            else:\n                pool = nccl_allocator.create_nccl_mem_pool()\n                with nccl_allocator.nccl_mem(pool):\n                    self._grad_buffers[dtypes] = torch.zeros(\n                        [buffer_size],\n                        dtype=grad_sync_dtype,\n                        device=self.device,\n                    )\n                shard_buffer_size = buffer_size // self.distributed_size\n                with nccl_allocator.nccl_mem(pool):\n                    self._shard_grad_buffers[dtypes] = torch.zeros(\n                        [shard_buffer_size],\n                        dtype=grad_sync_dtype,\n                        device=self.device,\n                    )\n\n    def parameters(self) -> Iterable[torch.nn.Parameter]:\n        \"\"\"Returns an iterator over optimizer parameters\"\"\"\n        return 
itertools.chain.from_iterable(group[\"params\"] for group in self.param_groups)\n\n    def parameter(\n        self,\n        *args: Union[int, ParameterFragment],\n    ) -> torch.nn.Parameter:\n        \"\"\"Get optimizer parameter\n\n        Can either accept two ints or one\n        DistributedFusedAdam.ParameterFragment.\n\n        Arguments:\n            param_group_id (int): Parameter group index\n            param_id (int): Parameter index within parameter group\n\n        \"\"\"\n        if len(args) == 2 and isinstance(args[0], int) and isinstance(args[1], int):\n            param_group_id = args[0]\n            param_id = args[1]\n        elif len(args) == 1 and isinstance(args[0], self.ParameterFragment):\n            fragment = args[0]\n            param_group_id = fragment.param_group_id\n            param_id = fragment.param_id\n        else:\n            raise TypeError(\n                \"Expected input types are \"\n                \"[int, int] or [DistributedFusedAdam.ParameterFragment], \"\n                f\"but found {[type(arg).__name__ for arg in args]}\"\n            )\n        return self.param_groups[param_group_id][\"params\"][param_id]\n\n    def init_params(\n        self,\n        params: Optional[Iterable[torch.nn.Parameter]] = None,\n        dtype: Optional[torch.dtype] = None,\n        grad_sync_dtype: Optional[torch.dtype] = None,\n        param_sync_dtype: Optional[torch.dtype] = None,\n    ) -> None:\n        \"\"\"Initialize optimizer state for parameters\n\n        Ignores parameters that have already been initialized.\n\n        Arguments:\n            params (iterable, optional): parameters to initialize\n                (default: all parameters)\n\n        \"\"\"\n\n        # Default cases\n        if params is None:\n            params = self.parameters()\n        elif isinstance(params, torch.Tensor):\n            params = [params]\n\n        # Ignore parameters that have already been initialized\n        params = [param 
for param in params if \"fragments\" not in self.state[param]]\n        if not params:\n            return\n\n        # Get indices corresponding to parameters\n        id_map = dict()\n        for param_group_id, group in enumerate(self.param_groups):\n            for param_id, param in enumerate(group[\"params\"]):\n                id_map[param] = (param_group_id, param_id)\n\n        # Initialize parameters\n        for param in params:\n            if param in id_map:\n                param_group_id, param_id = id_map[param]\n                self._init_param_state(\n                    param,\n                    param_group_id,\n                    param_id,\n                    dtype=dtype,\n                    grad_sync_dtype=grad_sync_dtype,\n                    param_sync_dtype=param_sync_dtype,\n                )\n\n    def init_params_bucket(\n        self,\n        params: Iterable[torch.nn.Parameter],\n        dtype: Optional[torch.dtype] = None,\n        grad_sync_dtype: Optional[torch.dtype] = None,\n        param_sync_dtype: Optional[torch.dtype] = None,\n    ) -> None:\n        \"\"\"Initialize optimizer state for parameters in one effective bucket\n\n        The buckets corresponding to the provided parameters are\n        configured so they all perform communication together. 
Ignores\n        parameters that have already been initialized.\n\n        Arguments:\n            params (iterable): parameters to initialize\n\n        \"\"\"\n\n        # Ignore parameters that have already been initialized\n        if isinstance(params, torch.Tensor):\n            params = [params]\n        params = [param for param in params if \"fragments\" not in self.state[param]]\n        if not params:\n            return\n\n        # Get indices corresponding to parameters\n        id_map = dict()\n        for param_group_id, group in enumerate(self.param_groups):\n            for param_id, param in enumerate(group[\"params\"]):\n                id_map[param] = [param_group_id, param_id]\n        param_ids = [tuple([param] + id_map[param]) for param in params]\n\n        # Mark existing buckets as fully filled\n        for bucket in self.state[\"buckets\"]:\n            bucket.able_to_fill = False\n\n        # Initialize optimizer state for parameters\n        start_bucket_id = len(self.state[\"buckets\"])\n        self.init_params(\n            params,\n            dtype=dtype,\n            grad_sync_dtype=grad_sync_dtype,\n            param_sync_dtype=param_sync_dtype,\n        )\n        end_bucket_id = len(self.state[\"buckets\"])\n\n        # Make sure all added buckets depend on provided params\n        for bucket_id in range(start_bucket_id, end_bucket_id):\n            bucket = self.state[\"buckets\"][bucket_id]\n            bucket_size = bucket.bucket_size\n            bucket.able_to_fill = False\n            ids_in_bucket = set(\n                (fragment.param_group_id, fragment.param_id) for fragment in bucket.fragments\n            )\n            for param, param_group_id, param_id in param_ids:\n                if (param_group_id, param_id) not in ids_in_bucket:\n                    param_size = param.numel()\n                    fragment = self.ParameterFragment(\n                        param_group_id=param_group_id,\n                     
   param_id=param_id,\n                        bucket_id=bucket_id,\n                        param_range=(param_size, param_size),\n                        bucket_range=(bucket_size, bucket_size),\n                        in_local_shard=False,\n                        shard_range=None,\n                        shard_bucket_range=None,\n                        shard_param_range=None,\n                    )\n                    self.state[param][\"fragments\"].append(fragment)\n                    bucket.fragments.append(fragment)\n\n    @torch.no_grad()\n    def _init_param_state(\n        self,\n        param: torch.nn.Parameter,\n        param_group_id: int,\n        param_id: int,\n        dtype: Optional[torch.dtype] = None,\n        grad_sync_dtype: Optional[torch.dtype] = None,\n        param_sync_dtype: Optional[torch.dtype] = None,\n    ) -> None:\n        \"\"\"Initialize optimizer state for a parameter\"\"\"\n\n        # Return immediately if already initialized\n        if \"fragments\" in self.state[param]:\n            return\n        self.state[param][\"fragments\"] = []\n\n        # Data type configuration\n        if dtype is None:\n            dtype = self.dtype\n        if grad_sync_dtype is None:\n            grad_sync_dtype = self.grad_sync_dtype\n        if param_sync_dtype is None:\n            param_sync_dtype = self.param_sync_dtype\n        if dtype != self.dtype:\n            raise ValueError(\"Optimizer states with non-default dtypes are not supported\")\n        supported_dtypes = (torch.float32, torch.float16, torch.bfloat16)\n        if dtype not in supported_dtypes or grad_sync_dtype not in supported_dtypes:\n            raise ValueError(\n                \"Unsupported dtypes for DistributedFusedAdam \"\n                f\"(dtype={dtype}, \"\n                f\"grad_sync_dtype={grad_sync_dtype}, \"\n                f\"param_sync_dtype={param_sync_dtype})\"\n            )\n\n        # Store params or param remainders\n        
store_params = (\n            self.store_params or dtype != self.dtype or param_sync_dtype != self.param_sync_dtype\n        )\n        store_param_remainders = (\n            self.store_param_remainders\n            and dtype == self.dtype\n            and param_sync_dtype == self.param_sync_dtype\n        )\n\n        def last_bucket_id() -> int:\n            \"\"\"Index of last optimizer state bucket with desired dtypes\n\n            -1 if there are no such buckets.\n\n            \"\"\"\n            dtypes = (dtype, grad_sync_dtype, param_sync_dtype)\n            bucket_id = len(self.state[\"buckets\"]) - 1\n            while bucket_id >= 0:\n                bucket = self.state[\"buckets\"][bucket_id]\n                if bucket.dtypes() == dtypes:\n                    break\n                bucket_id -= 1\n            return bucket_id\n\n        def make_bucket(\n            bucket_size: int,\n            shard_size: int,\n            buffer_offset: int,\n        ) -> None:\n            \"\"\"Construct new optimizer state bucket\"\"\"\n            self.state[\"buckets\"].append(\n                self.StateBucket(\n                    bucket_size,\n                    shard_size,\n                    dtype,\n                    self.device,\n                    grad_sync_dtype,\n                    param_sync_dtype,\n                    contiguous_buffer_offset=buffer_offset,\n                    store_params=store_params,\n                    store_param_remainders=store_param_remainders,\n                )\n            )\n\n        # Make sure there is at least one bucket with expected dtypes\n        if last_bucket_id() < 0:\n            shard_size = self.default_shard_size\n            bucket_size = shard_size * self.distributed_size\n            buffer_offset = 0\n            make_bucket(bucket_size, shard_size, buffer_offset)\n\n        # Split parameter values into fragments\n        # Note: Each fragment resides within a bucket\n        param_start = 0\n 
       param_size = param.numel()\n        while param_start < param_size:\n            # Get current bucket\n            bucket_id = last_bucket_id()\n            bucket = self.state[\"buckets\"][bucket_id]\n            fragment_id = len(bucket.fragments)\n            bucket_size = bucket.bucket_size\n            shard_size = bucket.shard_size\n\n            # Determine fragment position within bucket\n            bucket_start = _round_to_multiple(\n                bucket.filled_size,\n                self.alignment,\n                round_up=True,\n            )\n            fragment_size = min(param_size - param_start, bucket_size - bucket_start)\n            param_end = param_start + fragment_size\n            bucket_end = bucket_start + fragment_size\n\n            # Create new bucket if current one is full\n            if fragment_size <= 0 or not bucket.able_to_fill:\n                shard_size = self.default_shard_size\n                bucket_size = shard_size * self.distributed_size\n                buffer_offset = bucket.contiguous_buffer_offset + bucket.bucket_size\n                make_bucket(bucket_size, shard_size, buffer_offset)\n                continue\n\n            # Fragment position within local shard\n            shard_id = self.distributed_rank\n            shard_start = bucket_start - shard_size * shard_id\n            shard_end = bucket_end - shard_size * shard_id\n            shard_start = min(max(shard_start, 0), shard_size)\n            shard_end = min(max(shard_end, 0), shard_size)\n            in_local_shard = shard_start < shard_end\n            shard_range = None\n            shard_bucket_range = None\n            shard_param_range = None\n            if in_local_shard:\n                shard_range = (shard_start, shard_end)\n                shard_bucket_start = shard_start + shard_size * shard_id\n                shard_bucket_end = shard_bucket_start + shard_end - shard_start\n                shard_bucket_range = 
(shard_bucket_start, shard_bucket_end)\n                shard_param_start = shard_bucket_start - bucket_start + param_start\n                shard_param_end = shard_param_start + shard_end - shard_start\n                shard_param_range = (shard_param_start, shard_param_end)\n\n            # Record fragment info\n            fragment = self.ParameterFragment(\n                param_group_id=param_group_id,\n                param_id=param_id,\n                bucket_id=bucket_id,\n                param_range=(param_start, param_end),\n                bucket_range=(bucket_start, bucket_end),\n                in_local_shard=in_local_shard,\n                shard_range=shard_range,\n                shard_bucket_range=shard_bucket_range,\n                shard_param_range=shard_param_range,\n            )\n            self.state[param][\"fragments\"].append(fragment)\n            bucket.fragments.append(fragment)\n            bucket.filled_size = bucket_end\n            param_start = param_end\n\n        # Initialize optimizer state scaling factors if needed\n        if self.with_scaled_states:\n            for fragment in self.state[param][\"fragments\"]:\n                if not fragment.in_local_shard:\n                    continue\n                bucket_id = fragment.bucket_id\n                self._state_scales[(param_group_id, param_id, bucket_id)] = dict(\n                    param=torch.zeros([1], dtype=torch.float32, device=self.device),\n                    exp_avg=torch.zeros([1], dtype=torch.float32, device=self.device),\n                    exp_avg_sq=torch.zeros([1], dtype=torch.float32, device=self.device),\n                )\n\n        # Initialize main param buffer\n        if store_params:\n            for fragment in self.state[param][\"fragments\"]:\n                if not fragment.in_local_shard:\n                    continue\n                bucket_id = fragment.bucket_id\n                bucket = self.state[\"buckets\"][bucket_id]\n               
 # If param is channels last, i.e. a tensor with shape (N, C, H, W)\n                # and stride (HWC, 1, WC, C), then we will turn it into a tensor\n                # with shape (N, H, W, C) and stride (HWC, WC, C, 1). The purpose\n                # is to avoid failures when flattening the tensor (`.view(-1)`)\n                # and stepping the optimizer.\n                if param.is_contiguous(memory_format=torch.channels_last):\n                    param = param.permute(0, 2, 3, 1)\n                param_range = slice(*fragment.shard_param_range)\n                shard_range = slice(*fragment.shard_range)\n                model_param_fragment = param.detach().view(-1)[param_range]\n                if self.with_scaled_states:\n                    model_param_fragment = torch.empty_like(\n                        model_param_fragment,\n                        dtype=torch.float32,\n                    ).copy_(model_param_fragment)\n                    self._apply_state_scale(\n                        model_param_fragment,\n                        self._state_scales[(param_group_id, param_id, bucket_id)][\"param\"],\n                    )\n                main_param_fragment = bucket.params_shard[shard_range]\n                main_param_fragment.copy_(model_param_fragment)\n\n        # Check if buckets are underutilized\n        if all(\"fragments\" in self.state[param] for param in self.parameters()):\n            bucket_size = sum(bucket.bucket_size for bucket in self.state[\"buckets\"])\n            filled_size = sum(bucket.filled_size for bucket in self.state[\"buckets\"])\n            buckets_utilization = filled_size / bucket_size\n            if buckets_utilization < 0.7:\n                warnings.warn(\n                    f\"Only {buckets_utilization:.1%} of buckets are used. 
\"\n                    \"Consider decreasing the bucket_cap_mb argument.\"\n                )\n\n    def zero_grad(self, set_to_none: bool = False) -> None:\n        \"\"\"Clear parameter gradients\"\"\"\n\n        # Reset bucket buffers\n        self._grads_buckets.clear()\n\n        # Construct views into contiguous grad buffer, if needed\n        if self.contiguous_grad_buffer:\n            if not self._grad_buffers:\n                self._init_grad_buffer()\n            for grad_buffer in self._grad_buffers.values():\n                grad_buffer.zero_()\n            for bucket_id, bucket in enumerate(self.state[\"buckets\"]):\n                bucket_size = bucket.bucket_size\n                buffer_start = bucket.contiguous_buffer_offset\n                buffer_end = buffer_start + bucket_size\n                grad_buffer = self._grad_buffers[bucket.dtypes()]\n                self._grads_buckets[bucket_id].grads_bucket = grad_buffer[buffer_start:buffer_end]\n                if self.nccl_ub:\n                    shard_size = bucket.shard_size\n                    shard_buffer_start = bucket.contiguous_buffer_offset // self.distributed_size\n                    shard_buffer_end = shard_buffer_start + shard_size\n                    shard_grad_buffer = self._shard_grad_buffers[bucket.dtypes()]\n                    self._grads_buckets[bucket_id].sync_grads_shard = shard_grad_buffer[\n                        shard_buffer_start:shard_buffer_end\n                    ]\n\n        # Reset param grads\n        for param in self.parameters():\n            with _disable_pre_forward_hook(param):\n                need_to_zero = True\n                if set_to_none:\n                    param.grad = None\n                elif self.contiguous_grad_buffer:\n                    bucket_id = self.state[param][\"fragments\"][0].bucket_id\n                    bucket = self.state[\"buckets\"][bucket_id]\n                    if param.dtype == bucket.grad_sync_dtype and 
_devices_match(\n                        param.device, self.device\n                    ):\n                        param.grad = self.grad_buffer_view(param)\n                        need_to_zero = False\n                if need_to_zero and param.grad is not None:\n                    param.grad.zero_()\n\n        # Reset other state\n        self._grad_scale.fill_(1.0)\n        self._grad_norm = None\n        self._dummy_overflow_buf.zero_()\n\n    def _grad_copy(self, param: torch.nn.Parameter) -> None:\n        \"\"\"Copy parameter gradients to gradient buckets\n\n        Initializes gradient buckets if needed. The original parameter\n        gradient is set to None.\n\n        \"\"\"\n\n        # Initialize parameter if needed\n        if \"fragments\" not in self.state[param]:\n            for param_group_id, group in enumerate(self.param_groups):\n                for param_id, param_ in enumerate(group[\"params\"]):\n                    if param is param_:\n                        self._init_param_state(param, param_group_id, param_id)\n            if \"fragments\" not in self.state[param]:\n                raise RuntimeError(\"Could not initialize DistributedFusedAdam with parameter\")\n\n        # Copy param grad to buckets\n        for fragment in self.state[param][\"fragments\"]:\n            # Get fragment position\n            bucket_id = fragment.bucket_id\n            bucket = self._grads_buckets[bucket_id]\n            bucket_size = self.state[\"buckets\"][bucket_id].bucket_size\n            grad_sync_dtype = self.state[\"buckets\"][bucket_id].grad_sync_dtype\n            grad_start, grad_end = fragment.param_range\n            bucket_start, bucket_end = fragment.bucket_range\n\n            # Set reduction status\n            if bucket.status == self.GradientStatus.SYNCING:\n                self._finish_bucket_grad_sync()\n            bucket.status = self.GradientStatus.PARTIALLY_FILLED\n\n            # Allocate gradient buffer if needed\n            
if bucket.grads_bucket is None and self.contiguous_grad_buffer:\n                if not self._grad_buffers:\n                    self._init_grad_buffer()\n                state_bucket = self.state[\"buckets\"][bucket_id]\n                buffer_start = state_bucket.contiguous_buffer_offset\n                buffer_end = buffer_start + bucket_size\n                grad_buffer = self._grad_buffers[state_bucket.dtypes()]\n                grad_buffer = grad_buffer[buffer_start:buffer_end]\n                if (\n                    bucket.grads_shard is None\n                    or bucket.grads_shard.storage().data_ptr() != grad_buffer.storage().data_ptr()\n                ):\n                    bucket.grads_bucket = grad_buffer\n                    bucket.grads_bucket.zero_()\n            if bucket.grads_bucket is None:\n                bucket.grads_bucket = torch.zeros(\n                    [bucket_size],\n                    dtype=grad_sync_dtype,\n                    device=self.device,\n                )\n\n            # Copy param grad to bucket\n            if param.grad is not None:\n                if param.grad.is_contiguous(memory_format=torch.channels_last):\n                    grad_in = param.grad.permute(0, 2, 3, 1)\n                else:\n                    grad_in = param.grad\n                grad_in = grad_in.detach().view(-1)[grad_start:grad_end]\n                grad_out = bucket.grads_bucket[bucket_start:bucket_end]\n                if grad_in.data_ptr() != grad_out.data_ptr():\n                    grad_out.add_(grad_in)\n\n        # Free param grad buffer\n        param.grad = None\n\n    def _param_copy(\n        self,\n        params: Union[torch.nn.Parameter, Iterable[torch.nn.Parameter]],\n    ) -> None:\n        \"\"\"Update parameters with values from parameter buckets\n\n        Synchronizes and deletes parameter buckets as needed.\n\n        \"\"\"\n\n        # Get parameter fragments to be synchronized\n        if isinstance(params, 
torch.Tensor):\n            params = [params]\n        fragments = []\n        for param in params:\n            if \"fragments\" in self.state[param]:\n                fragments.extend(\n                    fragment\n                    for fragment in self.state[param][\"fragments\"]\n                    if fragment.bucket_id in self._params_buckets\n                )\n\n        # Return immediately if no fragments need to be synchronized\n        if not fragments:\n            return\n\n        # Make sure all needed buckets have been synchronized\n        buckets = collections.OrderedDict()\n        for fragment in fragments:\n            bucket_id = fragment.bucket_id\n            bucket = self._params_buckets[bucket_id]\n            buckets[bucket] = bucket.status\n        if any(status != self.ParameterStatus.READY for bucket, status in buckets.items()):\n            self._start_bucket_param_sync(buckets.keys())\n            self._finish_bucket_param_sync()\n\n        # Copy values from bucket buffers to params\n        self._param_copy_fragments(fragments)\n\n        # Delete buckets if possible\n        for fragment in fragments:\n            bucket_id = fragment.bucket_id\n            bucket = self._params_buckets[bucket_id]\n            bucket.params_updated.add(self.parameter(fragment))\n            bucket_fragments = self.state[\"buckets\"][bucket_id].fragments\n            if len(bucket.params_updated) == len(bucket_fragments):\n                del self._params_buckets[bucket_id]\n\n    def _param_copy_fragments(\n        self,\n        fragments: Iterable[ParameterFragment],\n    ) -> None:\n        \"\"\"Update parameter fragments with values from parameter buckets\"\"\"\n\n        # Figure out corresponding positions in param buckets and params\n        buffers_in = []\n        buffers_out = []\n        for fragment in fragments:\n            # Check if fragment needs to be updated\n            bucket_id = fragment.bucket_id\n            
bucket_start, bucket_end = fragment.bucket_range\n            param_start, param_end = fragment.param_range\n            if param_end <= param_start or bucket_id not in self._params_buckets:\n                continue\n\n            # Corresponding positions in param bucket and param\n            bucket = self._params_buckets[bucket_id]\n            param = self.parameter(fragment)\n\n            # A conv param with NHWC layout, i.e. shape (N, C, H, W) and\n            # stride (HWC, 1, WC, C), can't be flattened with `.view(-1)`.\n            # Permute it into a tensor with shape (N, H, W, C) and stride\n            # (HWC, WC, C, 1).\n            if param.is_contiguous(memory_format=torch.channels_last):\n                param = param.permute(0, 2, 3, 1)\n\n            buffer_in = bucket.params_bucket[bucket_start:bucket_end]\n            buffer_out = param.detach().view(-1)[param_start:param_end]\n\n            if torch.is_floating_point(buffer_in) and torch.is_floating_point(buffer_out):\n                # Cast between floating-point dtypes\n                buffers_in.append(buffer_in)\n                buffers_out.append(buffer_out)\n            else:\n                # Copy most significant bytes for non-floating-point\n                # dtypes\n                # Note: Assume dtypes are little-endian\n                in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8)\n                out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8)\n                copy_size = min(in_bytes.size(-1), out_bytes.size(-1))\n                buffers_in.append(in_bytes[..., -copy_size:])\n                buffers_out.append(out_bytes[..., -copy_size:])\n                if copy_size < out_bytes.size(-1):\n                    out_bytes[..., :-copy_size].zero_()\n\n        # Copy data from parameter buckets to parameters\n        _multi_tensor_copy(\n            buffers_in,\n            buffers_out,\n            dummy_overflow_buf=self._dummy_overflow_buf,\n        )\n\n    def grad_buffer_view(self, param: 
torch.nn.Parameter) -> torch.Tensor:\n        \"\"\"Construct view into grad buffer corresponding to param\n\n        Assumes optimizer is using a contiguous grad buffer.\n\n        \"\"\"\n\n        # Initialize contiguous grad buffers if needed\n        assert self.contiguous_grad_buffer\n        if not self._grad_buffers:\n            self._init_grad_buffer()\n\n        # Figure out corresponding position in grad buffer\n        fragment = self.state[param][\"fragments\"][0]\n        bucket_id = fragment.bucket_id\n        bucket = self.state[\"buckets\"][bucket_id]\n        bucket_start, _ = fragment.bucket_range\n        buffer_offset = bucket.contiguous_buffer_offset\n        buffer_start = buffer_offset + bucket_start\n        buffer_end = buffer_start + param.numel()\n\n        # Construct view into grad buffer\n        # Preserve memory format for gradient here\n        flat_buffer = self._grad_buffers[bucket.dtypes()]\n        flat_buffer = flat_buffer[buffer_start:buffer_end]\n        return flat_buffer.detach().as_strided(param.size(), param.stride())\n\n    def _force_bucket_grad_sync(self) -> None:\n        \"\"\"Ensure that all gradient buckets are synchronized\"\"\"\n\n        # Synchronize all unsynchronized buckets\n        Status = self.GradientStatus\n        buckets = []\n        for bucket_id, grads_bucket in sorted(self._grads_buckets.items()):\n            if grads_bucket.status not in (Status.READY, Status.SYNCING):\n                buckets.append(grads_bucket)\n                if grads_bucket.grads_bucket is None:\n                    state_bucket = self.state[\"buckets\"][bucket_id]\n                    grads_bucket.grads_bucket = torch.zeros(\n                        [state_bucket.bucket_size],\n                        dtype=state_bucket.grad_sync_dtype,\n                        device=self.device,\n                    )\n        if buckets:\n            self._start_bucket_grad_sync(buckets)\n        self._finish_bucket_grad_sync()\n\n   
     # Fill any unsynchronized gradients with zeros\n        for bucket_id in range(len(self.state[\"buckets\"])):\n            grads_bucket = self._grads_buckets[bucket_id]\n            if grads_bucket.grads_shard is None:\n                state_bucket = self.state[\"buckets\"][bucket_id]\n                grads_bucket.grads_shard = torch.zeros(\n                    [state_bucket.shard_size],\n                    dtype=state_bucket.grad_sync_dtype,\n                    device=self.device,\n                )\n\n    def _try_start_bucket_grad_sync(\n        self,\n        params: Optional[Iterable[torch.nn.Parameter]] = None,\n        ignore_last_bucket: bool = False,\n    ) -> None:\n        \"\"\"Attempt to launch gradient synchronization\n\n        Launches gradient synchronization if any bucket has received\n        all its expected gradients. Gradient synchronization is\n        asynchronous.\n\n        Arguments:\n            params (iterable): parameters that have had their\n                gradients copied to buckets\n            ignore_last_bucket (bool): avoid synchronizing last bucket\n                until all gradients have been generated. 
This avoids\n                excessive synchronization when initializing buckets in\n                the first backward pass.\n\n        \"\"\"\n\n        # Register params that have generated grads\n        if params is None:\n            params = []\n        for param in params:\n            for fragment in self.state[param][\"fragments\"]:\n                bucket_id = fragment.bucket_id\n                grads_bucket = self._grads_buckets[bucket_id]\n                state_bucket = self.state[\"buckets\"][bucket_id]\n                bucket_fragments = state_bucket.fragments\n                grads_bucket.grads_generated.add(param)\n                if len(grads_bucket.grads_generated) == len(bucket_fragments):\n                    grads_bucket.status = self.GradientStatus.FULLY_FILLED\n                    if grads_bucket.grads_bucket is None:\n                        grads_bucket.grads_bucket = torch.zeros(\n                            [state_bucket.bucket_size],\n                            dtype=state_bucket.grad_sync_dtype,\n                            device=self.device,\n                        )\n\n        # Launch reductions if enough buckets are ready\n        filled_buckets = []\n        for bucket_id, bucket in sorted(self._grads_buckets.items()):\n            if ignore_last_bucket and bucket_id == len(self.state[\"buckets\"]) - 1:\n                continue\n            if bucket.status == self.GradientStatus.FULLY_FILLED:\n                filled_buckets.append(bucket)\n        if filled_buckets:\n            self._start_bucket_grad_sync(filled_buckets)\n\n    def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:\n        \"\"\"Synchronize gradient buckets\n\n        Gradient synchronization is asynchronous. Involves\n        reduce-scatter over distributed process group and allreduce\n        over redundant process group. 
Assumes grad bucket buffers are\n        already initialized.\n\n        \"\"\"\n\n        # Complete any outstanding grad syncs\n        # Note: Not needed with contiguous grad buffer since there is\n        # no memory benefit from eagerly freeing grad buffers.\n        if not self.contiguous_grad_buffer:\n            self._finish_bucket_grad_sync()\n\n        # Reduction operation\n        if self.average_grad_sync and not self.nccl_ub:\n            reduce_op = torch.distributed.ReduceOp.AVG\n        else:\n            reduce_op = torch.distributed.ReduceOp.SUM\n\n        # Initialize grad state and buffers\n        for bucket in buckets:\n            if bucket.status == self.GradientStatus.SYNCING:\n                self._finish_bucket_grad_sync()\n            bucket.status = self.GradientStatus.SYNCING\n            bucket.grads_generated.clear()\n            if self.distributed_size == 1:\n                bucket.sync_grads_shard = bucket.grads_bucket\n            elif bucket.sync_grads_shard is None:\n                bucket_size = bucket.grads_bucket.numel()\n                shard_size = bucket_size // self.distributed_size\n                bucket.sync_grads_shard = torch.empty(\n                    [shard_size],\n                    dtype=bucket.grads_bucket.dtype,\n                    device=bucket.grads_bucket.device,\n                )\n\n            # Handle case with multiple grad accumulation steps\n            if bucket.grads_shard is not None:\n                if bucket.sync_grads_shard.data_ptr() == bucket.grads_shard.data_ptr():\n                    bucket.grads_shard = bucket.grads_shard.clone()\n\n        # Side stream for communication\n        # If a new bucket is ready before the last bucket's communication finishes, using\n        # multiple communication streams can help pipeline reduce-scatter and all-reduce.\n        main_stream = torch.cuda.current_stream()\n        self._last_comm_stream_id = (self._last_comm_stream_id + 1) % 
len(self._comm_streams)\n        comm_stream = self._comm_streams[self._last_comm_stream_id]\n        comm_stream.wait_stream(main_stream)\n\n        # Reduce-scatter over distributed process group\n        if buckets and self.distributed_size > 1:\n            with torch.cuda.stream(comm_stream):\n                group = self.distributed_process_group\n                with _coalescing_manager(group, self.device, async_ops=True) as cm:\n                    for bucket in buckets:\n                        if self.average_grad_sync and self.nccl_ub:\n                            bucket.grads_bucket /= self.distributed_size\n                        _coalescing_manager_append_work(\n                            cm,\n                            reduce_scatter_tensor(\n                                bucket.sync_grads_shard,\n                                bucket.grads_bucket,\n                                op=reduce_op,\n                                group=group,\n                                async_op=True,\n                            ),\n                        )\n                cm.wait()\n\n        # All-reduce over redundant process group\n        if buckets and self.redundant_size > 1:\n            with torch.cuda.stream(comm_stream):\n                group = self.redundant_process_group\n                with _coalescing_manager(group, self.device, async_ops=True) as cm:\n                    for bucket in buckets:\n                        _coalescing_manager_append_work(\n                            cm,\n                            torch.distributed.all_reduce(\n                                bucket.sync_grads_shard,\n                                op=reduce_op,\n                                group=group,\n                                async_op=True,\n                            ),\n                        )\n                cm.wait()\n\n    def _finish_bucket_grad_sync(self) -> None:\n        \"\"\"Wait for any gradient synchronizations that are in 
progress\"\"\"\n        main_stream = torch.cuda.current_stream()\n        for comm_stream in self._comm_streams:\n            main_stream.wait_stream(comm_stream)\n        for bucket_id, bucket in sorted(self._grads_buckets.items()):\n            if bucket.status == self.GradientStatus.SYNCING:\n                # Accumulate gradient in local shard\n                if bucket.grads_shard is None:\n                    bucket.grads_shard = bucket.sync_grads_shard\n                else:\n                    bucket.grads_shard.add_(bucket.sync_grads_shard)\n                bucket.grads_bucket = None\n\n                # Reset status\n                bucket.status = self.GradientStatus.READY\n\n                # Cached gradient norm has been invalidated\n                self._grad_norm = None\n\n    def _try_start_bucket_param_sync(\n        self,\n        params: Optional[Iterable[torch.nn.Parameter]] = None,\n    ) -> None:\n        \"\"\"Attempt to launch parameter synchronization\n\n        Launches parameter synchronization for buckets corresponding\n        to provided parameters, if needed. If parameters are not\n        provided and no other synchronizations are in progress,\n        attempts to find a parameter that still requires\n        synchronization. 
Parameter synchronization is asynchronous.\n\n        Arguments:\n            params (iterable, optional): parameters to synchronize\n\n        \"\"\"\n\n        # Default behavior: only launch param sync if no other syncs\n        # are in progress\n        if params is None:\n            params = []\n            if any(\n                bucket.status == self.ParameterStatus.SYNCING\n                for bucket in self._params_buckets.values()\n            ):\n                return\n            for bucket_id, bucket in self._params_buckets.items():\n                if bucket.status == self.ParameterStatus.SHARDED:\n                    params.append(self.parameter(self.state[\"buckets\"][bucket_id].fragments[-1]))\n                    break\n\n        # Find buckets corresponding to params\n        bucket_ids = set()\n        for param in params:\n            bucket_ids.update(fragment.bucket_id for fragment in self.state[param][\"fragments\"])\n        buckets = [\n            self._params_buckets[bucket_id]\n            for bucket_id in sorted(bucket_ids)\n            if bucket_id in self._params_buckets\n        ]\n        buckets = [bucket for bucket in buckets if bucket.status == self.ParameterStatus.SHARDED]\n\n        # Launch param sync if needed\n        if buckets:\n            self._start_bucket_param_sync(buckets)\n\n    def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None:\n        \"\"\"Synchronize parameter buckets\n\n        Parameter synchronization is asynchronous. Involves all-gather\n        over distributed process group. 
Assumes param shard buffers\n        are already initialized.\n\n        \"\"\"\n\n        # Complete any outstanding param syncs\n        self._finish_bucket_param_sync()\n\n        # Initialize param state and buffers\n        buckets = [bucket for bucket in buckets if bucket.status == self.ParameterStatus.SHARDED]\n        for bucket in buckets:\n            bucket.status = self.ParameterStatus.SYNCING\n            if bucket.params_bucket is not None:\n                pass\n            elif self.distributed_size == 1:\n                bucket.params_bucket = bucket.params_shard\n            else:\n                shard_size = bucket.params_shard.numel()\n                bucket_size = shard_size * self.distributed_size\n                bucket.params_bucket = torch.empty(\n                    [bucket_size],\n                    dtype=bucket.params_shard.dtype,\n                    device=bucket.params_shard.device,\n                )\n\n        # Side stream for communication\n        main_stream = torch.cuda.current_stream()\n        self._last_comm_stream_id = (self._last_comm_stream_id + 1) % len(self._comm_streams)\n        comm_stream = self._comm_streams[self._last_comm_stream_id]\n        comm_stream.wait_stream(main_stream)\n\n        # All-gather over distributed process group\n        if buckets and self.distributed_size > 1:\n            with torch.cuda.stream(comm_stream):\n                group = self.distributed_process_group\n                with _coalescing_manager(group, self.device, async_ops=True) as cm:\n                    for bucket in buckets:\n                        _coalescing_manager_append_work(\n                            cm,\n                            all_gather_into_tensor(\n                                bucket.params_bucket,\n                                bucket.params_shard,\n                                group=group,\n                                async_op=True,\n                            ),\n                        
)\n                cm.wait()\n\n    def _finish_bucket_param_sync(self) -> None:\n        \"\"\"Wait for any param synchronizations that are in progress\"\"\"\n        main_stream = torch.cuda.current_stream()\n        for comm_stream in self._comm_streams:\n            main_stream.wait_stream(comm_stream)\n        for bucket_id, bucket in self._params_buckets.items():\n            if bucket.status == self.ParameterStatus.SYNCING:\n                bucket.params_shard = None\n                bucket.status = self.ParameterStatus.READY\n\n    @contextlib.contextmanager\n    def no_sync(\n        self,\n        greedy_grad_copy: bool = False,\n    ) -> contextlib.AbstractContextManager:\n        \"\"\"Disable overlapped gradient synchronization\n\n        Context manager that is similar to\n        torch.nn.parallel.DistributedDataParallel.no_sync. The\n        gradients can be synchronized by calling grad_sync or step. If\n        overlapped gradient synchronization is enabled, gradients can\n        also be synchronized by leaving the context and performing a\n        backward pass.\n\n        Arguments:\n            greedy_grad_copy (bool, optional): copy parameter\n                gradients to buckets as soon as they are generated\n                (default: False)\n\n        \"\"\"\n        old_greedy_grad_copy = self.greedy_grad_copy\n        old_overlap_grad_sync = self.overlap_grad_sync\n        self.greedy_grad_copy = greedy_grad_copy\n        self.overlap_grad_sync = False\n        try:\n            yield\n        finally:\n            self.greedy_grad_copy = old_greedy_grad_copy\n            self.overlap_grad_sync = old_overlap_grad_sync\n\n    def grad_sync(self) -> None:\n        \"\"\"Ensure that all gradients are synchronized\"\"\"\n        for bucket in self.state[\"buckets\"]:\n            for fragment in bucket.fragments:\n                param = self.parameter(fragment)\n                if param.grad is not None:\n                    
self._grad_copy(param)\n                    if not self.contiguous_grad_buffer:\n                        self._try_start_bucket_grad_sync(\n                            params=[param],\n                            ignore_last_bucket=False,\n                        )\n        self._force_bucket_grad_sync()\n\n    def param_sync(self) -> None:\n        \"\"\"Ensure that all parameters are synchronized\"\"\"\n        if self.contiguous_param_buffer:\n            self._param_copy(self.parameters())\n        else:\n            while self._params_buckets:\n                bucket_id, bucket = next(iter(self._params_buckets.items()))\n                for fragment in reversed(self.state[\"buckets\"][bucket_id].fragments):\n                    self._param_copy(self.parameter(fragment))\n        self._params_buckets.clear()\n\n    @torch.no_grad()\n    def _local_grad_norm(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None,\n        norm_type: float = 2.0,\n    ) -> torch.Tensor:\n        \"\"\"Local contribution to parameter gradient norm\n\n        Returns square of 2-norm. Other norms are not yet supported.\n\n        If no parameters are provided, the norm is computed for all\n        parameters in optimizer. 
Provided parameters are assumed to be\n        in optimizer and to require gradients.\n\n        \"\"\"\n        norm_type = float(norm_type)\n        assert norm_type == 2.0\n\n        # Make sure that gradients have been reduced\n        self.grad_sync()\n\n        # Check if provided parameters are subset of all parameters\n        if parameters is not None:\n            parameters = list(parameters)\n            params_set = set(parameters)\n            all_params_set = set()\n            for bucket in self.state[\"buckets\"]:\n                for fragment in bucket.fragments:\n                    all_params_set.add(self.parameter(fragment))\n            if not params_set.issubset(all_params_set):\n                raise RuntimeError(\n                    \"Attempted to compute gradient norm for a parameter \"\n                    \"that is not managed by DistributedFusedAdam\"\n                )\n            if params_set == all_params_set:\n                parameters = None\n\n        # Group grads by dtype\n        grad_groups = collections.defaultdict(list)\n        if parameters is None:\n            # Compute norm of all local gradients\n            for bucket_id, grads_bucket in self._grads_buckets.items():\n                state_bucket = self.state[\"buckets\"][bucket_id]\n                dtype = state_bucket.grad_sync_dtype\n                grad_groups[dtype].append(grads_bucket.grads_shard)\n        else:\n            # Compute norm of selected local gradients\n            for param in parameters:\n                if \"fragments\" not in self.state[param]:\n                    continue\n                for fragment in self.state[param][\"fragments\"]:\n                    if not fragment.in_local_shard:\n                        continue\n                    shard_start, shard_end = fragment.shard_range\n                    if shard_end <= shard_start:\n                        continue\n                    bucket_id = fragment.bucket_id\n                
    grads_bucket = self._grads_buckets[bucket_id]\n                    state_bucket = self.state[\"buckets\"][bucket_id]\n                    grad_groups[state_bucket.grad_sync_dtype].append(\n                        grads_bucket.grads_shard[shard_start:shard_end]\n                    )\n\n        # Compute norm of each group of grads\n        grad_norm_sq = None\n        for grad_group in grad_groups.values():\n            grad_group_norm_sq = (\n                multi_tensor_applier(\n                    amp_C.multi_tensor_l2norm,\n                    self._dummy_overflow_buf,\n                    [grad_group],\n                    False,\n                )[0]\n                ** 2\n            )\n            if grad_norm_sq is None:\n                grad_norm_sq = grad_group_norm_sq\n            else:\n                grad_norm_sq += grad_group_norm_sq\n        if grad_norm_sq is None:\n            grad_norm_sq = torch.zeros([], dtype=torch.float32, device=self.device)\n\n        # Interpret norm as scalar\n        grad_norm_sq = grad_norm_sq.to(dtype=torch.float32, device=self.device)\n        grad_norm_sq = grad_norm_sq.view([])\n        return grad_norm_sq\n\n    def grad_norm(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None,\n        norm_type: float = 2.0,\n        force: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Gradient norm of parameters in optimizer\n\n        The norm is computed over all gradients together, as if they\n        were concatenated into a single vector. All provided\n        parameters must be managed by optimizer.\n\n        The computed value is cached to avoid redundant communication.\n\n        Arguments:\n            parameters (iterable, optional): an iterable of parameters\n                in optimizer (default: all parameters in optimizer).\n            norm_type (float, optional): type of the used p-norm\n                (default: 2). 
Only 2-norm is currently supported.\n            force (bool, optional): ignore cached value and force norm\n                computation (default: False).\n\n        \"\"\"\n        if force or self._grad_norm is None:\n            norm_type = float(norm_type)\n            assert norm_type == 2.0\n            grad_norm_sq = self._local_grad_norm(\n                parameters=parameters,\n                norm_type=norm_type,\n            )\n            torch.distributed.all_reduce(\n                grad_norm_sq,\n                op=torch.distributed.ReduceOp.SUM,\n                group=self.distributed_process_group,\n            )\n            self._grad_norm = grad_norm_sq.sqrt()\n        grad_norm = self._grad_norm * self._grad_scale\n        return grad_norm.detach()\n\n    def clip_grad_norm(\n        self,\n        max_norm: float,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None,\n        norm_type: float = 2.0,\n    ) -> torch.Tensor:\n        \"\"\"Clips gradient norm of parameters in optimizer\n\n        The norm is computed over all gradients together, as if they\n        were concatenated into a single vector. 
The scaling is\n        deferred until the optimizer step, which should be called\n        immediately after this function.\n\n        The computed grad norm is cached to avoid redundant\n        communication.\n\n        Arguments:\n            max_norm (float): max norm of the gradients\n            parameters (iterable, optional): an iterable of parameters\n                in optimizer (default: all parameters in optimizer).\n            norm_type (float, optional): type of the used\n                p-norm (default: 2)\n\n        \"\"\"\n        assert max_norm > 0\n        total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type)\n        clip_coef = max_norm / (total_norm + 1e-6)\n        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)\n        self._grad_scale *= clip_coef_clamped\n        return total_norm\n\n    @torch.no_grad()\n    def unscale_grads(\n        self,\n        *args: Union[Optional[torch.Tensor], Any],\n        inv_scale: Optional[torch.Tensor] = None,\n        grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,\n    ) -> Dict[torch.device, torch.Tensor]:\n        \"\"\"Custom unscale function for use by AMP gradient scaler\n\n        Either inv_scale or grad_scaler must be provided, but not\n        both. If grad_scaler is provided, this is equivalent to\n        calling its unscale_ function.\n\n        Arguments:\n            inv_scale (torch.Tensor, optional): factor to multiply\n                gradients. May be provided either as a kwarg or as the\n                first positional arg.\n            grad_scaler (torch.cuda.amp.GradScaler, optional): gradient\n                scaler (default: None)\n\n        \"\"\"\n\n        # inv_scale is either kwarg or first positional arg\n        if inv_scale is None and len(args) >= 1:\n            inv_scale = args[0]\n\n        # Check for non-finite values\n        # Note: We compute gradient norm to check for non-finite\n        # values. 
This is more conservative and compute intensive than\n        # directly checking, but it avoids extra communication if we\n        # have already computed gradient norm e.g. for gradient\n        # clipping.\n        found_inf = torch.logical_not(torch.isfinite(self.grad_norm()))\n        found_inf_per_device = {found_inf.device: found_inf.float()}\n\n        # Get inv_scale from GradScaler if provided\n        if grad_scaler is not None and grad_scaler._enabled:\n            grad_scaler_state = grad_scaler._per_optimizer_states[id(self)]\n            GradScalerOptState = torch.cuda.amp.grad_scaler.OptState\n            if grad_scaler_state[\"stage\"] is GradScalerOptState.UNSCALED:\n                raise RuntimeError(\n                    \"unscale_grads has already been called since the last GradScaler update\"\n                )\n            if grad_scaler_state[\"stage\"] is GradScalerOptState.STEPPED:\n                raise RuntimeError(\"unscale_grads is being called after optimizer step\")\n            if grad_scaler._scale is None:\n                raise RuntimeError(\"Attempted unscale_grads with GradScaler that is missing _scale\")\n            if inv_scale is not None:\n                raise ValueError(\n                    \"unscale_grads is being called with both inv_scale and grad_scaler\"\n                )\n            inv_scale = grad_scaler._scale.double().reciprocal()\n            inv_scale = inv_scale.to(dtype=torch.float32, device=self.device)\n            grad_scaler_state[\"found_inf_per_device\"] = found_inf_per_device\n            grad_scaler_state[\"stage\"] = GradScalerOptState.UNSCALED\n\n        # Apply inv_scale to grad_scale\n        if inv_scale is None:\n            raise ValueError(\"unscale_grads is being called with neither inv_scale nor grad_scaler\")\n        self._grad_scale *= inv_scale.view([])\n        return found_inf_per_device\n\n    def step(\n        self,\n        closure: Optional[Callable] = None,\n        *,\n     
   grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,\n    ):\n        \"\"\"Apply Adam optimizer step\n\n        Arguments:\n            closure (callable, optional): closure to recompute loss\n                (default: None)\n            grad_scaler (torch.cuda.amp.GradScaler, optional):\n                gradient scaler (default: None)\n\n        \"\"\"\n\n        # Apply closure\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # Make sure params are initialized\n        self.init_params()\n\n        # Make sure that parameters and gradients are synchronized\n        self.param_sync()\n        self.grad_sync()\n\n        # Apply gradient scaler if provided\n        if grad_scaler is not None and grad_scaler._enabled:\n            grad_scaler_state = grad_scaler._per_optimizer_states[id(self)]\n            GradScalerOptState = torch.cuda.amp.grad_scaler.OptState\n            if grad_scaler_state[\"stage\"] is GradScalerOptState.READY:\n                self.unscale_grads(grad_scaler=grad_scaler)\n            found_inf = grad_scaler_state[\"found_inf_per_device\"][self.device]\n            if self.capturable:\n                self._dummy_overflow_buf.copy_(found_inf)\n            elif found_inf.item():\n                return\n        self._grad_scale = self._grad_scale.to(dtype=torch.float32, device=self.device)\n\n        # Initialize buffers for param syncs\n        num_buckets = len(self.state[\"buckets\"])\n        for bucket_id in reversed(range(num_buckets)):\n            self._params_buckets[bucket_id] = self.ParameterBucket()\n            params_bucket = self._params_buckets[bucket_id]\n            state_bucket = self.state[\"buckets\"][bucket_id]\n            shard_size = state_bucket.shard_size\n            dtype = state_bucket.dtype\n            param_sync_dtype = state_bucket.param_sync_dtype\n\n            if self.contiguous_param_buffer:\n                # Construct views into contiguous param 
buffer\n                if not self._param_buffers:\n                    self.init_param_buffer()\n                bucket_size = state_bucket.bucket_size\n                buffer_start = state_bucket.contiguous_buffer_offset\n                buffer_end = buffer_start + bucket_size\n                param_buffer = self._param_buffers[state_bucket.dtypes()]\n                params_bucket.params_bucket = param_buffer[buffer_start:buffer_end]\n                bucket_start = self.distributed_rank * shard_size\n                bucket_end = bucket_start + shard_size\n                params_bucket.params_shard = params_bucket.params_bucket[bucket_start:bucket_end]\n\n            # Initialize param shard buffer\n            if self.with_scaled_states:\n                # Use FP32 workspace buffer with scaled optimizer state\n                params_bucket.params_shard = None\n            elif not param_sync_dtype.is_floating_point:\n                # Make sure param shard buffer is floating-point\n                if state_bucket.params_shard is not None and dtype.is_floating_point:\n                    params_bucket.params_shard = state_bucket.params_shard\n                else:\n                    params_bucket.params_shard = torch.empty(\n                        [shard_size],\n                        dtype=self.dtype,\n                        device=self.device,\n                    )\n            else:\n                # Allocate param shard buffer if needed\n                if params_bucket.params_shard is not None:\n                    pass\n                elif state_bucket.params_shard is not None and dtype == param_sync_dtype:\n                    params_bucket.params_shard = state_bucket.params_shard\n                else:\n                    params_bucket.params_shard = torch.empty(\n                        [shard_size],\n                        dtype=param_sync_dtype,\n                        device=self.device,\n                    )\n\n        # Apply optimizer 
step\n        self.state[\"step\"] += (\n            1 if not self.capturable else (self._dummy_overflow_buf != 1).to(torch.int)\n        )\n        overlap_first_bucket = (\n            self.distributed_size > 1 and self.overlap_param_sync and self.state[\"buckets\"]\n        )\n        if overlap_first_bucket:\n            # Local step and non-blocking param sync\n            # Note: Overlap param sync of first buckets with optimizer\n            # step of remaining buckets.\n\n            # Get buckets containing \"first\" parameter\n            first_param = self.parameter(self.state[\"buckets\"][-1].fragments[-1])\n            first_bucket_ids = sorted(\n                fragment.bucket_id for fragment in self.state[first_param][\"fragments\"]\n            )\n\n            # Local step and launch param sync for first buckets\n            self._local_step(first_bucket_ids)\n            self._start_bucket_param_sync(\n                self._params_buckets[bucket_id] for bucket_id in first_bucket_ids\n            )\n\n            # Local step for remaining buckets\n            first_bucket_ids = set(first_bucket_ids)\n            self._local_step(\n                [bucket_id for bucket_id in range(num_buckets) if bucket_id not in first_bucket_ids]\n            )\n\n        else:\n            # Local step\n            self._local_step(list(range(num_buckets)))\n\n        # Synchronize params\n        if self.distributed_size > 1 and self.overlap_param_sync:\n            # Asynchronous param sync\n            self._try_start_bucket_param_sync()\n            for param in self.parameters():\n                param._pre_forward_hook_is_enabled = True\n        else:\n            # Blocking param sync\n            self.param_sync()\n\n        return loss\n\n    def _local_step(self, bucket_ids: List[int]) -> None:\n        \"\"\"Apply optimizer step to local shard of parameter buckets\n\n        Arguments:\n            bucket_ids (list): bucket indices\n\n        
\"\"\"\n\n        # Implementation with scaled optimizer state\n        if self.with_scaled_states:\n            self._local_step_with_scaled_states(bucket_ids)\n            return\n\n        # Optimized implementation with BF16 params and 16-bit param\n        # remainders\n        if self.store_param_remainders:\n            bf16_rem_buckets = set()\n            for bucket_id in bucket_ids:\n                state_bucket = self.state[\"buckets\"][bucket_id]\n                if state_bucket.param_remainders_shard is not None:\n                    bf16_rem_buckets.add(bucket_id)\n            if bf16_rem_buckets:\n                self._local_step_with_param_remainders(sorted(bf16_rem_buckets))\n            bucket_ids = [\n                bucket_id for bucket_id in bucket_ids if bucket_id not in bf16_rem_buckets\n            ]\n            if not bucket_ids:\n                return\n\n        # Find param fragments for each bucket\n        buffers = collections.defaultdict(list)  # p_in, m, v, g, p_out\n        for bucket_id in bucket_ids:\n            state_bucket = self.state[\"buckets\"][bucket_id]\n            grads_bucket = self._grads_buckets[bucket_id]\n            params_bucket = self._params_buckets[bucket_id]\n\n            # Optimizer state buffers for local shard\n            fragments = state_bucket.fragments\n            exp_avg = state_bucket.exp_avg_shard\n            exp_avg_sq = state_bucket.exp_avg_sq_shard\n            grads = grads_bucket.grads_shard\n            params_out = params_bucket.params_shard\n\n            # Find param fragments in local shard\n            for fragment in fragments:\n                if not fragment.in_local_shard:\n                    continue\n                shard_start, shard_end = fragment.shard_range\n                if shard_end <= shard_start:\n                    continue\n                shard_range = slice(shard_start, shard_end)\n                if state_bucket.params_shard is None:\n                    param 
= self.parameter(fragment)\n                    if param.is_contiguous(memory_format=torch.channels_last):\n                        param = param.permute(0, 2, 3, 1)\n                    param_range = slice(*fragment.shard_param_range)\n                    param_fragment = param.detach().view(-1)[param_range]\n                    param_fragment = param_fragment.to(dtype=state_bucket.dtype, device=self.device)\n                else:\n                    params_shard = state_bucket.params_shard\n                    param_fragment = params_shard[shard_range]\n                buffers_key = (\n                    fragment.param_group_id,\n                    state_bucket.dtype,\n                    state_bucket.grad_sync_dtype,\n                    state_bucket.param_sync_dtype,\n                )\n                buffers[buffers_key].append(\n                    [\n                        param_fragment,\n                        exp_avg[shard_range],\n                        exp_avg_sq[shard_range],\n                        grads[shard_range],\n                        params_out[shard_range],\n                    ]\n                )\n\n        # Apply optimizer step to each param group\n        adam_func = (\n            distributed_adam_cuda.multi_tensor_fused_adam_capturable\n            if self.capturable\n            else distributed_adam_cuda.multi_tensor_fused_adam\n        )\n        for (group_id, _, _, _), group_buffers in buffers.items():\n            group = self.param_groups[group_id]\n            beta1, beta2 = group[\"betas\"]\n            multi_tensor_applier(\n                adam_func,\n                self._dummy_overflow_buf,\n                list(zip(*group_buffers)),\n                self._grad_scale,\n                group[\"lr\"],\n                beta1,\n                beta2,\n                group[\"eps\"],\n                self.state[\"step\"],\n                1 if self.adam_w_mode else 0,\n                1 if group[\"bias_correction\"] 
else 0,\n                group[\"weight_decay\"],\n            )\n\n        # Make sure param sync buffer has correct dtype\n        self._check_params_shard_dtypes(\n            {bucket_id: self._params_buckets[bucket_id] for bucket_id in bucket_ids}\n        )\n\n    def _local_step_with_param_remainders(\n        self,\n        bucket_ids: List[int],\n    ) -> None:\n        \"\"\"Apply optimizer step to local shard of parameter bucket\n\n        This is an experimental implementation that expects\n        store_params=False and store_param_remainders=True. The\n        optimizer dtype must be FP32 and the params must all be BF16\n        and GPU.\n\n        Arguments:\n            bucket_ids (list): bucket indices\n\n        \"\"\"\n\n        # Find param fragments for each bucket\n        buffers = collections.defaultdict(list)  # p_in, p_rem, m, v, g, p_out\n        for bucket_id in bucket_ids:\n            state_bucket = self.state[\"buckets\"][bucket_id]\n            grads_bucket = self._grads_buckets[bucket_id]\n            params_bucket = self._params_buckets[bucket_id]\n\n            # State buffers for local shard\n            fragments = state_bucket.fragments\n            param_remainders_shard = state_bucket.param_remainders_shard\n            exp_avg = state_bucket.exp_avg_shard\n            exp_avg_sq = state_bucket.exp_avg_sq_shard\n            grads = grads_bucket.grads_shard\n            params_out = params_bucket.params_shard\n\n            # Find param fragments in local shard\n            for fragment in fragments:\n                if not fragment.in_local_shard:\n                    continue\n                shard_start, shard_end = fragment.shard_range\n                if shard_end <= shard_start:\n                    continue\n                shard_range = slice(shard_start, shard_end)\n                buffers_key = (\n                    fragment.param_group_id,\n                    state_bucket.grad_sync_dtype,\n                )\n       
         param = self.parameter(fragment)\n                param_range = slice(*fragment.shard_param_range)\n                param_fragment = param.detach().view(-1)[param_range]\n                param_fragment = param_fragment.to(dtype=torch.bfloat16, device=self.device)\n                buffers[buffers_key].append(\n                    [\n                        param_fragment,\n                        param_remainders_shard[shard_range],\n                        exp_avg[shard_range],\n                        exp_avg_sq[shard_range],\n                        grads[shard_range],\n                        params_out[shard_range],\n                    ]\n                )\n\n        # Apply optimizer step to each param group\n        for (group_id, _), group_buffers in buffers.items():\n            group = self.param_groups[group_id]\n            beta1, beta2 = group[\"betas\"]\n            multi_tensor_applier(\n                distributed_adam_cuda.multi_tensor_fused_adam_with_param_remainders,\n                self._dummy_overflow_buf,\n                list(zip(*group_buffers)),\n                self._grad_scale,\n                group[\"lr\"],\n                beta1,\n                beta2,\n                group[\"eps\"],\n                self.state[\"step\"],\n                1 if self.adam_w_mode else 0,\n                1 if group[\"bias_correction\"] else 0,\n                group[\"weight_decay\"],\n            )\n\n        # Make sure param sync buffer has correct dtype\n        self._check_params_shard_dtypes(\n            {bucket_id: self._params_buckets[bucket_id] for bucket_id in bucket_ids}\n        )\n\n    @torch.no_grad()\n    def _local_step_with_scaled_states(\n        self,\n        bucket_ids: List[int],\n    ) -> None:\n        for bucket_id in bucket_ids:\n            state_bucket = self.state[\"buckets\"][bucket_id]\n            grads_bucket = self._grads_buckets[bucket_id]\n            params_bucket = self._params_buckets[bucket_id]\n       
     params_bucket.params_shard = torch.empty_like(\n                state_bucket.params_shard,\n                dtype=torch.float32,\n            )\n\n            # Find param fragments in local shard\n            group_buffers = collections.defaultdict(list)  # p_in, m, v, g, p_out\n            scaled_buffers = []\n            unscaled_buffers = []\n            buffer_scales = []\n            for fragment in state_bucket.fragments:\n                if not fragment.in_local_shard:\n                    continue\n                shard_start, shard_end = fragment.shard_range\n                if shard_end <= shard_start:\n                    continue\n                shard_range = slice(shard_start, shard_end)\n                param_group_id = fragment.param_group_id\n                param_id = fragment.param_id\n                scaled_param = state_bucket.params_shard[shard_range]\n                scaled_exp_avg = state_bucket.exp_avg_shard[shard_range]\n                scaled_exp_avg_sq = state_bucket.exp_avg_sq_shard[shard_range]\n                grads = grads_bucket.grads_shard[shard_range]\n                param = params_bucket.params_shard[shard_range]\n                exp_avg = torch.empty_like(scaled_exp_avg, dtype=torch.float32)\n                exp_avg_sq = torch.empty_like(scaled_exp_avg_sq, dtype=torch.float32)\n                scales = self._state_scales[(param_group_id, param_id, bucket_id)]\n                group_buffers[param_group_id].append((param, exp_avg, exp_avg_sq, grads, param))\n                scaled_buffers.extend((scaled_param, scaled_exp_avg, scaled_exp_avg_sq))\n                unscaled_buffers.extend((param, exp_avg, exp_avg_sq))\n                buffer_scales.extend((scales[\"param\"], scales[\"exp_avg\"], scales[\"exp_avg_sq\"]))\n\n            # Unscale optimizer state\n            _multi_tensor_copy(\n                scaled_buffers,\n                unscaled_buffers,\n                dummy_overflow_buf=self._dummy_overflow_buf,\n      
      )\n            for buf, scale in zip(unscaled_buffers, buffer_scales):\n                buf.mul_(scale)\n\n            # Apply optimizer step to each param group\n            for group_id, buffers in group_buffers.items():\n                group = self.param_groups[group_id]\n                beta1, beta2 = group[\"betas\"]\n                multi_tensor_applier(\n                    distributed_adam_cuda.multi_tensor_fused_adam,\n                    self._dummy_overflow_buf,\n                    list(zip(*buffers)),\n                    self._grad_scale,\n                    group[\"lr\"],\n                    beta1,\n                    beta2,\n                    group[\"eps\"],\n                    self.state[\"step\"],\n                    1 if self.adam_w_mode else 0,\n                    1 if group[\"bias_correction\"] else 0,\n                    group[\"weight_decay\"],\n                )\n            del group_buffers\n\n            # Make sure param sync buffer has correct dtype\n            self._check_params_shard_dtypes({bucket_id: params_bucket})\n\n            # Scale optimizer state\n            for buf, scale in zip(unscaled_buffers, buffer_scales):\n                self._apply_state_scale(buf, scale)\n            _multi_tensor_copy(\n                unscaled_buffers,\n                scaled_buffers,\n                dummy_overflow_buf=self._dummy_overflow_buf,\n            )\n            del scaled_buffers, unscaled_buffers, buffer_scales\n\n    @torch.no_grad()\n    def _check_params_shard_dtypes(\n        self,\n        params_buckets: Dict[int, ParameterBucket],\n    ) -> None:\n        \"\"\"Make sure local shards of parameters are in expected datatypes\n\n        The Adam kernel only supports floating-point datatypes. If we\n        want to perform parameter synchronization with\n        non-floating-point dtypes, we need to allocate temporary\n        buffers that can accommodate the Adam kernel. 
This function is\n        responsible for converting these temporary buffers to the\n        parameter synchronization datatype.\n\n        \"\"\"\n\n        # Find param shards that require dtype conversion\n        buffers_in = []\n        buffers_out = []\n        for bucket_id, param_bucket in params_buckets.items():\n            # Check if param shard is already in expected dtype\n            state_bucket = self.state[\"buckets\"][bucket_id]\n            param_sync_dtype = state_bucket.param_sync_dtype\n            if param_bucket.params_shard.dtype == param_sync_dtype:\n                continue\n\n            # Allocate buffer with required dtype\n            buffer_in = param_bucket.params_shard\n            buffer_out = torch.empty_like(\n                param_bucket.params_shard,\n                dtype=param_sync_dtype,\n            )\n            param_bucket.params_shard = buffer_out\n\n            if torch.is_floating_point(buffer_in) and torch.is_floating_point(buffer_out):\n                # Cast between floating-point dtypes\n                buffers_in.append(buffer_in)\n                buffers_out.append(buffer_out)\n            else:\n                # Copy most significant bytes for non-floating-point\n                # dtypes\n                # Note: Assume dtypes are little-endian\n                in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8)\n                out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8)\n                copy_size = min(in_bytes.size(-1), out_bytes.size(-1))\n                buffers_in.append(in_bytes[..., -copy_size:])\n                buffers_out.append(out_bytes[..., -copy_size:])\n                if copy_size < out_bytes.size(-1):\n                    out_bytes[..., :-copy_size].zero_()\n\n        # Perform dtype conversions\n        _multi_tensor_copy(\n            buffers_in,\n            buffers_out,\n            dummy_overflow_buf=self._dummy_overflow_buf,\n        )\n\n    @torch.no_grad()\n    def 
_apply_state_scale(\n        self,\n        tensor: torch.Tensor,\n        scale: torch.Tensor,\n    ) -> None:\n        \"\"\"Compute and apply scaling factor for scaled optimizer state\n\n        The scaling factor is chosen to maximize the dynamic range\n        while avoiding numerical overflows. The scale (used to unscale\n        the optimizer state) is written in-place to ``scale``, and\n        ``tensor`` is multiplied in-place by the scale-reciprocal\n        (used to generate the scaled optimizer state). Nothing is\n        returned.\n\n        \"\"\"\n        if not hasattr(self, \"_max_scaled_state\"):\n            self._max_scaled_state = torch.full(\n                [1],\n                torch.finfo(self.dtype).max / 2,\n                dtype=torch.float32,\n                device=self.device,\n            )\n        min_val, max_val = torch.aminmax(tensor)\n        absmax = torch.maximum(-min_val, max_val)\n        absmax = absmax.to(dtype=torch.float32, device=self.device)\n        torch.div(absmax, self._max_scaled_state, out=scale)\n        rscale = torch.where(scale > 0, scale.reciprocal(), 0.0)\n        tensor.mul_(rscale)\n\n    def state_dict(\n        self,\n        *,\n        state_dict_format: Optional[int] = None,\n        gather_on_root: Optional[bool] = None,\n    ) -> Optional[dict]:\n        \"\"\"Get dictionary containing optimizer state\n\n        All ranks in the process group must call this function since\n        it performs communication. 
The same optimizer state is\n        returned on all ranks.\n\n        Arguments:\n            state_dict_format (int, optional): Tag for custom or\n                deprecated state dict format.\n            gather_on_root (bool, optional): Option for deprecated v1\n                format.\n\n        \"\"\"\n\n        # Default state dict format\n        if state_dict_format is None:\n            state_dict_format = 2\n\n        # Construct state dict\n        state_dict = None\n        if state_dict_format == 1:\n            # Deprecated v1 format\n            kwargs = {}\n            if gather_on_root is not None:\n                kwargs[\"gather_on_root\"] = gather_on_root\n            state_dict = self._state_dict_v1(**kwargs)\n        elif state_dict_format == 2:\n            # Default v2 format\n            state_dict = self._state_dict_v2()\n        else:\n            # Unrecognized format\n            raise ValueError(f\"Unrecognized state dict format ({state_dict_format})\")\n\n        # Add format tag to state dict\n        if state_dict is not None:\n            state_dict[\"format\"] = state_dict_format\n\n        return state_dict\n\n    def _state_dict_v1(self, gather_on_root: bool = True) -> Optional[dict]:\n        \"\"\"Get dictionary containing optimizer state (deprecated v1 format)\n\n        Default behavior is to perform communication so that the\n        entire optimizer state is returned on the root rank in the\n        process group. In this case, all ranks in the process group\n        must enter this function and no value is returned on non-root\n        ranks.\n\n        Arguments:\n            gather_on_root (bool, optional): Gather state from all\n                ranks on the root rank (default: True)\n\n        \"\"\"\n        warnings.warn(\n            \"Making optimizer state dictionary in deprecated v1 format. 
\"\n            \"Future support is not guaranteed.\"\n        )\n        if self.with_scaled_states:\n            raise NotImplementedError(\"Deprecated v1 format does not support scaled state\")\n\n        state_dict = super().state_dict()\n        if not gather_on_root:\n            return state_dict\n\n        # Finish any asynchronous communication\n        self.grad_sync()\n        self.param_sync()\n\n        # Export local state to byte string\n        state_bytes = io.BytesIO()\n        torch.save(state_dict, state_bytes)\n        state_bytes.seek(0)\n        state_bytes_view = state_bytes.getbuffer()\n\n        # Get data sizes on all ranks\n        local_state_size = len(state_bytes_view)\n        state_sizes = [None] * self.distributed_size\n        torch.distributed.all_gather_object(\n            state_sizes,\n            local_state_size,\n            group=self.process_group,\n        )\n        max_state_size = max(state_sizes)\n\n        # Construct workspace buffers\n        chunk_size = self.default_shard_size * torch.finfo(self.grad_sync_dtype).bits // 8\n        if self.distributed_rank == 0:\n            gathered_state_bytes = [\n                torch.empty([size], dtype=torch.uint8, device=\"cpu\") for size in state_sizes\n            ]\n            gathered_state_bytes[0].copy_(torch.frombuffer(state_bytes_view, dtype=torch.uint8))\n            gathered_chunks_buffers = [\n                torch.empty(\n                    [chunk_size * self.distributed_size],\n                    dtype=torch.uint8,\n                    device=self.device,\n                )\n                for _ in range(self.pipeline_size)\n            ]\n        else:\n            chunk_buffers = [\n                torch.empty(\n                    [chunk_size],\n                    dtype=torch.uint8,\n                    device=self.device,\n                )\n                for _ in range(self.pipeline_size)\n            ]\n\n        # Split data into chunks and 
gather on root rank\n        # Note: Assuming we are using the NCCL backend, communication\n        # must happen on the GPU. We split the data into fixed-size\n        # chunks to limit GPU memory usage.\n        main_stream = torch.cuda.current_stream()\n        for stream in self._pipeline_streams:\n            stream.wait_stream(main_stream)\n        for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)):\n            stream_id %= self.pipeline_size\n            stream = self._pipeline_streams[stream_id]\n            with torch.cuda.stream(stream):\n                # Buffers for chunk\n                if self.distributed_rank == 0:\n                    gathered_chunks = [\n                        gathered_chunks_buffers[stream_id][i * chunk_size : (i + 1) * chunk_size]\n                        for i in range(self.distributed_size)\n                    ]\n                else:\n                    chunk = chunk_buffers[stream_id]\n\n                # Copy to GPU\n                if self.distributed_rank != 0 and offset < local_state_size:\n                    local_chunk_size = min(chunk_size, local_state_size - offset)\n                    chunk[:local_chunk_size].copy_(\n                        torch.frombuffer(\n                            state_bytes_view,\n                            dtype=torch.uint8,\n                            count=local_chunk_size,\n                            offset=offset,\n                        ),\n                        non_blocking=True,\n                    )\n\n                # Gather on root\n                # Note: Call in main stream to avoid memory pool\n                # overheads from internal memory allocations in\n                # gather.\n                main_stream.wait_stream(stream)\n                with torch.cuda.stream(main_stream):\n                    if self.distributed_rank == 0:\n                        if self._gather_no_copy:\n                            no_copy_kwarg = {\"no_copy\": 
True}\n                        else:\n                            no_copy_kwarg = {}\n                        torch.distributed.gather(\n                            gathered_chunks[0],\n                            gathered_chunks,\n                            dst=self.process_group_root,\n                            group=self.process_group,\n                            **no_copy_kwarg,\n                        )\n                    else:\n                        torch.distributed.gather(\n                            chunk,\n                            dst=self.process_group_root,\n                            group=self.process_group,\n                        )\n                stream.wait_stream(main_stream)\n\n                # Copy back to CPU\n                if self.distributed_rank == 0:\n                    for rank in range(1, self.distributed_size):\n                        rank_chunk_start = offset\n                        rank_chunk_end = min(offset + chunk_size, state_sizes[rank])\n                        rank_chunk_size = rank_chunk_end - rank_chunk_start\n                        if rank_chunk_size > 0:\n                            src = gathered_chunks[rank][:rank_chunk_size]\n                            dst = gathered_state_bytes[rank][rank_chunk_start:rank_chunk_end]\n                            dst.copy_(src, non_blocking=True)\n\n        # Synchronize GPU\n        for stream in self._pipeline_streams:\n            main_stream.wait_stream(stream)\n        main_stream.synchronize()\n\n        # Return gathered state data on root rank\n        if self.distributed_rank == 0:\n            return {\"gathered_states\": gathered_state_bytes}\n        else:\n            return None\n\n    @torch.no_grad()\n    def _state_dict_v2(self) -> Optional[dict]:\n        \"\"\"Get dictionary containing optimizer state (default v2 format)\n\n        All ranks in the process group must call this function since\n        it performs communication. 
The same optimizer state is\n        returned on all ranks.\n\n        \"\"\"\n\n        # Make sure params are initialized\n        self.init_params()\n\n        # Finish any asynchronous communication\n        self.grad_sync()\n        self.param_sync()\n\n        # Output tensor format\n        dtype = torch.float32 if self.with_scaled_states else self.dtype\n        device = torch.device(\"cpu\")\n\n        # Get state dict from base class\n        state_dict = super().state_dict()\n        state_dict[\"state\"] = {\"step\": state_dict[\"state\"][\"step\"]}\n\n        # Initialize state dict with CPU buffers\n        for param in self.parameters():\n            # Get param index in state dict\n            fragment = self.state[param][\"fragments\"][0]\n            param_group_id = fragment.param_group_id\n            param_id = fragment.param_id\n            index = state_dict[\"param_groups\"][param_group_id][\"params\"][param_id]\n\n            # Construct CPU buffers with optimizer state\n            state_dict[\"state\"][index] = dict(\n                param=torch.zeros_like(param, dtype=dtype, device=device),\n                exp_avg=torch.zeros_like(param, dtype=dtype, device=device),\n                exp_avg_sq=torch.zeros_like(param, dtype=dtype, device=device),\n            )\n\n        # Workspace buffers for gathering shards on root rank\n        num_buckets = len(self.state[\"buckets\"])\n        max_bucket_size = max(bucket.bucket_size for bucket in self.state[\"buckets\"])\n        bucket_buffers = [\n            torch.empty(\n                [max_bucket_size],\n                dtype=dtype,\n                device=self.device,\n            )\n            for _ in range(self.pipeline_size)\n        ]\n        if self.store_param_remainders:\n            max_shard_size = max(bucket.shard_size for bucket in self.state[\"buckets\"])\n            shard_bf16_buffers = [\n                torch.empty([max_shard_size], dtype=torch.bfloat16, 
device=self.device)\n                for _ in range(self.pipeline_size)\n            ]\n\n        # Synchronize streams\n        main_stream = torch.cuda.current_stream()\n        for stream in self._pipeline_streams:\n            stream.wait_stream(main_stream)\n\n        def get_workspace_shard(bucket_id: int) -> torch.Tensor:\n            \"\"\"Workspace buffer for local shard\"\"\"\n            bucket = self.state[\"buckets\"][bucket_id]\n            shard_size = bucket.shard_size\n            stream_id = bucket_id % self.pipeline_size\n            shard_range = slice(\n                shard_size * self.distributed_rank,\n                shard_size * (self.distributed_rank + 1),\n            )\n            return bucket_buffers[stream_id][shard_range]\n\n        def unscale_shard(\n            bucket_id: int,\n            shard: torch.Tensor,\n            state_key: str,\n        ) -> torch.Tensor:\n            \"\"\"Unscale local shard if needed\n\n            If state buffers are scaled, then the shard is unscaled\n            and output to a workspace buffer. 
Otherwise, the shard is\n            immediately returned.\n\n            \"\"\"\n            if not self.with_scaled_states:\n                return shard\n            out = get_workspace_shard(bucket_id)\n            bucket = self.state[\"buckets\"][bucket_id]\n            stream_id = bucket_id % self.pipeline_size\n            stream = self._pipeline_streams[stream_id]\n            with torch.cuda.stream(stream):\n                for fragment in bucket.fragments:\n                    if not fragment.in_local_shard:\n                        continue\n                    param_group_id = fragment.param_group_id\n                    param_id = fragment.param_id\n                    shard_range = slice(*fragment.shard_range)\n                    scale = self._state_scales[(param_group_id, param_id, bucket_id)][state_key]\n                    out[shard_range].copy_(shard[shard_range]).mul_(scale)\n            return out\n\n        def pack_param_shard(bucket_id: int) -> torch.Tensor:\n            \"\"\"Pack local shard of param values into contiguous buffer\"\"\"\n\n            # Stream objects\n            stream_id = bucket_id % self.pipeline_size\n            stream = self._pipeline_streams[stream_id]\n\n            # Bucket objects\n            bucket = self.state[\"buckets\"][bucket_id]\n            shard_size = bucket.shard_size\n\n            # Case 1: Param state is already packed\n            if bucket.params_shard is not None:\n                return unscale_shard(bucket_id, bucket.params_shard, \"param\")\n\n            # Case 2: Pack BF16 model params with 16-bit remainders\n            if bucket.param_remainders_shard is not None:\n                with torch.cuda.stream(stream):\n                    # Pack bf16 param values\n                    shard_bf16 = shard_bf16_buffers[stream_id][:shard_size]\n                    buffers_in = []\n                    buffers_out = []\n                    for fragment in bucket.fragments:\n                        if 
not fragment.in_local_shard:\n                            continue\n                        param_range = slice(*fragment.shard_param_range)\n                        shard_range = slice(*fragment.shard_range)\n                        param = self.parameter(fragment)\n                        buffers_in.append(param.view(-1)[param_range])\n                        buffers_out.append(shard_bf16[shard_range])\n                    _multi_tensor_copy(\n                        buffers_in,\n                        buffers_out,\n                        dummy_overflow_buf=self._dummy_overflow_buf,\n                    )\n\n                    # Reconstruct fp32 from bf16 and remainders\n                    shard_fp32 = get_workspace_shard(bucket_id)\n                    _bf16_rem_to_fp32(\n                        shard_bf16,\n                        bucket.param_remainders_shard,\n                        shard_fp32,\n                    )\n                    return shard_fp32\n\n            # Case 3: Pack model params\n            with torch.cuda.stream(stream):\n                shard = get_workspace_shard(bucket_id)\n                buffers_in = []\n                buffers_out = []\n                for fragment in bucket.fragments:\n                    if not fragment.in_local_shard:\n                        continue\n                    param_range = slice(*fragment.shard_param_range)\n                    shard_range = slice(*fragment.shard_range)\n                    param = self.parameter(fragment)\n                    buffers_in.append(param.view(-1)[param_range])\n                    buffers_out.append(shard[shard_range])\n                _multi_tensor_copy(\n                    buffers_in,\n                    buffers_out,\n                    dummy_overflow_buf=self._dummy_overflow_buf,\n                )\n                return shard\n\n        def start_all_gather(bucket_id: int, shard: torch.Tensor) -> None:\n            \"\"\"Launch all-gather on bucket 
shards\n\n            Communication is done on main stream to ensure consistent\n            ordering.\n\n            \"\"\"\n\n            # Stream objects\n            stream_id = bucket_id % self.pipeline_size\n            stream = self._pipeline_streams[stream_id]\n\n            # Workspace buffer\n            bucket = self.state[\"buckets\"][bucket_id]\n            bucket_size = bucket.bucket_size\n            bucket_buffer = bucket_buffers[stream_id][:bucket_size]\n\n            # All-gather shards\n            main_stream.wait_stream(stream)\n            all_gather_into_tensor(\n                bucket_buffer,\n                shard,\n                group=self.distributed_process_group,\n            )\n            stream.wait_stream(main_stream)\n\n        def finish_all_gather(bucket_id: int, state_dict_key: str) -> None:\n            \"\"\"Finish all-gather on bucket shards\n\n            Data is copied into state dict CPU buffers.\n\n            Splitting the NCCL all-gather and the CPU memcpys into\n            separate stages helps achieve good overlap when kernel\n            launches are serialized with\n            CUDA_DEVICE_MAX_CONNECTIONS=1. 
In particular, the pipeline\n            calls start_all_gather(bucket_id+1) before\n            finish_all_gather(bucket_id).\n\n            \"\"\"\n\n            # Stream objects\n            stream_id = bucket_id % self.pipeline_size\n            stream = self._pipeline_streams[stream_id]\n\n            # Bucket objects\n            bucket = self.state[\"buckets\"][bucket_id]\n            bucket_size = bucket.bucket_size\n            bucket_buffer = bucket_buffers[stream_id][:bucket_size]\n\n            # Update state dict\n            with torch.cuda.stream(stream):\n                for fragment in bucket.fragments:\n                    param_range = slice(*fragment.param_range)\n                    bucket_range = slice(*fragment.bucket_range)\n                    param_group_id = fragment.param_group_id\n                    param_id = fragment.param_id\n                    index = state_dict[\"param_groups\"][param_group_id][\"params\"][param_id]\n                    state_buffer = state_dict[\"state\"][index][state_dict_key]\n                    state_fragment = state_buffer.view(-1)[param_range]\n                    bucket_fragment = bucket_buffer[bucket_range]\n                    state_fragment.copy_(bucket_fragment, non_blocking=True)\n\n        # All-gather param state\n        for bucket_id in range(num_buckets):\n            shard = pack_param_shard(bucket_id)\n            start_all_gather(bucket_id, shard)\n            if bucket_id > 0:\n                finish_all_gather(bucket_id - 1, \"param\")\n            if bucket_id == num_buckets - 1:\n                finish_all_gather(bucket_id, \"param\")\n\n        # All-gather exp_avg state\n        for bucket_id in range(num_buckets):\n            shard = unscale_shard(\n                bucket_id,\n                self.state[\"buckets\"][bucket_id].exp_avg_shard,\n                \"exp_avg\",\n            )\n            start_all_gather(bucket_id, shard)\n            if bucket_id > 0:\n                
finish_all_gather(bucket_id - 1, \"exp_avg\")\n            if bucket_id == num_buckets - 1:\n                finish_all_gather(bucket_id, \"exp_avg\")\n\n        # All-gather exp_avg_sq state\n        for bucket_id in range(num_buckets):\n            shard = unscale_shard(\n                bucket_id,\n                self.state[\"buckets\"][bucket_id].exp_avg_sq_shard,\n                \"exp_avg_sq\",\n            )\n            start_all_gather(bucket_id, shard)\n            if bucket_id > 0:\n                finish_all_gather(bucket_id - 1, \"exp_avg_sq\")\n            if bucket_id == num_buckets - 1:\n                finish_all_gather(bucket_id, \"exp_avg_sq\")\n\n        # Synchronize GPU and return\n        for stream in self._pipeline_streams:\n            main_stream.wait_stream(stream)\n        main_stream.synchronize()\n        return state_dict\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        \"\"\"Load optimizer state\"\"\"\n\n        # Figure out state dict format\n        state_dict_format = state_dict.pop(\"format\", None)\n        if state_dict_format is None:\n            if \"buckets\" in state_dict or \"gathered_states\" in state_dict:\n                state_dict_format = 1\n            else:\n                state_dict_format = 2\n\n        # Load state dict\n        if state_dict_format == 1:\n            # Deprecated v1 format\n            self._load_state_dict_v1(state_dict)\n        elif state_dict_format == 2:\n            # Default v2 format\n            self._load_state_dict_v2(state_dict)\n        else:\n            # Unrecognized format\n            raise ValueError(f\"Unrecognized state dict format ({state_dict_format})\")\n\n    def _load_state_dict_v1(self, state_dict: dict) -> None:\n        \"\"\"Load optimizer state (deprecated v1 format)\n\n        Parallel configuration (e.g. 
process group sizes) and\n        optimizer options must match between saving and loading the\n        optimizer state.\n\n        \"\"\"\n        warnings.warn(\n            \"Loading checkpoint in deprecated v1 format. Future support is not guaranteed.\"\n        )\n        if self.with_scaled_states:\n            raise NotImplementedError(\"Deprecated v1 format does not support scaled state\")\n\n        # Get state dict for current rank\n        if \"gathered_states\" in state_dict:\n            # Deallocate distributed optimizer state to reduce GPU\n            # memory usage\n            if \"buckets\" in self.state:\n                del self.state[\"buckets\"]\n\n            # Get state for current rank and parse byte string\n            state_bytes = state_dict[\"gathered_states\"][self.distributed_rank]\n            state_bytes = io.BytesIO(state_bytes.numpy())\n            state_dict = torch.load(state_bytes)\n\n        # Load state dict\n        super().load_state_dict(state_dict)\n\n        # Handle old state dicts without per-bucket dtypes\n        for bucket in self.state[\"buckets\"]:\n            if getattr(bucket, \"dtype\", None) is None:\n                bucket.dtype = self.dtype\n            if getattr(bucket, \"grad_sync_dtype\", None) is None:\n                bucket.grad_sync_dtype = self.grad_sync_dtype\n            if getattr(bucket, \"param_sync_dtype\", None) is None:\n                bucket.param_sync_dtype = self.param_sync_dtype\n\n            if bucket.params_shard is not None:\n                bucket.params_shard = bucket.params_shard.to(self.device)\n            if bucket.param_remainders_shard is not None:\n                bucket.param_remainders_shard = bucket.param_remainders_shard.to(self.device)\n            bucket.exp_avg_shard = bucket.exp_avg_shard.to(self.device)\n            bucket.exp_avg_sq_shard = bucket.exp_avg_sq_shard.to(self.device)\n\n    @torch.no_grad()\n    def _load_state_dict_v2(self, state_dict: dict) -> 
None:\n        \"\"\"Load optimizer state (default v2 format)\n\n        The parallel configuration and optimizer options are allowed\n        to differ between saving and loading the model.\n\n        \"\"\"\n\n        # Make sure params are initialized\n        self.init_params()\n\n        # Finish any asynchronous communication\n        self.grad_sync()\n        self.param_sync()\n\n        # Load general state\n        # Note: State includes bucketing scheme (e.g.\n        # self.state[\"buckets\"] and self.state[param][\"fragments\"]).\n        # This was needed for v1 checkpoints, but not for v2. As a\n        # kludge, we temporarily set state to dummy dict to avoid\n        # messing up the bucketing scheme.\n        state = self.state\n        self.state = {}\n        super().load_state_dict(\n            {\n                \"state\": {},\n                \"param_groups\": state_dict[\"param_groups\"],\n            }\n        )\n        self.state = state\n        self.state[\"step\"] = state_dict[\"state\"][\"step\"]\n\n        # Load state for each param\n        for param in self.parameters():\n            # Get param index in state dict\n            fragment = self.state[param][\"fragments\"][0]\n            param_id = fragment.param_id\n            param_group_id = fragment.param_group_id\n            index = state_dict[\"param_groups\"][param_group_id][\"params\"][param_id]\n\n            # Buffers in state dict\n            param_state = state_dict[\"state\"][index][\"param\"].view(-1)\n            exp_avg = state_dict[\"state\"][index][\"exp_avg\"].view(-1)\n            exp_avg_sq = state_dict[\"state\"][index][\"exp_avg_sq\"].view(-1)\n\n            # Copy to local shard of state buckets\n            for fragment in self.state[param][\"fragments\"]:\n                if not fragment.in_local_shard:\n                    continue\n                bucket_id = fragment.bucket_id\n                bucket = self.state[\"buckets\"][bucket_id]\n            
    param_range = slice(*fragment.shard_param_range)\n                shard_range = slice(*fragment.shard_range)\n                if self.with_scaled_states:\n                    scales = self._state_scales[(param_group_id, param_id, bucket_id)]\n                    temp = torch.empty_like(\n                        param_state[param_range],\n                        dtype=torch.float32,\n                        device=self.device,\n                    )\n                    temp.copy_(param_state[param_range], non_blocking=True)\n                    self._apply_state_scale(temp, scales[\"param\"])\n                    bucket.params_shard[shard_range].copy_(temp)\n                    temp.copy_(exp_avg[param_range], non_blocking=True)\n                    self._apply_state_scale(temp, scales[\"exp_avg\"])\n                    bucket.exp_avg_shard[shard_range].copy_(temp)\n                    temp.copy_(exp_avg_sq[param_range], non_blocking=True)\n                    self._apply_state_scale(temp, scales[\"exp_avg_sq\"])\n                    bucket.exp_avg_sq_shard[shard_range].copy_(temp)\n                else:\n                    if bucket.params_shard is not None:\n                        bucket.params_shard[shard_range].copy_(\n                            param_state[param_range],\n                            non_blocking=True,\n                        )\n                    if bucket.param_remainders_shard is not None:\n                        param_state_int16 = param_state.unsqueeze(-1).view(torch.int16)\n                        bucket.param_remainders_shard[shard_range].copy_(\n                            param_state_int16[param_range, 0],\n                            non_blocking=True,\n                        )\n                    bucket.exp_avg_shard[shard_range].copy_(\n                        exp_avg[param_range],\n                        non_blocking=True,\n                    )\n                    bucket.exp_avg_sq_shard[shard_range].copy_(\n          
              exp_avg_sq[param_range],\n                        non_blocking=True,\n                    )\n\n        # Synchronize GPU\n        torch.cuda.current_stream().synchronize()\n"
  },
  {
    "path": "apex/contrib/optimizers/distributed_fused_lamb.py",
    "content": "import os\nimport inspect\nimport torch\nimport importlib\nimport amp_C\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\nimport torch.distributed.distributed_c10d as c10d\n\n# Fallback to private fields if using older PyTorch version\ntry:\n    from torch.distributed.distributed_c10d import get_process_group_ranks\nexcept ImportError:\n\n    def get_process_group_ranks(group):\n        return list(c10d._pg_group_ranks[group].keys())\n\n\n_make_nccl_premul_sum = getattr(torch.distributed, \"_make_nccl_premul_sum\", None)\n# Ref: https://github.com/pytorch/pytorch/pull/81272\nif _make_nccl_premul_sum is None:\n    if hasattr(torch.distributed, \"make_nccl_premul_sum\"):\n        _make_nccl_premul_sum = torch.distributed.make_nccl_premul_sum\n\n\nclass DistributedFusedLAMB(torch.optim.Optimizer):\n    \"\"\"Implements LAMB algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused LAMB implements 2 fusions.\n\n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary PyTorch optimizer::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  
If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1\" or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning - Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay (also known as AdamW) (default: True)\n        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether to set grad to None when zero_grad()\n            method is called. 
(default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 0.0)\n        use_nvlamb (boolean, optional): whether to also apply the adaptive learning\n            rate to parameters with 0.0 weight decay (default: False)\n        step_supports_amp_scaling (boolean, optional): whether to use customized\n            gradient unscaling logic (default: True)\n\n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. _On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    class AtomicCounter(object):\n        def __init__(self):\n            self.value = 0\n            self.order = []\n            import threading\n\n            self._lock = threading.Lock()\n\n        def add(self, idx):\n            with self._lock:\n                self.value += 1\n                self.order.append(idx)\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        bias_correction=True,\n        grad_averaging=True,\n        betas=(0.9, 0.999),\n        eps=1e-8,\n        weight_decay=0.0,\n        max_grad_norm=0.0,\n        adam_w_mode=True,\n        use_nvlamb=False,\n        step_supports_amp_scaling=True,\n        overlap_reductions=True,\n        dwu_group_size=0,\n        dwu_num_blocks=4,\n        dwu_num_chunks=4,\n        dwu_num_rs_pg=1,\n        dwu_num_ar_pg=4,\n        dwu_num_ag_pg=0,\n        fused_norm=False,\n        e5m2_allgather=False,\n        verbose=False,\n        clip_after_ar=True,\n        full_ar=False,\n        set_param_views_to_flat_buffer=False,\n        skip_allgather=False,\n        fuse_scale=False,\n        param_order=None,\n        nccl_allgather_channels=0,\n    ):\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            
grad_averaging=grad_averaging,\n            max_grad_norm=max_grad_norm,\n        )\n\n        super(DistributedFusedLAMB, self).__init__(params, defaults)\n\n        global fused_adam_cuda, distributed_lamb_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n        distributed_lamb_cuda = importlib.import_module(\"distributed_lamb_cuda\")\n\n        self._overflow_buf = torch.cuda.IntTensor([0])\n        self._has_overflow = False\n        self.multi_tensor_lamb_compute_update_term = (\n            distributed_lamb_cuda.multi_tensor_lamb_compute_update_term\n        )\n        self.multi_tensor_lamb_update_weights = (\n            distributed_lamb_cuda.multi_tensor_lamb_update_weights\n        )\n        import amp_C\n\n        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n\n        self._grad_averaging = grad_averaging\n        self._adam_w_mode = 1 if adam_w_mode else 0\n        self._use_nvlamb = use_nvlamb\n        self._step_supports_amp_scaling = step_supports_amp_scaling\n        self._is_accumulation_step = False\n        self._last_step = False\n        self._overlap_reductions = overlap_reductions\n        self._global_scale = None\n        self._num_blocks = dwu_num_blocks\n        self._num_chunks = dwu_num_chunks\n        self._e5m2_allgather = e5m2_allgather\n        self._verbose = verbose\n        self._clip_after_ar = clip_after_ar\n        self._full_ar = full_ar\n        self._fuse_scale = fuse_scale\n        self._L2_grad_norm = None\n        self._set_flat_param_view = set_param_views_to_flat_buffer\n        self._skip_ag = skip_allgather\n        self._fused_norm = fused_norm if not clip_after_ar else False\n        self._current_process_group = c10d._get_default_group()\n        self._available_ranks = get_process_group_ranks(self._current_process_group)\n        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size\n        self._world_size = 
torch.distributed.get_world_size()\n        self._num_groups = self._world_size // self._group_size\n        self._rank_in_group = torch.distributed.get_rank() % self._group_size\n\n        self._lr = torch.tensor(0.0, dtype=torch.float32, device=\"cuda\")\n\n        self._resume_from_checkpoint = False\n        self._step = torch.cuda.IntTensor([0])\n\n        # Master weight, moment, gradient buffers\n        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = (\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n        # Check if collectives have no_copy option\n        self._reduce_scatter_no_copy = (\n            \"no_copy\" in inspect.getfullargspec(torch.distributed.reduce_scatter).args\n        )\n        self._all_gather_no_copy = (\n            \"no_copy\" in inspect.getfullargspec(torch.distributed.all_gather).args\n        )\n\n        if \"reduce_scatter_tensor\" not in dir(torch.distributed):\n            torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base\n        if \"all_gather_into_tensor\" not in dir(torch.distributed):\n            torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base\n\n        self._num_rs_pg = dwu_num_rs_pg\n        self._num_ar_pg = dwu_num_ar_pg\n        self._num_ag_pg = dwu_num_ag_pg\n\n        if self._full_ar:  # full all reduce, only need AR and AG groups\n            # l2_grad_norm may be reduced within a node to limit from memory reads\n            for group_i in range(self._num_groups):\n                ranks = [group_i * self._group_size + j for j in range(self._group_size)]\n                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)\n                if torch.distributed.get_rank() in ranks:\n                    self._l2_grad_norm_pg = l2_grad_norm_pg\n\n            self._ar_pg = []\n            # consider all the ranks\n            ranks = list(range(0, 
self._world_size))\n            for i in range(self._num_ar_pg):\n                if self._verbose:\n                    print(f\"creating new AR group {i}: {ranks}\")\n                grp = torch.distributed.new_group(ranks=ranks)\n                if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:\n                    if self._verbose:\n                        print(f\"group {i}: init barrier (device: {torch.cuda.current_device()})\")\n                    torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])\n                if self._verbose:\n                    print(f\"created new AR group {i}: {ranks}\")\n\n                if torch.distributed.get_rank() in ranks:\n                    self._ar_pg.append(grp)\n            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            if nccl_allgather_channels > 0:\n                os.putenv(\"NCCL_MAX_NCHANNELS\", str(nccl_allgather_channels))\n            if self._num_ag_pg == 0:\n                self._ag_pg = self._ar_pg\n                self._ag_st = self._ar_st\n                self._num_ag_pg = self._num_ar_pg\n            else:\n                self._ag_pg = []\n                ranks = []\n                stride = torch.cuda.device_count()\n                for i in range(self._num_groups):\n                    rs = list(range(i * stride, (i + 1) * stride))\n                    ranks.append(rs)\n                for rs in ranks:\n                    for i in range(self._num_ag_pg):\n                        grp = torch.distributed.new_group(ranks=rs)\n                        if torch.distributed.get_rank() in rs:\n                            if self._verbose:\n                                print(f\"creating AG group {i}: {rs}\")\n                            self._ag_pg.append(grp)\n\n                self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n        else:  # reduce-scatter + all-reduce, need RS, AR, AG groups\n            if 
self._num_groups > 1:\n                self._ar_pg = []\n                for dev_i in range(self._group_size):\n                    ranks = [dev_i + j * self._group_size for j in range(self._num_groups)]\n                    for i in range(self._num_ar_pg):\n                        if self._verbose:\n                            print(f\"creating new AR group {i}: {ranks}\")\n                        grp = torch.distributed.new_group(ranks=ranks)\n                        if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:\n                            if self._verbose:\n                                print(\n                                    f\"group {i}: init barrier (device: {torch.cuda.current_device()})\"\n                                )\n                            torch.distributed.barrier(\n                                group=grp, device_ids=[torch.cuda.current_device()]\n                            )\n                        if self._verbose:\n                            print(f\"created new AR group {i}: {ranks}\")\n\n                        if torch.distributed.get_rank() in ranks:\n                            self._ar_pg.append(grp)\n                self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]\n            rs_ranks = []\n            for group_i in range(self._num_groups):\n                rs_ranks.append([group_i * self._group_size + j for j in range(self._group_size)])\n            self._rs_pg = []\n            for group_i in range(self._num_groups):\n                ranks = rs_ranks[group_i]\n                for i in range(self._num_rs_pg):\n                    grp = torch.distributed.new_group(ranks=ranks)\n                    if torch.distributed.get_rank() in ranks:\n                        self._rs_pg.append(grp)\n                        if self._verbose:\n                            print(f\"creating RS group : {ranks}\")\n                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)\n                
if torch.distributed.get_rank() in ranks:\n                    self._l2_grad_norm_pg = l2_grad_norm_pg\n            self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]\n            if self._num_ag_pg == 0:\n                self._ag_pg = self._rs_pg\n                self._ag_st = self._rs_st\n                self._num_ag_pg = self._num_rs_pg\n            else:\n                self._ag_pg = []\n                for group_i in range(self._num_groups):\n                    ranks = rs_ranks[group_i]\n                    for i in range(self._num_ag_pg):\n                        grp = torch.distributed.new_group(ranks=ranks)\n                        if torch.distributed.get_rank() in ranks:\n                            self._ag_pg.append(grp)\n                            if self._verbose:\n                                print(f\"creating AG group : {ranks}\")\n                self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]\n        for ag_pg in self._ag_pg:\n            torch.distributed.barrier(group=ag_pg)\n\n        self._l2_grad_norm_st = torch.cuda.Stream()\n        self._completion_st = torch.cuda.Stream()\n        self._step.record_stream(self._completion_st)\n\n        self._reductions_works = [None] * self._num_blocks\n        self._allgather_works = [None] * self._num_blocks\n\n        self._one = torch.cuda.IntTensor([1])\n\n        self._first_step = True\n        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False\n        self._param_order = self.AtomicCounter()\n\n        p_offset = 0\n        p_i = 0\n        self._model_params = []\n        self._grad_accs = []\n        self._group_properties = []\n        for group in self.param_groups:\n            prev = None\n            beta1, beta2 = group[\"betas\"]\n            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0\n            bias_correction = 1 if group[\"bias_correction\"] else 0\n            eps = group[\"eps\"]\n            weight_decay 
= group[\"weight_decay\"]\n            for p in group[\"params\"]:\n                if not p.requires_grad:\n                    continue\n                self._model_params.append(p)\n                self._group_properties.append(\n                    (weight_decay, bias_correction, beta1, beta2, beta3, eps)\n                )\n                p_grads_size = p.numel()\n                if self._set_flat_param_view:\n                    if param_order:\n                        # this is executed when param_order is specified by the user\n                        self._param_order.add(param_order[p])\n                    else:\n                        self._param_order.add(p_i)\n                p_offset += p_grads_size\n                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n                # RNN is one example of consecutive parameters:\n                # (weight_ih, weight_hh, bias_ih, bias_hh)\n                if prev is not None and (\n                    prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()\n                ):\n                    p_offset = ((p_offset + 63) // 64) * 64\n                prev = p\n                p_i += 1\n        if param_order:\n            self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()\n        self._grads_generated = [False] * len(self._model_params)\n        self._grads_fp16, self._grads_fp32 = [], []\n        if self._overlap_reductions:\n            self._current_block = self._num_blocks\n\n        self._net_total_param_size = p_offset\n        self._total_param_size = p_offset\n        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size\n        self._total_param_size = (\n            (self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size\n        ) * dwu_min_page_size\n        self._new_params = torch.zeros(\n            [self._total_param_size],\n            dtype=torch.uint8 if 
self._e5m2_allgather else torch.float16,\n            device=\"cuda\",\n        )\n\n    def _lazy_init_stage1(self):\n        if self._lazy_init_stage1_done:\n            return\n\n        p_i = 0\n        # self._model_params = []\n        # self._grad_accs = []\n        # self._group_properties = []\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                torch.distributed.broadcast(p, 0)\n                if not p.requires_grad:\n                    continue\n\n                def wrapper(param, param_i):\n                    param_tmp = param.expand_as(param)\n                    grad_acc = param_tmp.grad_fn.next_functions[0][0]\n\n                    def allreduce_hook(*unused):\n                        if not self._set_flat_param_view:\n                            if self._first_step:\n                                # first time\n                                self._param_order.add(param_i)\n                            else:\n                                idx = self._param_order.order.index(param_i)\n                                self._do_overlapped_reduction(idx, param)\n                        else:\n                            if not self._first_step:\n                                idx = self._param_order.order.index(param_i)\n                                self._do_overlapped_reduction(idx, param)\n\n                    grad_acc.register_hook(allreduce_hook)\n                    self._grad_accs.append(grad_acc)\n\n                wrapper(p, p_i)\n                p_i += 1\n\n        self._block_size = self._total_param_size // self._num_blocks\n        self._chunk_size = self._block_size // self._num_chunks\n        self._shard_size = self._chunk_size // self._group_size\n\n        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device=\"cuda\")\n        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size\n        # initialize master weights, moments 
buffers if not loaded from checkpoint\n        if self._fp32_p is None:\n            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device=\"cuda\")\n            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device=\"cuda\")\n            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device=\"cuda\")\n            self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device=\"cuda\")\n        # FIXME: Rethink fp16 label since it's either uint8 or fp16\n        self._fp16_p = torch.zeros(\n            [self._mega_shard_size],\n            dtype=torch.uint8 if self._e5m2_allgather else torch.float16,\n            device=\"cuda\",\n        )\n        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device=\"cuda\")\n\n        def _flat_split(p):\n            def __blockify(p):\n                return [\n                    p[block_id * self._block_size : (block_id + 1) * self._block_size]\n                    for block_id in range(self._num_blocks)\n                ]\n\n            def __chunkify(p):\n                return [\n                    p[chunk_id * self._chunk_size : (chunk_id + 1) * self._chunk_size]\n                    for chunk_id in range(self._num_chunks)\n                ]\n\n            def __shardify(p):\n                return [\n                    p[shard_id * self._shard_size : (shard_id + 1) * self._shard_size]\n                    for shard_id in range(self._group_size)\n                ]\n\n            list_of_blocks = __blockify(p)\n            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]\n            list_of_list_of_list_of_shards = [\n                [__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks\n            ]\n            return (\n                list_of_blocks,\n                list_of_list_of_chunks,\n                list_of_list_of_list_of_shards,\n         
   )\n\n        # note(crcrpar): the function below doesn't seem to be used at all.\n        # def _flat_split_no_shards(p):\n        #     def __blockify(p):\n        #         return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]\n        #     def __chunkify(p):\n        #         return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]\n        #     list_of_blocks = __blockify(self._flat_grads)\n        #     list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]\n        #     return list_of_blocks, list_of_list_of_chunks\n\n        def _full_packed_split(p):\n            def __shardify(p):\n                return [\n                    p[mega_shard * self._mega_shard_size : (mega_shard + 1) * self._mega_shard_size]\n                    for mega_shard in range(self._group_size)\n                ]\n\n            def __blockify(p):\n                return [\n                    p[\n                        block_id * self._num_chunks * self._shard_size : (block_id + 1)\n                        * self._num_chunks\n                        * self._shard_size\n                    ]\n                    for block_id in range(self._num_blocks)\n                ]\n\n            def __chunkify(p):\n                return [\n                    p[chunk_id * self._shard_size : (chunk_id + 1) * self._shard_size]\n                    for chunk_id in range(self._num_chunks)\n                ]\n\n            list_of_mega_shards = __shardify(p)\n            list_of_list_of_mega_blocks = [\n                __blockify(mega_shard) for mega_shard in list_of_mega_shards\n            ]\n            list_of_list_of_list_of_mega_chunks = [\n                [__chunkify(mega_block) for mega_block in mega_blocks]\n                for mega_blocks in list_of_list_of_mega_blocks\n            ]\n            return (\n                list_of_mega_shards,\n                
list_of_list_of_mega_blocks,\n                list_of_list_of_list_of_mega_chunks,\n            )\n\n        def _packed_split(p):\n            def __packed_blockify(p):\n                packed_block_size = self._num_chunks * self._shard_size\n                return [\n                    p[block_id * packed_block_size : (block_id + 1) * packed_block_size]\n                    for block_id in range(self._num_blocks)\n                ]\n\n            def __packed_chunkify(p):\n                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size\n                return [\n                    p[chunk_id * self._shard_size : (chunk_id + 1) * self._shard_size]\n                    for chunk_id in range(self._num_chunks)\n                ]\n\n            list_of_blocks = __packed_blockify(p)\n            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]\n            return list_of_blocks, list_of_list_of_chunks\n\n        def _split_assign(shards):\n            packed_block_size = self._num_chunks * self._shard_size\n            list_of_list_of_chunks = []\n            for block_id in range(self._num_blocks):\n                list_of_chunks = []\n                for chunk_id in range(self._num_chunks):\n                    # self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]\n                    list_of_chunks.append(shards[block_id][chunk_id][self._rank_in_group])\n                list_of_list_of_chunks.append(list_of_chunks)\n            return list_of_list_of_chunks\n\n        (\n            self._new_params_mega_shards,\n            self._new_params_mega_blocks,\n            self._new_params_mega_chunks,\n        ) = _full_packed_split(self._new_params)\n        # this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way\n      
  self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(\n            self._new_params\n        )\n\n        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)\n        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)\n        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)\n        self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)\n        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)\n\n        if self._full_ar:\n            # for gradient all-reduce\n            (\n                self._flat_grads_blocks,\n                self._flat_grads_chunks,\n                self._flat_grads_shards,\n            ) = _flat_split(self._flat_grads)\n            # for weight update\n            self._fp16_g_chunks = _split_assign(self._flat_grads_shards)\n        else:\n            (\n                self._flat_grads_blocks,\n                self._flat_grads_chunks,\n                self._flat_grads_shards,\n            ) = _flat_split(self._flat_grads)\n            self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)\n\n        self._lazy_init_stage1_done = True\n\n    def _lazy_init_stage2(self):\n        if self._lazy_init_stage2_done:\n            return\n        if not self._set_flat_param_view:\n            # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view\n            self._param_order.order.reverse()\n\n            # re-order model_params, grad_accs, group_properties lists\n        self._model_params = [self._model_params[i] for i in self._param_order.order]\n        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]\n        self._group_properties = [self._group_properties[i] for i in self._param_order.order]\n\n        def _get_flat_view(param):\n            if param.is_contiguous(memory_format=torch.channels_last):\n         
       K, C, H, W = param.shape\n                pv = param.as_strided(size=(K, H, W, C), stride=(H * W * C, W * C, C, 1))\n            elif param.is_contiguous(memory_format=torch.channels_last_3d):\n                K, C, D, H, W = param.shape\n                pv = param.as_strided(\n                    size=(K, D, H, W, C), stride=(D * H * W * C, H * W * C, W * C, C, 1)\n                )\n            else:\n                pv = param\n            return pv.view(-1)\n\n        # re-collect grads info (size, offset) after ordering\n        prev = None\n        p_offset = 0\n        self._grads_info = []\n        self._individual_flat_grads = []\n        for i, p in enumerate(self._model_params):\n            p_grads_size = p.numel()\n            self._grads_info.append({\"param_grads_size\": p_grads_size, \"param_offset\": p_offset})\n            self._individual_flat_grads.append(\n                self._flat_grads[p_offset : p_offset + p_grads_size].view_as(p)\n            )\n            # for the first iteration\n            self._do_overlapped_reduction(i, p)\n            p_offset += p_grads_size\n            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters\n            # RNN is one example of consecutive parameters:\n            # (weight_ih, weight_hh, bias_ih, bias_hh)\n            if prev is not None and (\n                prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()\n            ):\n                p_offset = ((p_offset + 63) // 64) * 64\n            prev = p\n\n        self._low_param_i = [0] * self._num_blocks\n        for block_id in range(self._num_blocks - 1, -1, -1):\n            p_i = len(self._grads_info) - 1\n            while p_i > 0 and self._grads_info[p_i][\"param_offset\"] > block_id * self._block_size:\n                p_i -= 1\n            self._low_param_i[block_id] = p_i\n        # print(\"self._low_param_i\", self._low_param_i)\n\n        # This paragraph does two things:\n        # 1) Copy 
model parameters into master buffer\n        # 2) Create tensor lists for unpacking new parameter tensor after all-gather\n        self._packed_flat_to_model_params_fp16 = []\n        self._packed_flat_to_model_params_fp32 = []\n        self._model_params_num = len(self._model_params)\n        self._contrib_tensor_list = []\n        self._contrib_min_param_i, self._contrib_max_param_i = -1, -1\n        self._contrib_update_frag_for_norm = []\n        self._contrib_model_param_for_norm_fp16 = []\n        self._contrib_model_param_for_norm_fp32 = []\n        self._contrib_model_param_for_norm_is_fp16 = []\n        self._model_param_is_contrib = []\n        self._contrib_group_properties = []\n        for shard_id in range(self._group_size):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    flat_shard_start = (\n                        ((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id\n                    ) * self._shard_size\n                    flat_shard_end = flat_shard_start + self._shard_size\n                    for param_i, (p, grads_info, group_props) in enumerate(\n                        zip(self._model_params, self._grads_info, self._group_properties)\n                    ):\n                        flat_grad_start = grads_info[\"param_offset\"]\n                        flat_grad_end = flat_grad_start + grads_info[\"param_grads_size\"]\n                        # Intersect the parameter's flat range with this shard's range\n                        clipped_start = max(flat_grad_start, flat_shard_start)\n                        clipped_end = min(flat_grad_end, flat_shard_end)\n                        if clipped_start < clipped_end:\n                            grad_offset = clipped_start - flat_grad_start\n                            grad_length = clipped_end - 
clipped_start\n                            shard_offset = clipped_start - flat_shard_start\n                            pf = _get_flat_view(p)\n                            model_param_fragment = pf[grad_offset : grad_offset + grad_length]\n                            new_param_packed_fragment = self._new_params_mega_chunks[shard_id][\n                                block_id\n                            ][chunk_id][shard_offset : shard_offset + grad_length]\n                            if model_param_fragment.dtype == torch.float16:\n                                self._packed_flat_to_model_params_fp16.append(\n                                    (new_param_packed_fragment, model_param_fragment)\n                                )\n                            else:\n                                self._packed_flat_to_model_params_fp32.append(\n                                    (new_param_packed_fragment, model_param_fragment)\n                                )\n                            if shard_id == self._rank_in_group:\n                                self._model_param_is_contrib.append(param_i)\n                                # copy model parameters into master buffer\n                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][\n                                    shard_offset : shard_offset + grad_length\n                                ]\n                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][\n                                    shard_offset : shard_offset + grad_length\n                                ]\n                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][\n                                    shard_offset : shard_offset + grad_length\n                                ]\n                                opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][\n                                    shard_offset : shard_offset + 
grad_length\n                                ]\n                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][\n                                    shard_offset : shard_offset + grad_length\n                                ]\n                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][\n                                    shard_offset : shard_offset + grad_length\n                                ]\n                                # print(\"model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s\" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))\n                                if not self._resume_from_checkpoint:\n                                    master_param_fragment.copy_(model_param_fragment)\n                                self._contrib_group_properties.append(group_props)\n                                self._contrib_tensor_list.append(\n                                    (\n                                        master_param_fragment,\n                                        opti_state_m_fragment,\n                                        opti_state_v_fragment,\n                                        opti_state_u_fragment,\n                                        opti_state_g_fragment,\n                                        opti_state_p_fragment,\n                                    )\n                                )  # p, m, v, u, g, p_copy\n                                self._contrib_update_frag_for_norm.append(opti_state_u_fragment)\n                                if p.dtype == torch.float16:\n                                    self._contrib_model_param_for_norm_fp16.append(p)\n                                else:\n                                    self._contrib_model_param_for_norm_fp32.append(p)\n                                
self._contrib_model_param_for_norm_is_fp16.append(\n                                    True if p.dtype == torch.float16 else False\n                                )\n                                if self._contrib_min_param_i < 0:\n                                    self._contrib_min_param_i = param_i\n                                self._contrib_max_param_i = param_i\n        self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)\n        if len(self._contrib_model_param_for_norm_fp16) == 0:\n            self._contrib_model_param_for_norm_fp16 = None\n        if len(self._contrib_model_param_for_norm_fp32) == 0:\n            self._contrib_model_param_for_norm_fp32 = None\n        self._contrib_model_param_for_norm_is_fp32 = torch.tensor(\n            [not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16],\n            dtype=torch.bool,\n            device=\"cuda\",\n        )\n        self._contrib_model_param_for_norm_is_fp16 = torch.tensor(\n            [is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16],\n            dtype=torch.bool,\n            device=\"cuda\",\n        )\n        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device=\"cuda\")\n\n        p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))\n        self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]\n        self._contrib_update_weights_tensor_list = [u, p, p_copy]\n\n        math_type = self._fp32_u.dtype\n        decay, bias_correction, beta1, beta2, beta3, epsilon = list(\n            zip(*self._contrib_group_properties)\n        )\n        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device=\"cuda\")\n        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device=\"cuda\")\n        self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device=\"cuda\")\n        self._contrib_bias_correction = torch.tensor(\n            bias_correction, 
dtype=torch.int, device=\"cuda\"\n        )\n        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device=\"cuda\")\n        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device=\"cuda\")\n\n        self._packed_flat_to_model_params_fp16 = (\n            list(zip(*self._packed_flat_to_model_params_fp16))\n            if len(self._packed_flat_to_model_params_fp16) > 0\n            else None\n        )\n        self._packed_flat_to_model_params_fp32 = (\n            list(zip(*self._packed_flat_to_model_params_fp32))\n            if len(self._packed_flat_to_model_params_fp32) > 0\n            else None\n        )\n\n        self._lazy_init_stage2_done = True\n\n        self.complete_reductions()\n        self._first_step = False\n\n    def set_is_accumulation_step(self, is_accumulation_step):\n        self._is_accumulation_step = is_accumulation_step\n\n    def set_last_step(self, last_step):\n        self._last_step = last_step\n\n    def _get_flush_block(self):\n        flush_block = []\n        if (\n            self._current_block > 0\n            and self._grads_generated[self._low_param_i[self._current_block - 1]]\n        ):\n            num_grads = len(self._grads_generated)\n            contiguous_idx = num_grads\n            while contiguous_idx > 0 and self._grads_generated[contiguous_idx - 1]:\n                contiguous_idx -= 1\n\n            if (\n                contiguous_idx < num_grads\n                and self._grads_info[contiguous_idx][\"param_offset\"]\n                <= (self._current_block - 1) * self._block_size\n            ):\n                self._current_block -= 1\n                start = self._current_block * self._block_size\n                end = (self._current_block + 1) * self._block_size\n                flush_block = [start, end]\n\n        return flush_block\n\n    def _full_all_reduce_scale(self, block_id, scale):\n        works = [None] * self._num_chunks\n        if self._clip_after_ar:\n  
          for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                ar_stream = self._ar_st[glob_chunk_id % self._num_ar_pg]\n                ar_stream.wait_stream(torch.cuda.current_stream())\n                with torch.cuda.stream(ar_stream):\n                    works[chunk_id] = torch.distributed.all_reduce(\n                        self._flat_grads_chunks[block_id][chunk_id],\n                        group=self._ar_pg[glob_chunk_id % self._num_ar_pg],\n                        async_op=True,\n                        op=_make_nccl_premul_sum(scale),\n                    )\n        else:\n            glob_chunk_id = block_id\n            ar_stream = self._ar_st[glob_chunk_id % self._num_ar_pg]\n            ar_stream.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(ar_stream):\n                works0 = torch.distributed.all_reduce(\n                    self._flat_grads_blocks[block_id],\n                    group=self._ar_pg[glob_chunk_id % self._num_ar_pg],\n                    async_op=True,\n                    op=_make_nccl_premul_sum(scale),\n                )\n            for i in range(self._num_chunks):\n                works[i] = works0\n        self._reductions_works[block_id] = works\n\n    def _full_all_reduce(self, block_id):\n        works = [None] * self._num_chunks\n\n        for chunk_id in range(self._num_chunks):\n            glob_chunk_id = block_id * self._num_chunks + chunk_id\n            ar_stream = self._ar_st[glob_chunk_id % self._num_ar_pg]\n            ar_stream.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(ar_stream):\n                works[chunk_id] = torch.distributed.all_reduce(\n                    self._flat_grads_chunks[block_id][chunk_id],\n                    group=self._ar_pg[glob_chunk_id % self._num_ar_pg],\n                    async_op=True,\n                )\n        
self._reductions_works[block_id] = works\n\n    def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):\n        # Reduction within each node\n        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n        # The output format is the same as the fp32 master parameters\n        works = [None] * self._num_chunks\n        for chunk_id in range(self._num_chunks):\n            glob_chunk_id = block_id * self._num_chunks + chunk_id\n            rs_stream = self._rs_st[glob_chunk_id % self._num_rs_pg]\n            rs_stream.wait_stream(torch.cuda.current_stream())\n            rs_stream.wait_stream(self._l2_grad_norm_st)\n            with torch.cuda.stream(rs_stream):\n                if self._reduce_scatter_no_copy:\n                    works[chunk_id] = torch.distributed.reduce_scatter(\n                        output=self._fp16_g_chunks[block_id][chunk_id],\n                        input_list=self._flat_grads_shards[block_id][chunk_id],\n                        group=self._rs_pg[glob_chunk_id % self._num_rs_pg],\n                        async_op=True,\n                        no_copy=True,\n                        op=_make_nccl_premul_sum(scale),\n                    )\n                else:\n                    works[chunk_id] = torch.distributed.reduce_scatter_tensor(\n                        output=self._fp16_g_chunks[block_id][chunk_id],\n                        input=self._flat_grads_chunks[block_id][chunk_id],\n                        group=self._rs_pg[glob_chunk_id % self._num_rs_pg],\n                        async_op=True,\n                        op=_make_nccl_premul_sum(scale),\n                    )\n\n        # Reduction across nodes for each rank\n        if self._num_groups > 1:\n            for chunk_id in range(self._num_chunks):\n                glob_chunk_id = block_id * self._num_chunks + chunk_id\n                ar_stream = self._ar_st[glob_chunk_id % self._num_ar_pg]\n                with 
torch.cuda.stream(ar_stream):\n                    works[chunk_id].wait()\n                    works[chunk_id] = torch.distributed.all_reduce(\n                        self._fp16_g_chunks[block_id][chunk_id],\n                        group=self._ar_pg[glob_chunk_id % self._num_ar_pg],\n                        async_op=True,\n                    )\n        self._reductions_works[block_id] = works\n\n    def _reduce_scatter_and_all_reduce(self, block_id):\n        # Reduction within each node\n        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]\n        # The output format is the same as the fp32 master parameters\n        works = [None] * self._num_chunks\n        for chunk_id in range(self._num_chunks):\n            glob_chunk_id = block_id * self._num_chunks + chunk_id\n            rs_stream = self._rs_st[glob_chunk_id % self._num_rs_pg]\n            rs_stream.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(rs_stream):\n                if self._reduce_scatter_no_copy:\n                    works[chunk_id] = torch.distributed.reduce_scatter(\n                        output=self._fp16_g_chunks[block_id][chunk_id],\n                        input_list=self._flat_grads_shards[block_id][chunk_id],\n                        group=self._rs_pg[glob_chunk_id % self._num_rs_pg],\n                        async_op=True,\n                        no_copy=True,\n                    )\n                else:\n                    works[chunk_id] = torch.distributed.reduce_scatter_tensor(\n                        output=self._fp16_g_chunks[block_id][chunk_id],\n                        input=self._flat_grads_chunks[block_id][chunk_id],\n                        group=self._rs_pg[glob_chunk_id % self._num_rs_pg],\n                        async_op=True,\n                    )\n\n        # Reduction across nodes for each rank\n        if self._num_groups > 1:\n            for chunk_id in range(self._num_chunks):\n            
    glob_chunk_id = block_id * self._num_chunks + chunk_id\n                ar_stream = self._ar_st[glob_chunk_id % self._num_ar_pg]\n                with torch.cuda.stream(ar_stream):\n                    works[chunk_id].wait()\n                    works[chunk_id] = torch.distributed.all_reduce(\n                        self._fp16_g_chunks[block_id][chunk_id],\n                        group=self._ar_pg[glob_chunk_id % self._num_ar_pg],\n                        async_op=True,\n                    )\n        self._reductions_works[block_id] = works\n\n    def _pipeline_block_reductions(self, block_id):\n        if self._clip_after_ar:\n            self._flatten_grad_mt(1.0 / self._world_size)\n\n            if self._full_ar:\n                self._full_all_reduce(block_id)\n            else:\n                self._reduce_scatter_and_all_reduce(block_id)\n\n            # Compute L2 grad norm\n            if block_id == 0:\n                with torch.cuda.stream(self._l2_grad_norm_st):\n                    for block_id in range(self._num_blocks):\n                        for chunk_id in range(self._num_chunks):\n                            self._reductions_works[block_id][chunk_id].wait()\n                    # Since the packed format is contiguous after reductions, only one norm is needed\n                    l2_grad_norm_sq = torch.empty([1], device=\"cuda\")\n                    if self._full_ar:\n                        # this flattening of lists is to keep multi_tensor_apply function happy, it wants depth=1 for l2 norm computation\n                        flat_list = [item for sublist in self._fp16_g_chunks for item in sublist]\n                        l2_grad_norm_sq = (\n                            multi_tensor_applier(\n                                self.multi_tensor_l2norm,\n                                self._overflow_buf,\n                                [flat_list],\n                                False,\n                            )[0]\n             
               ** 2\n                        )\n                    else:\n                        l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2) ** 2\n                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)\n                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()\n        else:\n            # Copy model grads to flat grads buffer\n            self._flatten_grad_mt(1.0)\n\n            # Compute L2 grad norm\n            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(self._l2_grad_norm_st):\n                if not self._fused_norm:\n                    self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()\n            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n            # Apply clipping & pre-reduction scaling on grads\n            loss_scale = self.global_scale\n            max_grad_norm = loss_scale * self.defaults[\"max_grad_norm\"]\n            coeff = max_grad_norm / (1e-6 + self.L2_grad_norm)\n            coeff = (coeff > 1) * self._one + (coeff <= 1) * coeff\n            tmp = torch.cat(((self._one), (coeff)))\n            index = (coeff + 1 > coeff).int()\n            scale = tmp.index_select(0, index).half() / self._world_size\n            if not self._fuse_scale:\n                self._flat_grads.mul_(scale)\n\n            if self._full_ar:\n                if self._fuse_scale:\n                    self._full_all_reduce_scale(block_id, scale)\n                else:\n                    self._full_all_reduce(block_id)\n            else:\n                if self._fuse_scale:\n                    self._reduce_scatter_and_all_reduce_scale(block_id, scale)\n                else:\n                    self._reduce_scatter_and_all_reduce(block_id)\n\n            if block_id == 0:\n                for block_id in range(self._num_blocks):\n                    for chunk_id in range(self._num_chunks):\n       
                 self._reductions_works[block_id][chunk_id].wait()\n\n    def __compute_contrib_param_norm(self):\n        if (\n            self._contrib_model_param_for_norm_fp16 is not None\n            and self._contrib_model_param_for_norm_fp32 is not None\n        ):\n            gnorm_fp16 = multi_tensor_applier(\n                self.multi_tensor_l2norm,\n                self._overflow_buf,\n                [self._contrib_model_param_for_norm_fp16],\n                True,\n            )[1]\n            gnorm_fp32 = multi_tensor_applier(\n                self.multi_tensor_l2norm,\n                self._overflow_buf,\n                [self._contrib_model_param_for_norm_fp32],\n                True,\n            )[1]\n            gnorm = torch.empty(\n                size=[self._contrib_model_param_for_norm_num],\n                dtype=torch.bool,\n                device=\"cuda\",\n            )\n            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)\n            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)\n        elif self._contrib_model_param_for_norm_fp16 is not None:\n            gnorm = multi_tensor_applier(\n                self.multi_tensor_l2norm,\n                self._overflow_buf,\n                [self._contrib_model_param_for_norm_fp16],\n                True,\n            )[1]\n        elif self._contrib_model_param_for_norm_fp32 is not None:\n            gnorm = multi_tensor_applier(\n                self.multi_tensor_l2norm,\n                self._overflow_buf,\n                [self._contrib_model_param_for_norm_fp32],\n                True,\n            )[1]\n        return gnorm\n\n    def __compute_contrib_update_norm(self):\n        l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device=\"cuda\")\n        local_contrib_l2_norm = (\n            multi_tensor_applier(\n                self.multi_tensor_l2norm,\n                self._overflow_buf,\n 
               [self._contrib_update_frag_for_norm],\n                True,\n            )[1]\n            ** 2\n        )\n        l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)\n        torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])\n        l2_norm = torch.sqrt(l2_norm)\n        return l2_norm\n\n    def _pipeline_step(self):\n        global_scale = self.global_scale\n        # if clip before ar, set max_grad_norm to 0\n        max_grad_norm = self.defaults[\"max_grad_norm\"] * self._clip_after_ar\n        self._completion_st.wait_stream(self._l2_grad_norm_st)\n        global_grad_norm = self.L2_grad_norm\n\n        # check global_grad_norm and fill overflow_buf\n        is_finite = (global_grad_norm + 1 > global_grad_norm).int()\n        self._overflow_buf = self._one * (is_finite ^ self._one)  # toggle between 0 and 1\n\n        if not self._clip_after_ar:\n            torch.distributed.all_reduce(\n                is_finite,\n                op=torch.distributed.ReduceOp.MIN,\n                group=self._current_process_group,\n            )\n            torch.distributed.all_reduce(\n                self._overflow_buf,\n                op=torch.distributed.ReduceOp.MAX,\n                group=self._current_process_group,\n            )\n\n        # increment step counter if no overflow\n        self._step += is_finite\n        self._completion_st.wait_stream(torch.cuda.current_stream())\n        self._completion_st.wait_stream(self._l2_grad_norm_st)\n\n        # Call step kernel once per step\n        # Call all-gather once per step\n        with torch.cuda.stream(self._completion_st):\n            for block_id in range(self._num_blocks):\n                for chunk_id in range(self._num_chunks):\n                    self._reductions_works[block_id][chunk_id].wait()\n            param_norm = self.__compute_contrib_param_norm()\n            multi_tensor_applier(\n                
self.multi_tensor_lamb_compute_update_term,\n                self._overflow_buf,\n                self._contrib_compute_update_term_tensor_list,  # g, p, m, v, u\n                self._contrib_beta1,\n                self._contrib_beta2,\n                self._contrib_beta3,\n                self._contrib_bias_correction,\n                self._step,\n                self._contrib_epsilon,\n                self._adam_w_mode,\n                self._contrib_weight_decay,\n                global_scale,\n                global_grad_norm,\n                max_grad_norm,\n            )\n            upd_norm = self.__compute_contrib_update_norm()\n            multi_tensor_applier(\n                self.multi_tensor_lamb_update_weights,\n                self._overflow_buf,\n                self._contrib_update_weights_tensor_list,  # u, p, p_copy\n                param_norm,\n                upd_norm,\n                self._offsets,\n                self._lr,\n                self._contrib_weight_decay,\n                global_grad_norm,\n                self._use_nvlamb,\n            )\n            if not self._skip_ag:\n                # allgather chunking is currently not supported for clip after allreduce\n                if not self._clip_after_ar:\n                    for block in range(self._num_blocks):\n                        for chunk in range(self._num_chunks):\n                            if self._all_gather_no_copy:\n                                torch.distributed.all_gather(\n                                    tensor_list=self._new_params2_shards[block][chunk],\n                                    tensor=self._fp16_p_chunks[block][chunk],\n                                    group=self._ag_pg[0],\n                                    no_copy=True,\n                                )\n                            else:\n                                torch.distributed.all_gather_into_tensor(\n                                    
output_tensor=self._new_params2_blocks[block],\n                                    input_tensor=self._fp16_p_chunks[block][chunk],\n                                    group=self._ag_pg[0],\n                                )\n                else:\n                    if self._all_gather_no_copy:\n                        torch.distributed.all_gather(\n                            tensor_list=self._new_params_mega_shards,\n                            tensor=self._fp16_p,\n                            group=self._ag_pg[0],\n                            no_copy=True,\n                        )\n                    else:\n                        torch.distributed.all_gather_into_tensor(\n                            output_tensor=self._new_params,\n                            input_tensor=self._fp16_p,\n                            group=self._ag_pg[0],\n                        )\n\n    def _flatten_grad_mt(self, scale):\n        if len(self._grads_fp16) > 0:\n            self._overflow_buf.zero_()\n            if not self._fused_norm:\n                multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp16)),\n                    scale,\n                )\n            else:\n                self._L2_grad_norm = multi_tensor_applier(\n                    amp_C.multi_tensor_l2norm_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp16)),\n                    scale,\n                    False,\n                )[0].float()\n\n            self._grads_fp16 = []\n        if len(self._grads_fp32) > 0:\n            self._overflow_buf.zero_()\n            if not self._fused_norm:\n                multi_tensor_applier(\n                    amp_C.multi_tensor_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp32)),\n                    scale,\n                )\n            else:\n                
self._L2_grad_norm = multi_tensor_applier(\n                    amp_C.multi_tensor_l2norm_scale,\n                    self._overflow_buf,\n                    list(zip(*self._grads_fp32)),\n                    scale,\n                    False,\n                )[0].float()\n            self._grads_fp32 = []\n\n    def _do_overlapped_reduction(self, param_i, param):\n        if not self._is_accumulation_step:\n            # handle overlapped reductions\n            if param.dtype == torch.float16:\n                self._grads_fp16.append((param.grad, self._individual_flat_grads[param_i]))\n            else:\n                self._grads_fp32.append((param.grad, self._individual_flat_grads[param_i]))\n            self._grads_generated[param_i] = True\n            if not self._first_step and not self._last_step:\n                if self._overlap_reductions:\n                    flush_block = self._get_flush_block()\n                    while flush_block:\n                        block_id = flush_block[0] // self._block_size\n                        self._pipeline_block_reductions(block_id)\n                        flush_block = self._get_flush_block()\n\n    def set_global_scale(self, global_scale):\n        \"\"\"Set global scale.\"\"\"\n        self._global_scale = global_scale\n\n    @property\n    def global_scale(self):\n        return self._global_scale\n\n    @property\n    def L2_grad_norm(self):\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n        return self._L2_grad_norm\n\n    def complete_reductions(self):\n        \"\"\"Complete reductions if full pipeline is not selected or overlap is not allowed.\"\"\"\n        if self._last_step:\n            # zero out gradients that have not been completed yet\n            for param_i, grad_generated in enumerate(self._grads_generated):\n                if not grad_generated:\n                    grad_info = self._grads_info[param_i]\n                    param_offset = 
grad_info[\"param_offset\"]\n                    param_size = grad_info[\"param_grads_size\"]\n                    self._flat_grads[param_offset : param_offset + param_size].zero_()\n                    self._grads_generated[param_i] = True\n\n        if self._first_step or self._last_step or not self._overlap_reductions:\n            # nothing done so far, run full pipeline after reductions\n            for block_id in range(self._num_blocks - 1, -1, -1):\n                self._pipeline_block_reductions(block_id)\n\n        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)\n\n        self._current_block = self._num_blocks\n        self._grads_generated = [False] * len(self._grads_info)\n\n    def step(self, closure=None, grad_scaler=None):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        self._pipeline_step()\n\n        if grad_scaler is not None:\n            found_inf = self._overflow_buf.float()\n            optimizer_state = grad_scaler._per_optimizer_states[id(self)]\n            current_device = torch.device(\"cuda\", torch.cuda.current_device())\n            optimizer_state[\"found_inf_per_device\"][current_device] = found_inf\n\n        self._completion_st.wait_stream(torch.cuda.current_stream())\n        if not self._set_flat_param_view:\n            with torch.cuda.stream(self._completion_st):\n                # Copy self._new_params to model params\n                with torch.no_grad():\n                    if self._packed_flat_to_model_params_fp16 is not None:\n                        multi_tensor_applier(\n                            fused_adam_cuda.maybe_cast_mt,\n                            self._overflow_buf,\n                            self._packed_flat_to_model_params_fp16,\n                        )\n                    if self._packed_flat_to_model_params_fp32 is not None:\n                        multi_tensor_applier(\n                            fused_adam_cuda.maybe_cast_mt,\n       
                     self._overflow_buf,\n                            self._packed_flat_to_model_params_fp32,\n                        )\n\n        torch.cuda.current_stream().wait_stream(self._completion_st)\n\n        self._reductions_works = [None] * self._num_blocks\n        self._allgather_works = [None] * self._num_blocks\n\n        return loss\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        # save step, master weights and first/second moments\n        state_dict = {}\n        state_dict[\"step\"] = self._step\n        state_dict[\"fp32_p\"] = self._fp32_p\n        state_dict[\"fp32_m\"] = self._fp32_m\n        state_dict[\"fp32_v\"] = self._fp32_v\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If an DistributedFusedAdam instance was constructed from some ``init_optimizer``,\n        whose parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``optimizer.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # restore step, master weights and first/second moments\n        self._step = state_dict[\"step\"]\n        
self._fp32_p = state_dict[\"fp32_p\"].to(device=\"cuda\")\n        self._fp32_m = state_dict[\"fp32_m\"].to(device=\"cuda\")\n        self._fp32_v = state_dict[\"fp32_v\"].to(device=\"cuda\")\n        self._resume_from_checkpoint = True\n"
  },
  {
    "path": "apex/contrib/optimizers/fp16_optimizer.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FP16_Optimizer(object):\n    \"\"\"\n    :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.\n    Designed only to wrap apex.contrib.optimizers.FusedAdam, FusedSGD.\n    Refer to apex.fp16_utils documents for more information.\n    Example::\n        model = torch.nn.Linear(D_in, D_out).cuda().half()\n        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())\n        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n        ...\n        # loss.backward() becomes:\n        optimizer.backward(loss)\n        ...\n    Example with dynamic loss scaling::\n        ...\n        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)\n                                   # optional arg to control dynamic loss scaling behavior\n                                   # dynamic_loss_args={'scale_window' : 500})\n                                   # Usually, dynamic_loss_args is not necessary.\n    \"\"\"\n\n    def __init__(\n        self,\n        init_optimizer,\n        static_loss_scale=1.0,\n        dynamic_loss_scale=False,\n        dynamic_loss_args=None,\n        verbose=True,\n    ):\n        print(\"\\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*\")\n        print(\"To update, use updated optimizers with AMP.\")\n        # The fused optimizer does all the work. We need this layer for two reason:\n        # 1. maintain same user API from apex.fp16_utils\n        # 2. 
keep common stuff here in case we need to add new fused optimizer later\n\n        if not torch.cuda.is_available():\n            raise SystemError(\"Cannot use fp16 without CUDA.\")\n        self.optimizer = init_optimizer\n\n        self.fp16_groups = []  # model params\n        self.fp32_groups = []  # master weights\n\n        # iterate over param_groups\n        for param_group in self.optimizer.param_groups:\n            fp16_group = []\n            fp32_group = []\n            for p in param_group[\"params\"]:\n                fp16_group.append(p)\n                fp32_group.append(p.clone().float().detach())\n            self.fp16_groups.append(fp16_group)\n            self.fp32_groups.append(fp32_group)\n            param_group[\"params\"] = fp32_group\n\n        if multi_tensor_applier.available:\n            import amp_C\n\n            self.overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n        else:\n            raise RuntimeError(\"FP16_Optimizer requires cuda extensions\")\n\n        # we may have a way of fusing dynamic scale. 
Do not support for now\n        if dynamic_loss_scale:\n            if dynamic_loss_args is not None:\n                raise SystemError(\"Do not support dynamic loss scale args for now.\")\n            self.dynamic_loss_scale = True\n            self.cur_scale = 2**16\n            self.cur_iter = 0\n            self.last_overflow_iter = -1\n            self.scale_factor = 2\n            self.scale_window = 1000\n        else:\n            self.dynamic_loss_scale = False\n            self.cur_iter = 0\n            self.cur_scale = static_loss_scale\n        self.verbose = verbose\n\n    def zero_grad(self, set_grads_to_None=True):\n        \"\"\"\n        Zero FP16 parameter grads.\n        \"\"\"\n        # FP32 grad should never exist.\n        # For speed, set model fp16 grad to None by default\n        for group in self.fp16_groups:\n            for p in group:\n                if set_grads_to_None:\n                    p.grad = None\n                else:\n                    if p.grad is not None:\n                        p.grad.detach_()\n                        p.grad.zero_()\n\n    def step(self, closure=None):\n        \"\"\"\n        Not supporting closure.\n        \"\"\"\n        fp16_grads = []\n        norm_groups = []\n        skip = False\n\n        for group in self.fp16_groups:\n            fp16_grad = []\n            for i, p in enumerate(group):\n                fp16_grad.append(p.grad)\n            fp16_grads.append(fp16_grad)\n\n        # nan check\n        self.overflow_buf.zero_()\n        for fp16_grad in fp16_grads:\n            if len(fp16_grad) > 0:\n                norm, norm_per_tensor = multi_tensor_applier(\n                    self.multi_tensor_l2norm, self.overflow_buf, [fp16_grad], True\n                )\n                norm_groups.append(norm)\n                if self.overflow_buf.item() != 0:\n                    skip = True\n\n        if skip:\n            self._update_scale(skip)\n            return\n\n        # norm is in 
fact norm*cur_scale\n        self.optimizer.step(\n            grads=fp16_grads,\n            output_params=self.fp16_groups,\n            scale=self.cur_scale,\n            grad_norms=norm_groups,\n        )\n\n        self._update_scale(False)\n        return\n\n    def backward(self, loss):\n        \"\"\"\n        :attr:`backward` performs the following steps:\n        1. fp32_loss = loss.float()\n        2. scaled_loss = fp32_loss*loss_scale\n        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves\n        \"\"\"\n        scaled_loss = (loss.float()) * self.cur_scale\n        scaled_loss.backward()\n\n    def _update_scale(self, skip):\n        if self.dynamic_loss_scale:\n            if skip:\n                if self.verbose:\n                    print(\"\\nGrad overflow on iteration\", self.cur_iter)\n                    print(\"Using dynamic loss scale of\", self.cur_scale)\n                self.cur_scale = max(self.cur_scale / self.scale_factor, 1)\n                self.last_overflow_iter = self.cur_iter\n            else:\n                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:\n                    self.cur_scale *= self.scale_factor\n        else:\n            if skip:\n                print(\"\\nGrad overflow on iteration\", self.cur_iter)\n                print(\"Using static loss scale of\", self.cur_scale)\n        self.cur_iter += 1\n        return\n\n    # Promote state so it can be retrieved or set via \"fp16_optimizer_instance.state\"\n    def _get_state(self):\n        return self.optimizer.state\n\n    def _set_state(self, value):\n        self.optimizer.state = value\n\n    state = property(_get_state, _set_state)\n\n    # Promote param_groups so it can be retrieved or set via \"fp16_optimizer_instance.param_groups\"\n    # (for example, to adjust the learning rate)\n    def _get_param_groups(self):\n        return 
self.optimizer.param_groups\n\n    def _set_param_groups(self, value):\n        self.optimizer.param_groups = value\n\n    param_groups = property(_get_param_groups, _set_param_groups)\n\n    def state_dict(self):\n        \"\"\"\n        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.\n        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict\n        of the contained Pytorch optimizer.\n        Example::\n            checkpoint = {}\n            checkpoint['model'] = model.state_dict()\n            checkpoint['optimizer'] = optimizer.state_dict()\n            torch.save(checkpoint, \"saved.pth\")\n        \"\"\"\n        state_dict = {}\n        state_dict[\"dynamic_loss_scale\"] = self.dynamic_loss_scale\n        state_dict[\"cur_scale\"] = self.cur_scale\n        state_dict[\"cur_iter\"] = self.cur_iter\n        if state_dict[\"dynamic_loss_scale\"]:\n            state_dict[\"last_overflow_iter\"] = self.last_overflow_iter\n            state_dict[\"scale_factor\"] = self.scale_factor\n            state_dict[\"scale_window\"] = self.scale_window\n        state_dict[\"optimizer_state_dict\"] = self.optimizer.state_dict()\n        state_dict[\"fp32_groups\"] = self.fp32_groups\n        return state_dict\n\n    def load_state_dict(self, state_dict):\n        \"\"\"\n        Loads a state_dict created by an earlier call to state_dict().\n        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,\n        whose parameters in turn came from ``model``, it is expected that the user\n        will call ``model.load_state_dict()`` before\n        ``fp16_optimizer_instance.load_state_dict()`` is called.\n        Example::\n            model = torch.nn.Linear(D_in, D_out).cuda().half()\n            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)\n            ...\n            checkpoint = 
torch.load(\"saved.pth\")\n            model.load_state_dict(checkpoint['model'])\n            optimizer.load_state_dict(checkpoint['optimizer'])\n        \"\"\"\n        # I think it should actually be ok to reload the optimizer before the model.\n        self.dynamic_loss_scale = state_dict[\"dynamic_loss_scale\"]\n        self.cur_scale = state_dict[\"cur_scale\"]\n        self.cur_iter = state_dict[\"cur_iter\"]\n        if state_dict[\"dynamic_loss_scale\"]:\n            self.last_overflow_iter = state_dict[\"last_overflow_iter\"]\n            self.scale_factor = state_dict[\"scale_factor\"]\n            self.scale_window = state_dict[\"scale_window\"]\n        self.optimizer.load_state_dict(state_dict[\"optimizer_state_dict\"])\n        # At this point, the optimizer's references to the model's fp32 parameters are up to date.\n        # The optimizer's hyperparameters and internal buffers are also up to date.\n        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still\n        # out of date.  There are two options.\n        # 1:  Refresh the master params from the model's fp16 params.\n        # This requires less storage but incurs precision loss.\n        # 2:  Save and restore the fp32 master copies separately.\n        # We choose option 2.\n        #\n        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device\n        # of their associated parameters, because it's possible those buffers might not exist yet in\n        # the current optimizer instance.  
In our case, as long as the current FP16_Optimizer has been\n        # constructed in the same way as the one whose state_dict we are loading, the same master params\n        # are guaranteed to exist, so we can just copy_() from the saved master params.\n        for current, saved in zip(self.fp32_groups, state_dict[\"fp32_groups\"]):\n            for _current, _saved in zip(current, saved):\n                _current.data.copy_(_saved.data)\n"
  },
  {
    "path": "apex/contrib/optimizers/fused_adam.py",
    "content": "import types\nimport torch\nimport importlib\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedAdam(torch.optim.Optimizer):\n    \"\"\"Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via\n    ``python setup.py install --cuda_ext --cpp_ext``.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,\n            adds eps to the bias-corrected second moment estimate before\n            evaluating square root instead of adding it to the square root of\n            second moment estimate as in the original paper. (default: False)\n        use_mt (boolean, optional): use multi tensor apply for lower launch\n            latency. (default: False)\n\n    .. _Adam - A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        bias_correction=True,\n        betas=(0.9, 0.999),\n        eps=1e-8,\n        eps_inside_sqrt=False,\n        weight_decay=0.0,\n        max_grad_norm=0.0,\n        amsgrad=False,\n        use_mt=False,\n        amp_scale_adjustment=1.0,\n    ):\n        global fused_adam_cuda\n        fused_adam_cuda = importlib.import_module(\"fused_adam_cuda\")\n\n        self._use_multi_tensor = False\n        if use_mt:\n            if not multi_tensor_applier.available:\n                print(\"Warning:  multi_tensor_applier is unavailable\")\n            else:\n                self._use_multi_tensor = True\n                self._overflow_buf = torch.cuda.IntTensor([0])\n\n        self._amp_scale_adjustment = amp_scale_adjustment\n\n        if amsgrad:\n            raise RuntimeError(\"FusedAdam does not support the AMSGrad variant.\")\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            max_grad_norm=max_grad_norm,\n        )\n        super(FusedAdam, self).__init__(params, defaults)\n        self.eps_mode = 0 if eps_inside_sqrt else 1\n\n    def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_norms=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n            grads (list of tensors, optional): weight gradient to use for the\n                optimizer update. If gradients have type torch.half, parameters\n                are expected to be in type torch.float. 
(default: None)\n            output params (list of tensors, optional): A reduced precision copy\n                of the updated weights written out in addition to the regular\n                updated weights. Have to be of same type as gradients. (default: None)\n            scale (float, optional): factor to divide gradient tensor values\n                by before applying to weights. (default: 1)\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if hasattr(self, \"_amp_stash\"):\n            grads = self._amp_stash.grads\n            output_params = self._amp_stash.output_params\n            scale = self._amp_stash.scale * self._amp_scale_adjustment\n            grad_norms = self._amp_stash.grad_norms\n\n        if grads is None:\n            grads_group = [None] * len(self.param_groups)\n        # backward compatibility\n        # assuming a list/generator of parameter means single group\n        elif isinstance(grads, types.GeneratorType):\n            grads_group = [grads]\n        elif type(grads[0]) != list:\n            grads_group = [grads]\n        else:\n            grads_group = grads\n\n        if output_params is None:\n            output_params_group = [None] * len(self.param_groups)\n        elif isinstance(output_params, types.GeneratorType):\n            output_params_group = [output_params]\n        elif type(output_params[0]) != list:\n            output_params_group = [output_params]\n        else:\n            output_params_group = output_params\n\n        if grad_norms is None:\n            grad_norms = [None] * len(self.param_groups)\n\n        for group, grads_this_group, output_params_this_group, grad_norm in zip(\n            self.param_groups, grads_group, output_params_group, grad_norms\n        ):\n            if grads_this_group is None:\n                grads_this_group = [None] * len(group[\"params\"])\n            if output_params_this_group is None:\n                
output_params_this_group = [None] * len(group[\"params\"])\n\n            # compute combined scale factor for this group\n            combined_scale = scale\n            if group[\"max_grad_norm\"] > 0:\n                # norm is in fact norm*scale\n                clip = ((grad_norm / scale) + 1e-6) / group[\"max_grad_norm\"]\n                if clip > 1:\n                    combined_scale = clip * scale\n\n            bias_correction = 1 if group[\"bias_correction\"] else 0\n\n            if self._use_multi_tensor:\n                if output_params:\n                    tensorlists = [[], [], [], [], []]\n                else:\n                    tensorlists = [[], [], [], []]\n                tensordevice = None\n\n            for p, grad, output_param in zip(\n                group[\"params\"], grads_this_group, output_params_this_group\n            ):\n                # note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients\n                if p.grad is None and grad is None:\n                    continue\n                if grad is None:\n                    grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError(\n                        \"FusedAdam does not support sparse gradients, please consider SparseAdam instead\"\n                    )\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state[\"step\"] = 0\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p.data)\n\n                exp_avg, exp_avg_sq = state[\"exp_avg\"], state[\"exp_avg_sq\"]\n                beta1, beta2 = group[\"betas\"]\n\n                state[\"step\"] += 1\n\n  
              out_p = (\n                    torch.tensor([], dtype=torch.float) if output_param is None else output_param\n                )\n                if self._use_multi_tensor:\n                    pl = [p.data, exp_avg, exp_avg_sq, grad]\n                    if output_param is not None:\n                        pl.append(out_p)\n\n                    for tl, t in zip(tensorlists, pl):\n                        tl.append(t)\n\n                    if tensordevice is None:\n                        tensordevice = p.device\n                    elif tensordevice != p.device:\n                        raise RuntimeError(\n                            \"FusedAdam does not support use_mt with tensors on multiple device\"\n                        )\n\n                else:\n                    with torch.cuda.device(p.device):\n                        fused_adam_cuda.adam(\n                            p.data,\n                            out_p,\n                            exp_avg,\n                            exp_avg_sq,\n                            grad,\n                            group[\"lr\"],\n                            beta1,\n                            beta2,\n                            group[\"eps\"],\n                            combined_scale,\n                            state[\"step\"],\n                            self.eps_mode,\n                            bias_correction,\n                            group[\"weight_decay\"],\n                        )\n\n            if self._use_multi_tensor:\n                with torch.cuda.device(tensordevice):\n                    multi_tensor_applier(\n                        fused_adam_cuda.adam_mt,\n                        self._overflow_buf,\n                        tensorlists,\n                        group[\"lr\"],\n                        beta1,\n                        beta2,\n                        group[\"eps\"],\n                        combined_scale,\n                        state[\"step\"],\n     
                   self.eps_mode,\n                        bias_correction,\n                        group[\"weight_decay\"],\n                    )\n\n        return loss\n"
  },
  {
    "path": "apex/contrib/optimizers/fused_lamb.py",
    "content": "import torch\nimport importlib\nimport math\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedLAMB(torch.optim.Optimizer):\n    \"\"\"Implements the LAMB algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--deprecated_fused_lamb\" ./``.\n\n    This version of fused LAMB implements 2 fusions.\n\n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary PyTorch optimizer::\n\n        opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1\" or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-6)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay (also known as AdamW) (default: True)\n        grad_averaging (bool, optional): whether to apply (1-beta2) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether to set grad to None when the zero_grad()\n            method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 1.0)\n\n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. _On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        bias_correction=True,\n        betas=(0.9, 0.999),\n        eps=1e-6,\n        weight_decay=0.01,\n        amsgrad=False,\n        adam_w_mode=True,\n        grad_averaging=True,\n        set_grad_none=True,\n        max_grad_norm=1.0,\n    ):\n        if amsgrad:\n            raise RuntimeError(\"FusedLAMB does not support the AMSGrad variant.\")\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            grad_averaging=grad_averaging,\n            max_grad_norm=max_grad_norm,\n        )\n        super(FusedLAMB, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n\n            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n  
          self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            fused_lamb_cuda = importlib.import_module(\"fused_lamb_cuda\")\n            self.multi_tensor_lamb = fused_lamb_cuda.lamb\n        else:\n            raise RuntimeError(\"apex.contrib.optimizers.FusedLAMB requires cuda extensions\")\n\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group[\"params\"]:\n                    p.grad = None\n        else:\n            super(FusedLAMB, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # create separate grad lists for fp32 and fp16 params\n        g_all_32, g_all_16 = [], []\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                if p.dtype == torch.float32:\n                    g_all_32.append(p.grad.data)\n                elif p.dtype == torch.float16:\n                    g_all_16.append(p.grad.data)\n                else:\n                    raise RuntimeError(\"FusedLAMB only support fp16 and fp32.\")\n\n        g_norm_32, g_norm_16 = 0.0, 0.0\n        # compute grad norm for two lists\n        if len(g_all_32) > 0:\n            g_norm_32 = multi_tensor_applier(\n                self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False\n            )[0].item()\n        if len(g_all_16) > 0:\n            g_norm_16 = multi_tensor_applier(\n                self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False\n            
)[0].item()\n\n        # combine the two grad norms to get the global grad norm\n        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)\n        max_grad_norm = self.defaults[\"max_grad_norm\"]\n\n        for group in self.param_groups:\n            bias_correction = 1 if group[\"bias_correction\"] else 0\n            beta1, beta2 = group[\"betas\"]\n            grad_averaging = 1 if group[\"grad_averaging\"] else 0\n\n            # assume the same step across the group for now to simplify things\n            # a per-parameter step could easily be supported by making it a tensor or passing a list into the kernel\n            if \"step\" in group:\n                group[\"step\"] += 1\n            else:\n                group[\"step\"] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError(\n                        \"FusedLAMB does not support sparse gradients, please consider SparseAdam instead\"\n                    )\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state[\"exp_avg\"])\n                    v_16.append(state[\"exp_avg_sq\"])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n   
                 m_32.append(state[\"exp_avg\"])\n                    v_32.append(state[\"exp_avg_sq\"])\n                else:\n                    raise RuntimeError(\"FusedLAMB only support fp16 and fp32.\")\n\n            if len(g_16) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_lamb,\n                    self._dummy_overflow_buf,\n                    [g_16, p_16, m_16, v_16],\n                    group[\"lr\"],\n                    beta1,\n                    beta2,\n                    group[\"eps\"],\n                    group[\"step\"],\n                    bias_correction,\n                    group[\"weight_decay\"],\n                    grad_averaging,\n                    self.adam_w_mode,\n                    global_grad_norm,\n                    max_grad_norm,\n                )\n            if len(g_32) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_lamb,\n                    self._dummy_overflow_buf,\n                    [g_32, p_32, m_32, v_32],\n                    group[\"lr\"],\n                    beta1,\n                    beta2,\n                    group[\"eps\"],\n                    group[\"step\"],\n                    bias_correction,\n                    group[\"weight_decay\"],\n                    grad_averaging,\n                    self.adam_w_mode,\n                    global_grad_norm,\n                    max_grad_norm,\n                )\n\n        return loss\n"
  },
  {
    "path": "apex/contrib/optimizers/fused_sgd.py",
    "content": "import types\nimport torch\nfrom torch.optim.optimizer import Optimizer, required\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedSGD(Optimizer):\n    r\"\"\"Implements stochastic gradient descent (optionally with momentum).\n\n    This version of fused SGD implements 2 fusions.\n      * Fusion of the SGD update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.contrib.optimizers.FusedSGD` should be used without AMP.\n\n    :class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad.\n\n    Nesterov momentum is based on the formula from\n    `On the importance of initialization and momentum in deep learning`__.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float): learning rate\n        momentum (float, optional): momentum factor (default: 0)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        dampening (float, optional): dampening for momentum (default: 0)\n        nesterov (bool, optional): enables Nesterov momentum (default: False)\n\n    Example:\n        model = ...\n        model.half()\n        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters(), lr=...)\n        # wrap with FP16_Optimizer\n        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)\n        optimizer.zero_grad()\n        ...\n        optimizer.backward(loss)\n        optimizer.step()\n\n    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf\n\n    .. note::\n        The implementation of SGD with Momentum/Nesterov subtly differs from\n        Sutskever et al. and implementations in some other frameworks.\n\n        Considering the specific case of Momentum, the update can be written as\n\n        .. 
math::\n                  v = \\rho * v + g \\\\\n                  p = p - lr * v\n\n        where p, g, v and :math:`\\rho` denote the parameters, gradient,\n        velocity, and momentum respectively.\n\n        This is in contrast to Sutskever et. al. and\n        other frameworks which employ an update of the form\n\n        .. math::\n             v = \\rho * v + lr * g \\\\\n             p = p - v\n\n        The Nesterov version is analogously modified.\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=required,\n        momentum=0,\n        dampening=0,\n        weight_decay=0,\n        nesterov=False,\n        wd_after_momentum=False,\n        materialize_master_grads=True,\n    ):\n        if lr is not required and lr < 0.0:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if momentum < 0.0:\n            raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n        if weight_decay < 0.0:\n            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n\n        defaults = dict(\n            lr=lr,\n            momentum=momentum,\n            dampening=dampening,\n            weight_decay=weight_decay,\n            nesterov=nesterov,\n        )\n        if nesterov and (momentum <= 0 or dampening != 0):\n            raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n        super(FusedSGD, self).__init__(params, defaults)\n\n        self.wd_after_momentum = wd_after_momentum\n\n        if multi_tensor_applier.available:\n            import amp_C\n\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_sgd = amp_C.multi_tensor_sgd\n        else:\n            raise RuntimeError(\"apex.contrib.optimizers.FusedSGD requires cuda extensions\")\n\n    def __setstate__(self, state):\n        super(FusedSGD, self).__setstate__(state)\n        for group in 
self.param_groups:\n            group.setdefault(\"nesterov\", False)\n\n    def get_momentums(self, params):\n        momentums = []\n        first_run = True\n        for p in params:\n            param_state = self.state[p]\n            # torch.optim.SGD initializes momentum in the main loop, we have\n            # to do it here, and track whether or not we've done so, so that\n            # momentum application can be skipped in the main kernel.\n            if \"momentum_buffer\" not in param_state:\n                first_run = True\n                buf = param_state[\"momentum_buffer\"] = torch.zeros_like(p.data)\n                momentums.append(buf)\n            else:\n                first_run = False\n                momentums.append(param_state[\"momentum_buffer\"])\n        return momentums, first_run\n\n    def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_norms=None):\n        \"\"\"Performs a single optimization step.\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n            grads (list of tensors, optional): weight gradient to use for the\n                optimizer update. If gradients have type torch.half, parameters\n                are expected to be in type torch.float. (default: None)\n            output_params (list of tensors, optional): A reduced precision copy\n                of the updated weights written out in addition to the regular\n                updated weights. Have to be of same type as gradients. (default: None)\n            scale (float, optional): factor to divide gradient tensor values\n                by before applying to weights. 
(default: 1)\n        \"\"\"\n        if hasattr(self, \"_amp_stash\"):\n            raise RuntimeError(\"apex.contrib.optimizers.FusedSGD should not be used with AMP.\")\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        if grads is None:\n            raise RuntimeError(\n                \"apex.contrib.optimizers.FusedSGD must be wrapped \\\n\t                       with apex.contrib.optimizers.FP16_Optimizer \\\n\t\t\t       which provides grads.\"\n            )\n        # backward compatibility\n        # assuming a list/generator of parameter means single group\n        elif isinstance(grads, types.GeneratorType):\n            grads_group = [grads]\n        elif type(grads[0]) != list:\n            grads_group = [grads]\n        else:\n            grads_group = grads\n\n        if output_params is None:\n            raise RuntimeError(\n                \"apex.contrib.optimizers.FusedSGD must be wrapped \\\n                               with apex.contrib.optimizers.FP16_Optimizer \\\n                               which provides output_params.\"\n            )\n        elif isinstance(output_params, types.GeneratorType):\n            output_params_group = [output_params]\n        elif type(output_params[0]) != list:\n            output_params_group = [output_params]\n        else:\n            output_params_group = output_params\n\n        for group, grads_this_group, output_params_this_group in zip(\n            self.param_groups, grads_group, output_params_group\n        ):\n            if grads_this_group is None or output_params_this_group is None:\n                raise RuntimeError(\n                    \"apex.contrib.optimizers.FusedSGD only works \\\n                                    when all parameters require grad.\"\n                )\n\n            weight_decay = group[\"weight_decay\"]\n            momentum = group[\"momentum\"]\n            dampening = group[\"dampening\"]\n            nesterov = 
group[\"nesterov\"]\n            lr = group[\"lr\"]\n\n            first_runs = [True, True]\n\n            # output_params_this_group: original weights (either fp16 or fp32)\n            # group['params']: master weights (fp32)\n\n            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy\n            # fp32, fp32, fp32, No\n            fp32_grads = [\n                g\n                for (p, g) in zip(output_params_this_group, grads_this_group)\n                if p.dtype == torch.float32\n            ]\n            fp32_params = [\n                p2\n                for (p1, p2) in zip(output_params_this_group, group[\"params\"])\n                if p1.dtype == torch.float32\n            ]\n            fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n            fp32_set = [fp32_grads, fp32_params, fp32_momentums]\n\n            # fp16, fp32, fp32, Yes\n            fp16_grads = [\n                g\n                for (p, g) in zip(output_params_this_group, grads_this_group)\n                if p.dtype == torch.float16\n            ]\n            fp32_from_fp16_params = [\n                p2\n                for (p1, p2) in zip(output_params_this_group, group[\"params\"])\n                if p1.dtype == torch.float16\n            ]\n            fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)\n            fp16_params = [\n                p1\n                for (p1, p2) in zip(output_params_this_group, group[\"params\"])\n                if p1.dtype == torch.float16\n            ]\n            fp16_set = [\n                fp16_grads,\n                fp32_from_fp16_params,\n                fp32_from_fp16_momentums,\n                fp16_params,\n            ]\n\n            launch_sets = [fp16_set, fp32_set]\n\n            for launch_set, first_run in zip(launch_sets, first_runs):\n                assert len(launch_set[0]) == len(launch_set[1])\n                assert 
len(launch_set[0]) == len(launch_set[2])\n                if len(launch_set[0]) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_sgd,\n                        self._dummy_overflow_buf,\n                        launch_set,\n                        weight_decay,\n                        momentum,\n                        dampening,\n                        lr,\n                        nesterov,\n                        first_run,\n                        self.wd_after_momentum,\n                        1.0 / scale,\n                    )\n\n        return loss\n"
  },
  {
    "path": "apex/contrib/peer_memory/__init__.py",
    "content": "from .peer_memory import PeerMemoryPool\nfrom .peer_halo_exchanger_1d import PeerHaloExchanger1d\n"
  },
  {
    "path": "apex/contrib/peer_memory/peer_halo_exchanger_1d.py",
    "content": "import torch\nimport peer_memory_cuda as pm\n\n\nclass PeerHaloExchanger1d:\n    def __init__(self, ranks, rank_in_group, peer_pool, half_halo):\n        self.peer_group_size = len(ranks)\n        self.ranks = ranks\n        self.peer_rank = rank_in_group\n        self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size\n        self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size\n        self.low_zero = True if self.peer_rank == 0 else False\n        self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False\n\n        self.peer_pool = peer_pool\n        self.half_halo = half_halo\n\n    def _allocate_peer_tensor(self, halo):\n        # Compute size in bytes\n        # Note: Pad buffer so each CUDA block gets required buffer size\n        size = 4 * halo.numel() * halo.element_size()\n        size_per_block = 128 * 2 * 16  # 128 threads each require two 128b buffers\n        size = (size + size_per_block - 1) // size_per_block * size_per_block\n\n        # Construct dtype peer buffer with desired size\n        shape = [1, 1, 1, size // halo.element_size()]\n        return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)\n\n    def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=0, diagnostics=False):\n        channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc\n        if H_split:\n            if explicit_nhwc:\n                _, Hs, _, _ = list(y.shape)\n                H = Hs - 2 * self.half_halo\n                low_out_halo = y[:, self.half_halo : 2 * self.half_halo, :, :]\n                low_tx = self._allocate_peer_tensor(low_out_halo)\n                low_inp_halo = y[:, : self.half_halo, :, :]\n                high_out_halo = y[:, H : H + self.half_halo, :, :]\n                high_tx = self._allocate_peer_tensor(high_out_halo)\n                high_inp_halo = y[:, H + self.half_halo : H + 2 * self.half_halo, 
:, :]\n            else:\n                _, _, Hs, _ = list(y.shape)\n                H = Hs - 2 * self.half_halo\n                low_out_halo = y[:, :, self.half_halo : 2 * self.half_halo, :]\n                low_tx = self._allocate_peer_tensor(low_out_halo)\n                low_inp_halo = y[:, :, : self.half_halo, :]\n                high_out_halo = y[:, :, H : H + self.half_halo, :]\n                high_tx = self._allocate_peer_tensor(high_out_halo)\n                high_inp_halo = y[:, :, H + self.half_halo : H + 2 * self.half_halo, :]\n        else:\n            if explicit_nhwc:\n                _, _, Ws, _ = list(y.shape)\n                W = Ws - 2 * self.half_halo\n                low_out_halo = y[:, :, self.half_halo : 2 * self.half_halo, :]\n                low_tx = self._allocate_peer_tensor(low_out_halo)\n                low_inp_halo = y[:, :, : self.half_halo, :]\n                high_out_halo = y[:, :, W : W + self.half_halo, :]\n                high_tx = self._allocate_peer_tensor(high_out_halo)\n                high_inp_halo = y[:, :, W + self.half_halo : W + 2 * self.half_halo, :]\n            else:\n                _, _, _, Ws = list(y.shape)\n                W = Ws - 2 * self.half_halo\n                low_out_halo = y[:, :, :, self.half_halo : 2 * self.half_halo]\n                low_tx = self._allocate_peer_tensor(low_out_halo)\n                low_inp_halo = y[:, :, :, : self.half_halo]\n                high_out_halo = y[:, :, :, W : W + self.half_halo]\n                high_tx = self._allocate_peer_tensor(high_out_halo)\n                high_inp_halo = y[:, :, :, W + self.half_halo : W + 2 * self.half_halo]\n        pm.push_pull_halos_1d(\n            diagnostics,\n            explicit_nhwc,\n            numSM,\n            self.peer_rank,\n            self.low_zero,\n            low_out_halo,\n            low_tx[self.peer_rank],\n            high_tx[self.low_neighbor],\n            low_inp_halo,\n            self.high_zero,\n            
high_out_halo,\n            high_tx[self.peer_rank],\n            low_tx[self.high_neighbor],\n            high_inp_halo,\n        )\n"
  },
  {
    "path": "apex/contrib/peer_memory/peer_memory.py",
    "content": "import torch\nimport numpy as np\nimport peer_memory_cuda as pm\n\n\nclass PeerMemoryPool(object):\n    def __init__(self, static_size, dynamic_size, peer_ranks=None):\n        rank = torch.distributed.get_rank()\n        world_size = torch.distributed.get_world_size()\n        ngpus = min(torch.cuda.device_count(), world_size)\n        peer_group_size = ngpus\n        peer_group = rank // ngpus\n        peer_rank_base = peer_group * ngpus\n        peer_rank = rank - peer_rank_base\n        if peer_ranks is None:\n            peer_ranks = [i + peer_rank_base for i in range(peer_group_size)]\n        peer_rank_start = peer_rank_base\n        peer_rank_end = peer_rank_start + peer_group_size - 1\n        for pr in peer_ranks:\n            assert pr >= peer_rank_start and pr <= peer_rank_end, (\n                \"%d :: peer_rank %d not on same node (ranks=[%d,%d])\"\n                % (rank, pr, peer_rank_start, peer_rank_end)\n            )\n\n        self.alignment = 256\n        self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment\n        self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment\n\n        # allocate giant pool of device memory\n        self.raw = pm.allocate_raw(self.static_size + self.dynamic_size)\n\n        # exchange peer pointers with nccl\n        raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()\n        peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]\n        torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)\n        peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()\n\n        # extract IPC pointers for ranks on same node\n        peer_raw = pm.get_raw_peers(\n            peer_raw_ipcs[peer_rank_base : peer_rank_base + ngpus], peer_rank, self.raw\n        )\n        self.peer_raw = [peer_raw[peer_rank - peer_rank_base] for peer_rank in peer_ranks]\n        self.static_offset = 0\n        self.dynamic_offset = 0\n        
self.peer_ranks = peer_ranks\n\n    def __del__(self):\n        pm.free_raw(self.raw)\n\n    def reset(self):\n        self.dynamic_offset = 0\n\n    def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):\n        nels = np.prod(shape)\n        if dtype == torch.float16:\n            elem_size = 2\n            if dynamic:\n                start = (\n                    (self.dynamic_offset + self.alignment - 1) // self.alignment\n                ) * self.alignment\n                self.dynamic_offset = start + nels * elem_size\n                assert self.dynamic_offset < self.dynamic_size, \"Dynamic peer memory pool exhausted\"\n                return [\n                    pm.blob_view_half(pr + self.static_size + start, shape, channels_last)\n                    for pr in self.peer_raw\n                ]\n            else:\n                start = (\n                    (self.static_offset + self.alignment - 1) // self.alignment\n                ) * self.alignment\n                self.static_offset = start + nels * elem_size\n                assert self.static_offset < self.static_size, \"Static peer memory pool exhausted\"\n                return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw]\n        if dtype == torch.float32:\n            elem_size = 4\n            if dynamic:\n                start = (\n                    (self.dynamic_offset + self.alignment - 1) // self.alignment\n                ) * self.alignment\n                self.dynamic_offset = start + nels * elem_size\n                assert self.dynamic_offset < self.dynamic_size, \"Dynamic peer memory pool exhausted\"\n                return [\n                    pm.blob_view_float(pr + self.static_size + start, shape, channels_last)\n                    for pr in self.peer_raw\n                ]\n            else:\n                start = (\n                    (self.static_offset + self.alignment - 1) // self.alignment\n                ) * 
self.alignment\n                self.static_offset = start + nels * elem_size\n                assert self.static_offset < self.static_size, \"Static peer memory pool exhausted\"\n                return [\n                    pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw\n                ]\n        if dtype == torch.int32:\n            elem_size = 4\n            if dynamic:\n                start = (\n                    (self.dynamic_offset + self.alignment - 1) // self.alignment\n                ) * self.alignment\n                self.dynamic_offset = start + nels * elem_size\n                assert self.dynamic_offset < self.dynamic_size, \"Dynamic peer memory pool exhausted\"\n                return [\n                    pm.blob_view_int(pr + self.static_size + start, shape, channels_last)\n                    for pr in self.peer_raw\n                ]\n            else:\n                start = (\n                    (self.static_offset + self.alignment - 1) // self.alignment\n                ) * self.alignment\n                self.static_offset = start + nels * elem_size\n                assert self.static_offset < self.static_size, \"Static peer memory pool exhausted\"\n                return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw]\n        else:\n            assert False, \"dtype %s not supported\" % (str(dtype))\n"
  },
  {
    "path": "apex/contrib/sparsity/COPYRIGHT",
    "content": "Copyright (c) 2011-2022, NVIDIA CORPORATION.  All rights reserved.\n"
  },
  {
    "path": "apex/contrib/sparsity/README.md",
    "content": "# Introduction to ASP\r\n\r\nThis serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.\r\n\r\nFor details on \"[Channel Permutations for N:M Sparsity](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html),\" please see the [permutation_tests](permutation_tests/README.md) directory.\r\n\r\n## Importing ASP\r\n\r\n```\r\nfrom apex.contrib.sparsity import ASP\r\n```\r\n\r\n## Initializing ASP\r\n\r\nApart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:\r\n\r\n```\r\nASP.prune_trained_model(model, optimizer)\r\n```\r\n\r\nIn the context of a typical PyTorch training loop, it might look like this:\r\n\r\n```\r\nASP.prune_trained_model(model, optimizer)\r\n\r\nx, y = DataLoader(args)\r\nfor epoch in range(epochs):\r\n    y_pred = model(x)\r\n    loss = loss_function(y_pred, y)\r\n    loss.backward()\r\n    optimizer.step()\r\n\r\ntorch.save(...)\r\n```\r\n\r\nThe `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. \r\n\r\n## Generate a Sparse Network\r\n\r\nThe following approach serves as a guiding example on how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e. inference mode.\r\n\r\n```\r\n(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.\r\n(2) Fine-tune  the  pruned  model  with  optimization  method  and  hyper-parameters (learning-rate, schedule, number of epochs, etc.) 
exactly as those used to obtain the trained model.\r\n(3) (If required) Quantize the model.\r\n```\r\n\r\nIn code, below is a sketch of how to use ASP for this approach (steps 1 and 2 above).\r\n\r\n```\r\nmodel = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)\r\ncriterion = ... # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model\r\noptimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model\r\nlr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model\r\n\r\nfrom apex.contrib.sparsity import ASP     \r\nASP.prune_trained_model(model, optimizer) # prune a trained model\r\n\r\nx, y = DataLoader(args)\r\nfor epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model\r\n    y_pred = model(x)\r\n    loss = criterion(y_pred, y)\r\n    loss.backward()\r\n    optimizer.step()\r\n    lr_scheduler.step() # step the scheduler after the optimizer\r\n\r\ntorch.save(...) # saves the pruned checkpoint with sparsity masks \r\n```\r\n\r\n## Non-Standard Usage\r\n\r\nIf your goal is to easily prepare a network for accelerated inference, please follow the recipe above.  However, ASP can also be used to perform experiments in advanced techniques like training with sparsity from initialization. For example, in order to recompute the sparse mask in between training steps, use the following method:\r\n\r\n```\r\nASP.compute_sparse_masks()\r\n```\r\n\r\nA more thorough example can be found in `./test/toy_problem.py`. \r\n\r\n## Advanced Usage: Channel Permutation\r\n\r\nWe introduce channel permutations as an advanced method to maximize the accuracy of structured sparse networks. 
By permuting weight matrices along their channel dimension and adjusting the surrounding layers appropriately, we demonstrate accuracy recovery for even small, parameter-efficient networks, without affecting inference run-time.\r\n\r\nFinal accuracy depends strongly on the quality of the permutations. We provide default algorithms to search for high-quality permutations. The permutation search process can be accelerated by the Apex CUDA extension: `apex.contrib.sparsity.permutation_search_kernels`\r\n\r\nIf you want to use the GPU to accelerate the permutation search process, we recommend installing Apex with the permutation search CUDA extension via\r\n\r\n```\r\npip install -v --disable-pip-version-check --no-cache-dir --global-option=\"--permutation_search\" ./\r\n```\r\n\r\nIf you want to disable the permutation search process, pass `allow_permutation=False` to the `init_model_for_pruning` function. For example:\r\n\r\n```\r\nASP.init_model_for_pruning(model, mask_calculator=\"m4n2_1d\", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False, allow_permutation=False)\r\n```\r\n\r\nNote that when using multiple GPUs, the same random seed must be set on all GPUs so that the permutation search produces identical results on each of them. The library implements the `set_identical_seed` function in `permutation_lib.py`, which is called by the ASP library. 
We still suggest that users set an identical random seed in their own code when using multiple GPUs. Example code is as follows:\r\n\r\n```\r\nimport torch\r\nimport numpy\r\nimport random\r\n\r\ntorch.manual_seed(identical_seed)\r\ntorch.cuda.manual_seed_all(identical_seed)\r\nnumpy.random.seed(identical_seed)\r\nrandom.seed(identical_seed)\r\ntorch.backends.cudnn.deterministic = True\r\ntorch.backends.cudnn.benchmark = False\r\n```\r\n\r\n## Reference Papers\r\n\r\nFor more details about sparsity support on the NVIDIA Ampere GPU with Sparse Tensor Cores, please refer to our [white paper](https://arxiv.org/abs/2104.08378).\r\n\r\n```\r\n@article{mishra2021accelerating,\r\n  title={Accelerating sparse deep neural networks},\r\n  author={Mishra, Asit and Latorre, Jorge Albericio and Pool, Jeff and Stosic, Darko and Stosic, Dusan and Venkatesh, Ganesh and Yu, Chong and Micikevicius, Paulius},\r\n  journal={arXiv preprint arXiv:2104.08378},\r\n  year={2021}\r\n}\r\n```\r\n\r\nFor details about sparsity with permutation, please refer to our [paper](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html) published at the *Thirty-fifth Conference on Neural Information Processing Systems* (**NeurIPS 2021**):\r\n\r\n```\r\n@inproceedings{pool2021channel,\r\n  author    = {Pool, Jeff and Yu, Chong},\r\n  booktitle = {Advances in Neural Information Processing Systems ({NeurIPS})},\r\n  title     = {Channel Permutations for {N:M} Sparsity},\r\n  url       = {https://proceedings.neurips.cc/paper/2021/file/6e8404c3b93a9527c8db241a1846599a-Paper.pdf},\r\n  volume    = {34},\r\n  year      = {2021}\r\n}\r\n\r\n```\r\n"
  },
  {
    "path": "apex/contrib/sparsity/__init__.py",
    "content": "from .sparse_masklib import create_mask\nfrom .asp import ASP\n"
  },
  {
    "path": "apex/contrib/sparsity/asp.py",
    "content": "import types\nimport torch\nfrom .sparse_masklib import create_mask\nfrom .permutation_lib import Permutation\n\ntorchvision_imported = True\ntry:\n    import torchvision\nexcept ImportError:\n    print(\"[ASP][Warning] torchvision cannot be imported.\")\n    torchvision_imported = False\n\nimport os\nimport time\n\n\ndef eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):\n    eligible_modules_list = []\n    for name, mod in model.named_modules():\n        if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:\n            if allowed_layer_names is not None and name not in allowed_layer_names:\n                continue\n            eligible_modules_list.append((name, mod))\n    return eligible_modules_list\n\n\nclass ASP:\n    __model = None\n    __verbosity = 0\n    __optimizer = None\n    __sparse_parameters = []\n    __calculate_mask = None\n    __allow_permutation = True\n    __all_parameters = []\n    __save_permutation_graph = False\n    __permutation_output_dir = \"\"\n\n    @classmethod\n    def init_model_for_pruning(\n        cls,\n        model,\n        mask_calculator=\"m4n2_1d\",\n        verbosity=3,\n        whitelist=[\n            torch.nn.Linear,\n            torch.nn.Conv1d,\n            torch.nn.Conv2d,\n            torch.nn.Conv3d,\n            torch.nn.MultiheadAttention,\n        ],\n        allowed_layer_names=None,\n        disallowed_layer_names=[],\n        allow_recompute_mask=False,\n        custom_layer_dict={},\n        allow_permutation=True,\n    ):\n        \"\"\"Call this method to modify your model to take advantage of sparse matrix multiplication.\n        Note that this call alone only augments the model with additional buffers needed for sparse MMA,\n        it does not enable use of sparse MMA.\n\n        If you are starting with a fresh model:\n\n        model = ...\n        ASP.init_model_for_pruning(model, mask_calculator, ...)\n        
if (training) ASP.init_optimizer_for_pruning(optimizer)\n        ASP.compute_sparse_masks() // sparsity is off by default, call when you want to enable it.\n\n        If you are starting from a checkpoint:\n\n        model = ...\n        ASP.init_model_for_pruning(model, mask_calculator, ...)\n        torch.load(...)\n        if (training) ASP.init_optimizer_for_pruning(optimizer)\n\n        Arguments:\n          model                    The model\n          mask_calculator          Either callable that computes mask given a tensor OR pattern string for sparse mask lib.\n          verbosity                Integer controlling verbosity level.\n                                   0 -> Only errors.\n                                   1 -> Errors and warnings.\n                                   2 -> Errors, warnings and info.\n                                   3 -> Errors, warnings, info and debug.\n          whitelist                Module types approved for sparsity.\n          allowed_layer_names      If not None, only layer names that appear in this list are considered for sparsity.\n          disallowed_layer_names   If not [], only layer names that do not appear in this list are considered for sparsity.\n          allow_recompute_mask     If True, stores pruned values so that dense weights can be restored.\n                                   Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.\n          custom_layer_dict        Dictionary of additional layer parameters to sparsify. e.g. 
{CustomLinear: ['weight']}\n          allow_permutation        If True, allow the input channel permutation to ease the influence of weight pruning.\n\n          [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe.\n        \"\"\"\n        assert cls.__model is None, \"ASP has been initialized already.\"\n        cls.__model = model\n        cls.__verbosity = verbosity\n        cls.__allow_permutation = allow_permutation\n\n        if isinstance(mask_calculator, str):\n\n            def create_mask_from_pattern(param):\n                return create_mask(param, mask_calculator).bool()\n\n            cls.__calculate_mask = create_mask_from_pattern\n        else:\n            cls.__calculate_mask = mask_calculator  # user defined function\n\n        # function to extract variables that will be sparsified.\n        # idea is that you will add one of these functions for each module type that can be sparsified.\n        if torchvision_imported:\n            print(\n                \"[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.\"\n            )\n            torchvision_version = str(torchvision.__version__)\n            torchvision_version_major = int(torchvision_version.split(\".\")[0])\n            torchvision_version_minor = int(torchvision_version.split(\".\")[1])\n            if torchvision_version_major == 0 and torchvision_version_minor < 12:\n                sparse_parameter_list = {\n                    torch.nn.Linear: [\"weight\"],\n                    torch.nn.Conv1d: [\"weight\"],\n                    torch.nn.Conv2d: [\"weight\"],\n                    torch.nn.Conv3d: [\"weight\"],\n                    torch.nn.modules.linear.NonDynamicallyQuantizableLinear: [\"weight\"],\n                    torch.nn.MultiheadAttention: [\n                        \"q_proj_weight\",\n                        \"k_proj_weight\",\n                        \"v_proj_weight\",\n              
          \"in_proj_weight\",\n                    ],\n                    torchvision.ops.misc.Conv2d: [\"weight\"],\n                }\n            else:  # Torchvision removed APIs that were deprecated before 0.8 (#5386) in 0.12.0, torchvision.ops.misc.Conv2d is removed\n                sparse_parameter_list = {\n                    torch.nn.Linear: [\"weight\"],\n                    torch.nn.Conv1d: [\"weight\"],\n                    torch.nn.Conv2d: [\"weight\"],\n                    torch.nn.Conv3d: [\"weight\"],\n                    torch.nn.modules.linear.NonDynamicallyQuantizableLinear: [\"weight\"],\n                    torch.nn.MultiheadAttention: [\n                        \"q_proj_weight\",\n                        \"k_proj_weight\",\n                        \"v_proj_weight\",\n                        \"in_proj_weight\",\n                    ],\n                }\n        else:\n            sparse_parameter_list = {\n                torch.nn.Linear: [\"weight\"],\n                torch.nn.Conv1d: [\"weight\"],\n                torch.nn.Conv2d: [\"weight\"],\n                torch.nn.Conv3d: [\"weight\"],\n                torch.nn.modules.linear.NonDynamicallyQuantizableLinear: [\"weight\"],\n                torch.nn.MultiheadAttention: [\n                    \"q_proj_weight\",\n                    \"k_proj_weight\",\n                    \"v_proj_weight\",\n                    \"in_proj_weight\",\n                ],\n            }\n        if custom_layer_dict:  # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune\n            sparse_parameter_list.update(custom_layer_dict)\n            # build a new list so the (mutable) default argument is not modified in place\n            whitelist = list(whitelist) + list(custom_layer_dict.keys())\n\n        for module_type in whitelist:\n            assert module_type in sparse_parameter_list, (\n                \"Module %s :: Don't know how to sparsify module.\" % module_type\n            )\n\n        # find all sparse 
modules, extract sparse parameters and decorate\n        def add_sparse_attributes(module_name, module):\n            sparse_parameters = sparse_parameter_list[type(module)]\n            for p_name, p in module.named_parameters():\n                if p_name in sparse_parameters and p.requires_grad:\n                    # check for NVIDIA's TC compatibility: we check along the horizontal direction\n                    if p.dtype == torch.float32 and (\n                        (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0\n                    ):  # User defines FP32 and APEX internally uses FP16 math\n                        print(\n                            \"[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity\"\n                            % (module_name, p_name, str(p.size()), str(p.dtype))\n                        )\n                        continue\n                    if p.dtype == torch.float16 and (\n                        (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0\n                    ):  # For Conv2d dim= K x CRS; we prune along C\n                        print(\n                            \"[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity\"\n                            % (module_name, p_name, str(p.size()), str(p.dtype))\n                        )\n                        continue\n\n                    if cls.__verbosity >= 3:\n                        print(\n                            \"[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity\"\n                            % (module_name, p_name, str(p.size()), str(p.dtype))\n                        )\n\n                    mask = torch.ones_like(p).bool()\n                    buffname = p_name.split(\".\")[-1]  # buffer names cannot contain \".\"\n                    module.register_buffer(\"__%s_mma_mask\" % buffname, mask)\n                    if allow_recompute_mask:\n                        pruned = torch.zeros_like(p).cpu()\n              
          module.register_buffer(\"__%s_mma_pruned_p\" % buffname, pruned)\n                    else:\n                        pruned = None\n                    cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))\n                else:\n                    if cls.__verbosity >= 3:\n                        print(\n                            \"[ASP] Not sparsifying %s::%s of size=%s and type=%s\"\n                            % (module_name, p_name, str(p.size()), str(p.dtype))\n                        )\n\n        for name, sparse_module in eligible_modules(\n            model, tuple(whitelist), allowed_layer_names, disallowed_layer_names\n        ):\n            add_sparse_attributes(name, sparse_module)\n\n        if allow_permutation:  # find all named modules, extract parameters and decorate, used for offline permutation in K dim\n            for module_name, module in model.named_modules():\n                module_type_str = str(type(module)).split(\"'\")[1]\n                if (\n                    module_type_str == \"torch.nn.modules.container.Sequential\"\n                    or module_type_str.startswith(\"torchvision.models\")\n                ):\n                    # filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'\n                    continue\n                for p_name, p in module.named_parameters():\n                    cls.__all_parameters.append((module_name, module, p_name, p))\n                if module_type_str == \"torch.nn.modules.batchnorm.BatchNorm2d\":\n                    # need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters\n                    module_mean_name = module_name + \".running_mean\"\n                    module_var_name = module_name + \".running_var\"\n                    for param_key in model.state_dict():\n                        if module_mean_name == param_key or 
module_var_name == param_key:\n                            cls.__all_parameters.append(\n                                (\n                                    module_name,\n                                    module,\n                                    param_key.split(\".\")[-1],\n                                    model.state_dict()[param_key],\n                                )\n                            )\n            # add the __permutation_output_dir field to save the intermediate results for permutation\n            cls.__permutation_output_dir = \".\"\n            # Set the corresponding params from ASP class to the Permutation class\n            permutation_verbosity = 5\n            Permutation.set_permutation_params_from_asp(\n                cls.__model,\n                cls.__sparse_parameters,\n                cls.__all_parameters,\n                permutation_verbosity,\n            )\n            # Set the identical random seed for all GPUs so that permutation search generates the same results on each\n            Permutation.set_identical_seed()\n\n    @classmethod\n    def already_init_asp_model(cls):\n        \"\"\"Call this method to check whether ASP has been initialized already.\"\"\"\n        if cls.__model is None:\n            if cls.__verbosity >= 3:\n                print(\"[ASP] ASP has not been initialized.\")\n            # return the result regardless of the verbosity level\n            return False\n        else:\n            if cls.__verbosity >= 3:\n                print(\"[ASP] ASP has been initialized already.\")\n            return True\n\n    @classmethod\n    def init_optimizer_for_pruning(cls, optimizer):\n        \"\"\"Call this method to monkey patch optimizer step function so that masks can be applied to\n        gradients and weights during training.\n        You must call init_model_for_pruning(...) 
before calling init_optimizer_for_pruning(...)\n        \"\"\"\n        assert cls.__optimizer is None, \"ASP has initialized optimizer already.\"\n        assert cls.__calculate_mask is not None, (\n            \"Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning.\"\n        )\n\n        # store pointer to original optimizer step method\n        cls.__optimizer = optimizer\n        cls.__optimizer.__step = optimizer.step\n\n        def __step(opt_self, *args, **kwargs):\n            # prune gradients before step method\n            with torch.no_grad():\n                for (\n                    module_name,\n                    module,\n                    p_name,\n                    p,\n                    mask,\n                    pruned,\n                ) in cls.__sparse_parameters:\n                    if p.grad is not None:  # thx pjudd\n                        p.grad.mul_(mask)\n            # call original optimizer step method\n            rval = opt_self.__step(*args, **kwargs)\n            # prune parameters after step method\n            with torch.no_grad():\n                for (\n                    module_name,\n                    module,\n                    p_name,\n                    p,\n                    mask,\n                    pruned,\n                ) in cls.__sparse_parameters:\n                    p.mul_(mask)\n            return rval\n\n        cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)\n\n    @classmethod\n    def compute_sparse_masks(cls):\n        \"\"\"Call this method to enable sparsity.\n        If init(...) 
was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.\n        \"\"\"\n        with torch.no_grad():\n            if cls.__allow_permutation:\n                # Step 1: use the Torch.FX library to build the graph\n                # Step 2: permutation search with the customized kernel\n                # The simplest without user intervention:\n                # A. try to import with the distributed mode of the original model\n                # B. if meet the error, import with the none-distributed mode of the original model\n                start_time_permute = time.perf_counter()\n                successful_permutation = False\n                try:\n                    successful_permutation = Permutation.permute_model(\n                        cls.__model.module,\n                        dump_fx_graph=cls.__save_permutation_graph,\n                        save_dumped_fx_graph=os.path.join(\n                            cls.__permutation_output_dir,\n                            \"model_offline_permutation_graph.json\",\n                        ),\n                    )\n                    if successful_permutation:\n                        print(\"\\n[compute_sparse_masks] permuted the (distributed) model.\")\n                except AttributeError:\n                    successful_permutation = Permutation.permute_model(\n                        cls.__model,\n                        dump_fx_graph=cls.__save_permutation_graph,\n                        save_dumped_fx_graph=os.path.join(\n                            cls.__permutation_output_dir,\n                            \"model_offline_permutation_graph.json\",\n                        ),\n                    )\n                    if successful_permutation:\n                        print(\"\\n[compute_sparse_masks] permuted the model.\")\n\n                if successful_permutation:\n                    duration_build_offline_permutation_graph = (\n                        
time.perf_counter() - start_time_permute\n                    )\n                    print(\n                        \"[compute_sparse_masks] Take {:.4f} seconds to find and apply permutations.\".format(\n                            duration_build_offline_permutation_graph\n                        )\n                    )\n\n            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                if mask.sum() < mask.numel():  # when recalculating masks\n                    # restore dense parameter if allow_recompute_mask is enabled\n                    assert pruned is not None, (\n                        \"Unable to restore dense parameter because allow_recompute_mask == False\"\n                    )\n                    p.add_(pruned.cuda())\n\n                mask.set_(cls.__calculate_mask(p))\n\n                if pruned is not None:  # stow away pruned weights to cpu\n                    pruned.set_((p * (~mask)).cpu())\n\n                p.mul_(\n                    mask\n                )  # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights\n                if cls.__verbosity >= 2:\n                    print(\n                        \"[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s with magnitude %s\"\n                        % (\n                            100.0 - 100.0 * mask.sum() / mask.numel(),\n                            module_name,\n                            p_name,\n                            str(p.size()),\n                            str(p.dtype),\n                            torch.sum(torch.abs(p)),\n                        )\n                    )\n\n    @classmethod\n    def restore_pruned_weights(cls):\n        \"\"\"Call this method to disable sparsity and restore all weights.\n        This will only work if init(...) 
was called with allow_recompute_mask=True.\n        \"\"\"\n        with torch.no_grad():\n            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n                if mask.sum() < mask.numel():\n                    assert pruned is not None, (\n                        \"Unable to restore dense parameter because allow_recompute_mask == False\"\n                    )\n                    p.add_(pruned.cuda())\n                    mask.fill_(1)\n                    pruned.zero_()\n                    if cls.__verbosity >= 2:\n                        print(\n                            \"[ASP] Disabled sparsity for %s::%s (dense weights restored)\"\n                            % (module_name, p_name)\n                        )\n\n    @classmethod\n    def is_sparsity_enabled(cls):\n        \"\"\"Call this method to determine if sparsity is enabled in the model.\n        The typical use case is right after checkpoint has been loaded.\n        \"\"\"\n        total, sp100, sp50 = 0, 0, 0\n        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n            total += 1\n            mask_sum = mask.sum()\n            mask_numel = mask.numel()\n            if mask_sum == mask_numel:\n                sp100 += 1\n            elif mask_sum * 2 == mask_numel:\n                sp50 += 1\n\n        assert total == sp100 or total == sp50, \"Inconsistent model sparsity\"\n        if total == sp100:\n            return False\n        elif total == sp50:\n            return True\n\n    @classmethod\n    def prune_trained_model(cls, model, optimizer):\n        # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)\n        cls.init_model_for_pruning(\n            model,\n            mask_calculator=\"m4n2_1d\",\n            verbosity=2,\n            whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention],\n            
allow_recompute_mask=False,\n        )\n        cls.init_optimizer_for_pruning(optimizer)\n        cls.compute_sparse_masks()\n\n    @classmethod\n    def set_permutation_saving_params(\n        cls,\n        allow_permutation=True,\n        save_permutation_graph=False,\n        permutation_output_dir=\".\",\n    ):\n        \"\"\"This function is used to set the permutation saving related parameters in ASP class and inside of the Permutation class.\"\"\"\n        print(\"\\n[ASP][set_permutation_saving_param] Set permutation saving related parameters\")\n        print(\"\\n[set_permutation_saving_param] Set permutation saving related parameters\")\n        cls.__allow_permutation = allow_permutation\n        print(\n            \"[set_permutation_saving_param]\\t Allow permutation: {}\".format(cls.__allow_permutation)\n        )\n        cls.__save_permutation_graph = save_permutation_graph\n        print(\n            \"[set_permutation_saving_param]\\t Save permutation graphs: {}\".format(\n                cls.__save_permutation_graph\n            )\n        )\n        cls.__permutation_output_dir = permutation_output_dir\n        print(\n            \"[set_permutation_saving_param]\\t Permutation graphs saving dir: {}\".format(\n                cls.__permutation_output_dir\n            )\n        )\n\n        Permutation.set_permutation_saving_params(\n            allow_permutation, save_permutation_graph, permutation_output_dir\n        )\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_lib.py",
    "content": "import os\nimport torch\nimport json\nimport string\nimport time\nimport numpy as np\nimport builtins as __builtin__\nimport io\n\ntry:\n    from .permutation_search_kernels import (\n        accelerated_search_for_good_permutation,\n        sum_after_2_to_4,\n    )\n\n    print(\"[ASP][Info] permutation_search_kernels can be imported.\")\nexcept ImportError:\n    print(\"[ASP][Warning] permutation_search_kernels cannot be imported.\")\n    print(\n        \"[ASP][Warning] If you want to accelerate the permutation search process by GPU, please build APEX by following the instructions at https://github.com/NVIDIA/apex/blob/master/apex/contrib/sparsity/README.md\"\n    )\n\n\ndef convert_fx_node_name(fx_node_name):\n    \"\"\"Standardize punctuation of a node's name: replace all '_' with '.'\"\"\"\n    return fx_node_name.replace(\"_\", \".\")\n\n\ndef get_node_parent_children(fx_node):\n    \"\"\"Populate lists of all direct parents and children of a node\"\"\"\n    # get node parent list, and convert node name to module name\n    node_parent_name_converted = []\n    if len(fx_node.all_input_nodes) > 0:\n        node_parent = fx_node.all_input_nodes\n        for item in node_parent:\n            converted_item = convert_fx_node_name(item.name)\n            node_parent_name_converted.append(converted_item)\n    else:\n        node_parent = []\n\n    # get node children list, and convert node name to module name\n    node_children_name_converted = []\n    if len(list(fx_node.users.keys())) > 0:\n        node_children = list(fx_node.users.keys())\n        for item in node_children:\n            converted_item = convert_fx_node_name(item.name)\n            node_children_name_converted.append(converted_item)\n    else:\n        node_children = []\n\n    return node_parent_name_converted, node_children_name_converted\n\n\ndef node_name_matches(node_name, module_name):\n    \"\"\"Check for a match between graph node name and stored module name, accounting 
for formatting and DDP training differences\"\"\"\n\n    # process: remove all punctuation, everything to lower case\n    def process(name):\n        return \"\".join(c for c in name if c not in string.punctuation).lower()\n\n    processed_node_name = process(node_name)\n    processed_module_name = process(module_name)\n\n    # module names start with 'module.' in distributed data-parallel training, but fx graph node names don't; check for both\n    distributed_node_name = \"module.\" + node_name\n    distributed_processed_node_name = \"module\" + processed_node_name\n\n    return (\n        (node_name == module_name)\n        or (distributed_node_name == module_name)\n        or (processed_node_name == processed_module_name)\n        or (distributed_processed_node_name == processed_module_name)\n    )\n\n\ndef replicate_sequence(sequence, replications):\n    \"\"\"Replicate a permutation to apply it to an even multiple of channel counts\"\"\"\n    replicated_sequence = []\n\n    for rep in range(replications):\n        offset = len(sequence) * rep\n        for c in sequence:\n            replicated_sequence.append(c + offset)\n\n    return replicated_sequence\n\n\nclass Permutation:\n    __model = None\n    __sparse_parameters = []\n    __allow_permutation = False\n    __all_parameters = []\n    __verbosity = 0  ## 0: errors only, 1: also high-level details, warnings, 2: also intermediate steps, 3: everything\n    __params_permuted_in_C = []\n    __params_permuted_in_K = []\n    __unpermuted_dims = []\n\n    __save_permutation_graph = False\n    __permutation_output_dir = \"\"\n    __manual_seed = None\n    __tcpstore_port = 2341\n\n    # these module types may be the target of permutations (have potentially sparse weights or are attributes with no parents)\n    __permutation_target_module_types = [\n        \"torch.nn.modules.conv.Conv1d\",\n        \"torch.nn.modules.conv.Conv2d\",\n        \"torch.nn.modules.linear.Linear\",\n        
\"torch.nn.modules.linear.LazyLinear\",\n        \"torch.nn.modules.linear.NonDynamicallyQuantizableLinear\",\n        \"torch.nn.modules.activation.MultiheadAttention\",\n        \"get_attr\",\n    ]\n\n    # these module types are not permuted, but must pass any permutation seen by a child's C or passed-thru K to the parents' K\n    __simple_passthru_module_types = [\n        \"torch.nn.modules.activation.ReLU6\",\n        \"torch.nn.modules.activation.ReLU\",\n        \"torch.nn.modules.dropout.Dropout\",\n        \"torch.nn.modules.dropout.Dropout1d\",\n        \"torch.nn.modules.dropout.Dropout2d\",\n        \"torch.nn.modules.dropout.Dropout3d\",\n        \"torch.nn.modules.dropout.AlphaDropout\",\n        \"torch.nn.modules.dropout.FeatureAlphaDropout\",\n        \"torch.nn.modules.pooling.MaxPool2d\",\n        \"torch.nn.modules.pooling.AdaptiveAvgPool2d\",\n        \"torch.nn.modules.pooling.AvgPool2d\",\n        \"torch.nn.modules.activation.Hardsigmoid\",\n        \"torch.nn.modules.activation.Hardswish\",\n        \"torch.nn.modules.activation.GELU\",\n        \"torch.nn.modules.normalization.LocalResponseNorm\",\n        \"torch.nn.modules.activation.Softmin\",\n        \"torch.nn.modules.activation.Softmax\",\n        \"torch.nn.modules.activation.Softmax2d\",\n        \"torch.nn.modules.activation.LogSoftmax\",\n        \"torch.nn.modules.activation.AdaptiveLogSoftmaxWithLoss\",\n        \"torch.nn.modules.activation.SiLU\",\n        \"torch.nn.modules.activation.Sigmoid\",\n        \"concat\",\n        \"torch.nn.modules.flatten.Flatten\",  # if it's a problem, it'll be handled via dimension mismatch check\n    ]\n\n    # these module types have parameters that must be permuted along K as well as need to pass the permutation thru to parents' K\n    __permute_K_and_passthru_module_types = [\n        \"torch.nn.modules.batchnorm.BatchNorm2d\",\n        \"torch.nn.modules.normalization.LayerNorm\",\n        
\"torch.nn.modules.instancenorm.InstanceNorm2d\",\n        \"torch.nn.modules.batchnorm.SyncBatchNorm\",\n    ]\n\n    # these module types cannot be permuted safely (today), and cause neighboring layers to have permutations disabled\n    __disallow_permutations_module_types = [\n        \"torch.nn.modules.normalization.GroupNorm\",  # to handle: influence GCD of real children's sibling group\n        \"torch.nn.modules.linear.Bilinear\",  # need to permute one input along in1_features and the other along in2_features\n        \"torch.nn.modules.activation.GLU\",  # may work OOTB, but might need to explicitly handle dimsionality change\n    ]\n\n    @classmethod\n    def set_identical_seed(cls, identical_seed=1):\n        \"\"\"Make all GPUs in DDP use the same seed to find identical permutations and not require syncing parameters later\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\n                \"[set_identical_seed] Set the identical seed: {:} for all GPUs to make sure the same results generated in permutation search\".format(\n                    identical_seed\n                )\n            )\n\n        cls.__manual_seed = identical_seed\n        cls.reset_seed()\n\n    @classmethod\n    def reset_seed(cls):\n        \"\"\"To find the same permutations no matter how many GPUs are used, we reset the seed before every search\"\"\"\n\n        identical_seed = cls.__manual_seed\n        assert identical_seed is not None, \"Must call set_identical_seed() before it can be reset\"\n\n        torch.manual_seed(identical_seed)\n        torch.cuda.manual_seed(identical_seed)\n        import random\n\n        np.random.seed(identical_seed)\n        random.seed(identical_seed)\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n    @classmethod\n    def set_tcpstore_port(cls, tcpstore_port):\n        \"\"\"Override the default port if it is in use in a distributed training session\"\"\"\n\n        
cls.__tcpstore_port = tcpstore_port\n        if cls.__verbosity > 0:\n            print(f\"[set_tcpstore_port] TCPStore port set to {cls.__tcpstore_port} .\")\n\n    @classmethod\n    def set_permutation_saving_params(\n        cls,\n        allow_permutation=False,\n        save_permutation_graph=False,\n        permutation_output_dir=\".\",\n    ):\n        \"\"\"This function is used to set the permutation saving related parameters.\"\"\"\n\n        cls.__allow_permutation = allow_permutation\n        cls.__save_permutation_graph = save_permutation_graph\n        cls.__permutation_output_dir = permutation_output_dir\n\n        if cls.__verbosity > 0:\n            print(\n                f\"[permutation_lib][set_permutation_saving_param] Set permutation saving related parameters\\n\\tAllow permutation: {cls.__allow_permutation}\\n\\tSave permutation graphs: {cls.__save_permutation_graph}\\n\\tPermutation graphs saving dir: {cls.__permutation_output_dir}\"\n            )\n\n    @classmethod\n    def set_permutation_params_from_asp(cls, model, sparse_parameters, all_parameters, verbosity):\n        \"\"\"This function is used to set the permutation needed parameters from ASP class.\"\"\"\n        cls.__verbosity = verbosity\n\n        if cls.__verbosity > 0:\n            print(\"[set_permutation_params_from_asp] Set permutation needed parameters\")\n        cls.__model = model\n        cls.__sparse_parameters = sparse_parameters\n        cls.__all_parameters = all_parameters\n\n        if cls.__verbosity > 1:\n            sparse_param_names = [\n                module_name + \":\" + p_name\n                for (\n                    module_name,\n                    module,\n                    p_name,\n                    p,\n                    mask,\n                    pruned,\n                ) in cls.__sparse_parameters\n            ]\n            all_param_names = [\n                module_name + \":\" + p_name\n                for (module_name, module, 
p_name, p) in cls.__all_parameters\n            ]\n            print(\n                f\"\\tSparse parameter names: {sparse_param_names}\\n\\tAll parameter names: {all_param_names}\"\n            )\n\n        cls.__params_permuted_in_C = []\n        cls.__params_permuted_in_K = []\n        cls.__unpermuted_dims = []\n\n    @classmethod\n    def permute_model(\n        cls,\n        model,\n        dump_fx_graph=False,\n        save_dumped_fx_graph=\"./model_permutation_graph.json\",\n    ):\n        \"\"\"Permute a model's weights in order to maintain more magnitude after enforcing the sparsity constraint.\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\"\\n[permute_model] Permuting the model\")\n\n        # extract the output_dir, so all the intermediate fx_graph can be saved under that path\n        extract_output_dir = os.path.split(save_dumped_fx_graph)[0]\n        cls.__permutation_output_dir = extract_output_dir\n        fx_graph, success_in_build_fx_graph = cls.build_fx_graph(\n            model,\n            dump_fx_graph=dump_fx_graph,\n            save_dumped_fx_graph=save_dumped_fx_graph,\n        )\n\n        if success_in_build_fx_graph:\n            fx_graph_after_init_flags = cls.init_permutation_flags(fx_graph)\n            fx_graph_after_find_real_parents = cls.find_real_parents(fx_graph_after_init_flags)\n            fx_graph_after_find_real_children = cls.find_real_children(\n                fx_graph_after_find_real_parents\n            )\n            fx_graph_after_making_groups = cls.make_sibling_coparent_groups(\n                fx_graph_after_find_real_children\n            )\n            fx_graph_after_fixup_concats = cls.fixup_concats(fx_graph_after_making_groups)\n            fx_graph_after_enforce_dimension_agreement = cls.enforce_dimension_agreement(\n                fx_graph_after_fixup_concats\n            )\n            fx_graph_after_propagate_flags = cls.propagate_permutation_flags(\n                
fx_graph_after_enforce_dimension_agreement\n            )\n\n            start_time_search_for_good_permutation = time.perf_counter()\n            fx_graph_after_find_permutations = cls.find_permutations(fx_graph_after_propagate_flags)\n\n            if torch.distributed.is_initialized():\n                if cls.__verbosity > 0:\n                    duration_search_for_good_permutation = (\n                        time.perf_counter() - start_time_search_for_good_permutation\n                    )\n                    print(\n                        f\"[permute_model] Rank {torch.distributed.get_rank()} completed search in {duration_search_for_good_permutation:.2f}s, waiting for others.\",\n                        force=True,\n                    )\n                torch.distributed.barrier()\n\n            duration_search_for_good_permutation = (\n                time.perf_counter() - start_time_search_for_good_permutation\n            )\n            if cls.__verbosity > 0:\n                print(\n                    \"\\n[permute_model] Take {:.4f} seconds to finish search_for_good_permutation function.\".format(\n                        duration_search_for_good_permutation\n                    )\n                )\n\n            fx_graph_after_sync_permutations = cls.sync_permutations(\n                fx_graph_after_find_permutations\n            )\n            fx_graph_after_apply_permutations = cls.apply_permutations(\n                fx_graph_after_sync_permutations\n            )\n            cls.check_graph_for_unpermuted_nodes(fx_graph_after_apply_permutations)\n\n            fx_graph = fx_graph_after_apply_permutations\n\n        if cls.__save_permutation_graph:\n            cls.save_graph_to_json(\n                fx_graph,\n                save_dumped_graph_path_with_name=os.path.join(\n                    cls.__permutation_output_dir, \"./model_graph_permutation_graph.json\"\n                ),\n            )  # save the intermediate graph as JSON 
file for debugging\n\n        return success_in_build_fx_graph\n\n    @classmethod\n    def get_permutation_stats(cls):\n        \"\"\"Return statistics for how many permutations were applied in various dimensions, used for testing\"\"\"\n\n        return (\n            cls.__params_permuted_in_C,\n            cls.__params_permuted_in_K,\n            cls.__unpermuted_dims,\n        )\n\n    @classmethod\n    def apply_permutation_in_C_dim(cls, node_name, permutation_sequence, dryrun):\n        \"\"\"This function is used to permutation for a node in C dim. (Only need to handle the weight of the node)\"\"\"\n\n        if cls.__verbosity > 1 and dryrun:\n            print(\n                \"[apply_permutation_in_C_dim] Permutation for node: '{:}' in C dim\".format(\n                    node_name\n                )\n            )\n\n        if len(permutation_sequence) == 0:\n            if cls.__verbosity >= 0:\n                print(\n                    f\"ERROR: [apply_permutation_in_C_dim] the permutation sequence for node {node_name} is empty, fail to apply permutation in C dim.\"\n                )\n            return False\n\n        is_node_in_sparse_parameters = False\n        success_permutation = False\n        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n            if node_name_matches(node_name, module_name):\n                if cls.__verbosity > 2 and dryrun:\n                    print(\n                        \"[apply_permutation_in_C_dim] find the node: '{:}' '{:}' in cls.__sparse_parameters, succeed to apply permutation in C dim.\".format(\n                            node_name, p_name\n                        )\n                    )\n                is_node_in_sparse_parameters = True\n                permutation_to_apply = permutation_sequence\n                if p.shape[1] != len(\n                    permutation_sequence\n                ):  # assumed to be grouped convolutions or concatenated weights\n        
            if p.shape[1] % len(permutation_sequence) != 0:\n                        return False\n\n                    permutation_to_apply = replicate_sequence(\n                        permutation_sequence, p.shape[1] // len(permutation_sequence)\n                    )\n\n                if not dryrun:\n                    p.data.copy_(p[:, permutation_to_apply, ...])\n                    cls.__params_permuted_in_C.append(node_name + \".\" + p_name)\n\n                success_permutation = True\n        if not is_node_in_sparse_parameters:\n            # A special case: if the node itself not in sparse_module_names but one of its real_siblings in sparse_module_names, then the node will not do the permutation search, but it may need to apply the offline permutation in C dim according to the searched permutation sequence from its real_siblings in sparse_module_names\n            try:\n                for (\n                    module_name_from_all_parameters,\n                    module_from_all_parameters,\n                    p_name_from_all_parameters,\n                    p_from_all_parameters,\n                ) in cls.__all_parameters:\n                    if (\n                        node_name_matches(node_name, module_name_from_all_parameters)\n                        and p_name_from_all_parameters == \"weight\"\n                    ):\n                        if cls.__verbosity > 3 and dryrun:\n                            print(\n                                \"[apply_permutation_in_C_dim] cannot find the node: '{:}' '{:}' in cls.__sparse_parameters, but can find in cls.__all_parameters.\".format(\n                                    node_name, p_name_from_all_parameters\n                                )\n                            )\n                        permutation_to_apply = permutation_sequence\n                        if p_from_all_parameters.shape[1] != len(\n                            permutation_sequence\n                        ):  # 
assumed to be grouped convolutions\n                            if p_from_all_parameters.shape[1] % len(permutation_sequence) != 0:\n                                return False\n\n                            permutation_to_apply = replicate_sequence(\n                                permutation_sequence,\n                                p_from_all_parameters.shape[1] // len(permutation_sequence),\n                            )\n\n                        if not dryrun:\n                            p_from_all_parameters.data.copy_(\n                                p_from_all_parameters[:, permutation_to_apply, ...]\n                            )\n                            cls.__params_permuted_in_C.append(\n                                node_name + \".\" + p_name_from_all_parameters\n                            )\n\n                        success_permutation = True\n                        if cls.__verbosity > 2 and dryrun:\n                            print(\n                                \"[apply_permutation_in_C_dim] cannot find the node: '{:}' in cls.__sparse_parameters, after trying with cls.__all_parameters, succeed to apply permutation in C dim.\".format(\n                                    node_name\n                                )\n                            )\n            except Exception:\n                success_permutation = False\n                if cls.__verbosity >= 0:\n                    print(\n                        \"ERROR: [apply_permutation_in_C_dim] cannot find the node: '{:}' in cls.__sparse_parameters, after trying with cls.__all_parameters, still fail to apply permutation in C dim.\".format(\n                            node_name\n                        )\n                    )\n        return success_permutation\n\n    @classmethod\n    def permute_attr(cls, node_name, permutation_sequence, fx_graph, dryrun):\n        \"\"\"Permute a node's attributes. 
Somewhat hacky, assumes that we'll find exactly one dimension with a length matching the permutation's\"\"\"\n\n        assert \"attr\" in fx_graph[node_name].keys()\n        attr = fx_graph[node_name][\"attr\"]\n        if cls.__verbosity > 1:\n            print(f\"Found attribute {node_name} of shape {attr.shape}\")\n        found_perm = False\n        for dim in range(len(attr.shape)):\n            if attr.shape[dim] == len(permutation_sequence):\n                if found_perm:\n                    if cls.__verbosity > 0:\n                        print(\n                            f\"\\tWARNING: {node_name} has already been permuted, but it's trying to happen again along another dimension {dim}.\"\n                        )\n\n                    return False\n\n                found_perm = True\n                if cls.__verbosity > 1 and dryrun:\n                    print(f\"\\tpermuting along dimension {dim}\")\n\n                if not dryrun:\n                    # permute the dimension of interest to the front, permute within that dimension, then reset it\n                    order = [c for c in range(len(attr.shape))]\n                    order[0] = dim\n                    order[dim] = 0\n                    prmt = tuple(order)\n\n                    temp_weight = torch.clone(attr)\n                    temp_weight = torch.permute(temp_weight, prmt)\n                    temp_weight.copy_(temp_weight[permutation_sequence, ...])\n                    temp_weight = torch.permute(temp_weight, prmt)\n                    attr.data.copy_(temp_weight)\n\n                    cls.__params_permuted_in_K.append(node_name + \"_\" + str(dim))\n\n        return found_perm\n\n    @classmethod\n    def apply_permutation_in_K_dim(cls, node_name, permutation_sequence, fx_graph, dryrun):\n        \"\"\"This function is used to permutation for a node in K dim. 
(Need to handle the weight/bias/running_mean/running_var of the node)\"\"\"\n\n        if cls.__verbosity > 1:\n            print(\n                \"[apply_permutation_in_K_dim] Permutation for node: '{:}' in K dim\".format(\n                    node_name\n                )\n            )\n\n        if len(permutation_sequence) == 0:\n            if cls.__verbosity >= 0:\n                print(\n                    \"ERROR: [apply_permutation_in_K_dim] the permutation sequence is empty, fail to apply permutation in K dim.\"\n                )\n            return False\n\n        # permute attribute nodes\n        if \"attr\" in fx_graph[node_name].keys():\n            return cls.permute_attr(node_name, permutation_sequence, fx_graph, dryrun)\n\n        # if we didn't store the attribute already, look in the modules' parameters\n        is_node_in_all_parameters = False\n        success_permutation = False\n\n        for module_name, module, p_name, p in cls.__all_parameters:\n            if node_name_matches(node_name, module_name):\n                if cls.__verbosity > 1 and dryrun:\n                    print(\n                        \"[apply_permutation_in_K_dim] find the node: '{:}' with '{:}' in cls.__all_parameters, may succeed to apply permutation in K dim.\".format(\n                            node_name, p_name\n                        )\n                    )\n                is_node_in_all_parameters = True\n                permutation_to_apply = permutation_sequence\n\n                if p.shape[0] != len(permutation_sequence):  # assumed to be grouped convolutions\n                    if cls.__verbosity > 2 and dryrun:\n                        print(\n                            f\"Mismatch in K dimension between found module {module_name} {p_name} for node {node_name}: permutation length {len(permutation_sequence)} but parameter shape in K {p.shape[0]}\"\n                        )\n\n                    if p.shape[0] % len(permutation_sequence) != 
0:\n                        return False\n\n                    permutation_to_apply = replicate_sequence(\n                        permutation_sequence, p.shape[0] // len(permutation_sequence)\n                    )\n\n                    if cls.__verbosity > 1 and dryrun:\n                        print(\n                            \"[apply_permutation_in_K_dim] the node: '{:}' with shape: '{:}' required replicating the permutation sequence with len '{:}' {:} times to succeed in applying the permutation in the K dimension.\".format(\n                                node_name,\n                                p.shape,\n                                len(permutation_sequence),\n                                p.shape[0] // len(permutation_sequence),\n                            )\n                        )\n                else:\n                    if cls.__verbosity > 1 and dryrun:\n                        print(\n                            \"[apply_permutation_in_K_dim] the node: '{:}' with shape: '{:}', can match the size of permutation sequence with len: '{:}', succeed to apply permutation in K dim.\".format(\n                                node_name, p.shape, len(permutation_sequence)\n                            )\n                        )\n\n                if not dryrun:\n                    p.data.copy_(p[permutation_to_apply, ...])\n                    cls.__params_permuted_in_K.append(node_name + \".\" + p_name)\n\n                success_permutation = True\n\n        if not is_node_in_all_parameters:\n            if cls.__verbosity >= 0:\n                print(\n                    \"ERROR: [apply_permutation_in _K_dim] cannot find the node: '{:}' in cls.__all_parameters, fail to apply permutation in K dim.\".format(\n                        node_name\n                    )\n                )\n            success_permutation = False\n\n        return success_permutation\n\n    @classmethod\n    def check_graph_for_unpermuted_nodes(cls, fx_graph):\n 
       \"\"\"Make sure that all permutable nodes/parameters were actually permuted and all GPUs agree\"\"\"\n\n        for node_name in fx_graph.keys():\n            node = fx_graph[node_name]\n\n            if \"C_permutable\" in node.keys() and node[\"C_permutable\"] and not node[\"C_permuted\"]:\n                sibling_group_id = node[\"sibling_group_id\"]\n                if (\n                    node[\"is_real\"]\n                    and cls.__group_data[\"skipped_sibling_groups\"][sibling_group_id] is None\n                ):\n                    if cls.__verbosity >= 0:\n                        print(\n                            f\"{node_name} was C_permutable in a not skipped sibling group but was not permuted along C! {node}\"\n                        )\n                    cls.__unpermuted_dims.append(node_name + \"_C\")\n\n            if \"K_permutable\" in node.keys() and node[\"K_permutable\"] and not node[\"K_permuted\"]:\n                coparent_group_id = node[\"coparent_group_id\"]\n                if (\n                    node[\"is_real\"]\n                    and cls.__group_data[\"skipped_coparent_groups\"][coparent_group_id] is None\n                ):\n                    if cls.__verbosity >= 0:\n                        print(\n                            f\"{node_name} was K_permutable in a not skipped coparent group but was not permuted along K! 
{node}\"\n                        )\n                    cls.__unpermuted_dims.append(node_name + \"_K\")\n\n        if cls.__verbosity > 0:\n            print(\n                f\"[check_graph_for_unpermuted_nodes] found nodes that missed permutations along {len(cls.__unpermuted_dims)} dimensions.\"\n            )\n\n        # make sure all GPUs agree\n        if torch.distributed.is_initialized():\n            cls.__unpermuted_dims = sorted(cls.__unpermuted_dims)\n            rank = torch.distributed.get_rank()\n            world_size = torch.distributed.get_world_size()\n            dist_store = torch.distributed.TCPStore(\n                \"127.0.0.1\", cls.__tcpstore_port, world_size, rank == 0\n            )\n            torch.distributed.barrier()\n\n            dist_store.set(str(rank), \",\".join(cls.__unpermuted_dims))\n            torch.distributed.barrier()\n\n            if rank == 0:\n                my_list = dist_store.get(\"0\").decode()\n\n                for peer in range(1, world_size):\n                    peer_list = dist_store.get(str(peer)).decode()\n                    assert my_list == peer_list, (\n                        f\"peer {peer} disagreed with rank 0's list of unpermuted nodes: \\n{my_list}\\n{peer_list}\"\n                    )\n\n    @classmethod\n    def find_sparse_parameters_for_node(cls, node_name):\n        \"\"\"If the node has parameters that are in the tracked sparse parameter list, find them and reshape to a 2D tensor with channels last\"\"\"\n        node_weight = None\n\n        # check the sparse parameters\n        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:\n            if node_name_matches(node_name, module_name):\n                node_weight = torch.zeros_like(p)\n                node_weight.copy_(p)\n\n        # if we found something, reshape to concatenate along the same dimension\n        if node_weight is not None:\n            # Need to handle the concat for layers with 
different R & S\n            shape = node_weight.shape\n            # 1d-tensor\n            if len(shape) == 1:\n                node_weight = node_weight.view(1, shape[0])\n            # 2d-tensor (K, C)\n            elif len(shape) == 2:\n                node_weight = node_weight.view(shape[0], shape[1])\n            # 3d-tensor (K, C, R)\n            elif len(shape) == 3:\n                node_weight = (\n                    node_weight.permute(0, 2, 1).contiguous().view(shape[0] * shape[2], shape[1])\n                )\n            # 4d-tensor (K, C, R, S)\n            elif len(shape) == 4:\n                # convs\n                node_weight = (\n                    node_weight.permute(2, 3, 0, 1)\n                    .contiguous()\n                    .view(shape[2] * shape[3] * shape[0], shape[1])\n                )\n\n        return node_weight\n\n    @classmethod\n    def find_permutation_for_matrix_group(cls, matrix_group):\n        \"\"\"Find a good permutation for some matrix (which may be concatenated matrices that require the same permutation)\"\"\"\n\n        if cls.__verbosity > 1:\n            print(\n                f\"Searching for a good permutation for this sibling group of shape {matrix_group.shape}\"\n            )\n\n        permutation_found = False\n        num_channels = matrix_group.shape[1]\n        group_permutation = [c for c in range(num_channels)]\n\n        # automatic check for skipping the permutation search process\n        original_magnitude = (torch.abs(matrix_group)).sum(dtype=torch.float64)\n        pruned_magnitude = sum_after_2_to_4(matrix_group.cpu().detach().numpy())\n        diff_ratio = abs(original_magnitude - pruned_magnitude) / original_magnitude\n        epsilon = 1e-3\n\n        if cls.__verbosity > 1:\n            print(\n                \"\\n[search_for_good_permutation] Original element abs sum: {:}, Pruned element abs sum: {:}, Diff ratio: {:}\".format(\n                    original_magnitude, 
pruned_magnitude, diff_ratio\n                )\n            )\n\n        start_time_accelerated_search_for_good_permutation = time.perf_counter()\n        if diff_ratio < epsilon:\n            if cls.__verbosity > 2:\n                print(\n                    \"[search_for_good_permutation] Original element abs sum is almost same as the pruned element abs sum, further permutation search will not help, skipping!\"\n                )\n\n        else:\n            if cls.__verbosity > 2:\n                print(\n                    \"[search_for_good_permutation] Original element abs sum is different from the pruned element abs sum, further permutation search will help, continue with the permutation search!\"\n                )\n\n            # call the permutation search CUDA kernels as ASP extension.\n            # users can provide prefer search strategy by providing a valid 'search_options' as a dictionary,\n            # or users can implement their customized 'accelerated_search_for_good_permutation' function.\n            search_options = {}\n            # No.1 Strategy: Exhaustive Search\n            search_options[\"strategy\"] = \"exhaustive\"\n            search_options[\"stripe_group_size\"] = 8\n            search_options[\"escape_attempts\"] = 100\n            # No.2 Strategy: Progressive Channel Swap Search\n            # search_options['strategy'] = 'progressive channel swap'\n            # search_options['progressive_search_time_limit'] = 10\n            # search_options['improvement_threshold'] = 1e-9\n\n            # permutation search time is too long for matrix_group with large channel num\n            # change from Exhaustive Search to Progressive Channel Swap Search based on input matrix_group size\n            if num_channels > 2048:\n                search_options = {}\n                search_options[\"strategy\"] = \"progressive channel swap\"\n                search_options[\"progressive_search_time_limit\"] = 120\n                
search_options[\"improvement_threshold\"] = 1e-9\n\n            if cls.__verbosity > 1:\n                print(f\"[search_for_good_permutation] search options: {search_options}\")\n\n            group_permutation = accelerated_search_for_good_permutation(\n                matrix_group, options=search_options, verbosity=cls.__verbosity\n            )\n            permutation_found = True\n\n        if cls.__verbosity > 1:\n            duration_accelerated_search_for_good_permutation = (\n                time.perf_counter() - start_time_accelerated_search_for_good_permutation\n            )\n            permuted_magnitude = sum_after_2_to_4(\n                matrix_group.cpu().detach().numpy()[:, group_permutation]\n            )\n            print(\n                \"[search_for_good_permutation] Take {:.4f} seconds to finish accelerated_search_for_good_permutation function and with final magnitude {:}.\".format(\n                    duration_accelerated_search_for_good_permutation, permuted_magnitude\n                )\n            )\n\n        return group_permutation, permutation_found\n\n    @classmethod\n    def skip_sibling_group(cls, fx_graph, sibling_group_id, reason):\n        \"\"\"Keep track of sibling groups that do not have permutations applied\"\"\"\n\n        # grab a parent to get the coparent group id\n        sibling_group = cls.__group_data[\"sibling_groups\"][sibling_group_id]\n        a_sibling = list(sibling_group)[0]\n        a_parent = fx_graph[a_sibling][\"real_parents\"][0]\n        coparent_group_id = fx_graph[a_parent][\"coparent_group_id\"]\n\n        if cls.__verbosity > 1:\n            print(\n                f\"Skipping permutations for Sibling Group {sibling_group_id} and Coparent Group {coparent_group_id}: {reason}\"\n            )\n\n        cls.__group_data[\"skipped_sibling_groups\"][sibling_group_id] = reason\n        cls.__group_data[\"skipped_coparent_groups\"][coparent_group_id] = reason\n\n    @classmethod\n    def 
collect_sparse_weights(cls, fx_graph, sibling_group, sibling_group_C_param):\n        \"\"\"Gather all sparse weights for a sibling group (to serve as input to the permutation search)\"\"\"\n\n        matrix_group = None\n\n        for sibling in sibling_group:\n            node_weight = cls.find_sparse_parameters_for_node(sibling)\n\n            if node_weight is not None:\n                # reshape due to siblings with grouped convolutions of different sizes\n                assert node_weight.shape[1] % sibling_group_C_param == 0, (\n                    f\"sibling {sibling}'s weights' C={node_weight.shape[1]} must be even multiple of the sibling group's C parameter {sibling_group_C_param}\"\n                )\n                node_weight = torch.reshape(node_weight, (-1, sibling_group_C_param))\n\n                if matrix_group is None:\n                    matrix_group = node_weight\n                else:\n                    try:\n                        matrix_group = torch.cat(\n                            (matrix_group, node_weight), dim=0\n                        )  # concat the weights in the K dimension, keep the same C dimension\n\n                    except:\n                        if cls.__verbosity >= 0:\n                            print(\n                                \"ERROR: [search_for_good_permutation][warning] cannot merge the weight for node: '{:}', with its weight shape: '{:}', the matrix_group shape: '{:}'.\".format(\n                                    sibling, node_weight.size(), matrix_group.size()\n                                )\n                            )\n                        continue\n                if cls.__verbosity > 2:\n                    print(\n                        \"[search_for_good_permutation] have merged the weight for node: '{:}', with its weight shape: '{:}', the matrix_group shape: '{:}'.\".format(\n                            sibling, node_weight.size(), matrix_group.size()\n                        )\n 
                   )\n            else:\n                if cls.__verbosity > 2:\n                    print(\n                        f\"[search_for_good_permutation] not adding dense weights for node {sibling} to the group\"\n                    )\n\n        return matrix_group\n\n    @classmethod\n    def find_sibling_group_permutation(cls, fx_graph, sibling_group_id):\n        \"\"\" \"Find a good permutation for some sibling group\"\"\"\n\n        if cls.__verbosity > 1:\n            print(f\"Finding permutation for sibling group {sibling_group_id}\")\n\n        cls.reset_seed()\n\n        sibling_group = cls.__group_data[\"sibling_groups\"][sibling_group_id]\n        sibling_group_C_param = int(cls.__group_data[\"sibling_group_C_params\"][sibling_group_id])\n\n        if sibling_group_C_param % 4 != 0 or sibling_group_C_param < 8:\n            cls.skip_sibling_group(\n                fx_graph, sibling_group_id, f\"Useless C: {sibling_group_C_param}\"\n            )\n            return\n\n        # collect *sparse* weights from all siblings, get the coparent group\n        matrix_group = cls.collect_sparse_weights(fx_graph, sibling_group, sibling_group_C_param)\n\n        # early-out if no siblings are sparse\n        if matrix_group is None:\n            cls.skip_sibling_group(fx_graph, sibling_group_id, \"Dense\")\n            return\n\n        # find a good permutation\n        group_permutation, found = cls.find_permutation_for_matrix_group(matrix_group)\n\n        # if no permutation was found, we didn't need it (input already sparse)\n        if not found:\n            cls.skip_sibling_group(fx_graph, sibling_group_id, \"Not needed\")\n            return\n\n        if cls.__verbosity > 2:\n            print(f\"Permutation for sibling group {sibling_group_id}: {group_permutation}\")\n\n        cls.__group_data[\"sibling_group_permutations\"][sibling_group_id] = group_permutation\n\n    @classmethod\n    def permute_sibling_group(cls, fx_graph, 
sibling_group_id, group_permutation):\n        \"\"\"Apply a permutation to some sibling group\"\"\"\n\n        if cls.__verbosity > 1:\n            print(f\"Attempting to permute sibling group {sibling_group_id}\")\n\n        sibling_group = cls.__group_data[\"sibling_groups\"][sibling_group_id]\n\n        # apply the permutation in two steps: first, a dry run to find any issues.\n        # if there were no issues, actually apply the permutation in the second step.\n        success = True\n        coparent_group_id = None\n        for dryrun in [True, False]:\n            # apply that permutation to the siblings' C dimension\n            for sibling in sibling_group:\n                assert fx_graph[sibling][\"C_permutable\"] and not fx_graph[sibling][\"C_permuted\"]\n                sibling_permuted = cls.apply_permutation_in_C_dim(\n                    sibling, group_permutation, dryrun\n                )\n                if dryrun:\n                    success = success and sibling_permuted\n                else:\n                    assert sibling_permuted, \"shouldn't fail permuting siblings after the dry run\"\n                    fx_graph[sibling][\"C_permuted\"] = sibling_permuted\n\n                a_parent = fx_graph[sibling][\"real_parents\"][0]\n                if coparent_group_id is None:\n                    coparent_group_id = fx_graph[a_parent][\"coparent_group_id\"]\n                else:\n                    assert coparent_group_id == fx_graph[a_parent][\"coparent_group_id\"], (\n                        f\"parent {a_parent} must belong to the same coparent group {coparent_group_id}, not {fx_graph[a_parent]['coparent_group_id']}\"\n                    )\n\n            # grab the parents (and co-parents) and apply to their K dimension\n            coparents = cls.__group_data[\"coparent_groups\"][coparent_group_id]\n            for coparent in coparents:\n                assert fx_graph[coparent][\"K_permutable\"] and not 
fx_graph[coparent][\"K_permuted\"]\n                coparent_permuted = cls.apply_permutation_in_K_dim(\n                    coparent, group_permutation, fx_graph, dryrun\n                )\n                if dryrun:\n                    success = success and coparent_permuted\n                else:\n                    assert coparent_permuted, \"shouldn't fail permuting coparents after the dry run\"\n                    fx_graph[coparent][\"K_permuted\"] = coparent_permuted\n\n                children_permuted = cls.apply_permutation_in_K_dim_to_children(\n                    fx_graph, coparent, group_permutation, dryrun\n                )\n                if dryrun:\n                    success = success and children_permuted\n                else:\n                    assert children_permuted, (\n                        \"shouldn't fail permuting coparents' children after the dry run\"\n                    )\n\n            if not success:\n                cls.skip_sibling_group(fx_graph, sibling_group_id, \"dryrun_failure\")\n\n                if cls.__verbosity > 0:\n                    print(\n                        f\"There was an issue permuting sibling group {sibling_group_id}, skipping it to preserve network quality.\"\n                    )\n\n                break\n\n    @classmethod\n    def apply_permutation_in_K_dim_to_children(cls, fx_graph, node_name, permutation, dryrun):\n        \"\"\"Apply a permutation along K to the children of some node\"\"\"\n\n        success = True\n        children = fx_graph[node_name][\"children\"]\n        if cls.__verbosity > 2 and dryrun:\n            print(f\"Applying a permutation in K to children of {node_name} : {children}\")\n\n        # apply the permutation along K to children as necessary\n        for child in children:\n            if \"is_real\" in fx_graph[child].keys() and fx_graph[child][\"is_real\"]:\n                if cls.__verbosity > 3 and dryrun:\n                    print(f\"\\tFound a real 
child {child}, not permuting it or its children along K\")\n            else:\n                if (\n                    \"module_type\" not in fx_graph[child].keys()\n                    or fx_graph[child][\"module_type\"] == \"None\"\n                ):\n                    if cls.__verbosity > 3 and dryrun:\n                        print(f\"\\tPermuting children of non-module {child} along K\")\n                    success = success and cls.apply_permutation_in_K_dim_to_children(\n                        fx_graph, child, permutation, dryrun\n                    )\n                elif not fx_graph[child][\"C_permutable\"]:\n                    if fx_graph[child][\"K_permutable\"] and not fx_graph[child][\"K_permuted\"]:\n                        if cls.__verbosity > 2 and dryrun:\n                            print(f\"\\tPermuting {child} along K\")\n                        child_permuted = cls.apply_permutation_in_K_dim(\n                            child, permutation, fx_graph, dryrun\n                        )\n                        success = success and child_permuted\n                        if not dryrun:\n                            fx_graph[child][\"K_permuted\"] = child_permuted\n                        assert fx_graph[child][\"K_passthru\"]\n\n                    if fx_graph[child][\"K_passthru\"]:\n                        success = success and cls.apply_permutation_in_K_dim_to_children(\n                            fx_graph, child, permutation, dryrun\n                        )\n                    else:\n                        if cls.__verbosity >= 0:\n                            print(\n                                f\"\\t!! 
ERROR {child} was a not real module that was not K_passthru\"\n                            )\n\n        return success\n\n    @classmethod\n    def defer_prints(cls):\n        \"\"\"Collect prints from this rank in distributed mode to avoid interleaved output\"\"\"\n\n        if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:\n            cls.__new_stdout = io.StringIO(str(torch.distributed.get_rank()))\n            cls.__builtin_print = __builtin__.print\n\n            def deferred_print(*args, **kwargs):\n                try:  # see if torchvision examples has suppressed other ranks with the force argument\n                    cls.__builtin_print(*args, file=cls.__new_stdout, force=True, **kwargs)\n                except:\n                    cls.__builtin_print(*args, file=cls.__new_stdout, **kwargs)\n\n            __builtin__.print = deferred_print\n\n    @classmethod\n    def resume_prints(cls):\n        \"\"\"Emit the collected outputs from this rank, resume immediate printing\"\"\"\n\n        if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:\n            output = cls.__new_stdout.getvalue()\n            __builtin__.print = cls.__builtin_print\n\n            try:\n                print(output, force=True)\n            except:\n                print(output)\n\n    @classmethod\n    def find_permutations(cls, fx_graph):\n        \"\"\"Search for permutations for all sibling groups\"\"\"\n\n        for sibling_group_id in cls.__group_data[\"sibling_groups\"].keys():\n            search_this_group = True\n            if torch.distributed.is_initialized():\n                rank = torch.distributed.get_rank()\n                world_size = torch.distributed.get_world_size()\n\n                if sibling_group_id % world_size != rank:\n                    search_this_group = False\n\n            cls.__group_data[\"sibling_group_permutations\"][sibling_group_id] = None\n            if search_this_group:\n  
              cls.defer_prints()\n\n                sibling_group = cls.__group_data[\"sibling_groups\"][sibling_group_id]\n                test_node_name = list(sibling_group)[0]\n                if not fx_graph[test_node_name][\"C_permutable\"]:\n                    if cls.__verbosity > 1:\n                        print(\n                            f\"Skipping permutation for sibling group {sibling_group_id} since it does not allow permutations along C\"\n                        )\n\n                else:\n                    if cls.__verbosity > 1:\n                        print(f\"Sibling group {sibling_group_id} can permute along C, permuting it\")\n\n                    cls.find_sibling_group_permutation(fx_graph, sibling_group_id)\n\n                cls.resume_prints()\n\n        return fx_graph\n\n    @classmethod\n    def sync_permutations(cls, fx_graph):\n        \"\"\"If multiple GPUs were involved in finding permutations, make sure everyone's in sync\"\"\"\n\n        if not torch.distributed.is_initialized():\n            return fx_graph\n\n        rank = torch.distributed.get_rank()\n        world_size = torch.distributed.get_world_size()\n        dist_store = torch.distributed.TCPStore(\n            \"127.0.0.1\", cls.__tcpstore_port, world_size, rank == 0\n        )\n\n        if cls.__verbosity > 0:\n            print(f\"Syncing permutations found among world size {world_size}\")\n\n        torch.distributed.barrier()\n        for sibling_group_id in sorted(cls.__group_data[\"sibling_groups\"].keys()):\n            src_rank = sibling_group_id % world_size\n\n            if src_rank == rank:\n                to_send = cls.__group_data[\"sibling_group_permutations\"].get(sibling_group_id, None)\n                skip_reason = None\n                if to_send is None:\n                    skip_reason = cls.__group_data[\"skipped_sibling_groups\"].get(\n                        sibling_group_id, None\n                    )\n                    if 
skip_reason is None:\n                        to_send = \"\"\n                    else:\n                        to_send = \"skip\"\n                else:\n                    to_send = \",\".join(str(c) for c in to_send)\n\n                dist_store.set(str(sibling_group_id), to_send)\n                if skip_reason is not None:\n                    dist_store.set(f\"skip {sibling_group_id}\", skip_reason)\n\n                if cls.__verbosity > 1:\n                    print(\n                        f\"{rank}: stored permutation for sibling group {sibling_group_id}\",\n                        force=True,\n                    )\n\n        torch.distributed.barrier()\n        for sibling_group_id in sorted(cls.__group_data[\"sibling_groups\"].keys()):\n            permutation = dist_store.get(str(sibling_group_id)).decode()\n\n            if permutation == \"skip\":\n                permutation = None\n                skip_reason = dist_store.get(f\"skip {sibling_group_id}\").decode()\n                cls.skip_sibling_group(fx_graph, sibling_group_id, skip_reason)\n            else:\n                if len(permutation) == 0:\n                    permutation = None\n                else:\n                    permutation = [int(c) for c in permutation.split(\",\")]\n\n            cls.__group_data[\"sibling_group_permutations\"][sibling_group_id] = permutation\n\n            if cls.__verbosity > 1:\n                print(f\"Got permutation for sibling group {sibling_group_id}\")\n\n        torch.distributed.barrier()\n        return fx_graph\n\n    @classmethod\n    def apply_permutations(cls, fx_graph):\n        \"\"\"Apply all the permutations that were found to the network appropriately\"\"\"\n\n        for sibling_group_id in cls.__group_data[\"sibling_group_permutations\"].keys():\n            permutation = cls.__group_data[\"sibling_group_permutations\"][sibling_group_id]\n\n            if permutation is not None:\n                
cls.permute_sibling_group(fx_graph, sibling_group_id, permutation)\n\n        return fx_graph\n\n    @staticmethod\n    def insert_MHA_out_proj(fx_graph, MHA_node, verbosity):\n        \"\"\"MHA nodes have a hidden out_proj node, so insert it and fix up neighboring nodes\"\"\"\n\n        if verbosity > 1:\n            print(f\"Inserting MHA out_proj for node {MHA_node}\")\n        out_proj_node_name = MHA_node + \".out_proj\"\n        # insert the new node\n        fx_graph[out_proj_node_name] = {}\n        fx_graph[out_proj_node_name][\"parents\"] = [MHA_node]\n        fx_graph[out_proj_node_name][\"children\"] = fx_graph[MHA_node][\"children\"]\n        fx_graph[MHA_node][\"children\"] = [out_proj_node_name]\n\n        # set the new node's properties\n        fx_graph[out_proj_node_name][\"fx_op\"] = \"call_module\"\n        fx_graph[out_proj_node_name][\"module_type\"] = \"torch.nn.modules.linear.Linear\"\n        fx_graph[out_proj_node_name][\"groups_param\"] = \"None\"\n        fx_graph[out_proj_node_name][\"C_param\"] = fx_graph[MHA_node][\"C_param\"]\n        fx_graph[out_proj_node_name][\"K_param\"] = fx_graph[MHA_node][\"K_param\"]\n        fx_graph[out_proj_node_name][\"sibling_group_id\"] = None\n        fx_graph[out_proj_node_name][\"coparent_group_id\"] = None\n\n        # set permutation flags\n        fx_graph[out_proj_node_name][\"C_permutable\"] = False\n        fx_graph[MHA_node][\"K_permutable\"] = False\n        fx_graph[MHA_node][\"C_permutable\"] = True\n        fx_graph[out_proj_node_name][\"K_permutable\"] = True\n        fx_graph[out_proj_node_name][\"K_passthru\"] = False\n        fx_graph[out_proj_node_name][\"C_permuted\"] = False\n        fx_graph[out_proj_node_name][\"K_permuted\"] = False\n        fx_graph[out_proj_node_name][\"is_real\"] = True\n\n        if verbosity > 2:\n            print(f\"\\tUpdated: {MHA_node}: {fx_graph[MHA_node]}\")\n            print(f\"\\tAdded: {out_proj_node_name}: {fx_graph[out_proj_node_name]}\")\n\n   
     # update any nodes that thought their parent was the MHA node\n        for node in fx_graph.keys():\n            parents = fx_graph[node][\"parents\"]\n            if node != out_proj_node_name and MHA_node in parents:\n                parents.remove(MHA_node)\n                parents.append(out_proj_node_name)\n                fx_graph[node][\"parents\"] = parents\n                if verbosity > 2:\n                    print(f\"\\tUpdated parents of {node}: {fx_graph[node]}\")\n\n        return fx_graph\n\n    @staticmethod\n    def init_grouped_conv_permutation_flags(fx_graph, node_name, node_groups, verbosity):\n        \"\"\"Handle grouped convolutions to make dimensions match\"\"\"\n\n        node_C = int(fx_graph.get(node_name).get(\"C_param\"))\n        node_K = int(fx_graph.get(node_name).get(\"K_param\"))\n        node_groups = int(node_groups)\n\n        if verbosity > 2:\n            print(f\"\\t{node_name} pre-divide C: {node_C}, K: {node_K}, G: {node_groups}\")\n        assert node_C % node_groups == 0\n        node_C = int(node_C / node_groups)\n        fx_graph[node_name][\"C_param\"] = str(node_C)\n        if verbosity > 2:\n            print(f\"\\t{node_name} post-divide C: {node_C}, K: {node_K}, G: {node_groups}\")\n\n        if node_C == 1:  # G == C (C is pre-divided by G)\n            if node_groups == node_K:  # true depthwise, G == C == K (C will be pre-divided by G)\n                fx_graph[node_name][\"K_permutable\"] = True\n                fx_graph[node_name][\"K_permuted\"] = False\n                fx_graph[node_name][\"K_passthru\"] = True\n                fx_graph[node_name][\"is_real\"] = False\n            # else:                                          # G != K, handling a permutation along K would be very tricky and not likely useful\n\n        else:  # G != C\n            if (\n                node_C > 4 and node_C % 4 == 0\n            ):  # permutations only help if there's more than one 2:4 pruning group\n                
fx_graph[node_name][\"C_permutable\"] = True\n                fx_graph[node_name][\"C_permuted\"] = False\n\n    @classmethod\n    def init_permutation_flags(cls, fx_graph):\n        \"\"\"Set the permutation flags for each node based only on that node's module type and parameters\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\n                \"\\n[init_permutation_flags] Initialize the permutation flags for each node according to module type and parameters\"\n            )\n\n        # initialize some graph-wide trackers\n        cls.__group_data = {}\n        cls.__group_data[\"next_sibling_group_id\"] = 0\n        cls.__group_data[\"next_coparent_group_id\"] = 0\n        cls.__group_data[\"sibling_groups\"] = {}\n        cls.__group_data[\"sibling_group_permutations\"] = {}\n        cls.__group_data[\"sibling_group_C_params\"] = {}\n        cls.__group_data[\"skipped_sibling_groups\"] = {}\n        cls.__group_data[\"coparent_groups\"] = {}\n        cls.__group_data[\"skipped_coparent_groups\"] = {}\n\n        # track MHA nodes\n        MHA_nodes = []\n\n        # initialize each node's details\n        for node_name in fx_graph.keys():\n            fx_node = fx_graph.get(node_name)\n            node_module_type = fx_node.get(\"module_type\")\n            if cls.__verbosity > 1:\n                if node_module_type == \"get_attr\":\n                    print(f\"Initializing node {node_name} of type {node_module_type}\")\n                else:\n                    print(f\"Initializing node {node_name} of type {node_module_type}: {fx_node}\")\n\n            # default for all nodes: don't allow anything\n            if node_module_type is not None:\n                fx_graph[node_name][\"C_permutable\"] = (\n                    False  # does this node have parameters that can be permuted in C\n                )\n                fx_graph[node_name][\"K_permutable\"] = (\n                    False  # does this node have parameters that can be permuted 
in K\n                )\n                fx_graph[node_name][\"K_passthru\"] = (\n                    False  # does this node need to pass a K permutation to its parents\n                )\n                fx_graph[node_name][\"is_real\"] = False\n                fx_graph[node_name][\"C_permuted\"] = False\n                fx_graph[node_name][\"K_permuted\"] = False\n\n                # initialize sibling and coparent groups\n                fx_graph[node_name][\"sibling_group_id\"] = None\n                fx_graph[node_name][\"coparent_group_id\"] = None\n\n                # update each node to be more permissive if supported\n                if node_module_type in cls.__permutation_target_module_types:\n                    fx_graph[node_name][\"is_real\"] = True\n                    node_groups = fx_graph.get(node_name).get(\"groups_param\")\n\n                    if node_groups in [\"None\", \"1\"]:  # no groups, no constraints\n                        fx_graph[node_name][\"C_permutable\"] = True\n                        fx_graph[node_name][\"K_permutable\"] = True\n\n                    else:  # handle groups\n                        Permutation.init_grouped_conv_permutation_flags(\n                            fx_graph, node_name, node_groups, cls.__verbosity\n                        )\n\n                elif node_module_type in cls.__permute_K_and_passthru_module_types:\n                    fx_graph[node_name][\"K_permutable\"] = True\n                    fx_graph[node_name][\"K_passthru\"] = True\n                    fx_graph[node_name][\"is_real\"] = False\n\n                elif node_module_type in cls.__simple_passthru_module_types:\n                    fx_graph[node_name][\"K_passthru\"] = True\n                    fx_graph[node_name][\"is_real\"] = False\n\n                elif node_module_type in cls.__disallow_permutations_module_types:\n                    fx_graph[node_name][\"is_real\"] = True\n                    fx_graph[node_name][\"C_param\"] = 
1\n                    fx_graph[node_name][\"K_param\"] = 1\n                    fx_graph[node_name][\"groups_param\"] = 1\n\n                elif \"activation\" in node_module_type:\n                    if cls.__verbosity > 0:\n                        print(\n                            f\"WARNING: how should permutation flags be initialized for node {node_name} of module type {node_module_type}?  Found 'activation', assuming simple passthru behavior.\"\n                        )\n                    fx_graph[node_name][\"K_passthru\"] = True\n                    fx_graph[node_name][\"is_real\"] = False\n\n                else:\n                    if cls.__verbosity > 0:\n                        print(\n                            f\"WARNING: how should permutation flags be initialized for node {node_name} of module type {node_module_type}?  Defaulting to strict, disallowing permutations around it.\"\n                        )\n                    # is_real coupled with disallowed C and K permutations will poison real parents and real children\n                    fx_graph[node_name][\"is_real\"] = True\n                    # dummy entries:\n                    fx_graph[node_name][\"C_param\"] = 1\n                    fx_graph[node_name][\"K_param\"] = 1\n                    fx_graph[node_name][\"groups_param\"] = 1\n\n                # MHA nodes only handle the in_proj, need to add out_proj nodes explicitly\n                # keep track here so we can iterate directly and change fx_graph keys\n                if node_module_type == \"torch.nn.modules.activation.MultiheadAttention\":\n                    MHA_nodes.append(node_name)\n\n            if cls.__verbosity > 1:\n                if node_module_type == \"get_attr\":\n                    print(f\"\\tInitialized node {node_name} of type {node_module_type}\")\n                else:\n                    print(\n                        f\"\\tInitialized node {node_name} of type {node_module_type}: 
{fx_graph[node_name]}\"\n                    )\n\n        for MHA_node in MHA_nodes:\n            fx_graph = Permutation.insert_MHA_out_proj(fx_graph, MHA_node, cls.__verbosity)\n\n        return fx_graph\n\n    @staticmethod\n    def collect_siblings(fx_graph, node_name, all_siblings):\n        \"\"\"Recursively build a set of some node's siblings in the graph\"\"\"\n\n        # find all siblings of the requested node\n        siblings = set()\n        parents = fx_graph.get(node_name).get(\"real_parents\")\n        for parent in parents:\n            children = fx_graph.get(parent).get(\"real_children\")\n            for child in children:\n                siblings.add(child)\n\n        # separate the new siblings, since we'll need to process them recursively\n        new_siblings = siblings.difference(all_siblings)\n        # update the final list with just the new elements\n        all_siblings.update(new_siblings)\n\n        for new_sibling in new_siblings:\n            all_siblings = Permutation.collect_siblings(fx_graph, new_sibling, all_siblings)\n\n        return all_siblings\n\n    @staticmethod\n    def propagate_sibling_group(fx_graph, all_siblings, verbosity):\n        \"\"\"Check a sibling group for ability to be permuted, disallow all siblings and coparents if there's an issue\"\"\"\n\n        made_change = False\n        allow_C = True\n        for sibling in all_siblings:\n            pre_check = allow_C\n            allow_C = allow_C and fx_graph[sibling][\"C_permutable\"]\n            if allow_C != pre_check:\n                if verbosity > 2:\n                    if fx_graph[sibling][\"module_type\"] == \"get_attr\":\n                        print(f\"\\tnode {sibling} has poisoned the sibling group of {all_siblings}\")\n                    else:\n                        print(\n                            f\"\\tnode {sibling} has poisoned the sibling group of {all_siblings}: {fx_graph[sibling]}\"\n                        )\n                
break\n\n        if not allow_C:\n            for sibling in all_siblings:\n                made_change = made_change or fx_graph[sibling][\"C_permutable\"]\n                fx_graph[sibling][\"C_permutable\"] = False\n\n                # only disable permutation along K for parents if this node cannot passthru, either\n                if not fx_graph[sibling][\"K_passthru\"]:\n                    sibling_parents = fx_graph[sibling][\"real_parents\"]\n                    for sibling_parent in sibling_parents:\n                        made_change = (\n                            made_change\n                            or fx_graph[sibling_parent][\"K_permutable\"]\n                            or fx_graph[sibling_parent][\"K_passthru\"]\n                        )\n                        fx_graph[sibling_parent][\"K_permutable\"] = False\n                        fx_graph[sibling_parent][\"K_passthru\"] = False\n\n        return made_change\n\n    @staticmethod\n    def collect_coparents(fx_graph, node_name, all_coparents):\n        \"\"\"Recursively build a set of all coparents of a particular node in the graph\"\"\"\n\n        # find all coparents of the requested node\n        coparents = set()\n        children = fx_graph.get(node_name).get(\"real_children\")\n        for child in children:\n            parents = fx_graph.get(child).get(\"real_parents\")\n            for parent in parents:\n                coparents.add(parent)\n\n                # coparents are used to restrict what nodes can be permuted along C, so we need to track if the current parents also pass their K permutations up\n                if fx_graph[parent][\"K_passthru\"]:\n                    grandparents = fx_graph[parent][\"real_parents\"]\n                    for grandparent in grandparents:\n                        coparents = coparents.union(\n                            Permutation.collect_coparents(fx_graph, grandparent, coparents)\n                        )\n\n        # separate the 
new coparents, since we'll need to process them recursively\n        new_coparents = coparents.difference(all_coparents)\n        # update the final list with just the new elements\n        all_coparents.update(new_coparents)\n\n        for new_coparent in new_coparents:\n            all_coparents = Permutation.collect_coparents(fx_graph, new_coparent, all_coparents)\n\n        return all_coparents\n\n    @staticmethod\n    def propagate_coparent_group(fx_graph, all_coparents, verbosity):\n        \"\"\"Check a coparent group for ability to be permuted, disallow all fellow coparents and children if there's an issue\"\"\"\n\n        # see if all coparents agree that K can be permuted\n        allow_K = True\n        made_change = False\n        for coparent in all_coparents:\n            pre_check = allow_K\n            allow_K = allow_K and (\n                fx_graph[coparent][\"K_permutable\"] or fx_graph[coparent][\"K_passthru\"]\n            )\n            if allow_K != pre_check:\n                if verbosity > 2:\n                    if fx_graph[coparent][\"module_type\"] == \"get_attr\":\n                        print(\n                            f\"\\tnode {coparent} has poisoned the coparent group of {all_coparents}\"\n                        )\n                    else:\n                        print(\n                            f\"\\tnode {coparent} has poisoned the coparent group of {all_coparents}: {fx_graph[coparent]}\"\n                        )\n                break\n\n        # if anyone says no, force everyone to 'no', keep track of updated state\n        if not allow_K:\n            for coparent in all_coparents:\n                # all coparents can no longer be permuted along K\n                if fx_graph[coparent][\"K_permutable\"] or fx_graph[coparent][\"K_passthru\"]:\n                    made_change = True\n\n                    fx_graph[coparent][\"K_permutable\"] = False\n                    fx_graph[coparent][\"K_passthru\"] = 
False\n\n                # children of coparents can't be permuted along C\n                coparent_children = fx_graph[coparent][\"real_children\"]\n                for coparent_child in coparent_children:\n                    if fx_graph[coparent_child][\"C_permutable\"]:\n                        fx_graph[coparent_child][\"C_permutable\"] = False\n                        made_change = True\n\n        return made_change\n\n    @classmethod\n    def fixup_concats(cls, fx_graph):\n        \"\"\"concat operations/modules may concatenate along the channel dimension, which requires special handling (like grouped convs)\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\"[fixup_concats]\")\n\n        for node_name in fx_graph.keys():\n            fx_node = fx_graph[node_name]\n            if fx_node.get(\"module_type\") == \"concat\":\n                # get real parents, find GCD of their Ks\n                node_real_parents = fx_node[\"real_parents\"]\n\n                # some concats are at the front of networks (googlenet)\n                if len(node_real_parents) == 0:\n                    continue\n\n                parents_K_params = []\n                for parent in node_real_parents:\n                    parent_K_param = int(fx_graph[parent][\"K_param\"])\n                    parents_K_params.append(parent_K_param)\n                    fx_graph[parent][\"allow_K_mismatch\"] = \"concat op\"\n\n                # if grouped convolutions give siblings input channels of different sizes,\n                # restrict the permutation atom to the greatest common divisor so it can be tiled as needed for each sibling (and parent)\n                if cls.__verbosity > 2:\n                    print(\n                        f\"\\tfixing up concat node {node_name}, found parents' {node_real_parents} Ks: {parents_K_params}\"\n                    )\n\n                children_GCD_param = str(np.gcd.reduce(parents_K_params))\n\n                # set 
this to GCD of children's sibling group\n                sibling_group_id = -1\n                node_real_children = fx_node[\"real_children\"]\n                for child in node_real_children:\n                    sibling_group_id = fx_graph[child][\"sibling_group_id\"]\n                    fx_graph[child][\"C_param\"] = children_GCD_param\n\n                old_children_GCD = cls.__group_data[\"sibling_group_C_params\"][sibling_group_id]\n                cls.__group_data[\"sibling_group_C_params\"][sibling_group_id] = children_GCD_param\n\n                # fixup this node's dimensions\n                # use the functionality of grouped convolutions\n                fx_node[\"C_param\"] = children_GCD_param\n                fx_node[\"K_param\"] = old_children_GCD\n                fx_node[\"groups_param\"] = str(int(old_children_GCD) // int(children_GCD_param))\n\n                if cls.__verbosity > 2:\n                    print(\n                        f\"\\tfixed up concat node {node_name}, found GCD of parents' {node_real_parents} K to be {children_GCD_param}, updated children's {node_real_children} C_params and sibling group {sibling_group_id} GCD\"\n                    )\n                    print(f\"\\tthis node now: {fx_node}\")\n\n        return fx_graph\n\n    @classmethod\n    def enforce_dimension_agreement(cls, fx_graph):\n        \"\"\"Check all nodes' channel dimensions against parents and children to make sure they agree; e.g. 
flatten ops may change these dimensions\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\"[enforce_dimension_agreement]\")\n\n        for node_name in fx_graph.keys():\n            fx_node = fx_graph[node_name]\n            if \"is_real\" in fx_node.keys() and fx_node[\"is_real\"]:\n                # enforce this node's input dimension matches its parents' output dimensions\n                node_C = int(fx_node[\"C_param\"])\n                node_K = int(fx_node[\"K_param\"])\n\n                if fx_graph[node_name][\"groups_param\"] not in [\"1\", \"None\"]:\n                    node_C = node_C * int(fx_node[\"groups_param\"])\n\n                node_real_parents = fx_node[\"real_parents\"]\n                if len(node_real_parents) == 0:\n                    if cls.__verbosity > 1:\n                        print(f\"\\t{node_name} has no real parents, disabling permutations along C\")\n                    fx_graph[node_name][\"C_permutable\"] = False\n                else:\n                    for real_parent in node_real_parents:\n                        parent_K = int(fx_graph[real_parent][\"K_param\"])\n                        ignore_mismatch = fx_graph[real_parent].get(\"allow_K_mismatch\")\n\n                        if ignore_mismatch is not None:\n                            if cls.__verbosity > 1:\n                                print(\n                                    f\"\\tIgnoring dimension mismatch between {node_name} (C={node_C}) and its parent {real_parent} (K={parent_K}) as requested: {ignore_mismatch}\"\n                                )\n\n                        elif parent_K >= 0 and node_C != parent_K:\n                            if cls.__verbosity > 1:\n                                print(\n                                    f\"\\tDimensions mismatch between {node_name} (C={node_C}) and its parent {real_parent} (K={parent_K}), disallowing the relevant permutations\"\n                                )\n\n                     
       fx_graph[node_name][\"C_permutable\"] = False\n                            fx_graph[real_parent][\"K_permutable\"] = False\n\n                            if cls.__verbosity > 2:\n                                print(f\"\\t{fx_graph[node_name]}\\n\\t{fx_graph[real_parent]}\")\n\n                if len(fx_graph[node_name][\"real_children\"]) == 0:\n                    if cls.__verbosity > 1:\n                        print(f\"\\t{node_name} has no real children, disabling permutations along K\")\n                    fx_graph[node_name][\"K_permutable\"] = False\n\n        return fx_graph\n\n    @classmethod\n    def make_sibling_coparent_groups(cls, fx_graph):\n        \"\"\"Traverse all real nodes in the graph and collect their siblings and coparents\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\"[make_sibling_coparent_groups]\")\n\n        for node_name in fx_graph.keys():\n            fx_node = fx_graph[node_name]\n\n            if \"is_real\" in fx_node.keys() and fx_node[\"is_real\"]:\n                sibling_group_id = fx_node[\"sibling_group_id\"]\n                if sibling_group_id is None:  # need to make a new sibling group for this node\n                    all_siblings = cls.collect_siblings(fx_graph, node_name, set([node_name]))\n                    all_siblings = sorted(all_siblings)  # deterministic order for DDP setups\n                    sibling_group_id = cls.__group_data[\"next_sibling_group_id\"]\n                    cls.__group_data[\"sibling_groups\"][sibling_group_id] = all_siblings\n                    cls.__group_data[\"next_sibling_group_id\"] = sibling_group_id + 1\n\n                    sibling_group_C_params = []\n                    for sibling in all_siblings:\n                        fx_graph[sibling][\"sibling_group_id\"] = sibling_group_id\n                        sibling_C_param = int(fx_graph[sibling][\"C_param\"])\n                        sibling_group_C_params.append(sibling_C_param)\n\n                    
# if grouped convolutions give siblings input channels of different sizes,\n                    # restrict the permutation atom to the greatest common divisor so it can be tiled as needed for each sibling (and parent)\n                    sibling_group_C_param = str(np.gcd.reduce(sibling_group_C_params))\n                    cls.__group_data[\"sibling_group_C_params\"][sibling_group_id] = (\n                        sibling_group_C_param\n                    )\n                    cls.__group_data[\"skipped_sibling_groups\"][sibling_group_id] = None\n\n                    if cls.__verbosity > 1:\n                        print(\n                            f\"New sibling group {sibling_group_id} with GCD(C) of {sibling_group_C_param}: {all_siblings}\"\n                        )\n\n                coparent_group_id = fx_node[\"coparent_group_id\"]\n                if coparent_group_id is None:\n                    all_coparents = cls.collect_coparents(fx_graph, node_name, set([node_name]))\n                    coparent_group_id = cls.__group_data[\"next_coparent_group_id\"]\n                    cls.__group_data[\"coparent_groups\"][coparent_group_id] = all_coparents\n                    cls.__group_data[\"next_coparent_group_id\"] = coparent_group_id + 1\n                    cls.__group_data[\"skipped_coparent_groups\"][coparent_group_id] = None\n\n                    for coparent in all_coparents:\n                        fx_graph[coparent][\"coparent_group_id\"] = coparent_group_id\n\n                    if cls.__verbosity > 1:\n                        print(f\"New coparent group {coparent_group_id}: {all_coparents}\")\n        return fx_graph\n\n    @classmethod\n    def propagate_permutation_flags(cls, fx_graph):\n        \"\"\"Disallow sibling groups from having different C_permutable flags and coparent groups from having different K_permutable flags within the groups\"\"\"\n\n        made_change = True  # will we need to repeat this 
propagation?\n        # TODO: just propagate to sibling groups and coparent groups directly, instead of iteratively to direct real_parents and siblings\n        while made_change:\n            made_change = False\n\n            if cls.__verbosity > 0:\n                print(\"Making a pass at propagating permutation flags\")\n\n            for node_name in fx_graph.keys():\n                fx_node = fx_graph.get(node_name)\n\n                node_parents = fx_graph.get(node_name).get(\"parents\")\n                node_real_parents = fx_graph.get(node_name).get(\"real_parents\")\n                node_children = fx_graph.get(node_name).get(\"children\")\n                node_real_children = fx_graph.get(node_name).get(\"real_children\")\n\n                # input layers can't be permuted along C without a runtime fixup, skip them\n                if node_parents is None or (\n                    \"x\" in node_parents\n                    and \"C_permutable\" in fx_graph[node_name].keys()\n                    and fx_graph[node_name][\"C_permutable\"]\n                ):\n                    if cls.__verbosity > 1:\n                        print(\n                            f\"{node_name} has no parents, or only an input, disabling permutations in C\"\n                        )\n                    made_change = True\n                    fx_graph[node_name][\"C_permutable\"] = False\n\n                # output layers can't be permuted along K without a runtime fixup, skip them\n                if node_children is None or (\n                    \"output\" in node_children\n                    and \"K_permutable\" in fx_graph[node_name].keys()\n                    and fx_graph[node_name][\"K_permutable\"]\n                ):\n                    if cls.__verbosity > 1:\n                        print(\n                            f\"{node_name} has no children, or only an output, disabling permutations in K\"\n                        )\n                    made_change = 
True\n                    fx_graph[node_name][\"K_permutable\"] = False\n                    fx_graph[node_name][\"K_passthru\"] = False\n\n                if \"is_real\" in fx_node.keys() and fx_node[\"is_real\"]:\n                    # siblings must share C-flags; if one cannot be permuted along C, none can\n                    sibling_group_id = fx_graph[node_name][\"sibling_group_id\"]\n                    all_siblings = cls.__group_data[\"sibling_groups\"][sibling_group_id]\n                    made_change = (\n                        cls.propagate_sibling_group(fx_graph, all_siblings, cls.__verbosity)\n                        or made_change\n                    )\n\n                    # coparents must share K-flags; if one cannot be permuted along K, none can\n                    coparent_group_id = fx_graph[node_name][\"coparent_group_id\"]\n                    all_coparents = cls.__group_data[\"coparent_groups\"][coparent_group_id]\n                    made_change = (\n                        cls.propagate_coparent_group(fx_graph, all_coparents, cls.__verbosity)\n                        or made_change\n                    )\n\n        return fx_graph\n\n    @classmethod\n    def find_node_real_children(cls, fx_graph, node_name, found_children):\n        \"\"\"Collect the real children of some node\"\"\"\n\n        if \"real_children\" in fx_graph[node_name].keys():\n            return found_children.union(fx_graph[node_name][\"real_children\"])\n\n        children = fx_graph[node_name][\"children\"]\n        for child in children:\n            if child in fx_graph.keys():  # not the output node\n                if cls.__verbosity > 3:\n                    print(f\"\\tchecking child {child} of node {node_name}\")\n\n                # if it's a real node, just add it\n                if \"is_real\" in fx_graph[child].keys() and fx_graph[child][\"is_real\"]:\n                    found_children.add(child)\n                else:  # otherwise, search its 
children\n                    found_children = cls.find_node_real_children(fx_graph, child, found_children)\n\n        return found_children\n\n    @classmethod\n    def find_real_children(cls, fx_graph):\n        \"\"\"Collect the real children of all nodes in the graph\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\n                \"\\n[find_real_children] Find the real children for each node according to the whole network graph built with Torch.FX\"\n            )\n\n        reversible_fx_graph_keys = list(fx_graph.keys())\n        for node_name in reversed(\n            reversible_fx_graph_keys\n        ):  # as an optimization, traverse from back to front so the already saved 'real_children' entries can be reused\n            node_children = fx_graph.get(node_name).get(\"children\")\n\n            if cls.__verbosity > 2:\n                print(\n                    \"[find_real_children] node_name: '{:}', children: {:}\".format(\n                        node_name, node_children\n                    )\n                )\n\n            real_children = cls.find_node_real_children(fx_graph, node_name, set())\n\n            if cls.__verbosity > 1:\n                print(\n                    f\"[find_real_children] {node_name} has {len(real_children)} real children: {real_children}\"\n                )\n\n            fx_graph[node_name][\"real_children\"] = sorted(real_children)\n\n        if cls.__save_permutation_graph:\n            cls.save_graph_to_json(\n                fx_graph,\n                save_dumped_graph_path_with_name=os.path.join(\n                    cls.__permutation_output_dir,\n                    \"./model_graph_find_real_children.json\",\n                ),\n            )  # save the intermediate graph as JSON file for debugging\n        return fx_graph\n\n    @classmethod\n    def find_node_real_parents(cls, fx_graph, node_name, found_parents):\n        \"\"\"Collect the real parents of some node\"\"\"\n\n        if 
\"real_parents\" in fx_graph[node_name].keys():\n            return found_parents.union(fx_graph[node_name][\"real_parents\"])\n\n        parents = fx_graph[node_name][\"parents\"]\n        for parent in parents:\n            if parent in fx_graph.keys():  # not the input node\n                if cls.__verbosity > 3:\n                    print(f\"\\tchecking parent {parent} of node {node_name}\")\n\n                # if it's a real node, just add it\n                if \"is_real\" in fx_graph[parent].keys() and fx_graph[parent][\"is_real\"]:\n                    found_parents.add(parent)\n                else:  # otherwise, search its parents\n                    found_parents = cls.find_node_real_parents(fx_graph, parent, found_parents)\n\n        return found_parents\n\n    @classmethod\n    def find_real_parents(cls, fx_graph):\n        \"\"\"Collect the real parents of all nodes in the graph\"\"\"\n\n        if cls.__verbosity > 0:\n            print(\n                \"\\n[find_real_parents] Find the real parents for each node according to the whole network graph built with Torch.FX\"\n            )\n\n        for node_name in fx_graph.keys():\n            node_real_parents_name = []\n            node_real_parents_module_type = []\n\n            real_parents = cls.find_node_real_parents(fx_graph, node_name, set())\n\n            if cls.__verbosity > 1:\n                print(\n                    f\"[find_real_parents] {node_name} has {len(real_parents)} real parents: {real_parents}\"\n                )\n\n            fx_graph[node_name][\"real_parents\"] = sorted(real_parents)\n\n        if cls.__save_permutation_graph:\n            cls.save_graph_to_json(\n                fx_graph,\n                save_dumped_graph_path_with_name=os.path.join(\n                    cls.__permutation_output_dir, \"./model_graph_find_real_parent.json\"\n                ),\n            )  # save the intermediate graph as JSON file for debugging\n        return fx_graph\n\n    
@classmethod\n    def build_fx_graph(\n        cls, model, dump_fx_graph=False, save_dumped_fx_graph=\"./model_fx_graph.json\"\n    ):\n        \"\"\"Build the whole network graph with Torch.FX.\"\"\"\n\n        network_fx_graph = {}\n        success = True\n        torch_version = str(torch.__version__)\n        torch_version_major = int(torch_version.split(\".\")[0])\n        torch_version_minor = int(torch_version.split(\".\")[1])\n        try:\n            torch_version_minimum = int(torch_version.split(\".\")[2])\n        except ValueError:  # support non-standard version strings\n            torch_version_minimum = torch_version.split(\".\")[2]\n        if cls.__verbosity > 2:\n            print(\n                \"[build_fx_graph] The torch version is: {}, version major is: {}, version minor is: {}, version minimum is: {}\".format(\n                    torch_version,\n                    torch_version_major,\n                    torch_version_minor,\n                    torch_version_minimum,\n                )\n            )\n\n        if torch_version_major >= 2 or (torch_version_major >= 1 and torch_version_minor >= 8):\n            if cls.__verbosity > 1:\n                print(\"[build_fx_graph] The Torch.FX is supported.\")\n        else:  # Torch.FX was introduced in torch 1.8.0\n            if cls.__verbosity >= 0:\n                print(\n                    \"[build_fx_graph] The Torch.FX is not supported. 
So cannot build the Torch.FX graph.\"\n                )\n            success = False\n            return network_fx_graph, success\n\n        if cls.__verbosity > 2:\n            print(\"\\n[build_fx_graph] Print the model structure with pure PyTorch function\")\n            print(model)\n\n        graph_module = cls.trace_and_print_raw_fx_graph(\n            model, print_tabular=cls.__verbosity > 1\n        )  # needs \"tabulate\" library\n        if graph_module is None:\n            success = False\n            return network_fx_graph, success\n\n        if cls.__verbosity > 0:\n            print(\"\\n[build_fx_graph] Build the module name and type dictionary\")\n\n        module_name_type_dict = {}\n        module_name_group_conv_dict = {}\n        module_name_C_dict = {}\n        module_name_K_dict = {}\n        for name, mod in model.named_modules():\n            if cls.__verbosity > 1:\n                print(\"[build_fx_graph] module_name: {}, module type: {}\".format(name, type(mod)))\n            module_name_type_dict[name] = str(type(mod)).split(\"'\")[1]\n            try:\n                module_name_C_dict[name] = str(mod.in_channels)\n            except:\n                try:\n                    module_name_C_dict[name] = str(mod.in_features)\n                except:\n                    try:\n                        module_name_C_dict[name] = str(mod.embed_dim)\n                    except:\n                        module_name_C_dict[name] = \"None\"\n\n            try:\n                module_name_K_dict[name] = str(mod.out_channels)\n            except:\n                try:\n                    module_name_K_dict[name] = str(mod.out_features)\n                except:\n                    try:\n                        module_name_K_dict[name] = str(mod.embed_dim)\n                    except:\n                        module_name_K_dict[name] = \"None\"\n\n            try:\n                module_name_group_conv_dict[name] = str(mod.groups)\n         
       if cls.__verbosity > 1:\n                    print(\n                        \"[build_fx_graph] this module has 'group' param with value: {}\".format(\n                            mod.groups\n                        )\n                    )\n            except:\n                module_name_group_conv_dict[name] = \"None\"\n                continue\n\n        # keep track of children and parents for each layer (could be call_module or call_function)\n        if cls.__verbosity > 0:\n            print(\"\\n[build_fx_graph] Print the children and parents relationship for each layer\")\n        network_fx_graph = {}\n        for node in graph_module.graph.nodes:\n            if node.op == \"placeholder\":\n                if cls.__verbosity > 2:\n                    print(\"[build_fx_graph] This is the 'input' node: {:}\".format(node.target))\n                continue\n            elif node.op == \"get_attr\":\n                if cls.__verbosity > 2:\n                    print(\"[build_fx_graph] This is the 'get_attr' node: {:}\".format(node.target))\n                node_parent, node_children = get_node_parent_children(node)\n                converted_node_name = convert_fx_node_name(node.target)\n\n                network_fx_graph[converted_node_name] = {}\n                network_fx_graph[converted_node_name][\"parents\"] = node_parent\n                network_fx_graph[converted_node_name][\"children\"] = node_children\n                network_fx_graph[converted_node_name][\"module_type\"] = \"get_attr\"\n                network_fx_graph[converted_node_name][\"groups_param\"] = \"None\"\n\n                # inspired by https://pytorch.org/docs/stable/fx.html\n                def fetch_attr(target: str, mod):\n                    target_atoms = target.split(\".\")\n                    attr_itr = mod\n                    for i, atom in enumerate(target_atoms):\n                        if not hasattr(attr_itr, atom):\n                            raise 
RuntimeError(\n                                f\"Node referenced nonexistant target {'.'.join(target_atoms[:i])}\"\n                            )\n                        attr_itr = getattr(attr_itr, atom)\n                    return attr_itr\n\n                attr = fetch_attr(node.target, graph_module)\n                network_fx_graph[converted_node_name][\"C_param\"] = 1\n                network_fx_graph[converted_node_name][\"K_param\"] = -1\n                network_fx_graph[converted_node_name][\"attr\"] = attr\n\n            elif (\n                node.op == \"call_function\"\n            ):  # e.g. 'adaptive.avg.pool2d', 'add', 'cat', 'flatten', 'floordiv', 'getattr', 'getitem', 'hardsigmoid', 'mean', 'mul', 'relu', 'transpose'\n                node_parent, node_children = get_node_parent_children(node)\n                converted_node_name = convert_fx_node_name(node.name)\n                if cls.__verbosity > 2:\n                    print(\n                        \"[build_fx_graph] This is the 'call_function' node: {:}, its parent list: {:}, its children list: {:}\".format(\n                            converted_node_name, node_parent, node_children\n                        )\n                    )\n                network_fx_graph[converted_node_name] = {}\n                network_fx_graph[converted_node_name][\"parents\"] = node_parent\n                network_fx_graph[converted_node_name][\"children\"] = node_children\n                network_fx_graph[converted_node_name][\"fx_op\"] = \"call_function\"\n\n                ### \"convert\" some ops to modules\n\n                # concatenating along K can be handled by reducing the size of the childrens' C appropriately\n                # see fixup_concats, if no dim arg, default is 0 (handled automatically)\n                if node.target == torch.cat and len(node.args) > 1 and node.args[1] == 1:\n                    network_fx_graph[converted_node_name][\"fx_op\"] = \"call_module\"\n                  
  network_fx_graph[converted_node_name][\"module_type\"] = \"concat\"\n                    network_fx_graph[converted_node_name][\"groups_param\"] = (\n                        \"N/A\"  # just need placeholders\n                    )\n                    network_fx_graph[converted_node_name][\"C_param\"] = \"N/A\"\n                    network_fx_graph[converted_node_name][\"K_param\"] = \"N/A\"\n\n            elif (\n                node.op == \"call_method\"\n            ):  # e.g. 'chunk', 'contiguous', 'mean', 'size', 'unsqueeze', 'view'\n                node_parent, node_children = get_node_parent_children(node)\n                converted_node_name = convert_fx_node_name(node.name)\n                if cls.__verbosity > 2:\n                    print(\n                        \"[build_fx_graph] This is the 'call_method' node: {:}, its parent list: {:}, its children list: {:}\".format(\n                            converted_node_name, node_parent, node_children\n                        )\n                    )\n                network_fx_graph[converted_node_name] = {}\n                network_fx_graph[converted_node_name][\"parents\"] = node_parent\n                network_fx_graph[converted_node_name][\"children\"] = node_children\n                network_fx_graph[converted_node_name][\"fx_op\"] = \"call_method\"\n                continue\n\n            elif node.op == \"call_module\":\n                node_parent, node_children = get_node_parent_children(node)\n                converted_node_name = convert_fx_node_name(node.name)\n                # check whether the converted_node_name is same as node.target, especially for ReLU case\n                if converted_node_name != node.target:\n                    if cls.__verbosity > 2:\n                        print(\n                            \"[build_fx_graph][warning] The target name from Torch.FX is '{:}', the manually converted node name is '{:}', not the same one, choose the converted node name\".format(\n  
                              node.target, converted_node_name\n                            )\n                        )\n\n                # assume the modules share the same target name have the same type, because converted_node_name may not be obtained by model.named_modules(), like some ReLU (defined in forward function)\n                node_type = module_name_type_dict[node.target]\n                if cls.__verbosity > 2:\n                    print(\n                        \"[build_fx_graph] This is the 'call_module' node: {:}, its parent list: {:}, its children list: {:}, its type: {:}\".format(\n                            converted_node_name, node_parent, node_children, node_type\n                        )\n                    )\n                network_fx_graph[converted_node_name] = {}\n                network_fx_graph[converted_node_name][\"parents\"] = node_parent\n                network_fx_graph[converted_node_name][\"children\"] = node_children\n                network_fx_graph[converted_node_name][\"fx_op\"] = \"call_module\"\n                network_fx_graph[converted_node_name][\"module_type\"] = node_type\n                network_fx_graph[converted_node_name][\"groups_param\"] = module_name_group_conv_dict[\n                    node.target\n                ]\n                network_fx_graph[converted_node_name][\"C_param\"] = module_name_C_dict[node.target]\n                network_fx_graph[converted_node_name][\"K_param\"] = module_name_K_dict[node.target]\n\n            elif node.op == \"output\":\n                if cls.__verbosity > 2:\n                    print(\"[build_fx_graph] This is the 'output' node: {:}\".format(node.target))\n                continue\n\n        if dump_fx_graph:\n            if cls.__verbosity > 0:\n                print(\n                    \"\\n[build_fx_graph] Dump the overall dict for children and parents relationship into JSON file\"\n                )\n            cls.save_graph_to_json(\n                
network_fx_graph, save_dumped_graph_path_with_name=save_dumped_fx_graph\n            )\n\n        return network_fx_graph, success\n\n    @classmethod\n    def trace_and_print_raw_fx_graph(cls, model, print_tabular=False, generate_python_code=False):\n        \"\"\"This function is used to find and print the intermediate representation (IR) - Graph representation with Torch.FX features.\"\"\"\n\n        from torch.fx import symbolic_trace\n        import traceback\n\n        # Symbolic tracing frontend - captures the semantics of the module\n        try:\n            symbolic_traced: torch.fx.GraphModule = symbolic_trace(model)\n        except Exception as ex:\n            if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:\n                if cls.__verbosity > 0:\n                    print(ex)\n                    # pass the exception triple positionally: the 'etype' keyword was removed in Python 3.10\n                    print(\n                        \"\".join(\n                            traceback.format_exception(\n                                type(ex), ex, ex.__traceback__\n                            )\n                        )\n                    )\n                    print(\n                        \"\\n[print_raw_fx_graph] Hit a fatal error while symbolically tracing the model with Torch.FX\"\n                    )\n            return None\n\n        # High-level intermediate representation (IR) - Graph representation\n        if cls.__verbosity > 1:\n            print(\"\\n[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX\")\n            print(symbolic_traced.graph)\n\n        if print_tabular:\n            print(\n                \"\\n[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX in a table format\"\n            )\n            try:\n                from tabulate import tabulate\n\n                symbolic_traced.graph.print_tabular()\n            except ImportError:\n                if cls.__verbosity > 1:\n                    print(\n            
            \"[print_raw_fx_graph][Warning] 'print_tabular' relies on the library `tabulate`; run `pip install tabulate` to install it.\"\n                    )\n            except (\n                AttributeError\n            ):  # to avoid the AttributeError: 'Graph' object has no attribute 'print_tabular'\n                if cls.__verbosity > 1:\n                    print(\n                        \"[print_raw_fx_graph][Warning] 'print_tabular' function is not supported in current Torch version. Skip!\"\n                    )\n\n        # Code generation - valid Python code\n        if generate_python_code:\n            print(\n                \"\\n[print_raw_fx_graph] Create valid Python code matching the IR/Graph's semantics with Torch.FX\"\n            )\n            print(symbolic_traced.code)\n\n        return symbolic_traced\n\n    @classmethod\n    def save_graph_to_json(cls, graph, save_dumped_graph_path_with_name=\"./model_fx_graph.json\"):\n        \"\"\"This function is used to save the graph into JSON file for inspection.\"\"\"\n\n        # use dumps to transfer the dict to JSON string\n        json_graph_str = json.dumps(graph)\n        with open(save_dumped_graph_path_with_name, \"w\", encoding=\"utf-8\") as dumped_graph_file:\n            dumped_graph_file.write(\n                json_graph_str\n            )  # write the transferred JSON string into JSON file\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu",
    "content": "#include <pybind11/numpy.h>\n#include <pybind11/pybind11.h>\n#include <stdio.h>\nnamespace py = pybind11;\n\n#define gpuErrchk(ans)                    \\\n  {                                       \\\n    gpuAssert((ans), __FILE__, __LINE__); \\\n  }\ninline void gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {\n  if (code != cudaSuccess) {\n    fprintf(stderr, \"GPUassert %d: %s %s %d\\n\", (int)code, cudaGetErrorString(code), file, line);\n    if (abort) exit(code);\n  }\n}\n\n// find the magnitude after enforcing the 2:4 sparsity constraint on a group of 4 values\n__device__ float group_2_to_4(float4 vals) {\n  vals.x = fabs(vals.x);\n  vals.y = fabs(vals.y);\n  vals.z = fabs(vals.z);\n  vals.w = fabs(vals.w);\n\n  float sum0 = vals.x + vals.y;\n  float sum1 = vals.x + vals.z;\n  float sum2 = vals.x + vals.w;\n  float sum3 = vals.y + vals.z;\n  float sum4 = vals.y + vals.w;\n  float sum5 = vals.z + vals.w;\n\n  float best_sum0 = fmax(sum0, sum1);\n  float best_sum1 = fmax(sum2, sum3);\n  float best_sum2 = fmax(sum4, sum5);\n  float best_sum = fmax(fmax(best_sum0, best_sum1), best_sum2);\n\n  return best_sum;\n}\n\ninline float* float_ptr_from_numpy(py::array_t<float>& py_float) { return (float*)py_float.data(); }\n\ninline unsigned int* uint_ptr_from_numpy(py::array_t<unsigned int>& py_uint) { return (unsigned int*)py_uint.data(); }\n\n/**********************************************************\n *  Check for the best permutation for an entire matrix\n **********************************************************/\n__global__ void permute_and_sum_after_2_to_4(float* matrix, unsigned int rows, unsigned int cols, unsigned int* stripes,\n                                             unsigned int total_stripes, unsigned int* permutations, float* output) {\n  // vectorize\n  float4* mat4 = (float4*)matrix;\n  cols /= 4;\n\n  // each thread in a block takes some number of rows\n  size_t num_rows = max((int)ceilf((float)rows / 
(float)blockDim.x), 1);\n  size_t row_offset = num_rows * threadIdx.x;\n  size_t num_stripes = total_stripes;  // total_stripes / gridDim.x;\n  size_t stripe_offset = 0;            // num_stripes * blockIdx.x;\n  unsigned int localStart = stripe_offset;\n  unsigned int localEnd = localStart + num_stripes;\n\n  // each block takes care of one permutation\n  unsigned int p = blockIdx.x;\n  unsigned int* permutation = &permutations[p * total_stripes * 4];\n\n  float sum = 0.0f;\n  extern __shared__ float s[32][32];\n  float4* local_stripes = (float4*)&s[threadIdx.x];\n  float* local_columns = (float*)&s[threadIdx.x];\n  float4* permuted_local_stripes = (float4*)&local_stripes[num_stripes];\n  float* permuted_local_columns = (float*)&local_columns[num_stripes * 4];\n\n  for (unsigned int r = row_offset; r < row_offset + num_rows; ++r) {\n    if (r >= rows) break;\n\n    // load into smem\n    for (unsigned int s = localStart; s < localEnd; ++s) {\n      unsigned int stripe = stripes[s];\n      local_stripes[s] = mat4[r * cols + stripe];\n    }\n\n// now permute\n#pragma unroll 4\n    for (unsigned int c = 0; c < num_stripes * 4; ++c) {\n      permuted_local_columns[c] = local_columns[permutation[c]];\n    }\n\n    // now sum 2:4\n    for (unsigned int s = 0; s < num_stripes; ++s) {\n      sum += group_2_to_4(permuted_local_stripes[s]);\n    }\n  }\n\n  atomicAdd(&output[p], sum);\n}\n\nvoid free_permutation_memory(float** dmatrix, unsigned int** dstripe_groups, unsigned int** dpermutations,\n                             float** dresults, float** hresults) {\n  cudaFree(*dmatrix);\n  cudaFree(*dresults);\n  cudaFree(*dpermutations);\n  cudaFree(*dstripe_groups);\n  free(*hresults);\n}\n\nint set_up_check_permutation_memory(float** dmatrix, unsigned int rows, unsigned int cols,\n                                    unsigned int** dstripe_groups, unsigned int group_width, unsigned int num_groups,\n                                    unsigned int** dpermutations, unsigned 
int num_permutations, float** dresults,\n                                    float** hresults) {\n  static unsigned int setupRows = 0;\n  static unsigned int setupCols = 0;\n  static unsigned int setupGroupWidth = 0;\n  static unsigned int setupNumGroups = 0;\n  static unsigned int setupNumPermutations = 0;\n  static bool allocated = false;\n  int fresh_alloc = 0;\n  if (!allocated || setupRows != rows || setupCols != cols || setupGroupWidth != group_width ||\n      setupNumGroups != num_groups || setupNumPermutations != num_permutations) {\n    if (allocated) {\n      free_permutation_memory(dmatrix, dstripe_groups, dpermutations, dresults, hresults);\n    }\n\n    gpuErrchk(cudaMalloc((void**)dmatrix, rows * cols * sizeof(float)));\n    gpuErrchk(cudaMalloc((void**)dstripe_groups, group_width * num_groups * sizeof(unsigned int)));\n    gpuErrchk(cudaMalloc((void**)dpermutations, num_permutations * group_width * 4 * sizeof(unsigned int)));\n    gpuErrchk(cudaMalloc((void**)dresults, num_permutations * sizeof(float)));\n    *hresults = (float*)malloc(num_permutations * sizeof(float));\n\n    allocated = true;\n    setupRows = rows;\n    setupCols = cols;\n    setupGroupWidth = group_width;\n    setupNumGroups = num_groups;\n    setupNumPermutations = num_permutations;\n    fresh_alloc = 1;\n  }\n\n  return fresh_alloc;\n}\n\nint run_check_permutations(\n    py::array_t<float>& py_matrix, unsigned int rows, unsigned int cols,\n    py::array_t<unsigned int>&\n        py_stripe_groups,  // groups of stripes, group_width = stripes per group, num_groups = groups in the array\n    unsigned int group_width, unsigned int num_groups,\n    py::array_t<unsigned int>& py_permutations,  // array of permutations to try, group_width*4 values for each of\n                                                 // num_permutations permutations\n    unsigned int num_permutations,\n    py::array_t<float>& py_improvement,        // improvement offered by the best permutation\n    
py::array_t<unsigned int>& py_permutation  // the best permutation\n) {\n  const unsigned int threads = 32;\n  static float* d_matrix;\n  static unsigned int* d_permutations;\n  static unsigned int* d_stripes;\n  static float* d_results;\n  static float* results;\n\n  float* matrix = float_ptr_from_numpy(py_matrix);\n  unsigned int* stripe_groups = uint_ptr_from_numpy(py_stripe_groups);\n  unsigned int* permutations = uint_ptr_from_numpy(py_permutations);\n  float* improvement = float_ptr_from_numpy(py_improvement);\n  unsigned int* permutation = uint_ptr_from_numpy(py_permutation);\n\n  int fresh_alloc = set_up_check_permutation_memory(&d_matrix, rows, cols, &d_stripes, group_width, num_groups,\n                                                    &d_permutations, num_permutations, &d_results, &results);\n  if (fresh_alloc == 1) {\n    gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * group_width * 4 * sizeof(unsigned int),\n                         cudaMemcpyHostToDevice));\n    gpuErrchk(\n        cudaMemcpy(d_stripes, stripe_groups, group_width * num_groups * sizeof(unsigned int), cudaMemcpyHostToDevice));\n  }\n\n  // initialize results, new matrix\n  gpuErrchk(cudaMemset(d_results, 0, num_permutations * sizeof(float)));\n  gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice));\n\n  // get results for all permutations\n  permute_and_sum_after_2_to_4<<<num_permutations, threads, threads * group_width * 4 * 2 * sizeof(float)>>>(\n      d_matrix, rows, cols, d_stripes, group_width, d_permutations, d_results);\n  gpuErrchk(cudaDeviceSynchronize());\n\n  gpuErrchk(cudaMemcpy(results, d_results, num_permutations * sizeof(float), cudaMemcpyDeviceToHost));\n\n  // find the best permutation - could reduce on GPU\n  unsigned int best_permutation = 0;\n  float best_improvement = 0.0f;\n  for (unsigned int p = 1; p < num_permutations; ++p) {\n    float cur_improvement = results[p] - results[0];\n    if 
(best_improvement < cur_improvement) {\n      best_permutation = p;\n      best_improvement = cur_improvement;\n    }\n  }\n\n  *improvement = best_improvement;\n  *permutation = best_permutation;\n\n  return 0;\n}\n\n///////////////////////////////////////////////////////////\n\n/**********************************************************\n * Get the magnitude of a matrix after applying 2:4\n **********************************************************/\n// find the magnitude after enforcing the 2:4 sparsity constraint on a subset of the columns of an input matrix\n__global__ void subset_sum_after_2_to_4(float* matrix, unsigned int rows, unsigned int cols, unsigned int start_col,\n                                        unsigned int end_col, float* output) {\n  // vectorize\n  float4* mat4 = (float4*)matrix;\n  cols /= 4;\n  start_col /= 4;\n  end_col /= 4;\n\n  // each thread in a block takes some number of rows\n  size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1);\n  size_t row_offset = num_rows * threadIdx.x;\n  // each block takes some number of columns\n  size_t num_cols = (end_col - start_col) / gridDim.x;\n  size_t col_offset = num_cols * blockIdx.x;\n  start_col += col_offset;\n  end_col = start_col + num_cols;\n\n  float sum = 0.0f;\n  for (unsigned int r = row_offset; r < row_offset + num_rows; ++r) {\n    if (r < rows) {\n      for (unsigned int c = start_col; c < end_col; c++) {\n        sum += group_2_to_4(mat4[r * cols + c]);\n      }\n    }\n  }\n\n  atomicAdd(output, sum);\n}\n\n// build the entire permute map at once\n// each block handles one group of stripes\n// each threads in the block handle all handle the same permutation at the same time on different rows before moving to\n// the next permutation\n__global__ void build_permute_map(float* matrix, unsigned int rows, unsigned int cols, unsigned int* stripes,\n                                  unsigned int group_width, unsigned int* permutations, unsigned int 
num_permutations,\n                                  unsigned int perm_length, float* output, unsigned int* best_indices) {\n  // vectorize\n  float4* mat4 = (float4*)matrix;\n  cols /= 4;\n\n  // each block handles a group of stripes\n  unsigned int* stripe_group = (unsigned int*)&stripes[blockIdx.x * group_width];\n\n  // shared memory: 32 threads each need 16*2\n  extern __shared__ float pm_shared[32][32];\n  float4* local_stripes = (float4*)&pm_shared[threadIdx.x];\n  float* local_columns = (float*)&pm_shared[threadIdx.x];\n  float4* permuted_stripes = (float4*)&local_stripes[4];\n  float* permuted_columns = (float*)&local_columns[16];\n\n  // each thread handles all permutations in the row before moving on to the next row\n  size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1);\n  size_t row_offset = num_rows * threadIdx.x;\n\n  for (unsigned int r = row_offset; r < row_offset + num_rows; ++r) {\n    if (r >= rows) break;\n\n    // load a row into smem\n    for (unsigned int s = 0; s < group_width; ++s) {\n      unsigned int const stripe = stripe_group[s];\n      local_stripes[s] = mat4[r * cols + stripe];\n    }\n\n    for (unsigned int p = 0; p < num_permutations; ++p) {\n      unsigned int* permutation = &permutations[p * perm_length];\n      float sum = 0.0f;\n\n// permute\n#pragma unroll 4\n      for (unsigned int c = 0; c < group_width * 4; ++c) {\n        permuted_columns[c] = local_columns[permutation[c]];\n      }\n\n      // sum 2:4\n      for (unsigned int s = 0; s < group_width; ++s) {\n        sum += group_2_to_4(permuted_stripes[s]);\n      }\n\n      // update the running sum for this stripe group's permutation\n      atomicAdd(&output[blockIdx.x * num_permutations + p], sum);\n    }\n  }\n\n  // at this point, each permutation's sum in this stripe group has been calculated\n  // now, find the best option\n  __syncthreads();\n\n  if (threadIdx.x == 0) {\n    unsigned int best_permutation = 0;\n    float best_magnitude = 
output[blockIdx.x * num_permutations];\n    float base_magnitude = best_magnitude;\n\n    // #pragma unroll 32\n    for (unsigned int p = 1; p < num_permutations; ++p) {\n      float magnitude = output[blockIdx.x * num_permutations + p];\n      if (magnitude > best_magnitude) {\n        best_permutation = p;\n        best_magnitude = magnitude;\n      }\n    }\n\n    output[blockIdx.x * num_permutations] = best_magnitude - base_magnitude;\n    best_indices[blockIdx.x] = best_permutation;\n  }\n}\n\nvoid free_sum_after_2_to_4_memory(float** dmatrix, float** dresult) {\n  cudaFree(*dmatrix);\n  cudaFree(*dresult);\n}\n\nint set_up_sum_after_2_to_4_memory(float** dmatrix, unsigned int rows, unsigned int cols, float** dresult) {\n  static unsigned int setupRows = 0;\n  static unsigned int setupCols = 0;\n  static bool allocated = false;\n\n  int fresh_allocation = 0;\n  if (!allocated || setupRows != rows || setupCols != cols) {\n    if (allocated) free_sum_after_2_to_4_memory(dmatrix, dresult);\n\n    gpuErrchk(cudaMalloc((void**)dmatrix, rows * cols * sizeof(float)));\n    gpuErrchk(cudaMalloc((void**)dresult, sizeof(float)));\n\n    setupRows = rows;\n    setupCols = cols;\n\n    fresh_allocation = 1;\n  }\n\n  allocated = true;\n\n  return fresh_allocation;\n}\n\nint run_subset_sum_after_2_to_4(py::array_t<float>& py_matrix, unsigned int rows, unsigned int cols,\n                                unsigned int start_col, unsigned int end_col, unsigned int blocks, unsigned int threads,\n                                py::array_t<float>& py_output) {\n  static float* d_matrix;\n  static float* d_result;\n\n  int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result);\n\n  float* matrix = float_ptr_from_numpy(py_matrix);\n  float* output = float_ptr_from_numpy(py_output);\n\n  gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice));\n  gpuErrchk(cudaMemset(d_result, 0, sizeof(float)));\n\n  
subset_sum_after_2_to_4<<<blocks, threads>>>(d_matrix, rows, cols, start_col, end_col, d_result);\n  gpuErrchk(cudaDeviceSynchronize());\n\n  gpuErrchk(cudaMemcpy(output, d_result, sizeof(float), cudaMemcpyDeviceToHost));\n\n  return 0;\n}\n\nvoid set_up_permute_map_memory(float** dmatrix, unsigned int rows, unsigned int cols, unsigned int** dstripes,\n                               unsigned int num_groups, unsigned int group_width, unsigned int** dpermutations,\n                               unsigned int num_permutations, unsigned int perm_length, float** doutput,\n                               unsigned int** dindices, float** hresult, unsigned int** hindices) {\n  static unsigned int setUpRows = 0;\n  static unsigned int setUpCols = 0;\n  static unsigned int setUpGroupWidth = 0;\n  static unsigned int setUpNumGroups = 0;\n  static unsigned int setUpNumPerms = 0;\n  static unsigned int setUpPermLength = 0;\n\n  if (setUpRows != rows || setUpCols != cols) {\n    if (*dmatrix != NULL) {\n      gpuErrchk(cudaFree(*dmatrix));\n      *dmatrix = NULL;\n    }\n    gpuErrchk(cudaMalloc((void**)dmatrix, rows * cols * sizeof(float)));\n  }\n\n  if (setUpGroupWidth < group_width || setUpNumGroups < num_groups) {\n    if (*dstripes != NULL) {\n      gpuErrchk(cudaFree(*dstripes));\n      *dstripes = NULL;\n    }\n    gpuErrchk(cudaMalloc((void**)dstripes, num_groups * group_width * sizeof(unsigned int)));\n\n    if (setUpNumGroups < num_groups) {\n      if (*dindices != NULL) {\n        gpuErrchk(cudaFree(*dindices));\n        *dindices = NULL;\n      }\n      gpuErrchk(cudaMalloc((void**)dindices, num_groups * sizeof(unsigned int)));\n      if (*hindices != NULL) {\n        free(*hindices);\n        *hindices = NULL;\n      }\n      *hindices = (unsigned int*)malloc(num_groups * sizeof(unsigned int));\n    }\n  }\n\n  if (setUpNumPerms < num_permutations || setUpPermLength < perm_length) {\n    if (*dpermutations != NULL) {\n      gpuErrchk(cudaFree(*dpermutations));\n     
 *dpermutations = NULL;\n    }\n    gpuErrchk(cudaMalloc((void**)dpermutations, perm_length * num_permutations * sizeof(unsigned int)));\n  }\n\n  if (setUpNumPerms < num_permutations || setUpNumGroups < num_groups) {\n    if (*doutput != NULL) {\n      gpuErrchk(cudaFree(*doutput));\n      *doutput = NULL;\n    }\n    gpuErrchk(cudaMalloc((void**)doutput, num_permutations * num_groups * sizeof(float)));\n    if (*hresult != NULL) {\n      free(*hresult);\n      *hresult = NULL;\n    }\n    *hresult = (float*)malloc(num_permutations * num_groups * sizeof(float));\n  }\n\n  setUpRows = rows;\n  setUpCols = cols;\n  setUpGroupWidth = group_width;\n  setUpNumGroups = num_groups;\n  setUpNumPerms = num_permutations;\n  setUpPermLength = perm_length;\n}\n\nint run_build_permute_map(py::array_t<float>& py_matrix, unsigned int rows, unsigned int cols,\n                          py::array_t<unsigned int>& py_stripes, unsigned int num_groups, unsigned int group_width,\n                          py::array_t<unsigned int>& py_permutations, unsigned int perm_length,\n                          py::array_t<float>& py_improvements, py::array_t<unsigned int>& py_best_indices) {\n  static float* d_matrix = NULL;\n  static unsigned int* d_stripes = NULL;\n  static unsigned int* d_permutations = NULL;\n  static float* d_output = NULL;\n  static unsigned int* d_indices = NULL;\n  static float* hresult = NULL;\n  static unsigned int* hindices = NULL;\n\n  const unsigned int num_permutations = py_permutations.size() / perm_length;\n\n  const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40;\n  const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH;\n  const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH;\n  const unsigned int launches = full_launches + (final_launch != 0 ? 
1 : 0);\n\n  set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups, MAX_GROUPS_PER_LAUNCH), group_width,\n                            &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices);\n\n  float* matrix = float_ptr_from_numpy(py_matrix);\n  unsigned int* stripes = uint_ptr_from_numpy(py_stripes);\n  unsigned int* permutations = uint_ptr_from_numpy(py_permutations);\n  float* improvements = float_ptr_from_numpy(py_improvements);\n  unsigned int* best_indices = uint_ptr_from_numpy(py_best_indices);\n\n  gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice));\n  gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * perm_length * sizeof(unsigned int),\n                       cudaMemcpyHostToDevice));\n\n  unsigned int group_offset = 0;\n  for (unsigned int l = 0; l < launches; ++l) {\n    unsigned int groups_this_launch = (l < full_launches) ? MAX_GROUPS_PER_LAUNCH : final_launch;\n\n    gpuErrchk(cudaMemcpy(d_stripes, &stripes[group_offset * group_width],\n                         groups_this_launch * group_width * sizeof(unsigned int), cudaMemcpyHostToDevice));\n    gpuErrchk(cudaMemset(d_output, 0, groups_this_launch * num_permutations * sizeof(float)));\n    gpuErrchk(cudaMemset(d_indices, 0, groups_this_launch * sizeof(unsigned int)));\n\n    unsigned int shmem = 32 * (32) * sizeof(float);\n    build_permute_map<<<groups_this_launch, 32, shmem>>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations,\n                                                         num_permutations, perm_length, d_output, d_indices);\n    gpuErrchk(cudaDeviceSynchronize());\n\n    gpuErrchk(\n        cudaMemcpy(hresult, d_output, num_permutations * groups_this_launch * sizeof(float), cudaMemcpyDeviceToHost));\n    gpuErrchk(cudaMemcpy(hindices, d_indices, groups_this_launch * sizeof(unsigned int), cudaMemcpyDeviceToHost));\n\n    // thread0 stuck the minimum in the 
first slot of each group\n    for (unsigned int g = 0; g < groups_this_launch; ++g) {\n      improvements[group_offset + g] = hresult[g * num_permutations];\n      best_indices[group_offset + g] = hindices[g];\n    }\n\n    group_offset += groups_this_launch;\n  }\n\n  return 0;\n}\n\n/**********************************************************\n * Build the swap map for channel_swaps\n **********************************************************/\n// find the magnitude improvement if some columns were swapped (check all pairs of columns in all the stripe_pairs)\n__global__ void swap_columns_sum_after_2_to_4(float* matrix, unsigned int rows, unsigned int cols,\n                                              unsigned int* stripe_pairs, float* output) {\n  // vectorize\n  float4* mat4 = (float4*)matrix;\n  cols /= 4;\n\n  // each thread takes some number of rows\n  size_t const num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1);\n  size_t const row_offset = num_rows * threadIdx.x;\n\n  // each block is responsible for a pair of stripes\n  unsigned int const stripe0 = stripe_pairs[2 * blockIdx.x];\n  unsigned int const stripe1 = stripe_pairs[2 * blockIdx.x + 1];\n  // space for 32 threads, 8 values (2 stripes) in use at a time, plus 16 partial sums and one base sum\n  extern __shared__ float cs[32][32];\n  float4* local_stripe0 = (float4*)&cs[threadIdx.x][0];\n  float* local_cols0 = (float*)&cs[threadIdx.x][0];\n  float4* local_stripe1 = (float4*)&cs[threadIdx.x][4];\n  float* local_cols1 = (float*)&cs[threadIdx.x][4];\n  float* local_psum = (float*)&cs[threadIdx.x][8];\n  float* base_psum = (float*)&cs[threadIdx.x][24];\n\n  *base_psum = 0.0f;\n  for (unsigned int s = 0; s < 16; ++s) {\n    local_psum[s] = 0.0f;\n  }\n\n  for (unsigned int r = row_offset; r < row_offset + num_rows; ++r) {\n    if (r >= rows) break;\n    *local_stripe0 = mat4[r * cols + stripe0];\n    *local_stripe1 = mat4[r * cols + stripe1];\n    *base_psum += group_2_to_4(*local_stripe0) + 
group_2_to_4(*local_stripe1);\n    unsigned int swap_idx = 0;\n    for (unsigned int c0 = 0; c0 < 4; ++c0) {\n      for (unsigned int c1 = 0; c1 < 4; ++c1) {\n        // swap c0 and c1\n        float tmp = local_cols0[c0];\n        local_cols0[c0] = local_cols1[c1];\n        local_cols1[c1] = tmp;\n\n        // grab the sum\n        local_psum[swap_idx] += group_2_to_4(*local_stripe0) + group_2_to_4(*local_stripe1);\n\n        // swap back\n        local_cols1[c1] = local_cols0[c0];\n        local_cols0[c0] = tmp;\n\n        swap_idx++;\n      }\n    }\n  }\n\n  // reduce partial sums, store local diffs in the output\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    for (unsigned int t = 1; t < blockDim.x; ++t) {\n      for (unsigned int swap = 0; swap < 16; ++swap) {\n        local_psum[swap] += cs[t][8 + swap];\n      }\n      *base_psum += cs[t][24];\n    }\n\n    for (unsigned int swap = 0; swap < 16; ++swap) {\n      atomicAdd(&output[blockIdx.x * 16 + swap], local_psum[swap] - (*base_psum));\n    }\n  }\n}\n\nvoid set_up_swap_map_memory(float** dmatrix, unsigned int rows, unsigned int cols, unsigned int** dstripe_pairs,\n                            unsigned int num_pairs, float** dresult) {\n  static unsigned int setupRows = 0;\n  static unsigned int setupCols = 0;\n  static unsigned int setupPairs = 0;\n\n  if (*dmatrix == NULL || setupRows != rows || setupCols != cols) {\n    if (*dmatrix != NULL) {\n      gpuErrchk(cudaFree(*dmatrix));\n      *dmatrix = NULL;\n    }\n    gpuErrchk(cudaMalloc((void**)dmatrix, rows * cols * sizeof(float)));\n    setupRows = rows;\n    setupCols = cols;\n  }\n\n  if (*dstripe_pairs == NULL || *dresult == NULL || setupPairs < num_pairs) {\n    if (*dstripe_pairs != NULL) {\n      gpuErrchk(cudaFree(*dstripe_pairs));\n      *dstripe_pairs = NULL;\n    }\n    if (*dresult != NULL) {\n      gpuErrchk(cudaFree(*dresult));\n      *dresult = NULL;\n    }\n    gpuErrchk(cudaMalloc((void**)dstripe_pairs, num_pairs * 2 * 
sizeof(unsigned int)));\n    gpuErrchk(cudaMalloc((void**)dresult, num_pairs * 16 * sizeof(float)));\n\n    setupPairs = num_pairs;\n  }\n}\n\nint run_build_swap_map(py::array_t<float>& py_matrix, unsigned int rows, unsigned int cols,\n                       py::array_t<uint32_t>& py_stripe_pairs, py::array_t<float>& py_output) {\n  static float* d_matrix = NULL;\n  static float* d_result = NULL;\n  static unsigned int* d_stripe_pairs = NULL;\n\n  float* matrix = float_ptr_from_numpy(py_matrix);                    //(float*)py_matrix.data();\n  unsigned int* stripe_pairs = uint_ptr_from_numpy(py_stripe_pairs);  //(unsigned int*)py_stripe_pairs.data();\n  float* output = float_ptr_from_numpy(py_output);                    //(float*)py_output.data();\n\n  unsigned int num_pairs = py_stripe_pairs.size() / 2;\n\n  set_up_swap_map_memory(&d_matrix, rows, cols, &d_stripe_pairs, num_pairs, &d_result);\n  gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice));\n  gpuErrchk(cudaMemcpy(d_stripe_pairs, stripe_pairs, 2 * num_pairs * sizeof(unsigned int), cudaMemcpyHostToDevice));\n  gpuErrchk(cudaMemset(d_result, 0, num_pairs * 16 * sizeof(float)));\n\n  unsigned int shmem = 32 * (32) * sizeof(float);\n  swap_columns_sum_after_2_to_4<<<num_pairs, 32, shmem>>>(d_matrix, rows, cols, d_stripe_pairs, d_result);\n  gpuErrchk(cudaDeviceSynchronize());\n\n  gpuErrchk(cudaMemcpy(output, d_result, num_pairs * 16 * sizeof(float), cudaMemcpyDeviceToHost));\n\n  return 0;\n}\n///////////////////////////////////////////////////////////\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"sum_after_2_to_4\", &run_subset_sum_after_2_to_4, \"matrix sum after applying 2:4 (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"build_permute_map\", &run_build_permute_map, \"optimize stripe groups (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"check_permutations\", &run_check_permutations, \"exhaustively check 
all permutations (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"build_swap_map\", &run_build_swap_map, \"channel swaps (CUDA)\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_search_kernels/__init__.py",
    "content": "from .call_permutation_search_kernels import accelerated_search_for_good_permutation\nfrom .permutation_utilities import sum_after_2_to_4\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_search_kernels/call_permutation_search_kernels.py",
    "content": "import numpy as np\nfrom .permutation_utilities import *\nfrom .exhaustive_search import Exhaustive_Search\n\n\ndef accelerated_search_for_good_permutation(matrix_group, options=None, verbosity=0):\n    \"\"\"This function is used to call the permutation search CUDA kernels.\n    users can provide prefer search strategy by providing a valid 'options' as a dictionary,\n    or users can implement their customized 'accelerated_search_for_good_permutation' function.\n    \"\"\"\n    input_matrix = matrix_group.cpu().detach().numpy()\n    if verbosity > 1:\n        print(\n            \"\\n[accelerated_search_for_good_permutation] input matrix shape: '{:}'.\".format(\n                input_matrix.shape\n            )\n        )\n\n    result = np.copy(input_matrix)\n    # init a sequential permutation search sequence\n    input_channel_num = matrix_group.size(1)\n    permutation_sequence = [n for n in range(input_channel_num)]\n    duration = 0.0\n\n    if options == None:\n        options = {}\n    if (\n        \"strategy\" not in options\n    ):  # right now, the default permutation search strategy is: 'exhaustive' search\n        options[\"strategy\"] = \"exhaustive\"\n\n    if verbosity > 1:\n        print(\n            \"[accelerated_search_for_good_permutation] the permutation strategy is: '{:} search'.\".format(\n                options[\"strategy\"]\n            )\n        )\n\n    # define sub options for each search strategy\n    if options[\"strategy\"] == \"exhaustive\":\n        # right now, the default options for 'exhaustive' search is: 'exhaustive,8,100'\n        if \"stripe_group_size\" not in options:\n            options[\"stripe_group_size\"] = 8\n        if \"escape_attempts\" not in options:\n            options[\"escape_attempts\"] = 100\n    elif options[\"strategy\"] == \"progressive channel swap\":\n        # just swaps meaningful channels, keeping the good swaps, until the search time limit expires.\n        if 
\"progressive_search_time_limit\" not in options:\n            options[\"progressive_search_time_limit\"] = 60\n        if \"improvement_threshold\" not in options:\n            options[\"improvement_threshold\"] = 1e-9\n\n    # execute the requested strategy\n    if options[\"strategy\"] == \"exhaustive\":\n        result, duration, permutation_sequence = Exhaustive_Search(\n            result,\n            stripe_group_size=options[\"stripe_group_size\"],\n            escape_attempts=options[\"escape_attempts\"],\n        )\n    elif options[\"strategy\"] == \"progressive channel swap\":\n        real_swap_num = 0\n        start_time = time.perf_counter()\n        while time.perf_counter() - start_time < options[\"progressive_search_time_limit\"]:\n            src = np.random.randint(result.shape[1])\n            dst = np.random.randint(result.shape[1])\n            src_group = int(src / 4)\n            dst_group = int(dst / 4)\n            if src_group == dst_group:  # channel swapping within a stripe does nothing\n                continue\n            new_sum, improvement = try_swap(result, dst, src)\n            if improvement > options[\"improvement_threshold\"]:\n                result[..., [src, dst]] = result[..., [dst, src]]\n                permutation_sequence[src], permutation_sequence[dst] = (\n                    permutation_sequence[dst],\n                    permutation_sequence[src],\n                )\n                real_swap_num += 1\n        duration = time.perf_counter() - start_time\n        if verbosity > 1:\n            print(\n                \"\\tFinally swap {} channel pairs until the search time limit expires.\".format(\n                    real_swap_num\n                )\n            )\n    elif (\n        options[\"strategy\"] == \"user defined\"\n    ):  # need to get the permutated matrix (result) by applying customized permutation search function\n        if verbosity > 1:\n            print(\n                
\"[accelerated_search_for_good_permutation] Use the user customized permutation search function!\"\n            )\n    else:\n        if verbosity >= 0:\n            print(\n                \"[accelerated_search_for_good_permutation] Cannot find the implementation of the required strategy!\"\n            )\n\n    if verbosity > 1:\n        print(\n            \"[accelerated_search_for_good_permutation] Take {:.4f} seconds to search the permutation sequence.\".format(\n                duration\n            )\n        )\n\n    return permutation_sequence\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_search_kernels/channel_swap.py",
    "content": "from .permutation_utilities import *\n\n################################################################################################################\n# Greedy Channel Swaps - iterative, deterministic, can be parallelized\n#   1. Build a map of the magnitude improvement of involved stripes for all pairs of channel swaps\n#   2. Sort the map, march through by decreasing improvement, skipping entries whose stripes have been modified\n#   3. Repeat until there's no entry with positive improvement (convergence)\n################################################################################################################\n\n\n## try swapping columns and tracking magnitude after pruning\ndef try_swap(matrix, dst, src):\n    src_base = sum_after_2_to_4(matrix[..., int(src / 4) * 4 : int(src / 4) * 4 + 4])\n    dst_base = sum_after_2_to_4(matrix[..., int(dst / 4) * 4 : int(dst / 4) * 4 + 4])\n\n    # swap\n    matrix[..., [src, dst]] = matrix[..., [dst, src]]\n\n    # check the Nx4 slices of the swapped columns\n    src_sum = sum_after_2_to_4(matrix[..., int(src / 4) * 4 : int(src / 4) * 4 + 4])\n    dst_sum = sum_after_2_to_4(matrix[..., int(dst / 4) * 4 : int(dst / 4) * 4 + 4])\n\n    # swap back\n    matrix[..., [src, dst]] = matrix[..., [dst, src]]\n\n    return src_sum + dst_sum, (src_sum + dst_sum) - (src_base + dst_base)\n\n\n## convert stripe and a swap indices to columns\ndef stripes_and_swap_idx_to_columns(stripe0, stripe1, idx):\n    i = 0\n    for c0 in range(4):\n        for c1 in range(4):\n            if i == idx:\n                return stripe0 * 4 + c0, stripe1 * 4 + c1\n            i += 1\n    return None\n\n\n## convert columns to stripe and swap indices\ndef columns_to_stripes_and_swap_idx(col0, col1):\n    stripe0 = int(col0 / 4)\n    col0 %= 4\n    stripe1 = int(col1 / 4)\n    col1 %= 4\n\n    idx = 0\n    for c0 in range(4):\n        for c1 in range(4):\n            if c0 == col0 and c1 == col1:\n                return stripe0, 
stripe1, idx\n            idx += 1\n    return None\n\n\n## build a list of stripe pairs that need their benefits recomputed because one stripe was modified\ndef build_stripe_pairs(matrix, used_stripes):\n    stripe_pairs = []\n    total_stripes = int(matrix.shape[1] / 4)\n\n    used_stripes = np.sort(used_stripes)\n    for stripe0 in range(total_stripes - 1):\n        for stripe1 in range(stripe0, total_stripes):\n            if stripe0 in used_stripes or stripe1 in used_stripes:\n                stripe_pairs.append([stripe0, stripe1])\n\n    return np.asarray(stripe_pairs)\n\n\n## compute the benefit of swapping each pair of columns in the matrix using the GPU\n## only update stripes' columns that appear in used_stripes to avoid unnecessary computations\ndef compute_swap_map(matrix, used_stripes):\n    do_gpu = use_gpu()\n    assert do_gpu\n\n    stripe_pairs = build_stripe_pairs(matrix, used_stripes).astype(np.uint32)\n    matrix_view = matrix.astype(np.float32).flatten()\n    stripe_pairs_view = stripe_pairs.flatten()\n    output = np.zeros((len(stripe_pairs) * 16), dtype=np.float32).flatten()\n    result = permutation_search_cuda_kernels.build_swap_map(\n        matrix_view, matrix.shape[0], matrix.shape[1], stripe_pairs_view, output\n    )\n\n    # translate the flat array from the GPU to a map\n    pair_improvement_map = {}\n    for i, pair in enumerate(stripe_pairs):\n        for swap_idx in range(16):\n            col0, col1 = stripes_and_swap_idx_to_columns(pair[0], pair[1], swap_idx)\n            pair_improvement_map[(col0, col1)] = output[i * 16 + swap_idx]\n    return pair_improvement_map\n\n\n## build the full swap map\ndef build_swap_map(matrix, swap_map, swap_ids, used_stripes, verbosity):\n    improvements = None\n\n    # if we have a GPU and built kernels, pre-compute the needed values\n    do_gpu = use_gpu()\n    if do_gpu:\n        if len(swap_map) == 0:\n            used_stripes = [s for s in range(int(matrix.shape[1] / 4))]\n        
improvements = compute_swap_map(matrix, used_stripes)\n\n    idx = 0\n    updates = 0\n    for src in range(matrix.shape[1] - 1):  # parallelize these loops\n        for dst in range(src + 1, matrix.shape[1]):\n            # swapping within a stripe does nothing\n            if int(src / 4) == int(dst / 4):\n                continue\n\n            # if we touched this stripe last time, update it\n            if (\n                (int(src / 4) in used_stripes)\n                or (int(dst / 4) in used_stripes)\n                or len(swap_map) <= idx\n            ):\n                tmp_improvement = 0.0\n\n                # use the pre-computed values from the GPU if possible, otherwise compute on the CPU\n                if do_gpu:\n                    tmp_improvement = improvements[(src, dst)]\n                else:\n                    tmp_mag, tmp_improvement = try_swap(matrix, src, dst)\n                updates += 1\n\n                if len(swap_map) <= idx:\n                    swap_map.append(tmp_improvement)\n                    swap_ids.append((src, dst))\n                else:\n                    swap_map[idx] = tmp_improvement\n                    swap_ids[idx] = (src, dst)\n\n            idx += 1\n\n    if verbosity > 15:\n        print(f\"\\tupdated {updates} map entries\")\n    return swap_map, swap_ids\n\n\ndef use_swap_map(\n    matrix,\n    swap_map,\n    swap_ids,\n    threshold,\n    used_escape_attempts,\n    escape_attempts,\n    permutation,\n    verbosity,\n):\n    used_stripes = []\n    swaps = 0\n    improvement = 0.0\n\n    # set the traversal order and threshold\n    ix = np.flip(np.argsort(swap_map))  # small to large -> large to small\n    threshold = min(max(swap_map[ix[0]] * threshold, 0.0001), 1.0)\n\n    # iterate through the potential swaps in benefit order\n    for swap in range(len(ix)):\n        swap_id = ix[swap]\n        src = swap_ids[swap_id][0]\n        dst = swap_ids[swap_id][1]\n\n        # early-out of swaps that are 
below the threshold (don't be so greedy)\n        if swap_map[ix[swap]] < threshold:\n            # see if an arbitrary swap helps things if we've converged\n            if len(used_stripes) == 0 and used_escape_attempts < escape_attempts:\n                swap_id = np.random.randint(len(swap_ids))\n                if verbosity > 15:\n                    print(\n                        f\"converged, attempt #{used_escape_attempts + 1} to jiggle out, using index {swap_id} into the sorted list={ix[swap_id]}\"\n                    )\n                swap_id = ix[swap_id]\n                src = swap_ids[swap_id][0]\n                dst = swap_ids[swap_id][1]\n                used_escape_attempts += 1\n            else:\n                break\n\n        # skip swaps that include a stripe we've already modified\n        if int(src / 4) in used_stripes or int(dst / 4) in used_stripes:\n            continue\n\n        # we'll need to update these stripes later\n        used_stripes.append(int(src / 4))\n        used_stripes.append(int(dst / 4))\n\n        # make the swap\n        if verbosity > 20:\n            print(f\"\\t{swap}\\t{src},{dst}  {swap_map[swap_id]:.4f}\")\n        matrix[..., [src, dst]] = matrix[..., [dst, src]]\n        permutation[src], permutation[dst] = permutation[dst], permutation[src]\n        improvement += swap_map[swap_id]\n        swaps += 1\n\n    return (\n        matrix,\n        swaps,\n        swap_map,\n        swap_ids,\n        used_stripes,\n        improvement,\n        used_escape_attempts,\n        permutation,\n    )\n\n\ndef Channel_Swap(matrix, escape_attempts=0, verbosity=0, permutation=None):\n    threshold = 0.00001\n    used_escape_attempts = 0\n\n    # initialize\n    if permutation is None:\n        permutation = [c for c in range(matrix.shape[1])]\n    swap_map = []\n    swap_ids = []\n    used_stripes = []\n    swap_count = 0\n    iterations = 0\n    agg_improvement = 0.0\n    cur_total_sum = sum_after_2_to_4(matrix)\n    
start_time = time.perf_counter()\n\n    # do the work\n    swapped = 1  # just start with nonzero value to fall into the loop\n    while swapped > 0:\n        swap_map, swap_ids = build_swap_map(matrix, swap_map, swap_ids, used_stripes, verbosity)\n        (\n            matrix,\n            swapped,\n            swap_map,\n            swap_ids,\n            used_stripes,\n            improvement,\n            used_escape_attempts,\n            permutation,\n        ) = use_swap_map(\n            matrix,\n            swap_map,\n            swap_ids,\n            threshold,\n            used_escape_attempts,\n            escape_attempts,\n            permutation,\n            verbosity,\n        )\n        agg_improvement += improvement\n\n        # keep track of statistics, print occasionally\n        swap_count += swapped\n        if verbosity > 10:\n            iterations += 1\n            cur_total_sum += agg_improvement\n            duration = time.perf_counter() - start_time\n            print(\n                f\"\\t{iterations:8} {cur_total_sum:7.2f} {agg_improvement:7.2f} {swap_count:4} {agg_improvement / max(swap_count, 1):5.2f} {duration:7.2f}\"\n            )\n            agg_improvement = 0.0\n            swap_count = 0\n\n    # final status\n    seconds = time.perf_counter() - start_time\n\n    return matrix, seconds, permutation\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py",
    "content": "from .permutation_utilities import *\n\nASP_CACHE_DIR_ENV_VAR = \"APEX_ASP_CACHE_DIR\"\nASP_CACHE_DIR_DEFAULT = \".cache\"\n\n################################################################################################################\n# Exhaustive\n#   Try them all\n#   - order of columns within a group doesn't matter\n#   - order of groups doesn't matter\n#   - we can eliminate effective duplicates by defining aunique combination to be a sorted list of sorted groups\n################################################################################################################\n\n####################################################################\n# generate unique permutations\n####################################################################\n\n\n# check if adding a column index to a current permutation would keep it in canonical form\n# assumes that perm is in canonical form already!\ndef is_canonical(perm, col):\n    # if it's a new group\n    if len(perm) % 4 == 0:\n        # every column ID < col needs to be in the permutation already\n        for val in range(col):\n            if val not in perm:\n                return False\n        # this new group needs to be sorted w.r.t. 
the previous group\n        return col > perm[-4]\n\n    # not a new group, just check to see if it will still be sorted\n    return col > perm[-1]\n\n\n# recursive: build a unique permutation one column index at a time\ndef generate_unique_combinations(\n    built_permutation, remaining_columns, full_permutation_list, group_width\n):\n    # base case: nothing else to add\n    if len(remaining_columns) == 0:\n        full_permutation_list.append(np.copy(built_permutation))\n        if len(full_permutation_list) % 1000000 == 0:\n            print(f\"{len(full_permutation_list)} unique permutations found so far\")\n\n    # still more choices to make, so add each remaining column in turn if it keeps everything sorted\n    else:\n        for c in range(len(remaining_columns)):\n            # to satisfy our invariants (values within groups are sorted, groups are globally sorted),\n            # only add this column if either:\n            #   it's starting a new group and is larger than the previous group's first entry\n            #   OR\n            #   it's larger than the last value in the built_permutation\n            col_to_add = remaining_columns[c]\n\n            if is_canonical(built_permutation, col_to_add):\n                # add the column to the running permutation, remove it from remaining columns\n                built_permutation.append(col_to_add)\n                remaining_columns.pop(c)\n                # recurse\n                generate_unique_combinations(\n                    built_permutation,\n                    remaining_columns,\n                    full_permutation_list,\n                    group_width,\n                )\n                # remove the most recent column and put it back on the remaining column list where we found it (sorted)\n                remaining_columns.insert(c, built_permutation.pop(-1))\n\n\nimport os\nfrom os import path\n\nunique_permutation_list = {}\n\n\ndef generate_all_unique_combinations(C, M, 
must_use_all_groups=False):\n    cache_dir_path = os.getenv(ASP_CACHE_DIR_ENV_VAR, ASP_CACHE_DIR_DEFAULT)\n    cache_file_path = path.join(cache_dir_path, f\"permutations_{C}_{M}.npy\")\n\n    global unique_permutation_list\n    if (C, M) not in unique_permutation_list:\n        if path.exists(cache_file_path):\n            unique_permutation_list[(C, M)] = np.load(cache_file_path, allow_pickle=False)\n\n        else:\n            full_permutation_list = []\n            generate_unique_combinations([0], [c for c in range(1, C)], full_permutation_list, M)\n            unique_permutation_list[(C, M)] = full_permutation_list\n            if not path.exists(cache_dir_path):\n                os.makedirs(cache_dir_path)\n            np.save(cache_file_path, full_permutation_list, allow_pickle=False)\n\n    unique_permutations = unique_permutation_list[(C, M)]\n\n    return unique_permutations\n\n\n# analytical solution\nimport math\n\n\ndef predict_unique_combinations(C, M):\n    assert C % M == 0\n    G = C // M\n    # integer arithmetic keeps the count exact and avoids float overflow for large C\n    return math.factorial(C) // (math.factorial(M) ** G * math.factorial(G))\n\n\n#################################################################\n# exhaustively try all unique permutations\n#################################################################\n\n\n# exhaustively search the entire matrix\ndef search_matrix(matrix, group_width):\n    # give up quickly if we'd go on forever\n    prediction = predict_unique_combinations(matrix.shape[1], group_width)\n    best_permutation = [c for c in range(matrix.shape[1])]\n    if prediction > 1e10:\n        print(\n            f\"There are {prediction} unique combinations with {matrix.shape[1]} columns and a group width of {group_width}, not searching.\"\n        )\n        # callers unpack (matrix, seconds, permutation, improvement); nothing was searched, so report zeros\n        return matrix, 0.0, best_permutation, 0.0\n\n    start_time = time.perf_counter()\n    full_permutation_list = generate_all_unique_combinations(matrix.shape[1], group_width)\n\n    # found them, now try them\n    
best_improvement = 0.0\n    use_cuda = use_gpu()\n    if (\n        use_cuda and matrix.shape[1] >= 8 and group_width == 4\n    ):  # CUDA path only works for a group width of 4\n        best_improvement, best_permutation = try_permutations_on_matrix(\n            matrix, full_permutation_list\n        )\n    else:\n        base_sum = sum_after_2_to_4(matrix)\n        for i in range(1, len(full_permutation_list)):\n            permutation = full_permutation_list[i]\n            permuted = matrix[:, permutation]\n            cur_improvement = sum_after_2_to_4(permuted) - base_sum\n\n            if cur_improvement > best_improvement:\n                best_improvement = cur_improvement\n                best_permutation = permutation\n    seconds = time.perf_counter() - start_time\n    return matrix[:, best_permutation], seconds, best_permutation, best_improvement\n\n\n#############\n# Stripe group handling\n#############\n\n\n# gather stripes from a larger matrix into a single matrix\ndef collect_stripes(matrix, stripes, group_width):\n    subset = np.zeros((matrix.shape[0], len(stripes) * group_width))\n    for s, stripe in enumerate(stripes):\n        subset[..., s * group_width : s * group_width + group_width] = matrix[\n            ..., stripe * group_width : stripe * group_width + group_width\n        ]\n    return subset\n\n\n# apply the stripe group permutation to the entire permutation\ndef apply_stripe_group_permutation(sgp, stripes, group_width, permutation):\n    new_permutation = permutation.copy()\n    for subset_idx in range(len(sgp)):\n        dst_stripe_idx = stripes[int(subset_idx / group_width)]\n        dst_col_idx = subset_idx % group_width\n\n        subset_val = sgp[subset_idx]\n        src_stripe_idx = stripes[int(subset_val / group_width)]\n        src_col_idx = subset_val % group_width\n\n        new_permutation[dst_stripe_idx * group_width + dst_col_idx] = permutation[\n            src_stripe_idx * group_width + src_col_idx\n        ]\n\n    
return new_permutation\n\n\n# generate all possible stripe groups\ndef generate_stripe_groups(num_stripes, window_size):\n    stripe_array = [[c] for c in range(num_stripes)]\n\n    next_stripe_array = []\n    for w in range(1, window_size):\n        for g in range(len(stripe_array)):\n            start_c = stripe_array[g][w - 1] + 1\n            group = stripe_array[g]\n            for c in range(start_c, num_stripes):\n                new_group = group.copy()\n                new_group.append(c)\n                next_stripe_array.append(new_group)\n        stripe_array = next_stripe_array\n        next_stripe_array = []\n\n    return set(tuple(stripe_array[g]) for g in range(len(stripe_array)))\n\n\n# It is not safe to just reset the stripe_set as None here.\n# When calling the Exhaustive_Search in E2E search, the stripe_set will not be reset as None.\nstripe_set = None\nstripe_set_config = None\n\n\n# build the stripe map\ndef build_stripe_map(\n    matrix, group_width, window_size, stripe_map, stripe_ids, perm_map, used_stripes\n):\n    global stripe_set, stripe_set_config\n\n    window_size = int(window_size / group_width)\n\n    if (\n        stripe_set is None\n        or stripe_set_config is None\n        or stripe_set_config != (group_width, window_size)\n    ):\n        num_stripes = int(matrix.shape[1] / group_width)\n        assert group_width * num_stripes == matrix.shape[1]\n        stripe_set = generate_stripe_groups(num_stripes, window_size)\n        stripe_set_config = (group_width, window_size)\n\n    # step through each, update the stripe_map/stripe_ids if necessary\n    updates = 0\n    use_cuda = use_gpu()\n    gpu_list = []\n    gpu_groups = []\n    for i, s in enumerate(stripe_set):\n        sg = []  # build the group of stripes, check if any members changed\n        need_update = i >= len(stripe_map)\n        for stripe in s:\n            sg.append(stripe)\n            if stripe in used_stripes:\n                need_update = True\n\n        
# pre-populate if we're building fresh\n        if i >= len(stripe_map):\n            stripe_ids.append(sg)\n            stripe_map.append(0.0)\n            perm_map.append([c for c in range(group_width * window_size)])\n\n        # update entries if needed (only stripe_map and perm_map)\n        if need_update:\n            updates += 1\n\n            if not use_cuda:  # do the work here if using the CPU\n                subset = collect_stripes(matrix, sg, group_width)\n                sub_result, sub_duration, permutation, improvement = search_matrix(\n                    subset, group_width\n                )\n                stripe_map[i] = improvement\n                perm_map[i] = permutation\n            else:  # otherwise, just track the work needed to farm off to the GPU\n                gpu_groups.append(sg)\n                gpu_list.append(i)\n\n    if use_cuda:  # if using the GPU, perform the work\n        matrix_view = np.copy(matrix).astype(np.float32).flatten()\n        all_permutations = generate_all_unique_combinations(window_size * group_width, group_width)\n        num_permutations = len(all_permutations)\n        permutation_view = np.copy(np.asarray(all_permutations)).astype(np.uint32).flatten()\n        stripe_groups_view = np.asarray(gpu_groups).astype(np.uint32).flatten()\n        num_gpu_groups = len(gpu_list)\n        gpu_improvement = np.zeros((num_gpu_groups), dtype=np.float32).flatten()\n        gpu_permutation = np.zeros((num_gpu_groups), dtype=np.uint32).flatten()\n\n        result = permutation_search_cuda_kernels.build_permute_map(\n            matrix_view,\n            matrix.shape[0],\n            matrix.shape[1],\n            stripe_groups_view,\n            num_gpu_groups,\n            window_size,\n            permutation_view,\n            window_size * group_width,\n            gpu_improvement,\n            gpu_permutation,\n        )\n\n        # put the data where python expects it\n        for i in 
range(len(gpu_list)):\n            stripe_map[gpu_list[i]] = gpu_improvement[i]\n            perm_map[gpu_list[i]] = all_permutations[gpu_permutation[i]]\n\n    return stripe_map, stripe_ids, perm_map\n\n\n# start performing stripe checks\nsm_perturbations = 0\nsm_perturbation_limit = 0\n\n\ndef use_stripe_map(matrix, group_width, stripe_map, stripe_ids, perm_map, permutation):\n    global sm_perturbations, sm_perturbation_limit\n    used_stripes = []\n    stripe_groups_optimized = 0\n    improvement = 0.0\n\n    # set the traversal order\n    ix = np.flip(np.argsort(stripe_map))  # small to large --> large to small\n\n    for i in range(len(ix)):\n        stripe_group_id = ix[i]\n        perm = perm_map[stripe_group_id].copy()\n\n        if stripe_map[stripe_group_id] <= np.finfo(np.float16).tiny * 5.0:\n            # perturbations\n            if len(used_stripes) == 0 and sm_perturbations < sm_perturbation_limit:\n                sm_perturbations += 1\n                # use this permutation, but swap two channels from left/right halves to include two stripes, no matter the group size\n                stripe_group_id = ix[np.random.randint(len(ix))]\n                perm = perm_map[stripe_group_id].copy()\n                # a little easier to escape from\n                src = np.random.randint(int(len(perm) / 2))\n                dst = int(len(perm) / 2) + np.random.randint(int(len(perm) / 2))\n                perm[src], perm[dst] = perm[dst], perm[src]\n            else:\n                break\n\n        stripe_group = stripe_ids[stripe_group_id]\n\n        # don't work on stripes we've already touched\n        touched_stripe = False\n        for stripe in stripe_group:\n            if stripe in used_stripes:\n                touched_stripe = True\n        if touched_stripe:\n            continue\n\n        # apply the permutation we've already found to this stripe group\n        subset = collect_stripes(matrix, stripe_group, group_width)\n        sub_result = 
subset[..., perm]\n        permutation = apply_stripe_group_permutation(perm, stripe_group, group_width, permutation)\n\n        # scatter the results, track what changed\n        for s, stripe in enumerate(stripe_group):\n            # see if this group is in canonical form (entry 0 a multiple of 4, contiguous values))\n            group = perm[\n                s * group_width : s * group_width + group_width\n            ]  # columns in this group of the used permutation\n            changed = False\n            if group[0] % 4 != 0:\n                changed = True\n            for c in range(1, group_width):\n                if group[c] != group[c - 1] + 1:\n                    changed = True\n                    break\n            # if it's not, then it changed\n            if changed:\n                used_stripes.append(stripe_group[s])\n\n            matrix[..., stripe * group_width : stripe * group_width + group_width] = sub_result[\n                ..., s * group_width : s * group_width + group_width\n            ]\n\n        improvement += stripe_map[stripe_group_id]\n        stripe_groups_optimized += 1\n\n    return (\n        matrix,\n        stripe_groups_optimized,\n        stripe_map,\n        stripe_ids,\n        used_stripes,\n        improvement,\n        permutation,\n    )\n\n\n# entry point for exhaustive searches - both the entire matrix, as well as stripe groups\ndef Exhaustive_Search(matrix, stripe_group_size=-1, escape_attempts=0, permutation=None):\n    global sm_perturbation_limit, sm_perturbations\n    sm_perturbations = 0\n    sm_perturbation_limit = escape_attempts\n    if permutation is None:\n        permutation = [c for c in range(matrix.shape[1])]\n\n    # It is much safer to reset the stripe_set as None in the entry point of Exhaustive_Search\n    global stripe_set, stripe_set_config\n    stripe_set = None\n    stripe_set_config = None\n\n    # only support N:4 for now\n    group_width = 4\n\n    result = np.copy(matrix)\n\n    # 
if the matrix is too large for a window size of 12, subdivide, then fix up with a global optimization with a window size of 8\n    if group_width == 4 and stripe_group_size == 12 and matrix.shape[1] > 512:\n        stripe_split = int(matrix.shape[1] / 2 / group_width)\n        col_split = stripe_split * group_width\n        result[:, :col_split], durationL, permutation[:col_split] = Exhaustive_Search(\n            result[:, :col_split],\n            stripe_group_size=stripe_group_size,\n            escape_attempts=escape_attempts,\n            permutation=permutation[:col_split],\n        )\n        result[:, col_split:], durationR, permutation[col_split:] = Exhaustive_Search(\n            result[:, col_split:],\n            stripe_group_size=stripe_group_size,\n            escape_attempts=escape_attempts,\n            permutation=permutation[col_split:],\n        )\n        escape_attempts = max(escape_attempts, 100) * 10\n        result, duration, permutation = Exhaustive_Search(\n            result,\n            stripe_group_size=8,\n            escape_attempts=escape_attempts,\n            permutation=permutation,\n        )\n        return result, durationL + durationR + duration, permutation\n\n    # small enough to optimize the entire matrix at once\n    if stripe_group_size != -1 and stripe_group_size < matrix.shape[1]:\n        stripe_map = []\n        stripe_ids = []\n        perm_map = []\n        used_stripes = []\n\n        # in practice, this work will be cached ahead of time; doing it now.\n        # (Reading the cached list from disk can take several seconds, which shouldn't be counted against the search, but amortized over every layer in a network)\n        generate_all_unique_combinations(stripe_group_size, group_width)\n\n        start_time = time.perf_counter()\n\n        while True:\n            # print(\"[Debug][Exhaustive_Search] Before entering the build_stripe_map function.\")\n            # print(\"[Debug][Exhaustive_Search] Now the 
stripe_set value is: {}\".format(stripe_set))\n            stripe_map, stripe_ids, perm_map = build_stripe_map(\n                result,\n                group_width,\n                stripe_group_size,\n                stripe_map,\n                stripe_ids,\n                perm_map,\n                used_stripes,\n            )\n            (\n                result,\n                stripe_groups_optimized,\n                stripe_map,\n                stripe_ids,\n                used_stripes,\n                improvement,\n                permutation,\n            ) = use_stripe_map(result, group_width, stripe_map, stripe_ids, perm_map, permutation)\n\n            # converged?\n            if len(used_stripes) == 0:\n                break\n\n        duration = time.perf_counter() - start_time\n\n    else:  # no sliding window, single iteration\n        print(\n            f\"Matrix has {matrix.shape[1]} columns and the search window is only {stripe_group_size}: searching exhaustively\"\n        )\n        result, duration, permutation, improvement = search_matrix(matrix, group_width)\n\n    return result, duration, permutation\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py",
    "content": "import numpy as np\nimport subprocess\nimport math\n\ngpus_tested = False\ngpus_found = 0\nkernels_found = True\ntry:\n    import permutation_search_cuda as permutation_search_cuda_kernels\n\n    print(\"Found permutation search CUDA kernels\")\nexcept ImportError:\n    try:\n        from . import permutation_search_cuda as permutation_search_cuda_kernels\n\n        print(\"Found permutation search CUDA kernels for standalone testing\")\n\n    except ImportError:\n        print(\"Could not find permutation search CUDA kernels, falling back to CPU path\")\n        kernels_found = False\n\n\ndef use_gpu(initial_override=True):\n    global gpus_tested, gpus_found, kernels_found\n    if not gpus_tested:\n        if not initial_override:\n            gpus_tested = True\n            return False\n\n        try:\n            gpus_found = str(subprocess.check_output([\"nvidia-smi\", \"-L\"])).count(\"UUID\")\n            print(f\"Found {gpus_found} gpus\")\n        except:\n            gpus_found = 0\n            print(\"Could not find nvidia-smi, please check your cuda installation\")\n\n        gpus_tested = True\n\n    return gpus_found > 0 and kernels_found\n\n\n##############################################################################################\n# pruning utilities\n##############################################################################################\n## apply 2:4 to some matrix\ndef apply_2_to_4(matrix):\n    for row in range(matrix.shape[0]):\n        for col in range(0, matrix.shape[1], 4):\n            ix = np.argsort(np.abs(matrix[row, col : col + 4]))\n            matrix[row, col + ix[0]] = 0.0\n            matrix[row, col + ix[1]] = 0.0\n    return matrix\n\n\n## find the sum of magnitudes if 2:4 were applied to a matrix\ndef sum_after_2_to_4(matrix):\n    cur_sum = 0.0\n    use_cuda = use_gpu()\n    if not use_cuda:\n        for row in range(matrix.shape[0]):\n            for col in range(0, matrix.shape[1], 4):\n             
   ix = np.argsort(np.abs(matrix[row, col : col + 4]))\n                cur_sum += abs(matrix[row, col + ix[2]])\n                cur_sum += abs(matrix[row, col + ix[3]])\n    else:\n        matrix = matrix.astype(np.float32)\n        cuda_sum = np.zeros((1), dtype=np.float32)\n        matrix_view = np.copy(matrix).flatten()\n        sum_view = cuda_sum.flatten()\n        blocks = max(int(matrix.shape[1] / 4 / 2), 1)\n        threads = min(max(math.ceil(matrix.shape[0] / 4), 1), 1024)\n        result = permutation_search_cuda_kernels.sum_after_2_to_4(\n            matrix_view,\n            matrix.shape[0],\n            matrix.shape[1],\n            0,\n            matrix.shape[1],\n            blocks,\n            threads,\n            sum_view,\n        )\n        cur_sum = sum_view[0]\n    return cur_sum\n\n\n# perform unstructured pruning on some matrix\ndef unstructured_prune(matrix, sparsity):\n    shp = matrix.shape\n    matrix = matrix.flatten()\n    ix = np.argsort(matrix)\n    ix = ix[: int(len(ix) * sparsity)]\n    matrix[ix] = 0.0\n    matrix = np.reshape(matrix, shp)\n    return matrix\n\n\n## try swapping columns and tracking magnitude after pruning\ndef try_swap(matrix, dst, src):\n    src_base = sum_after_2_to_4(matrix[..., int(src / 4) * 4 : int(src / 4) * 4 + 4])\n    dst_base = sum_after_2_to_4(matrix[..., int(dst / 4) * 4 : int(dst / 4) * 4 + 4])\n\n    # swap\n    matrix[..., [src, dst]] = matrix[..., [dst, src]]\n\n    # check the Nx4 slices of the swapped columns\n    src_sum = sum_after_2_to_4(matrix[..., int(src / 4) * 4 : int(src / 4) * 4 + 4])\n    dst_sum = sum_after_2_to_4(matrix[..., int(dst / 4) * 4 : int(dst / 4) * 4 + 4])\n\n    # swap back\n    matrix[..., [src, dst]] = matrix[..., [dst, src]]\n\n    return src_sum + dst_sum, (src_sum + dst_sum) - (src_base + dst_base)\n\n\n## magnitude improvement from the naive 2:4 matrix / how much was lost by naive 2:4 compared to the optimal\ndef efficacy(optimal_lost_magnitude, 
base_lost_magnitude, cur_lost_magnitude):\n    if base_lost_magnitude == optimal_lost_magnitude:\n        eff = 1.0\n    else:\n        eff = (base_lost_magnitude - cur_lost_magnitude) / (\n            base_lost_magnitude - optimal_lost_magnitude\n        )\n    return eff\n\n\n## find the magnitude if the rows of a matrix were pruned independently, without structure\ndef magnitude_after_pruning_rows(matrix, rate=0.5):\n    magnitude = 0.0\n    cols = matrix.shape[1]\n    for r in range(matrix.shape[0]):\n        rowVals = matrix[r]\n        rowVals = np.sort(np.abs(rowVals))\n        magnitude += np.sum(rowVals[int(cols * rate) :])\n\n    return magnitude\n\n\n##############################################################################################\n# permutation utilities\n##############################################################################################\n\n\n## exhaustively search an entire matrix on the GPU\ndef try_permutations_on_matrix(matrix, permutations):\n    use_cuda = use_gpu()\n    assert use_cuda  # caller should have checked\n    matrix = np.copy(matrix)\n    matrix = matrix.astype(np.float32)\n    matrix_view = np.copy(matrix).flatten()\n    permutations_view = np.copy(np.asarray(permutations)).astype(np.uint32).flatten()\n\n    stripe_groups = np.asarray([[s for s in range(int(matrix.shape[1] / 4))]]).astype(np.uint32)\n    stripe_groups_view = stripe_groups.flatten()\n\n    improvement = np.zeros((1), dtype=np.float32).flatten()\n    permutation = np.zeros((1), dtype=np.uint32).flatten()\n\n    result = permutation_search_cuda_kernels.check_permutations(\n        matrix_view,\n        matrix.shape[0],\n        matrix.shape[1],\n        stripe_groups_view,\n        len(stripe_groups[0]),\n        len(stripe_groups),\n        permutations_view,\n        len(permutations),\n        improvement,\n        permutation,\n    )\n    return improvement[0], permutations[permutation[0]]\n\n\n## find the permutation needed to make matrix A 
look like matrix B\ndef find_permutation(A, B):\n    permutation = []\n    for col in range(A.shape[1]):\n        Avals = A[..., col]\n        for bcol in range(B.shape[1]):\n            if np.all(Avals - B[..., bcol] == np.zeros(Avals.shape)):\n                permutation.append(bcol)\n                break\n    return permutation\n\n\n########################################\n# reasonable method to find distance between permutations\n# this is used to generate permutations \"between\" two other permutations to divide efficacy space\n#######################################\n\n\n## separate a flat permutation array into its groups, sort each group and the overall order to\n## put the output into a canonical order: if two permutations have the same groups, they should appear identical\ndef make_grouped(A):\n    groups = []\n    for x in range(0, len(A), 4):\n        group = []\n        for c in range(4):\n            group.append(A[x + c])\n        group = np.sort(group)\n\n        groups.append(group)\n    return groups\n\n\n## given two permutations, find the groups they have in common\ndef common_groups(A, B):\n    Ag = make_grouped(A)\n    Bg = make_grouped(B)\n\n    # convert to sets to take the intersection\n    As = set(tuple(Ag[g]) for g in range(len(Ag)))\n    Bs = set(tuple(Bg[g]) for g in range(len(Bg)))\n    common = As.intersection(Bs)\n\n    # flatten\n    C = []\n    for s in common:\n        for v in s:\n            C.append(v)\n\n    # group\n    return make_grouped(C)\n\n\n## given two permutations, remove the groups that are common between them\ndef remove_common_groups(A, B):\n    Ag = make_grouped(A)\n    Bg = make_grouped(B)\n\n    # convert to sets to take set difference\n    As = set(tuple(Ag[g]) for g in range(len(Ag)))\n    Bs = set(tuple(Bg[g]) for g in range(len(Bg)))\n    Ad = As - Bs\n    Bd = Bs - As\n\n    # turn the differences back into flat arrays\n    A = []\n    for s in Ad:\n        for v in s:\n            A.append(v)\n    B = 
[]\n    for s in Bd:\n        for v in s:\n            B.append(v)\n\n    # group to put into canonical order, re-flatten\n    A = make_grouped(A)\n    B = make_grouped(B)\n    A = [item for sublist in A for item in sublist]\n    B = [item for sublist in B for item in sublist]\n\n    return A, B\n\n\n## given two permutations, find which elements in B need to go where to look like A\ndef group_differences(A, B):\n    Ag = make_grouped(A)\n    Bg = make_grouped(B)\n\n    wrong_entries = []\n    # for g,group in enumerate(Bg):\n    for g in range(len(Bg)):\n        group = Bg[g]\n        for i in range(len(group)):\n            val = group[i]\n            if val not in Ag[g]:\n                group_in_a = int(np.where(A == val)[0][0] / 4)\n                wrong_entries.append((val, g, group_in_a))\n\n    return wrong_entries\n\n\n## (val, cur_group, desired_group) ==> dict[(cur_group, desired_group)] = [vals]\ndef dictify(wrong_entries):\n    result = {}\n    for entry in wrong_entries:\n        key = (entry[1], entry[2])\n        if key in result:\n            result[key].append(entry[0])\n        else:\n            result[key] = [entry[0]]\n    return result\n\n\n## move groups of B to where they best match A's groups\ndef move_groups_to_match(B, A, debug=False):\n    Ag = make_grouped(A)\n    Bg = make_grouped(B)\n\n    new_Bg = [[] for g in range(len(Ag))]\n    wrong_entry_dict = dictify(group_differences(A, B))\n\n    if debug:\n        print(f\"MGTM:\\n\\tAg: {Ag}\\n\\tBg: {Bg}\\n\\tWED: {wrong_entry_dict}\")\n\n    moved_groups = []\n\n    keys_to_del = []\n    # move triples to the right spot\n    for k in wrong_entry_dict.keys():\n        if k[0] in moved_groups:\n            keys_to_del.append(k)\n            continue\n\n        if len(wrong_entry_dict[k]) == 3:\n            new_Bg[k[1]] = Bg[k[0]]\n            moved_groups.append(k[0])\n            keys_to_del.append(k)\n            if debug:\n                print(f\"MGTM: moved triple 
{wrong_entry_dict[k]} from group {k[0]} to group {k[1]}\")\n\n    for k in keys_to_del:\n        del wrong_entry_dict[k]\n    keys_to_del = []\n\n    # move doubles\n    for k in wrong_entry_dict.keys():\n        # if we've already moved the group to which this key belongs, remove it\n        if k[0] in moved_groups:\n            keys_to_del.append(k)\n            continue\n\n        if len(wrong_entry_dict[k]) == 2:\n            if len(new_Bg[k[1]]) == 0:  # move it to its requested destination if possible\n                new_Bg[k[1]] = Bg[k[0]]\n                keys_to_del.append(k)\n                assert k[0] not in moved_groups\n                moved_groups.append(k[0])\n                if debug:\n                    print(\n                        f\"MGTM: moved double {wrong_entry_dict[k]} from group {k[0]} to its preferred group {k[1]}\"\n                    )\n            elif len(new_Bg[k[0]]) == 0:  # otherwise leave it where it is (if possible)\n                new_Bg[k[0]] = Bg[k[0]]\n                keys_to_del.append(k)\n                assert k[0] not in moved_groups\n                moved_groups.append(k[0])\n                if debug:\n                    print(f\"MGTM: left double {wrong_entry_dict[k]} where it was in group {k[0]}\")\n    for k in keys_to_del:\n        del wrong_entry_dict[k]\n    keys_to_del = []\n\n    # move singles\n    # try to leave things where they are to prevent oscillating\n    for k in wrong_entry_dict.keys():\n        if k[0] in moved_groups:\n            keys_to_del.append(k)\n            continue\n\n        if len(new_Bg[k[1]]) == 0:  # requested destination\n            new_Bg[k[1]] = Bg[k[0]]\n            keys_to_del.append(k)\n            assert k[0] not in moved_groups\n            moved_groups.append(k[0])\n            if debug:\n                print(\n                    f\"MGTM: moved single {wrong_entry_dict[k]} from group {k[0]} to its preferred group {k[1]}\"\n                )\n\n        elif 
len(new_Bg[k[0]]) == 0:\n            new_Bg[k[0]] = Bg[k[0]]\n            keys_to_del.append(k)\n            assert k[0] not in moved_groups\n            moved_groups.append(k[0])\n            if debug:\n                print(f\"MGTM: left group {wrong_entry_dict[k]} where it was in group {k[0]}\")\n\n    for k in keys_to_del:\n        del wrong_entry_dict[k]\n    keys_to_del = []\n\n    # put what's left where it'll fit\n    for k in wrong_entry_dict.keys():\n        if k[0] in moved_groups:\n            keys_to_del.append(k)\n            continue\n\n        for dst in range(len(new_Bg)):\n            if len(new_Bg[dst]) == 0:\n                new_Bg[dst] = Bg[k[0]]\n                keys_to_del.append(k)\n                assert k[0] not in moved_groups\n                moved_groups.append(k[0])\n                if debug:\n                    print(\n                        f\"MGTM: put group {wrong_entry_dict[k]} where it found a spot in group {dst}\"\n                    )\n                break\n\n    for k in keys_to_del:\n        del wrong_entry_dict[k]\n    keys_to_del = []\n\n    assert len(wrong_entry_dict) == 0\n    Agsize = sum([len(group) for group in Ag])\n    Bgsize = sum([len(group) for group in new_Bg])\n    assert Agsize == Bgsize\n    new_B = [item for sublist in new_Bg for item in sublist]\n    return new_B\n\n\n## swap two permutation entries and put the permutation into unique order\ndef swap_and_correct(permutation, src, tgt):\n    permutation[src], permutation[tgt] = permutation[tgt], permutation[src]\n    grouped = make_grouped(permutation)\n    grouped = [item for sublist in grouped for item in sublist]\n    return grouped\n\n\n## make a swap that will move B in the direction of A\nnum_diffs = 0\n\n\ndef move_permutation_towards(B, A, debug=False):\n    global num_diffs\n    B = move_groups_to_match(B, A, debug)\n    wrong_entries = group_differences(A, B)\n    num_diffs = len(wrong_entries)\n\n    # nothing to do, early out\n    if 
len(wrong_entries) == 0:\n        if debug:\n            print(\"MPT: early out\")\n        return B\n\n    if debug:\n        print(f\"MPT: checking {len(wrong_entries)} diffs: {wrong_entries}\")\n\n    # look for a group of three wrong entries that want to do the same thing\n    entry_dict = dictify(wrong_entries)\n    for k in entry_dict.keys():\n        entry = entry_dict[k]\n        if len(entry) == 3:\n            if debug:\n                print(f\"MPT: found a triple swap at {k}: {entry_dict[k]}\")\n            (src, dst) = k\n            # find the index of the one needed to complete the group\n            # the value is the value in A[dst] that's not in B[src]\n            # it's already in the destination group and may or may not need to move\n            group_id = dst\n            Ag = make_grouped(np.copy(A))\n            Bg = make_grouped(np.copy(B))\n            value = -1\n            for c in range(4):\n                if Ag[dst][c] not in Bg[src]:\n                    value = Ag[dst][c]\n                    if debug:\n                        print(f\"\\tMPT: found the missing value {value} in A group {dst} offset {c}\")\n                    break\n            assert value != -1\n\n            # now find that value in B\n            idx0 = np.where(B == value)[0][0]\n            # find the index of the one this group doesn't need\n            # it's a member of the group but not in the dict entry\n            group_id = src\n            for c in range(4):\n                if B[group_id * 4 + c] not in entry_dict[k]:\n                    if debug:\n                        print(f\"\\tMPT: swapping {idx0} and {group_id * 4 + c}\")\n                    return swap_and_correct(B, idx0, group_id * 4 + c)\n\n    # look for a group of two entries that are heading to the same place as another wrong entry\n    victim_loner_pair = None\n    for k in entry_dict.keys():\n        entry = entry_dict[k]\n        if len(entry) == 2:\n            if debug:\n       
         print(f\"MPT: found a double swap at {k}: {entry_dict[k]}\")\n            (src, dst) = k\n            # find a wrong entry whose dst is the same\n            for k2 in entry_dict.keys():\n                if k2 == k:\n                    continue\n\n                # k2 is a key whose value also belongs in stripe k2[1] (dst2)\n                if dst == k2[1]:\n                    if debug:\n                        print(\n                            f\"\\tMPT: found a loner going in the same direction at {k2}: {entry_dict[k2][0]}\"\n                        )\n                    # instead of moving these three to where they're headed, start merging them by moving the loner into the double\n\n                    # look for a complement: something moving from src to src2\n                    (src2, dst2) = k2\n                    complement_key = (src, src2)\n                    if complement_key in entry_dict:\n                        complement = entry_dict[complement_key][0]\n                        if debug:\n                            print(f\"\\t\\tMPT: found a complement to the loner:{complement}\")\n                        return swap_and_correct(\n                            B,\n                            np.where(B == entry_dict[k2][0])[0][0],\n                            np.where(B == complement)[0][0],\n                        )\n                    # didn't find a complement, choose one of the two in the src group that don't belong\n                    elif victim_loner_pair is None:\n                        for k3 in entry_dict.keys():\n                            if k3 == k:\n                                continue\n\n                            if k3[0] == src:  # found the victim\n                                victim = entry_dict[k3][0]\n                                if debug:\n                                    print(\n                                        f\"\\t\\tMPT: found a victim for the double swap:{k3} -> {victim}\"\n        
                            )\n                                victim_loner_pair = (victim, entry_dict[k2][0])\n                                # return swap_and_correct(B, np.where(B == entry_dict[k2][0])[0][0], np.where(B == victim)[0][0])\n\n    if victim_loner_pair is not None:\n        if debug:\n            print(\n                f\"\\t\\tMPT: couldn't find any complements for double swaps, so going with a loner to make a triple: {victim_loner_pair}\"\n            )\n        return swap_and_correct(\n            B,\n            np.where(B == victim_loner_pair[0])[0][0],\n            np.where(B == victim_loner_pair[1])[0][0],\n        )\n\n    # look for one swap that will correct two entries\n    candidate_second = None\n    for we in range(len(wrong_entries)):\n        cur_entry = wrong_entries[we]\n        # if debug:\n        #    print(f\"\\tMPT: checking {cur_entry} for complement\")\n        for we2 in range(0, len(wrong_entries)):\n            pos_swap = wrong_entries[we2]\n            # if debug:\n            #    print(f\"\\t\\tMPT: is {pos_swap}?\")\n            if cur_entry[1] == pos_swap[2] and cur_entry[2] == pos_swap[1]:\n                if debug:\n                    print(f\"\\t\\tfound complements: swapping {cur_entry} and {pos_swap}\")\n                return swap_and_correct(\n                    B,\n                    np.where(B == cur_entry[0])[0][0],\n                    np.where(B == pos_swap[0])[0][0],\n                )\n            elif (\n                wrong_entries[0][2] == pos_swap[1]\n            ):  # if pos_swap is currently where we[0] wants to go, keep it in mind\n                candidate_second = pos_swap\n\n    # fall back on picking the first one we come across\n    assert candidate_second is not None\n    if debug:\n        print(f\"No complement, swapping two entries: {wrong_entries[0]} {candidate_second}\")\n    return swap_and_correct(\n        B,\n        np.where(B == wrong_entries[0][0])[0][0],\n        
np.where(B == candidate_second[0])[0][0],\n    )\n\n\n## find a shortest path from permutation A to B\ndef permutation_distance(A, B, matrix=None, magnitude_targets=None, debug=False, verbosity=0):\n    global num_diffs\n    swaps = 0\n    debug = False  # debug output is forced off here, regardless of the argument\n\n    swap_limit = int(math.pow(2, int(len(A) / 4) - 1))\n    num_diffs = swap_limit\n    common = []\n    target_results = None\n    if magnitude_targets is not None:\n        assert matrix is not None\n        cur_mag = sum_after_2_to_4(matrix[:, A])\n        target_results = [(cur_mag, A) for i in range(len(magnitude_targets))]\n\n    if verbosity > 0 and matrix is not None:\n        print(f\"swap {'0':>4} {sum_after_2_to_4(matrix[:, B]):>15.3f}\")\n        if verbosity > 5:\n            print(f\"swap {0:>4}, {make_grouped(A)} {make_grouped(B)}\")\n\n    while not np.all(np.array(A) - np.array(B) == np.zeros(np.array(A).shape)):\n        cGroups = common_groups(A, B)\n        for g in cGroups:\n            common.append(g)\n        A, B = remove_common_groups(A, B)\n        if len(A) == 0:\n            break\n\n        B = move_permutation_towards(np.array(B), np.array(A), debug=debug)\n        swaps += 1\n\n        if matrix is not None:\n            total_cur_permute = [c for c in B]\n\n            for c in [item for sublist in common for item in sublist]:\n                total_cur_permute.append(c)\n\n            if verbosity > 0 or magnitude_targets is not None:\n                cur_mag = sum_after_2_to_4(matrix[:, total_cur_permute])\n                # target_results is only populated when magnitude_targets was provided\n                if target_results is not None:\n                    for i in range(len(target_results)):\n                        result = target_results[i]\n                        if abs(magnitude_targets[i] - result[0]) > abs(magnitude_targets[i] - cur_mag):\n                            target_results[i] = (cur_mag, total_cur_permute)\n                if verbosity > 0:\n                    print(f\"swap {swaps:>4} {cur_mag:>15.3f}\")\n\n        if verbosity > 5 or swaps > swap_limit:\n            print(f\"swap {swaps:>4}, {A} {B}, 
{num_diffs} diffs remain\")\n\n        # safety net\n        if swaps > swap_limit + 3:\n            return swaps, target_results\n\n    return swaps, target_results\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_tests/README.md",
    "content": "# ChannelPermutations\n\nStandalone code to reproduce results in \"[Channel Permutations for N:M Sparsity](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html),\" Jeff Pool and Chong Yu, NeurIPS 2021.\n\nThree search strategies are supported: randomly generating permutations and checking quality, greedily swapping columns until convergence (i.e. TETRIS adapted for 2:4 sparsity), and the technique presented in the above paper, optimizing stripe groups.  This tool will apply these strategies, as configured below, to either a randomly-generated matrix or an .npy file (typically from a real network) and report the efficacy and runtime of the strategy.\n\n## Quick Start\n\n### Installation\n\n#### GPU path\n\nRequirements:\n- CUDA\n- pybind11\n\nA container such as `nvcr.io/nvidia/pytorch:21.12-py3` satisfies these requirements.\n\nInstallation (from this directory):\n```\npushd ../permutation_search_kernels/CUDA_kernels\nnvcc -O3 -shared -Xcompiler -fPIC -Xcompiler -DTORCH_EXTENSION_NAME=permutation_search_cuda -std=c++11 $(python3 -m pybind11 --includes) permutation_search_kernels.cu -o ../permutation_search_cuda$(python3-config --extension-suffix)\npopd\n```\n\n#### CPU path\n\nOnly NumPy is required for CPU-only execution.\n\n### Important arguments\n\n`python3 permutation_test.py` will tell you all the available arguments and alert you about required arguments:\n```\n    usage: permutation_test.py [-h] [--infile INFILE] [--channels CHANNELS] [--filters FILTERS] \n                               [--verbosity VERBOSITY] [--seed SEED] [--pretty_print PRETTY_PRINT] \n                               [--unstructured UNSTRUCTURED] [--gpu GPU] [--check_permutation CHECK_PERMUTATION] \n                               [--intermediate_steps INTERMEDIATE_STEPS] [--print_permutation PRINT_PERMUTATION]\n                               strategy [strategy ...]\n    permutation_test.py: error: the following arguments are 
required: strategy\n```\n\nDetailed information about each argument:\n\n- `--infile` (string) accepts .npy files with weights dumped from some model checkpoint.  By default, the input file is `'random'`, which will generate a random 2D matrix with `CHANNELS` columns and `FILTERS` rows.\n- `--channels` and `--filters` (unsigned integers) specify the size of the randomly-generated matrix if there is no input file specified.\n- `--verbosity` (unsigned integer) controls the amount of debug and status information printed.  `0` is just the important data, `11` can give periodic status details, and higher integers provide increasingly more detail.\n- `--seed` (unsigned integer) allows for changing the random seed, which will affect the random matrix generation, random permutations generated, and columns swapped for bounded regressions.\n- `--pretty_print` (bool) prints a pretty graph by default (below), but disabling will generate output friendly for redirecting to a .csv file.\n- `--unstructured` (float) will apply unstructured pruning to the matrix before searching for permutations.  A negative value will find the minimum unstructured sparsity for which a search strategy can find a perfect permutation and not create any extra zeros.\n- `--gpu` (bool) uses CUDA kernels by default (if they are built and there is a GPU available), but you can override this to run on the CPU.\n- `--check_permutation` (bool) makes sure the permutation tracked during the search process matches the one that's recovered directly from the permuted matrix.\n- `--intermediate_steps` (unsigned integer) will emit permutations with efficacies equally dividing the distance between the default order and the best permutation found.\n- `--print_permutation` (bool) prints the permutation found for each strategy.\n\nFinally, after these optional arguments, provide the search strategies desired.  
There are three strategies offered:\n- `random,<num_seeds=10>`\n- `channel_swap,<bounded_regressions=100>`\n- `optimize_stripe_groups,<stripe_group_size_in_columns=8>,<bounded_regressions=100>`\n\n### Launch a test with interesting search strategies\n\nNow that the kernels are built, you can use them to accelerate the search, which can be quite time-consuming without using the GPU.  Below, we report results on a number of interesting strategies for a 64-column, 128-row random matrix using a V100 accelerator.\n\n    $ python3 permutation_test.py --channels 64 --filters 128 channel_swap,0 channel_swap,100 channel_swap,1000 optimize_stripe_groups,8,0 optimize_stripe_groups,8,100 optimize_stripe_groups,8,1000 optimize_stripe_groups,12,0 random,1000 random,10000 random,100000\n    Found permutation search CUDA kernels for standalone testing\n    Found 2 gpus\n    strategy                           ,      magnitude,       efficacy,       duration\n    unpruned                           ,       4083.169,       -       ,       -       \n    unstructured                       ,       3060.238,       -       ,       -       \n    50% rows                           ,       3042.332,          100.0,       -       \n    default 2:4                        ,       2852.376,            0.0,          0.000\n    channel_swap,0                     ,       2913.352,           32.1,          0.214               \n    channel_swap,100                   ,       2914.174,           32.5,          2.249               \n    channel_swap,1000                  ,       2920.694,           36.0,         20.248               \n    optimize_stripe_groups,8,0         ,       2919.757,           35.5,          0.013               \n    optimize_stripe_groups,8,100       ,       2919.758,           35.5,          0.152               \n    optimize_stripe_groups,8,1000      ,       2919.935,           35.6,          1.387               \n    optimize_stripe_groups,12,0        ,       2921.947,           
36.6,          0.860               \n    random,1000                        ,       2873.380,           11.1,          0.116               \n    random,10000                       ,       2873.603,           11.2,          1.149               \n    random,100000                      ,       2879.129,           14.1,         11.510   \n\nFor this particular input, the `channel_swap` strategy needs 1000 bounded regressions (efficacy 36.0, about 20 seconds) to surpass optimizing 8-column stripe groups without any bounded regressions (efficacy 35.5, 0.013 seconds).  Allowing 1000 bounded regressions when optimizing 8-column stripe groups raises the efficacy only to 35.6, slightly below swapping channels with 1000 bounded regressions.  Optimizing 12-column stripe groups reaches the highest efficacy of all the approaches tested (36.6) in under a second.  Testing many random permutations is inefficient and ineffective: 100,000 random permutations reach an efficacy of only 14.1 after more than 11 seconds.\n\nWithout GPU acceleration, these tests are much slower (though they find the same final permutations):\n\n    $ python3 permutation_test.py --gpu 0 --channels 64 --filters 128 channel_swap,0 channel_swap,100 optimize_stripe_groups,8,0 optimize_stripe_groups,8,100 random,1000\n    strategy                           ,      magnitude,       efficacy,       duration\n    unpruned                           ,       4083.169,       -       ,       -       \n    unstructured                       ,       3060.238,       -       ,       -       \n    50% rows                           ,       3042.332,          100.0,       -       \n    default 2:4                        ,       2852.377,            0.0,          0.016\n    channel_swap,0                     ,       2913.351,           32.1,         55.972\n    channel_swap,100                   ,       2914.174,           32.5,        450.025\n    optimize_stripe_groups,8,0         ,       2919.759,           35.5,         60.653\n    optimize_stripe_groups,8,100       ,       2919.759,           35.5,        465.709\n    random,1000                        ,       2873.381,           
11.1,         14.889\n\n\n### Perform the ablation study from Table 1\n\n`bash ablation_studies.sh` will generate the results for the ablation study, showing the relative importance of the bounded regressions and stripe group greedy phase.\n\n### Generate the runtime results from Table 3\n\n`bash runtime_table.sh` will generate the search strategies' efficacies and runtime shown in Table 3.\n\n### Traverse permutation space (as in Figure 3)\n\nWe developed a heuristic approach to interpolating between permutations which allows us to find permutations with efficacies that evenly divide some range.  The `--intermediate_steps <N>` argument can be used to emit such a sequence of permutations:\n\n    $ python3 permutation_test.py --channels 64 --filters 128 --intermediate_steps 7 --print_permutation 1 optimize_stripe_groups,8,0\n    Found permutation search CUDA kernels for standalone testing\n    Found 2 gpus\n    strategy                           ,      magnitude,       efficacy,       duration\n    unpruned                           ,       4083.169,       -       ,       -\n    unstructured                       ,       3060.238,       -       ,       -\n    50% rows                           ,       3042.332,          100.0,       -\n    default 2:4                        ,       2852.377,            0.0,          0.000\n    (2859.8855, [2, 8, 14, 24, 9, 12, 13, 15, 4, 5, 6, 7, 0, 1, 3, 46, 40, 41, 42, 43, 32, 33, 34, 35, 25, 26, 27, 55, 16, 17, 18, 58, 20, 21, 22, 23, 38, 60, 61, 63, 11, 44, 45, 47, 36, 37, 39, 62, 10, 28, 29, 30, 31, 52, 53, 54, 19, 56, 57, 59, 48, 49, 50, 51])\n    (2870.1387, [5, 6, 7, 41, 9, 12, 13, 35, 0, 1, 3, 46, 30, 40, 42, 43, 2, 32, 33, 34, 25, 26, 27, 55, 16, 17, 18, 58, 20, 21, 22, 23, 38, 60, 61, 63, 11, 44, 45, 47, 36, 37, 39, 62, 4, 10, 28, 29, 31, 52, 53, 54, 19, 56, 57, 59, 15, 48, 49, 50, 8, 14, 24, 51])\n    (2878.0679, [36, 37, 39, 62, 9, 12, 13, 35, 0, 3, 16, 46, 30, 40, 42, 43, 2, 5, 32, 33, 23, 26, 27, 55, 1, 20, 21, 22, 
38, 60, 61, 63, 11, 44, 45, 47, 6, 7, 25, 41, 4, 10, 28, 29, 31, 52, 53, 54, 19, 56, 57, 59, 15, 48, 49, 50, 8, 14, 24, 51, 17, 18, 34, 58])\n    (2884.8323, [9, 12, 35, 54, 0, 3, 16, 46, 30, 40, 42, 43, 2, 5, 32, 33, 23, 26, 27, 55, 11, 44, 45, 47, 36, 37, 39, 62, 4, 10, 28, 29, 31, 52, 53, 60, 19, 21, 56, 57, 15, 48, 49, 50, 8, 14, 24, 51, 17, 18, 34, 58, 6, 7, 25, 41, 1, 13, 20, 22, 38, 59, 61, 63])\n    (2894.9697, [9, 12, 33, 35, 0, 3, 16, 46, 2, 5, 32, 52, 23, 26, 27, 55, 11, 44, 45, 47, 36, 37, 39, 62, 4, 10, 28, 29, 19, 21, 50, 56, 15, 43, 48, 49, 8, 14, 24, 51, 17, 18, 34, 58, 6, 7, 25, 41, 1, 13, 20, 22, 38, 59, 61, 63, 30, 40, 42, 54, 31, 53, 57, 60])\n    (2901.5115, [9, 12, 35, 56, 0, 3, 16, 46, 23, 26, 27, 55, 33, 36, 37, 39, 4, 10, 28, 29, 19, 21, 45, 50, 8, 14, 24, 51, 17, 18, 34, 58, 6, 7, 25, 41, 1, 13, 20, 22, 38, 59, 61, 63, 30, 40, 42, 54, 31, 53, 57, 60, 2, 5, 32, 52, 15, 43, 49, 62, 11, 44, 47, 48])\n    (2910.2043, [4, 10, 28, 37, 9, 12, 35, 56, 0, 3, 16, 46, 23, 33, 36, 39, 8, 14, 24, 51, 17, 18, 34, 58, 6, 7, 25, 41, 1, 13, 20, 22, 38, 59, 61, 63, 30, 40, 42, 54, 31, 53, 57, 60, 2, 5, 32, 52, 15, 43, 49, 62, 11, 44, 47, 48, 19, 21, 45, 50, 26, 27, 29, 55])\n    optimize_stripe_groups,8,0         ,       2919.757,           35.5,          0.015\n    [0, 9, 12, 35, 4, 10, 28, 37, 50, 19, 45, 21, 34, 17, 18, 58, 16, 46, 39, 3, 49, 43, 15, 62, 6, 7, 41, 25, 48, 11, 44, 47, 13, 20, 22, 1, 55, 29, 26, 27, 5, 2, 32, 52, 40, 30, 42, 54, 53, 57, 60, 31, 36, 56, 23, 33, 59, 38, 61, 63, 51, 24, 14, 8]\n\n### Transform unstructured sparsity to structured sparsity (as in Figure 4)\n\nIf you have a directory with .npy weight files for each layer of a network, `bash unstructured_study.sh <path_to_directory> <network_name>` will perform a binary search for each file to find the minimum unstructured sparsity required to transparently transform that layer with a number of permutation search techniques; this file was used to generate Figure 4, using weights 
dumped from a pre-trained ResNet50 in Torchvision.\n\n## References\n\nThe baseline algorithm which we adapted for use with 2:4 sparsity and upon which we improved is \"[TETRIS](https://papers.nips.cc/paper/2018/hash/89885ff2c83a10305ee08bd507c1049c-Abstract.html): TilE-matching the TRemendous Irregular Sparsity,\" Ji et al., NeurIPS 2018.\n\nIf you want to use this technique when generating a 2:4 sparse network for inference, we've packaged it into our [ASP](https://github.com/NVIDIA/apex/tree/master/apex/contrib/sparsity) library - this will perform the permutation searches for each layer as required, as well as fix up neighboring layers so there are no extra operations inserted at runtime.\n\n## Citation\n\nIf you use this idea or code in your own research, please cite the [paper](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html) that describes it:\n\n```\n@inproceedings{pool2021channel,\n  author    = {Pool, Jeff and Yu, Chong},\n  booktitle = {Advances in Neural Information Processing Systems ({NeurIPS})},\n  title     = {Channel Permutations for {N:M} Sparsity},\n  url       = {https://proceedings.neurips.cc/paper/2021/file/6e8404c3b93a9527c8db241a1846599a-Paper.pdf},\n  volume    = {34},\n  year      = {2021}\n}\n\n```\n\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_tests/ablation_studies.sh",
    "content": "#!/bin/bash\n\nOUTDIR=\"results/ablation_logs\"\nmkdir -p $OUTDIR\n\nR1000=random,1000\nCS=channel_swap,0\nCS_100=channel_swap,100\nCS_1000=channel_swap,1000\nOSG2=optimize_stripe_groups,8,0\nOSG2_100=optimize_stripe_groups,8,100\nOSG2_1000=optimize_stripe_groups,8,1000\nOSG3=optimize_stripe_groups,12,0\nOSG3_100=optimize_stripe_groups,12,100\nOSG3_1000=optimize_stripe_groups,12,1000\noptimal=optimize_stripe_groups,16,0\n\n# Table 1\nfor seed in {0..24}; do\n    echo $seed\n    python3 permutation_test.py --channels 16 --filters 32 --seed $seed --pretty_print=False $R1000 $CS $CS_100 $CS_1000 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 $optimal | tee \"${OUTDIR}/ablations_32x16_$seed.log\"\n    python3 permutation_test.py --channels 128 --filters 64 --seed $seed --pretty_print=False $R1000 $CS $CS_100 $CS_1000 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 | tee \"${OUTDIR}/ablations_64x128_$seed.log\"\ndone\n\necho \"Gathering results ...\"\n\n################# collect results into a .csv file\n# get mean and stddev of efficacy from all seeds for one strategy\nget_mean_stddev() {\n    local strategy=$1\n    local OUTFILE=$2\n\n    # get the strategy's line,                           pull out efficacy and time,              use sum-of-squares to compute stddev and mean in a single pass\n    grep \"$strategy,\" $OUTDIR/ablations_64x128_*.log | awk -F \",\" '{print $3,$4}' | awk '{sum += $1; sumsq += ($1)^2; timesum += $2} END {printf \"%.1f,%.1f,%.2f,\", sum/NR, sqrt((sumsq-sum^2/NR)/NR), timesum/NR}' >> $OUTFILE\n}\n\n# get the number of times some strategy matched the optimal solution\nget_num_optimal() {\n    local strategy=$1\n    local OUTFILE=$2\n\n    matches=0\n    for seed in {0..24}; do\n        # compare floats with epsilon: add one thousandth to the efficacy under test\n        this_eff=$(grep \"$strategy,\" \"${OUTDIR}/ablations_32x16_${seed}.log\" | awk -F \",\" '{print int($3 * 1000 + 1)}')\n        best_eff=$(grep 
\"optimize_stripe_groups_16_0,\" \"${OUTDIR}/ablations_32x16_${seed}.log\" | awk -F \",\" '{print int($3 * 1000)}')\n        if [ \"$this_eff\" -ge \"$best_eff\" ]; then\n            let \"matches = $matches + 1\"\n        fi\n    done\n\n    printf \"$matches,\" >> $OUTFILE\n}\n\n# populate a row of the ablation study table\npopulate_row() {\n    local greedy=$1\n    local escape=$2\n    local strategy=$(echo \"$3\" | sed 's/,/_/g')\n    local OUTFILE=$4\n\n    printf \"$greedy,$escape,\" >> $OUTFILE\n    get_mean_stddev \"$strategy\" \"$OUTFILE\"\n    printf \",\" >> $OUTFILE\n    get_num_optimal \"$strategy\" \"$OUTFILE\"\n    printf \"\\n\" >> $OUTFILE\n}\n\n# prepare output file header\nOUTFILE=\"results/ablation_studies.csv\"\nprintf \",,25x 64x128,,,,25x 32x16\\n\" > $OUTFILE\nprintf \",,Efficacy,,Runtime,,Optimal\\n\" >> $OUTFILE\nprintf \"Greedy Phase,Escape Phase,Mean,StdDev,Mean,,# Found\\n\" >> $OUTFILE\n\n# finally, gather the data for each strategy into a row of the table\npopulate_row \"Random 1000\" \"-\" \"$R1000\" \"$OUTFILE\"\npopulate_row \"Channel Swap\" \"-\" \"$CS\" \"$OUTFILE\"\npopulate_row \"Channel Swap\" \"BR(100)\" \"$CS_100\" \"$OUTFILE\"\npopulate_row \"Channel Swap\" \"BR(1000)\" \"$CS_1000\" \"$OUTFILE\"\npopulate_row \"OSG(2)\" \"-\" \"$OSG2\" \"$OUTFILE\"\npopulate_row \"OSG(2)\" \"BR(100)\" \"$OSG2_100\" \"$OUTFILE\"\npopulate_row \"OSG(2)\" \"BR(1000)\" \"$OSG2_1000\" \"$OUTFILE\"\npopulate_row \"OSG(3)\" \"-\" \"$OSG3\" \"$OUTFILE\"\npopulate_row \"OSG(3)\" \"BR(100)\" \"$OSG3_100\" \"$OUTFILE\"\npopulate_row \"OSG(3)\" \"BR(1000)\" \"$OSG3_1000\" \"$OUTFILE\"\n\necho \"Done! $OUTFILE\"\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_tests/permutation_test.py",
    "content": "import numpy as np\nimport time\nimport sys\n\n# permutation-specifics\nsys.path.append(\"../\")\nfrom permutation_search_kernels.permutation_utilities import *\nfrom permutation_search_kernels.exhaustive_search import Exhaustive_Search\nfrom permutation_search_kernels.channel_swap import Channel_Swap\n\n# Arguments\nimport argparse\n\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n        return True\n    elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n        return False\n    else:\n        raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n\n\nparser = argparse.ArgumentParser(description=\"Test channel permutations\")\nparser.add_argument(\"--infile\", default=\"random\", type=str, help='input file or \"random\"')\nparser.add_argument(\"--channels\", default=384, type=int, help=\"random input channel count (C)\")\nparser.add_argument(\"--filters\", default=96, type=int, help=\"random input filter count (K)\")\nparser.add_argument(\"--verbosity\", default=0, type=int, help=\"print status updates\")\nparser.add_argument(\"--seed\", default=1, type=int, help=\"random seed\")\nparser.add_argument(\n    \"--pretty_print\",\n    default=True,\n    type=str2bool,\n    help=\"print the table for pretty viewing (as opposed to strict .csv)\",\n)\nparser.add_argument(\n    \"--unstructured\",\n    default=0.0,\n    type=float,\n    help='perform unstructured pruning to a target sparsity before processing, emulate an unstructured sparse network. 
\"-1\" will find the minimum sparsity required to achieve a perfect permutation',\n)\nparser.add_argument(\n    \"--gpu\",\n    default=True,\n    type=str2bool,\n    help=\"uses a gpu to accelerate the search if possible\",\n)\nparser.add_argument(\n    \"--check_permutation\",\n    default=False,\n    type=str2bool,\n    help=\"check that the tracked permutation matches the recovered permutation\",\n)\nparser.add_argument(\n    \"--intermediate_steps\",\n    default=0,\n    type=int,\n    help=\"find roughly evenly-spaced permutations in efficacy\",\n)\nparser.add_argument(\n    \"--print_permutation\",\n    default=False,\n    type=str2bool,\n    help=\"print the final permutation found by each strategy\",\n)\nparser.add_argument(\"strategies\", metavar=\"strategy\", type=str, nargs=\"+\", help=\"strategies to try\")\n\n\n## binary search for the minimum sparsity necessary to achieve a perfect permutation with some strategy\ndef find_minimum_sparsity(matrix, search_function, **kwargs):\n    duration = 0\n    min_sparsity = 50\n    max_sparsity = 100\n    sparsity = 75\n    verbosity = 0\n    if \"verbosity\" in kwargs:\n        verbosity = kwargs[\"verbosity\"]\n\n    while min_sparsity < max_sparsity:\n        if verbosity > 5:\n            print(f\"\\tlooking now at {sparsity} (between {min_sparsity} and {max_sparsity})\")\n\n        # prepare unstructured sparse matrix, get row sparsity magnitude\n        tmp_result = unstructured_prune(result, sparsity / 100.0)\n        local_unpruned_magnitude = np.sum(np.abs(tmp_result))\n        local_unstructured_rows_magnitude = magnitude_after_pruning_rows(tmp_result, rate=0.5)\n\n        # quick check to see if this sparsity is trivially too low\n        if local_unstructured_rows_magnitude * 1.0001 < local_unpruned_magnitude:\n            if verbosity > 5:\n                print(\n                    f\"Skipping sparsity {sparsity} since there's no perfect permutation (unstructured mag {local_unpruned_magnitude} is 
larger than sparse rows {local_unstructured_rows_magnitude}).\"\n                )\n            min_sparsity = sparsity + 1\n            sparsity = int(min_sparsity + (max_sparsity - min_sparsity) / 2.0)\n            continue\n\n        tmp_result, tmp_duration, found_permutation = search_function(tmp_result, **kwargs)\n        duration += tmp_duration\n        nonzeros = np.count_nonzero(tmp_result)\n        tmp_result = apply_2_to_4(tmp_result)\n        nonzeros_after_2to4 = np.count_nonzero(tmp_result)\n        if nonzeros == nonzeros_after_2to4:  # found a winner, are we done?\n            if verbosity > 3:\n                print(f\"Found an unstructured sparsity that we can turn into 2:4: {sparsity}\")\n\n            max_sparsity = sparsity\n            if max_sparsity <= min_sparsity and verbosity > 0:\n                print(\n                    f\"Found the minimum unstructured sparsity that we can turn into 2:4: {sparsity}\"\n                )\n                break\n        else:\n            if verbosity > 5:\n                print(f\"Unstructured sparsity {sparsity} was insufficient to produce 2:4 sparsity\")\n            min_sparsity = sparsity + 1\n            if max_sparsity <= min_sparsity and verbosity > 0:\n                print(\n                    f\"Found the minimum unstructured sparsity that we can turn into 2:4: {max_sparsity}\"\n                )\n                sparsity = max_sparsity\n                break\n\n        sparsity = int(min_sparsity + (max_sparsity - min_sparsity) / 2.0)\n\n    return sparsity, duration\n\n\n# Entry point\nif __name__ == \"__main__\":\n    args = parser.parse_args()\n    verbosity = args.verbosity\n    np.random.seed(seed=args.seed)\n    use_gpu(initial_override=args.gpu)\n\n    # get or create the input matrix\n    input_vals = np.random.rand(args.filters, args.channels)\n    if args.infile != \"random\":\n        if \"npy\" in args.infile:\n            input_vals = np.load(args.infile, \"r\")\n        shp 
= input_vals.shape\n        shp_str = str(shp).replace(\",\", \"x\")\n        newshp_str = \"\"\n        if len(shp) == 4:  # K,C,R,S -> RSK,C\n            input_vals = (\n                np.transpose(input_vals, (2, 3, 0, 1))\n                .flatten()\n                .reshape((shp[2] * shp[3] * shp[0], shp[1]))\n            )\n            newshp_str = str(input_vals.shape).replace(\",\", \"x\")\n        print(f\"{args.infile},{shp_str},{newshp_str}\")\n        if input_vals.shape[1] % 4 != 0:\n            print(f\"Unfriendly shape {input_vals.shape}, not pruning.\")\n            sys.exit()\n\n    # unstructured prune if requested\n    if args.unstructured > 0.0:\n        args.unstructured = min(args.unstructured, 1.0)\n        input_vals = unstructured_prune(input_vals, args.unstructured)\n        print(\n            f\"{args.infile} pruned to {args.unstructured * 100.0:>.1f} sparsity, shape is {input_vals.shape}\"\n        )\n\n    # calculate some early metrics\n    sorted_magnitudes = np.sort(np.abs(input_vals), axis=None)\n    unpruned_magnitude = np.sum(sorted_magnitudes)\n    num_weights = sorted_magnitudes.size\n    unstructured_magnitude = np.sum(sorted_magnitudes[int(num_weights / 2) :])\n    unstructured_rows_magnitude = magnitude_after_pruning_rows(input_vals, rate=0.5)\n    simple_2to4 = apply_2_to_4(np.copy(input_vals))\n    simple_2to4_magnitude = sum_after_2_to_4(input_vals)\n    tmp_time = time.perf_counter()\n    simple_2to4_magnitude = sum_after_2_to_4(input_vals)\n    default_duration = time.perf_counter() - tmp_time\n    best_magnitude = unstructured_rows_magnitude\n\n    best_lost_magnitude = unpruned_magnitude - best_magnitude\n    base_lost_magnitude = unpruned_magnitude - simple_2to4_magnitude\n\n    # prep results table\n    final_metric = \"efficacy\"\n    if args.unstructured < 0.0:\n        final_metric = \"min_sparsity\"\n    if args.pretty_print:\n        
print(f\"{'strategy':<35},{'magnitude':>15},{final_metric:>15},{'duration':>15}\")\n        print(f\"{'unpruned':<35},{unpruned_magnitude:>15.3f},{'-':^15},{'-':^15}\")\n        print(f\"{'unstructured':<35},{unstructured_magnitude:>15.3f},{'-':^15},{'-':^15}\")\n        print(f\"{'50% rows':<35},{unstructured_rows_magnitude:>15.3f},{'100.0':>15},{'-':^15}\")\n        print(\n            f\"{'default 2:4':<35},{simple_2to4_magnitude:>15.3f},{'0.0':>15},{default_duration:>15.3f}\"\n        )\n    else:\n        print(f\"strategy,magnitude,{final_metric},duration\")\n        print(f\"unpruned,{unpruned_magnitude},-,-\")\n        print(f\"unstructured,{unstructured_magnitude},-,-\")\n        print(f\"50%_rows,{unstructured_rows_magnitude},100.0,-\")\n        print(f\"2:4,{simple_2to4_magnitude},0.0,{default_duration}\")\n\n    # try the requested strategies\n    for i, strategy in enumerate(args.strategies):\n        result = np.copy(input_vals)\n        np.random.seed(seed=args.seed)\n\n        duration = 0.0\n        min_sparsity = 0.0\n        strat_split = strategy.split(\",\")\n        found_permutation = None\n\n        # optimize stripe groups\n        if strat_split[0] == \"optimize_stripe_groups\":\n            stripe_group_size_in_cols = 8\n            if len(strat_split) >= 2:\n                stripe_group_size_in_cols = int(strat_split[1])\n            escape_attempts = 100\n            if len(strat_split) >= 3:\n                escape_attempts = int(strat_split[2])\n\n            if args.unstructured >= 0.0:  # just perform the search on the current matrix\n                result, duration, found_permutation = Exhaustive_Search(\n                    result,\n                    stripe_group_size=stripe_group_size_in_cols,\n                    escape_attempts=escape_attempts,\n                )\n            else:  # find the minimum sparsity needed to transparently transform the input\n                min_sparsity, duration = find_minimum_sparsity(\n       
             result,\n                    Exhaustive_Search,\n                    stripe_group_size=stripe_group_size_in_cols,\n                    escape_attempts=escape_attempts,\n                )\n                result = unstructured_prune(result, min_sparsity / 100.0)\n\n        # channel swaps\n        elif strat_split[0] == \"channel_swap\":\n            escape_attempts = 0\n            if len(strat_split) >= 2:\n                escape_attempts = int(strat_split[1])\n\n            if args.unstructured >= 0.0:  # just perform the search on the current matrix\n                result, duration, found_permutation = Channel_Swap(\n                    result, escape_attempts=escape_attempts, verbosity=verbosity\n                )\n            else:  # find the minimum sparsity needed to transparently transform the input\n                min_sparsity, duration = find_minimum_sparsity(\n                    result,\n                    Channel_Swap,\n                    escape_attempts=escape_attempts,\n                    verbosity=verbosity,\n                )\n                result = unstructured_prune(result, min_sparsity / 100.0)\n\n        # random permutations\n        elif strat_split[0] == \"random\":\n            if (\n                args.unstructured < 0.0\n            ):  # searching for minimum sparsity not supported for random permutations\n                continue\n\n            num_perms = 10\n            if len(strat_split) >= 2 and int(strat_split[1]) >= 1:\n                num_perms = int(strat_split[1])\n\n            # try the seeds/permutations\n            permutation = [c for c in range(result.shape[1])]\n            best_sum = sum_after_2_to_4(result)\n            best_perm = permutation.copy()\n            start_time = time.perf_counter()\n            for x in range(num_perms):\n                permutation = np.random.permutation(permutation)\n                cur_sum = sum_after_2_to_4(result[:, permutation])\n                if cur_sum > 
best_sum:\n                    best_sum = cur_sum\n                    best_perm = permutation.copy()\n                    if verbosity > 0:\n                        print(f\"\\tnew best permutation {x} found with magnitude {best_sum:>15.3f}\")\n                elif verbosity > 5:\n                    print(f\"\\tpermutation {x} magnitude too low: {cur_sum:>15.3f}\")\n            duration = time.perf_counter() - start_time\n            result = result[:, best_perm]\n            found_permutation = best_perm\n\n        else:\n            print(f\"Unknown strategy: {strategy}!\")\n            sys.exit()\n\n        # report stats for this strategy\n        cur_mag = sum_after_2_to_4(result)\n        cur_eff = (\n            efficacy(best_lost_magnitude, base_lost_magnitude, unpruned_magnitude - cur_mag) * 100.0\n        )\n        final_metric = cur_eff\n        if args.unstructured < 0.0:\n            final_metric = min_sparsity\n        perm_distance = \"\"\n\n        error = None\n        if args.check_permutation and found_permutation is not None:\n            recovered_perm = find_permutation(result, input_vals)\n\n            error = False\n            for c in range(len(recovered_perm)):\n                if recovered_perm[c] != found_permutation[c]:\n                    if verbosity > 0:\n                        print(\n                            f\"tracked permutation at index {c} was {found_permutation[c]}, but the recovered permutation thought it was {recovered_perm[c]}\"\n                        )\n                    error = True\n\n        # if requested, generate permutations that divide the efficacy space into equal steps\n        if args.intermediate_steps != 0:\n            magnitude_targets = None\n            if args.intermediate_steps != 0:\n                ratios = [\n                    step / float(args.intermediate_steps + 1)\n                    for step in range(1, args.intermediate_steps + 1)\n                ]\n                mag_diff = 
cur_mag - (unpruned_magnitude - base_lost_magnitude)\n                magnitude_targets = [\n                    (unpruned_magnitude - base_lost_magnitude) + mag_diff * ratio\n                    for ratio in ratios\n                ]\n            perm_distance, target_permutations = permutation_distance(\n                found_permutation,\n                [c for c in range(result.shape[1])],\n                matrix=input_vals,\n                magnitude_targets=magnitude_targets,\n                debug=False,\n                verbosity=verbosity,\n            )\n            if target_permutations is not None:\n                for target_permutation in target_permutations:\n                    print(target_permutation)\n\n        error_str = \"\"\n        if error is not None:\n            error_str = \",       correct\"\n            if error:\n                error_str = \",      mismatch\"\n\n        if args.pretty_print:\n            print(\n                f\"{strategy:35},{cur_mag:>15.3f},{final_metric:>15.1f},{duration:>15.3f}{error_str:>15}\"\n            )\n        else:\n            strat_string = strategy.replace(\",\", \"_\")\n            print(f\"{strat_string},{cur_mag},{final_metric},{duration}{error_str}\")\n\n        if args.print_permutation and found_permutation is not None:\n            print(found_permutation)\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_tests/runtime_table.sh",
    "content": "#!/bin/bash\n\nOUTDIR=\"results/runtime_logs\"\nmkdir -p $OUTDIR\n\nR1000=random,1000\nCS=channel_swap,0\nCS_100=channel_swap,100\nOSG2=optimize_stripe_groups,8,0\nOSG2_100=optimize_stripe_groups,8,100\nOSG2_1000=optimize_stripe_groups,8,1000\nOSG3=optimize_stripe_groups,12,0\nOSG3_100=optimize_stripe_groups,12,100\nOSG3_1000=optimize_stripe_groups,12,1000\n\nfor cols in \"32\" \"64\" \"128\" \"256\"; do\n    echo \"$cols x $cols\"\n    python3 permutation_test.py --channels $cols --filters $cols --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 | tee \"${OUTDIR}/runtime_${cols}x${cols}.log\"\n    let \"rows = $cols * 2\"\n    echo \"$cols x $rows\"\n    python3 permutation_test.py --channels $cols --filters $rows --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 | tee \"${OUTDIR}/runtime_${cols}x${rows}.log\"\ndone\n\n# 2048x2048 is too large for OSG3\necho \"2048 x 2048\"\npython3 permutation_test.py --channels 2048 --filters 2048 --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 | tee \"${OUTDIR}/runtime_2048x2048.log\"\n\n\n############### collect results into a .csv file\necho \"Gathering results ...\"\n\n# efficacy and runtime from one strategy and size\nget_results() {\n    local strategy=$1\n    local cols=$2\n    local rows=$3\n    local OUTFILE=$4\n\n    grep \"$strategy,\" \"$OUTDIR/runtime_${cols}x${rows}.log\" | awk -F \",\" '{printf \"%s,%s,\",$3,$4}' >> $OUTFILE\n}\n\n# prepare output file headers\nOUTFILE=\"results/runtimes.csv\"\nprintf \"Columns,\" > $OUTFILE\nfor cols in \"32\" \"64\" \"128\" \"256\"; do\n    printf \"$cols,$cols,$cols,$cols,\" >> $OUTFILE\ndone\nprintf \"2048,2048\\n\" >> $OUTFILE\n\nprintf \"Rows,\" >> $OUTFILE\nfor cols in \"32\" \"64\" \"128\" \"256\"; do\n    let \"rows = $cols * 2\"\n    printf \"$cols,$cols,$rows,$rows,\" >> $OUTFILE\ndone\nprintf \"2048,2048\\n\" >> $OUTFILE\n\nprintf \"Metric,\" >> 
$OUTFILE\nfor cols in \"32\" \"64\" \"128\" \"256\"; do\n    printf \"Efficacy,Runtime,Efficacy,Runtime,\" >> $OUTFILE\ndone\nprintf \"Efficacy,Runtime\\n\" >> $OUTFILE\n\n# gather data in a reasonable order\nfor strategy in \"$R1000\" \"$CS\" \"$CS_100\" \"$OSG2\" \"$OSG2_100\" \"$OSG2_1000\" \"$OSG3\" \"$OSG3_100\" \"$OSG3_1000\"; do\n    strategy=$(echo \"$strategy\" | sed 's/,/_/g') # replace commas with underscores, as they'll appear in the results logs\n    printf \"$strategy,\" >> $OUTFILE\n    for cols in \"32\" \"64\" \"128\" \"256\"; do\n        get_results \"$strategy\" \"$cols\" \"$cols\" \"$OUTFILE\"\n        let \"rows = $cols * 2\"\n        get_results \"$strategy\" \"$cols\" \"$rows\" \"$OUTFILE\"\n    done\n\n    get_results \"$strategy\" \"2048\" \"2048\" \"$OUTFILE\"\n\n    printf \"\\n\" >> $OUTFILE\ndone\n\necho \"Done! $OUTFILE\"\n"
  },
  {
    "path": "apex/contrib/sparsity/permutation_tests/unstructured_study.sh",
    "content": "#!/bin/bash\n\nif [ \"$#\" -ne 2 ]; then\n  echo \"Please specify both the source directory and a run tag: bash unstructured_study.sh <directory> <tag>\"\n  exit\nfi\n\ndir=$1  # or set to the directory containing .npy files of interest\ntag=$2 # or set to an identifier, e.g. \"network_name\"\n\nresdir=\"results/unstructured_logs/${tag}\"\nmkdir -p $resdir\n\nCS=channel_swap,0\nOSG2=optimize_stripe_groups,8,0\nOSG2_100=optimize_stripe_groups,8,100\nOSG2_1000=optimize_stripe_groups,8,1000\nOSG3=optimize_stripe_groups,12,0\n\nCS_successes=()\nOSG2_successes=()\nOSG2_100_successes=()\nOSG2_1000_successes=()\nOSG3_successes=()\n\nfor sparsity in {50..100}; do\n    CS_successes+=(0)\n    OSG2_successes+=(0)\n    OSG2_100_successes+=(0)\n    OSG2_1000_successes+=(0)\n    OSG3_successes+=(0)\ndone\n\nupdate_successes () {\n    strategy=$1\n    local -n _successes=$2\n    logfile=$3\n\n    limit=$(grep \"${strategy},\" $logfile | awk -F \",\" '{print $3}')\n \n    echo $logfile, $strategy, $limit\n    for (( sparsity=$limit; sparsity<=100; sparsity++ )); do\n        let \"entry = $sparsity - 50\"\n        let \"value = ${_successes[$entry]} + 1\"\n        _successes[$entry]=$value\n    done\n}\n\n# Figure 4\nfor filename in $dir/*.npy; do\n    out=$(basename -- \"$filename\")\n    echo \"Searching for minimum sparsities for $out\"\n    out=$resdir/$out.unstructured\n    python3 permutation_test.py --infile=$filename --pretty_print=False --unstructured=-1 $CS $OSG2 $OSG2_100 $OSG2_1000 $OSG3 > $out\n\n    update_successes \"channel_swap_0\" CS_successes \"$out\"\n    update_successes \"optimize_stripe_groups_8_0\" OSG2_successes \"$out\"\n    update_successes \"optimize_stripe_groups_8_100\" OSG2_100_successes \"$out\"\n    update_successes \"optimize_stripe_groups_8_1000\" OSG2_1000_successes \"$out\"\n    update_successes \"optimize_stripe_groups_12_0\" OSG3_successes \"$out\"\ndone\n\n#################### save the table\n# log a single strategy in as a 
row in the table\nlog_success () {\n    strategy=$1\n    local -n _successes=$2\n    OUTFILE=$3\n\n    printf \"$strategy,\" >> $OUTFILE\n    for sparsity in {50..100}; do\n        let \"entry = $sparsity - 50\"\n        printf \"%d,\" ${_successes[$entry]} >> $OUTFILE\n    done\n    printf \"\\n\" >> $OUTFILE\n}\n\n# prepare the header\nOUTFILE=\"results/unstructured.csv\"\nprintf \"Sparsity,\" > $OUTFILE\nfor sparsity in {50..100}; do\n    printf \"%d,\" $sparsity >> $OUTFILE\ndone\nprintf \"\\n\" >> $OUTFILE\n\n# add data for each strategy\nlog_success \"channel_swap_0\" CS_successes \"$OUTFILE\"\nlog_success \"optimize_stripe_groups_8_0\" OSG2_successes \"$OUTFILE\"\nlog_success \"optimize_stripe_groups_8_100\" OSG2_100_successes \"$OUTFILE\"\nlog_success \"optimize_stripe_groups_8_1000\" OSG2_1000_successes \"$OUTFILE\"\nlog_success \"optimize_stripe_groups_12_0\" OSG3_successes \"$OUTFILE\"\n\necho \"Done! ${OUTFILE}\"\n"
  },
  {
    "path": "apex/contrib/sparsity/sparse_masklib.py",
    "content": "import sys\nimport torch\nimport numpy as np\nimport collections\nfrom itertools import permutations\n\n\n\"\"\" compute density (helper fn to compute % NNZs in a tensor) \"\"\"\n\n\ndef fill(x):\n    return float(x.nonzero().size(0)) / torch.numel(x)\n\n\n\"\"\" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) \"\"\"\n\n\ndef reshape_1d(matrix, m):\n    # If not a nice multiple of m, fill with zeroes.\n    if matrix.shape[1] % m > 0:\n        mat = torch.cuda.FloatTensor(\n            matrix.shape[0], matrix.shape[1] + (m - matrix.shape[1] % m)\n        ).fill_(0)\n        mat[:, : matrix.shape[1]] = matrix\n        shape = mat.shape\n        return mat.view(-1, m), shape\n    else:\n        return matrix.view(-1, m), matrix.shape\n\n\n\"\"\" return all possible m:n patterns in a 1d vector \"\"\"\nvalid_m4n2_1d_patterns = None\n\n\ndef compute_valid_1d_patterns(m, n):\n    # Early exit if patterns was already created.\n    global valid_m4n2_1d_patterns\n\n    if m == 4 and n == 2 and valid_m4n2_1d_patterns is not None:\n        return valid_m4n2_1d_patterns\n    patterns = torch.zeros(m)\n    patterns[:n] = 1\n    valid_patterns = torch.tensor(list(set(permutations(patterns.tolist()))))\n    if m == 4 and n == 2:\n        valid_m4n2_1d_patterns = valid_patterns\n    return valid_patterns\n\n\n\"\"\" m:n 1d structured best \"\"\"\n\n\ndef mn_1d_best(matrix, m, n):\n    # Find all possible patterns.\n    patterns = compute_valid_1d_patterns(m, n).cuda()\n\n    # Find the best m:n pattern (sum of non-masked weights).\n    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1, m)\n    mat, shape = reshape_1d(matrix, m)\n    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)\n    mask[:] = patterns[pmax[:]]\n    mask = mask.view(matrix.shape)\n    return mask\n\n\ndef m4n2_1d(mat, density):\n    return mn_1d_best(mat, 4, 2)\n\n\n\"\"\"\n  Below 2d-masking related code is targeted more for training (from scratch).\n  
2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop\n  phase of training algorithm. Acceleration comes from using SpMMA instructions in\n  Tensor Cores of NVIDIA Ampere GPU Architecture \n  (note: this code does not do the acceleration, GPU kernels are required for this).\n  1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern\n  along the horizontal (logical) direction.\n  During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask\n  weight tensor such that their transposed versions are also 2:4 sparse along the\n  horizontal (logical) direction. Thus, with 2d pruning, weight tensors are \n  2:4 sparse along row and column directions.\n \"\"\"\n\n\"\"\" m:n 2d structured pruning: greedy method to select mask \"\"\"\n\n\ndef mn_2d_greedy(matrix, m, n):\n    # Convert to numpy\n    mat = matrix.cpu().detach().numpy()\n    mask = np.ones(mat.shape, dtype=int)\n\n    rowCount = int(mat.shape[0] / m) * m\n    colCount = int(mat.shape[1] / m) * m\n    for rowStartIdx in range(0, rowCount, m):\n        rowEndIdx = rowStartIdx + m\n        for colStartIdx in range(0, colCount, m):\n            colEndIdx = colStartIdx + m\n            matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]))\n            maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])\n            maskSub.fill(0.0)\n            matrixVecView = matrixSub.reshape(-1)\n            maskVecView = maskSub.reshape(-1)\n            linearIdx = np.argsort(matrixVecView)\n            matrixIdx = [(int(x / m), x % m) for x in linearIdx]\n            rowCounter = collections.Counter()\n            colCounter = collections.Counter()\n            for currIdx in range(len(linearIdx) - 1, -1, -1):\n                currMatrixEntry = matrixIdx[currIdx]\n                if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n):\n                    continue\n                
# end if\n                maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0\n                rowCounter[currMatrixEntry[0]] += 1\n                colCounter[currMatrixEntry[1]] += 1\n\n    # mask is a numpy array: convert it to a tensor first, then move it to the GPU.\n    return torch.tensor(mask).cuda()\n\n\ndef m4n2_2d_greedy(mat, density):\n    return mn_2d_greedy(mat, 4, 2)\n\n\n\"\"\" return all possible m:n patterns in an m x m block. \"\"\"\nvalid_m4n2_2d_patterns = None\n\n\ndef compute_valid_2d_patterns(m, n):\n    # Early exit if the 4:2 patterns were already created (the cache only holds the m=4, n=2 case).\n    global valid_m4n2_2d_patterns\n    if m == 4 and n == 2 and valid_m4n2_2d_patterns is not None:\n        return valid_m4n2_2d_patterns\n\n    patterns = torch.zeros(m)\n    patterns[:n] = 1\n    patterns = list(set(permutations(patterns.tolist())))\n    patterns = patterns + patterns\n    # Build a tensor from the candidate values; torch.empty would interpret the list as a shape.\n    patterns = torch.tensor(list(set(permutations(patterns, m))))\n\n    valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)\n    valid_patterns = torch.empty(valid.shape[0], m, m)\n    valid_patterns[:] = patterns[valid[:]]\n\n    if m == 4 and n == 2:\n        valid_m4n2_2d_patterns = valid_patterns\n    return valid_patterns\n\n\n\"\"\" m:n 2d structured pruning: exhaustive method to select best mask \"\"\"\n\n\ndef mn_2d_best(matrix, m, n):\n    # Find all possible patterns.\n    patterns = compute_valid_2d_patterns(m, n).cuda()\n\n    # Find the best m:n pattern (sum of non-masked weights).\n    mask = torch.cuda.IntTensor(matrix.shape).fill_(1)\n    mat = reshape_2d(matrix, m, m).abs()\n    pmax = torch.argmax(torch.matmul(mat, patterns.view(patterns.shape[0], m * m).t()), dim=2)\n\n    # Copy best m:n patterns into mask.\n    mat = mat.view(mat.shape[0] * mat.shape[1], -1)\n    pmax = pmax.view(pmax.shape[0] * pmax.shape[1]).unsqueeze(1).expand(-1, mat.shape[1])\n    patterns = patterns.view(patterns.shape[0], patterns.shape[1] * patterns.shape[2])\n    mat = torch.gather(patterns, 0, pmax)\n    mat = reshape_2d_inv(mat.view(matrix.shape[0] // m, matrix.shape[1] // m, m, m))\n    
mask.copy_(mat.type(mask.type()))\n    return mask\n\n\ndef m4n2_2d_best(mat, density):\n    return mn_2d_best(mat, 4, 2)\n\n\n\"\"\" returns a sparse mask \"\"\"\n\n\ndef create_mask(tensor, pattern=\"m4n2_1d\", density=0.5):\n    # Reshape tensor and mask.\n    shape = tensor.shape\n    ttype = tensor.type()\n    t = tensor.float().contiguous()\n\n    # 1d-tensor\n    if len(shape) == 1:\n        t = t.view(1, shape[0])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 2d-tensor (K, C)\n    elif len(shape) == 2:\n        # linear\n        t = t.view(shape[0], shape[1])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n    # 3d-tensor (K, C, R)\n    elif len(shape) == 3:\n        # 1d convs\n        t = t.permute(0, 2, 1).contiguous().view(shape[0] * shape[2], shape[1])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        mask = mask.view(shape[0], shape[2], shape[1]).permute(0, 2, 1).contiguous()\n        return mask.view(shape).type(ttype)\n    # 4d-tensor (K, C, R, S)\n    elif len(shape) == 4:\n        \"\"\"\n        # transformers (bmm)\n        t = t.view(shape[0]*shape[1]*shape[2], shape[3])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        return mask.view(shape).type(ttype)\n        \"\"\"\n        # 2d convs\n        t = t.permute(2, 3, 0, 1).contiguous().view(shape[2] * shape[3] * shape[0], shape[1])\n        func = getattr(sys.modules[__name__], pattern, None)\n        mask = func(t, density)\n        mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2, 3, 0, 1).contiguous()\n        return mask.view(shape).type(ttype)\n"
  },
  {
    "path": "apex/contrib/sparsity/test/checkpointing_test_part1.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.input_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n        elif i == args.num_layers - 1:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.output_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.output_features]\n            )\n        else:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n    return torch.nn.Sequential(od)\n\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target - target_batch) ** 2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    # print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\n\ndef main(args):\n    #\n    # PART1\n    #\n\n    torch.manual_seed(args.seed)\n\n    model = 
build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(\n        model,\n        args.pattern,\n        verbosity=args.verbosity,\n        whitelist=args.whitelist,\n        allow_recompute_mask=args.allow_recompute_mask,\n    )\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    torch.save(\n        {\n            \"step\": step,\n            \"verbosity\": args.verbosity,\n            \"seed2\": args.seed2,\n            \"pattern\": args.pattern,\n            \"whitelist\": args.whitelist,\n            \"allow_recompute_mask\": args.allow_recompute_mask,\n            \"model_state_dict\": model.state_dict(),\n            \"optimizer_state_dict\": optimizer.state_dict(),\n        },\n        args.checkpoint_path,\n    )\n\n\nif __name__ == \"__main__\":\n\n    class Args:\n        verbosity = 3\n        seed = 4873\n        seed2 = 99875\n        pattern = \"m4n2_2d_best\"\n        whitelist = [torch.nn.Linear]\n        allow_recompute_mask = True\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "apex/contrib/sparsity/test/checkpointing_test_part2.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.input_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n        elif i == args.num_layers - 1:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.output_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.output_features]\n            )\n        else:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n    return torch.nn.Sequential(od)\n\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target - target_batch) ** 2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    # print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\n\ndef main(step, args, model_state_dict, optimizer_state_dict):\n    #\n    # PART2\n    #\n\n    
model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(\n        model,\n        args.pattern,\n        verbosity=args.verbosity,\n        whitelist=args.whitelist,\n        allow_recompute_mask=args.allow_recompute_mask,\n    )\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    torch.manual_seed(args.seed2)\n    model.load_state_dict(model_state_dict)\n    optimizer.load_state_dict(optimizer_state_dict)\n\n    print(\"Model sparsity is %s\" % (\"enabled\" if ASP.is_sparsity_enabled() else \"disabled\"))\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\n\nif __name__ == \"__main__\":\n    checkpoint = torch.load(\"part1.chkp\")\n\n    class Args:\n        verbosity = checkpoint[\"verbosity\"]\n        seed = 4873\n        seed2 = checkpoint[\"seed2\"]\n        pattern = checkpoint[\"pattern\"]\n        whitelist = checkpoint[\"whitelist\"]\n        allow_recompute_mask = checkpoint[\"allow_recompute_mask\"]\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n\n    args = Args()\n\n    main(\n        checkpoint[\"step\"],\n        args,\n        checkpoint[\"model_state_dict\"],\n        checkpoint[\"optimizer_state_dict\"],\n    )\n"
  },
  {
    "path": "apex/contrib/sparsity/test/checkpointing_test_reference.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\n#\n# Reference run for checkpointing test (part1 + part2)\n#\n\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.input_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n        elif i == args.num_layers - 1:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.output_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.output_features]\n            )\n        else:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n    return torch.nn.Sequential(od)\n\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target - target_batch) ** 2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    # print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\n\ndef main(args):\n    #\n    # 
PART1\n    #\n\n    torch.manual_seed(args.seed)\n\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n    optimizer = FusedAdam(model.parameters())\n    ASP.init_model_for_pruning(\n        model,\n        args.pattern,\n        whitelist=args.whitelist,\n        allow_recompute_mask=args.allow_recompute_mask,\n    )\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    #\n    # PART 2\n    #\n\n    torch.manual_seed(args.seed2)\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\n\nif __name__ == \"__main__\":\n\n    class Args:\n        seed = 4873\n        seed2 = 99875\n        pattern = \"m4n2_2d_best\"\n        whitelist = [torch.nn.Linear]\n        allow_recompute_mask = True\n        batch_size = 32\n        input_features = 8\n        output_features = 8\n        hidden_features = 32\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        checkpoint_path = \"part1.chkp\"\n\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "apex/contrib/sparsity/test/test_permutation_application.py",
    "content": "import torch\nimport torch.onnx\nfrom apex.contrib.sparsity.permutation_lib import Permutation\n\n\"\"\"\nFunctional and behavioral correctness checking for network permutations\nEach test class is a torch.nn.Module with three required members:\n- self.input_shape is used to populate a dummy input\n- self.expected_C_params indicates how many parameters are expected to be permuted in the C dimension\n- self.expected_K_params indicates how many parameters are expected to be permuted in the K dimension\n\nA test is successful if and only if:\n1. The output of the un-permuted module matches (within a tolerance) the ouput of the permuted module\n2. The number of parameters permuted in C, as reported by the Permutation class, matches the expected value in the test module\n3. The number of parameters permuted in K, as reported by the Permutation class, matches the expected value in the test module\n\nThis file has all the test modules defined first, followed by the common test routine to check each module's correctness, and finally the main/entry point.\n\"\"\"\n\n\nclass simple_convs(torch.nn.Module):\n    \"\"\"Stack of 2d convolutions with different normalization and activation functions\"\"\"\n\n    def __init__(\n        self,\n        num_convs: int,\n        channels: int,\n        normalization: str = \"none\",\n        activation: str = \"ReLU\",\n    ):\n        super().__init__()\n        self.num_convs = num_convs\n        self.channels = channels\n        self.normalization = normalization\n        self.activation = activation\n\n        self.input_shape = [4, channels, 7, 7]\n\n        # we'll permute all convs' weights along C except the first\n        self.expected_C_params = -1\n        self.expected_K_params = 0\n\n        self.conv_stack = torch.nn.Sequential()\n        for c in range(self.num_convs - 1):\n            self.conv_stack.add_module(\n                f\"conv_{c}\",\n                torch.nn.Conv2d(self.channels, 
self.channels, kernel_size=(3, 3), padding=1),\n            )\n            self.expected_C_params += 1\n            self.expected_K_params += 2\n\n            if self.normalization == \"BatchNorm2d\":\n                self.conv_stack.add_module(\n                    f\"norm_{c}\",\n                    torch.nn.BatchNorm2d(self.channels, track_running_stats=False),\n                )\n                self.expected_K_params += 2\n            elif self.normalization == \"LazyBatchNorm2d\":\n                self.conv_stack.add_module(\n                    f\"norm_{c}\", torch.nn.LazyBatchNorm2d(track_running_stats=False)\n                )\n                self.expected_K_params += 2\n            elif self.normalization == \"GroupNorm\":\n                self.conv_stack.add_module(\n                    f\"norm_{c}\", torch.nn.GroupNorm(4, self.channels, affine=True)\n                )\n                self.expected_C_params -= 1  # GN prevents permutations of the neighboring convs\n                self.expected_K_params -= 2\n            elif self.normalization == \"InstanceNorm2d\":\n                self.conv_stack.add_module(\n                    f\"norm_{c}\",\n                    torch.nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=False),\n                )\n                self.expected_K_params += 2\n            elif self.normalization == \"LocalResponseNorm\":\n                self.conv_stack.add_module(f\"norm_{c}\", torch.nn.LocalResponseNorm(16))\n            elif self.normalization == \"LayerNorm1\":\n                self.conv_stack.add_module(f\"norm_{c}\", torch.nn.LayerNorm(7))\n            elif self.normalization == \"LayerNorm2\":\n                self.conv_stack.add_module(f\"norm_{c}\", torch.nn.LayerNorm([7, 7]))\n            elif self.normalization == \"LayerNorm3\":\n                self.conv_stack.add_module(f\"norm_{c}\", torch.nn.LayerNorm([self.channels, 7, 7]))\n                self.expected_K_params += 2\n            elif 
self.normalization == \"SyncBatchNorm\":\n                self.conv_stack.add_module(\n                    f\"norm_{c}\",\n                    torch.nn.SyncBatchNorm(self.channels, track_running_stats=False),\n                )\n                self.expected_K_params += 2\n\n            self.conv_stack.add_module(f\"act_{c}\", torch.nn.ReLU())\n\n        self.conv_stack.add_module(\n            \"conv_out\", torch.nn.Conv2d(self.channels, 8, kernel_size=(1, 1))\n        )\n        self.expected_C_params += 1\n\n    def forward(self, x: torch.Tensor):\n        x = self.conv_stack(x)\n\n        return x\n\n\nclass conv_1d(torch.nn.Module):\n    \"\"\"1D convolutions in isolation and with siblings\"\"\"\n\n    def __init__(\n        self,\n        with_2d=False,\n    ):\n        super().__init__()\n        self.input_shape = [4, 16, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n        self.with_2d = with_2d\n\n        self.input_conv = torch.nn.Conv2d(self.input_shape[1], 32, kernel_size=(3, 3), padding=1)\n        self.expected_K_params += 2\n\n        self.branch_a_1D = torch.nn.Conv1d(32, 32, kernel_size=3, padding=1)\n        self.expected_C_params += 1\n        self.expected_K_params += 2\n        if self.with_2d:\n            self.branch_b_2D = torch.nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1)\n            self.expected_C_params += 1\n            self.expected_K_params += 2\n\n        self.out_conv = torch.nn.Conv2d(32, 8, kernel_size=(1, 1))\n        self.expected_C_params += 1\n\n    def forward(self, x: torch.Tensor):\n        step0 = self.input_conv(x)\n\n        s0shape = step0.shape\n        step1 = self.branch_a_1D(step0.view(s0shape[0], s0shape[1], s0shape[2] * s0shape[3])).view(\n            s0shape\n        )\n        if self.with_2d:\n            step1 = step1 + self.branch_b_2D(step0)\n\n        return self.out_conv(step1)\n\n\nclass grouped_convs(torch.nn.Module):\n    \"\"\"Stack of 2d convolutions with 
different types of grouped convolutions\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.channels = 128\n        self.input_shape = [4, self.channels, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.conv_stack = torch.nn.Sequential()\n        self.conv_stack.add_module(\n            \"conv_in\",\n            torch.nn.Conv2d(self.channels, self.channels, kernel_size=(3, 3), padding=1),\n        )\n\n        # dw conv will let previous and this layers' weights and biases permute along K\n        self.expected_K_params += 4\n        self.conv_stack.add_module(\n            \"conv_dw\",\n            torch.nn.Conv2d(\n                self.channels,\n                self.channels,\n                kernel_size=(3, 3),\n                padding=1,\n                groups=self.channels,\n            ),\n        )\n\n        # regular conv permutes both\n        self.expected_C_params += 1\n        self.expected_K_params += 2\n        self.conv_stack.add_module(\n            \"conv_0\",\n            torch.nn.Conv2d(self.channels, self.channels, kernel_size=(3, 3), padding=1, groups=1),\n        )  # explicit '1' groups for extra coverage\n\n        # only 2 groups should allow permutations only in C\n        self.expected_C_params += 1\n        self.conv_stack.add_module(\n            \"conv_gr2\",\n            torch.nn.Conv2d(self.channels, self.channels, kernel_size=(3, 3), padding=1, groups=2),\n        )\n\n        # another regular conv, this one can't do anything\n        self.conv_stack.add_module(\n            \"conv_1\",\n            torch.nn.Conv2d(self.channels, self.channels, kernel_size=(3, 3), padding=1),\n        )\n\n        # finally, grouped conv with small groups\n        self.conv_stack.add_module(\n            \"conv_gr64\",\n            torch.nn.Conv2d(\n                self.channels,\n                self.channels,\n                kernel_size=(3, 3),\n              
  padding=1,\n                groups=self.channels // 2,\n            ),\n        )\n\n    def forward(self, input: torch.Tensor):\n        return self.conv_stack(input)\n\n\nclass simple_forks_joins(torch.nn.Module):\n    \"\"\"Some simple residual connections to test collecting parameters into a single group.  Four sections: input, blocka + residual, blockb + blockc, output\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.channels = 64\n        self.input_shape = [4, self.channels, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.input_convs = torch.nn.Sequential()\n        # input conv can only permute along K\n        self.expected_K_params += 2\n        self.input_convs.add_module(\n            \"conv_in0\",\n            torch.nn.Conv2d(self.channels, self.channels, kernel_size=(3, 3), padding=1),\n        )\n        # the next conv can permute along both C and K\n        self.expected_C_params += 1\n        self.expected_K_params += 2\n        self.input_convs.add_module(\n            \"conv_in1\",\n            torch.nn.Conv2d(self.channels, self.channels, kernel_size=(3, 3), padding=1),\n        )\n        # BN will permute 2 more along K\n        self.expected_K_params += 2\n        self.input_convs.add_module(\n            \"bn_in1\", torch.nn.BatchNorm2d(self.channels, track_running_stats=False)\n        )\n\n        self.block_a = torch.nn.Sequential()\n        # cut channels in half, then back to full, two fully permutable convs\n        self.expected_C_params += 2\n        self.expected_K_params += 4\n        self.block_a.add_module(\n            \"conv_a0\",\n            torch.nn.Conv2d(self.channels, self.channels // 2, kernel_size=(3, 3), padding=1),\n        )\n        self.block_a.add_module(\n            \"conv_a1\",\n            torch.nn.Conv2d(self.channels // 2, self.channels, kernel_size=(3, 3), padding=1),\n        )\n\n        self.block_b = 
torch.nn.Sequential()\n        # cut channels in half, then back to full, two fully permutable convs\n        self.expected_C_params += 2\n        self.expected_K_params += 4\n        self.block_b.add_module(\n            \"conv_b0\",\n            torch.nn.Conv2d(self.channels, self.channels // 2, kernel_size=(3, 3), padding=1),\n        )\n        self.block_b.add_module(\n            \"conv_b1\",\n            torch.nn.Conv2d(self.channels // 2, self.channels, kernel_size=(3, 3), padding=1),\n        )\n\n        self.block_c = torch.nn.Sequential()\n        # cut channels in half, then back to full, two fully permutable convs\n        self.expected_C_params += 2\n        self.expected_K_params += 4\n        self.block_c.add_module(\n            \"conv_c0\",\n            torch.nn.Conv2d(self.channels, self.channels // 2, kernel_size=(3, 3), padding=1),\n        )\n        self.block_c.add_module(\n            \"conv_c1\",\n            torch.nn.Conv2d(self.channels // 2, self.channels, kernel_size=(3, 3), padding=1),\n        )\n\n        self.output_conv = torch.nn.Sequential()\n        self.expected_C_params += 1\n        self.output_conv.add_module(\n            \"conv_out\", torch.nn.Conv2d(self.channels, 8, kernel_size=(3, 3), padding=1)\n        )\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.input_convs(input)\n        step1 = step0 + self.block_a(step0)\n        step2 = self.block_b(step1) + self.block_c(step1)\n        return self.output_conv(step2)\n\n\nclass different_grouped_convs(torch.nn.Module):\n    \"\"\"Convolutions with different group sizes need to use the GCD of the input channel counts if siblings\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.channels = 16\n        self.input_shape = [4, self.channels, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.input_conv = torch.nn.Sequential()\n        self.expected_K_params += 2\n        
self.input_conv.add_module(\n            \"input_conv\",\n            torch.nn.Conv2d(self.channels, 128, kernel_size=(3, 3), padding=1),\n        )\n\n        self.expected_C_params += 4\n        # 4 parallel blocks with decreasing group size from \"left\" to \"right\"\n        self.block_a = torch.nn.Sequential()\n        self.block_a.add_module(\"conv_a\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1))\n        self.block_b = torch.nn.Sequential()\n        self.block_b.add_module(\n            \"conv_b\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1, groups=2)\n        )\n        self.block_c = torch.nn.Sequential()\n        self.block_c.add_module(\n            \"conv_c\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1, groups=4)\n        )\n        self.block_d = torch.nn.Sequential()\n        self.block_d.add_module(\n            \"conv_d\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1, groups=8)\n        )\n\n        # output can't permute along C, disallowed by parents\n        self.output_conv = torch.nn.Sequential()\n        self.output_conv.add_module(\n            \"output_conv\", torch.nn.Conv2d(128, 8, kernel_size=(3, 3), padding=1)\n        )\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.input_conv(input)\n        step1 = (\n            self.block_a(step0) + self.block_b(step0) + self.block_c(step0) + self.block_d(step0)\n        )\n        return self.output_conv(step1)\n\n\nclass siblings_poison(torch.nn.Module):\n    \"\"\"A single sibling that cannot permute along C poisons all other siblings in its group\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.input_shape = [4, 16, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.input_conv = torch.nn.Sequential()\n        self.input_conv.add_module(\n            \"input_conv\",\n            torch.nn.Conv2d(self.input_shape[1], 128, kernel_size=(3, 3), 
padding=1),\n        )\n\n        # two parallel block: conv->flatten->linear | flatten->linear\n        self.expected_K_params += (\n            4  # two linears will have their output channels permuted for the output layer\n        )\n        self.block_a = torch.nn.Sequential()\n        self.block_a.add_module(\"conv_a\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1))\n        self.block_a.add_module(\"flatten_a\", torch.nn.Flatten(1))\n        self.block_a.add_module(\"linear_a\", torch.nn.Linear(6272, 128))\n\n        self.block_b = torch.nn.Sequential()\n        self.block_b.add_module(\"flatten_b\", torch.nn.Flatten(1))\n        self.block_b.add_module(\"linear_b\", torch.nn.Linear(6272, 128))\n\n        self.output = torch.nn.Sequential()\n        self.expected_C_params += 1  # output layer will have its C dimension permuted\n        self.output.add_module(\"output\", torch.nn.Linear(128, 8))\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.input_conv(input)\n        step1 = self.block_a(step0) + self.block_b(step0)\n        return self.output(step1)\n\n\nclass coparent_poison(torch.nn.Module):\n    \"\"\"A single coparent that cannot permute along K poisons all other coparents in its group\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.input_shape = [4, 16, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.input_conv = torch.nn.Sequential()\n        self.expected_K_params += 2\n        self.input_conv.add_module(\n            \"input_conv\",\n            torch.nn.Conv2d(self.input_shape[1], 128, kernel_size=(3, 3), padding=1),\n        )\n\n        # two parallel block: conv | conv-> grouped conv\n        self.expected_C_params += 3  # all convs permute along C\n        self.expected_K_params += 2  # only conv_b0 permutes along K\n        self.block_a = torch.nn.Sequential()\n        self.block_a.add_module(\"conv_a\", 
torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1))\n\n        self.block_b = torch.nn.Sequential()\n        self.block_b.add_module(\"conv_b0\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1))\n        self.block_b.add_module(\n            \"conv_b1\",\n            torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1, groups=4),\n        )\n\n        self.output = torch.nn.Sequential()\n        self.output.add_module(\"output\", torch.nn.Conv2d(128, 8, kernel_size=(1, 1)))\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.input_conv(input)\n        step1 = self.block_a(step0) + self.block_b(step0)\n        return self.output(step1)\n\n\nclass depthwise_child_is_sibling(torch.nn.Module):\n    \"\"\"The child of a depthwise convolution should act as a sibling\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.input_shape = [4, 16, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.input_conv = torch.nn.Sequential()\n        self.expected_K_params += 2\n        self.input_conv.add_module(\n            \"input_conv\",\n            torch.nn.Conv2d(self.input_shape[1], 128, kernel_size=(3, 3), padding=1),\n        )\n\n        # two parallel block: conv | depthwise->conv\n        self.expected_C_params += 2\n        self.expected_K_params += 4 + 2\n        self.block_a = torch.nn.Sequential()\n        self.block_a.add_module(\"conv_a\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1))\n\n        self.block_b = torch.nn.Sequential()\n        self.block_b.add_module(\n            \"conv_b_dw\",\n            torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1, groups=128),\n        )\n        self.block_b.add_module(\n            \"conv_b_1\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1)\n        )\n\n        self.output_conv = torch.nn.Sequential()\n        self.expected_C_params += 1\n        
self.output_conv.add_module(\"output_conv\", torch.nn.Conv2d(128, 8, kernel_size=(1, 1)))\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.input_conv(input)\n        step1 = self.block_a(step0) + self.block_b(step0)\n        return self.output_conv(step1)\n\n\nclass module_attribute(torch.nn.Module):\n    \"\"\"Attributes of some module must be permuted if they feed some operation that is permuted\"\"\"\n\n    def __init__(\n        self,\n        complexity: int = 0,\n    ):\n        super().__init__()\n        self.input_shape = [4, 16, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n        self.complexity = complexity\n\n        self.input_conv = torch.nn.Sequential()\n        self.expected_K_params += (\n            3  # conv weight, conv bias, input_offset C (counts as K since it's acting as a parent)\n        )\n        self.input_offset = torch.nn.Parameter(torch.zeros(128, 7, 7))\n        torch.nn.init.normal_(self.input_offset.data, mean=0.0, std=2.0)\n        self.input_conv.add_module(\n            \"conv_input\",\n            torch.nn.Conv2d(self.input_shape[1], 128, kernel_size=(3, 3), padding=1),\n        )\n\n        # add a couple more layers, and let the same offset affect another layer, as well\n        if complexity == 1:\n            self.expected_C_params += 2\n            self.expected_K_params += 4\n            self.stack_a = torch.nn.Sequential()\n            self.stack_a.add_module(\n                \"conv_a\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1)\n            )\n\n            self.stack_b = torch.nn.Sequential()\n            self.stack_b.add_module(\n                \"conv_b\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1)\n            )\n\n        self.output_conv = torch.nn.Sequential()\n        self.expected_C_params += 1\n        self.output_conv.add_module(\"conv_output\", torch.nn.Conv2d(128, 8, kernel_size=(3, 3)))\n\n    def forward(self, input: 
torch.Tensor):\n        batch_input_offset = self.input_offset.expand(input.shape[0], -1, -1, -1)\n        x = self.input_conv(input) + batch_input_offset\n        if self.complexity == 1:\n            x = self.stack_a(x) + batch_input_offset\n            x = self.stack_b(x) + batch_input_offset\n        return self.output_conv(x)\n\n\nclass square_attribute(torch.nn.Module):\n    \"\"\"Attributes with multiple dimensions matching the permutation length should only be permuted along the correct dimension\"\"\"\n\n    # TODO: currently, such an attribute will disallow permutations around it, but with effort, it could be handled correctly.\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.input_shape = [4, 16, 16]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.input_linear = torch.nn.Sequential()\n        # self.expected_K_params += 2  # if handled correctly, the linear's K and the offset's K should both be permuted\n        self.input_linear.add_module(\"linear_input\", torch.nn.Linear(self.input_shape[1], 16))\n        self.input_offset = torch.nn.Parameter(torch.zeros(16, 16))\n        torch.nn.init.normal_(self.input_offset.data, mean=0.0, std=2.0)\n\n        self.output_linear = torch.nn.Sequential()\n        # self.expected_C_params += 1  # if handled correctly, this should be permuted\n        self.output_linear.add_module(\"linear_output\", torch.nn.Linear(16, 8))\n\n    def forward(self, input: torch.Tensor):\n        batch_input_offset = self.input_offset.expand(input.shape[0], -1, -1)\n        x = self.input_linear(input) + torch.permute(batch_input_offset, (0, 2, 1))\n        return self.output_linear(x)\n\n\nclass MHA_test(torch.nn.Module):\n    \"\"\"MultiheadAttention modules are unique; we need to check permutations for input and output projections\"\"\"\n\n    def __init__(self, hidden_dim: int = 256, seq_len: int = 64, num_heads: int = 16):\n        super().__init__()\n      
  self.hidden_dim = hidden_dim\n        self.seq_len = seq_len\n        self.num_heads = num_heads\n        self.input_shape = [4, self.seq_len, self.hidden_dim]\n\n        self.expected_C_params = 1\n        self.expected_K_params = 2\n\n        self.MHA0 = torch.nn.MultiheadAttention(\n            self.hidden_dim, self.num_heads, dropout=False, batch_first=True\n        )\n        self.MHA1 = torch.nn.MultiheadAttention(\n            self.hidden_dim, self.num_heads, dropout=False, batch_first=True\n        )\n\n    def forward(self, input: torch.Tensor):\n        step0, _ = self.MHA0(input, input, input)\n        step1, _ = self.MHA1(step0, step0, step0)\n        return step1\n\n\nclass one_sparse_sibling(torch.nn.Module):\n    \"\"\"If only one of two siblings is sparse, both need to be permuted\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        self.input_shape = [4, 16, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.in_conv = torch.nn.Sequential()\n        self.expected_K_params += 2\n        self.in_conv.add_module(\n            \"conv_in\",\n            torch.nn.Conv2d(self.input_shape[1], 128, kernel_size=(3, 3), padding=1),\n        )\n\n        self.block_a = torch.nn.Sequential()\n        self.expected_C_params += 1  # only conv_a0 will be permuted along C\n        self.expected_K_params += 2  # only conv_a1 will be permuted along K\n        self.block_a.add_module(\"conv_a0\", torch.nn.Conv2d(128, 3, kernel_size=(1, 1)))\n        self.block_a.add_module(\"conv_a1\", torch.nn.Conv2d(3, 128, kernel_size=(3, 3), padding=1))\n\n        self.block_b = torch.nn.Sequential()\n        self.expected_C_params += 2  # even though conv_a0 will not be sparse (only 3 output channels), conv_b0 can still be permuted along C\n        self.expected_K_params += 4\n        self.block_b.add_module(\"conv_b0\", torch.nn.Conv2d(128, 128, kernel_size=(3, 3), padding=1))\n        
self.block_b.add_module(\"conv_b1\", torch.nn.Conv2d(128, 128, kernel_size=(1, 1)))\n\n        self.out_conv = torch.nn.Sequential()\n        self.expected_C_params += 1\n        self.out_conv.add_module(\"conv_out\", torch.nn.Conv2d(128, 8, kernel_size=(1, 1)))\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.in_conv(input)\n        step1 = self.block_a(step0) + self.block_b(step0)\n        return self.out_conv(step1)\n\n\nclass test_concat(torch.nn.Module):\n    \"\"\"If concats are along the channel dimension (dim1 of NCHW), downstream layers can still be permuted despite C!=parentK\"\"\"\n\n    def __init__(\n        self,\n        ratio=1,  # ratio between # channels in either path to be concatenated\n        dim=1,  # dimension to concatenate, K by default\n        depth=1,  # number of concats to stack\n    ):\n        super().__init__()\n        assert dim == 1 or ratio == 1, (\n            \"can't concat along dimensions other than K if K's don't match\"\n        )\n        self.dim = dim\n        self.depth = depth\n        self.input_shape = [4, 16, 7, 7]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.in_conv = torch.nn.Sequential()\n        self.expected_K_params += 2\n        self.in_conv.add_module(\n            \"conv_in\", torch.nn.Conv2d(self.input_shape[1], 64, kernel_size=(1, 1))\n        )\n\n        self.left_paths = torch.nn.ModuleList([torch.nn.Conv2d(64, 64, kernel_size=(1, 1))])\n        self.expected_C_params += 1\n        self.expected_K_params += 2\n\n        in_C = 64\n        out_C = 64\n        for d in range(1, depth, 1):\n            self.expected_C_params += 1\n            self.expected_K_params += 2\n            if dim == 1:\n                out_C += 64\n            self.left_paths.append(torch.nn.Conv2d(in_C + 64, out_C, kernel_size=(1, 1)))\n            if dim == 1:\n                in_C += 64\n\n        self.right_path = torch.nn.Sequential()\n        
self.expected_C_params += 1\n        self.expected_K_params += 2\n        self.right_path.add_module(\"conv_b\", torch.nn.Conv2d(64, 64 * ratio, kernel_size=(1, 1)))\n\n        self.out_conv = torch.nn.Sequential()\n        self.expected_C_params += 1\n        if dim == 1:\n            out_C += 64 * ratio\n        self.out_conv.add_module(\"conv_out\", torch.nn.Conv2d(out_C, 16, kernel_size=(1, 1)))\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.in_conv(input)\n        step1 = step0\n        for d, layer in enumerate(self.left_paths):\n            if d == 0:\n                step1 = layer(step1)\n            else:\n                step1 = layer(torch.cat([step1, step0], 1))\n\n        step2 = torch.cat([step1, self.right_path(step0)], self.dim)\n        return self.out_conv(step2)\n\n\nclass test_flatten_op(torch.nn.Module):\n    \"\"\"flatten ops may change the effective channel count, typically by collapsing N,C,H,W into N,C*H*W before a classifier\"\"\"\n\n    def __init__(\n        self,\n        change_dims=True,\n    ):\n        super().__init__()\n        self.change_dims = change_dims\n        self.input_shape = [4, 16, 3, 3]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        if not self.change_dims:\n            self.input_shape = [4, 16, 1, 1]\n            self.expected_C_params = 1\n            self.expected_K_params = 2\n\n        self.flattened_C = self.input_shape[2] * self.input_shape[3] * 64\n\n        self.in_conv = torch.nn.Conv2d(self.input_shape[1], 64, kernel_size=(1, 1))\n        self.out_gemm = torch.nn.Linear(self.flattened_C, 16)\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.in_conv(input)\n        step1 = torch.flatten(step0, start_dim=1)\n        return self.out_gemm(step1)\n\n\nclass test_flatten_module(torch.nn.Module):\n    \"\"\"flatten modules may change the effective channel count, typically by collapsing N,C,H,W into N,C*H*W before a classifier\"\"\"\n\n   
 def __init__(\n        self,\n        change_dims=True,\n    ):\n        super().__init__()\n        self.change_dims = change_dims\n        self.input_shape = [4, 16, 3, 3]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        if not self.change_dims:\n            self.input_shape = [4, 16, 1, 1]\n            self.expected_C_params = 1\n            self.expected_K_params = 2\n\n        self.flattened_C = self.input_shape[2] * self.input_shape[3] * 64\n        self.stack = torch.nn.Sequential()\n        self.stack.add_module(\n            \"conv_in\", torch.nn.Conv2d(self.input_shape[1], 64, kernel_size=(1, 1))\n        )\n        self.stack.add_module(\"flatten\", torch.nn.Flatten(1))\n        self.stack.add_module(\"gemm_out\", torch.nn.Linear(self.flattened_C, 16))\n\n    def forward(self, input: torch.Tensor):\n        return self.stack(input)\n\n\nclass test_trace_failure(torch.nn.Module):\n    \"\"\"make sure tracing failures are handled gracefully\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.input_shape = [4, 16, 1, 1]\n        self.expected_C_params = 0\n        self.expected_K_params = 0\n\n        self.in_conv = torch.nn.Conv2d(self.input_shape[1], 64, kernel_size=(1, 1))\n        self.out_conv = torch.nn.Conv2d(64, 16, kernel_size=(1, 1))\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.in_conv(input)\n        # NCHW = 4,64,1,1\n        channels = step0.size(1)\n        channel_offset = torch.arange(channels, dtype=torch.long, device=step0.device)\n        channel_offset = channel_offset.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(step0)\n        step0.add_(channel_offset)\n        return self.out_conv(step0)\n\n\nclass already_sparse(torch.nn.Module):\n    \"\"\"if weights are already sparse, permutations should be skipped\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.input_shape = [4, 16, 3, 3]\n        self.expected_C_params = 0\n        
self.expected_K_params = 0\n\n        self.in_conv = torch.nn.Conv2d(self.input_shape[1], 64, kernel_size=(1, 1))\n        self.out_conv = torch.nn.Conv2d(64, 16, kernel_size=(1, 1))\n\n        # apply 2:4 to the output weights, it will not require a permutation\n        out_weights = torch.ones_like(self.out_conv.weight)\n        out_weights[:, 0::2, ...] = 0\n        assert torch.sum(out_weights) == torch.numel(out_weights) / 2\n        self.out_conv.weight.data.copy_(out_weights)\n\n    def forward(self, input: torch.Tensor):\n        step0 = self.in_conv(input)\n        return self.out_conv(step0)\n\n\ndef test_model(model, tag, verbosity=0, save_onnx=False):\n    Permutation.set_identical_seed()\n    x = torch.rand(model.input_shape)\n    if save_onnx:\n        torch.onnx.export(model, x, f\"{tag}.onnx\", verbose=False)\n\n    base_out = model(x)\n\n    sparse_parameters = []\n    all_parameters = []\n\n    module_to_params = {}\n    module_to_params[torch.nn.MultiheadAttention] = (\n        \"q_proj_weight\",\n        \"k_proj_weight\",\n        \"v_proj_weight\",\n        \"in_proj_weight\",\n    )\n\n    for module_name, module in model.named_modules():\n        module_type_str = str(type(module)).split(\"'\")[1]\n        if module_type_str == \"torch.nn.modules.container.Sequential\" or module_type_str.startswith(\n            \"torchvision.models\"\n        ):\n            # filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'\n            continue\n        for p_name, p in module.named_parameters():\n            all_parameters.append((module_name, module, p_name, p))\n\n            if isinstance(\n                module,\n                (\n                    torch.nn.Linear,\n                    torch.nn.Conv1d,\n                    torch.nn.Conv2d,\n                    torch.nn.Conv3d,\n                    torch.nn.MultiheadAttention,\n                    
torch.nn.modules.linear.NonDynamicallyQuantizableLinear,\n                ),\n            ):\n                allowed_names = (\"weight\",)\n                if type(module) in module_to_params.keys():\n                    allowed_names = module_to_params[type(module)]\n\n                if p_name not in allowed_names:\n                    continue\n\n                if len(p.size()) >= 2 and (p.size()[0] % 8) == 0 and (p.size()[1] % 16) == 0:\n                    mask = torch.ones_like(p).bool()\n                    buffname = p_name.split(\".\")[-1]\n                    module.register_buffer(\"__%s_mma_mask\" % buffname, mask)\n                    sparse_parameters.append((module_name, module, p_name, p, mask, None))\n\n        if module_type_str == \"torch.nn.modules.batchnorm.BatchNorm2d\":\n            # need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters\n            module_mean_name = module_name + \".running_mean\"\n            module_var_name = module_name + \".running_var\"\n            for param_key in model.state_dict():\n                if module_mean_name == param_key or module_var_name == param_key:\n                    all_parameters.append(\n                        (\n                            module_name,\n                            module,\n                            param_key.split(\".\")[-1],\n                            model.state_dict()[param_key],\n                        )\n                    )\n\n    if verbosity > 1:\n        sparse_param_names = [\n            module_name + \":\" + p_name\n            for (module_name, module, p_name, p, mask, pruned) in sparse_parameters\n        ]\n        all_param_names = [\n            module_name + \":\" + p_name for (module_name, module, p_name, p) in all_parameters\n        ]\n        print(\n            f\"\\tSparse parameter names: {sparse_param_names}\\n\\tAll parameter names: {all_param_names}\"\n        )\n\n    
Permutation.set_permutation_params_from_asp(model, sparse_parameters, all_parameters, verbosity)\n    Permutation.permute_model(model)\n\n    C_params, K_params, missed_dims = Permutation.get_permutation_stats()\n\n    success = True\n    fail_str = \"\"\n    succ_str = \"\"\n    if len(C_params) != model.expected_C_params:\n        success = False\n        fail_str = (\n            fail_str + f\"\\n\\tC expected {model.expected_C_params}, got {len(C_params)} ({C_params})\"\n        )\n    elif verbosity > 0:\n        succ_str = (\n            succ_str + f\"\\n\\tC expected {model.expected_C_params}, got {len(C_params)} ({C_params})\"\n        )\n\n    if len(K_params) != model.expected_K_params:\n        success = False\n        fail_str = (\n            fail_str + f\"\\n\\tK expected {model.expected_K_params}, got {len(K_params)} ({K_params})\"\n        )\n    elif verbosity > 0:\n        succ_str = (\n            succ_str + f\"\\n\\tK expected {model.expected_K_params}, got {len(K_params)} ({K_params})\"\n        )\n\n    if len(missed_dims) != 0:\n        success = False\n        fail_str = (\n            fail_str\n            + f\"\\n\\tMissed permutations along {len(missed_dims)} dimensions ({missed_dims})\"\n        )\n\n    perm_out = model(x)\n\n    atol = 1e-5\n    rtol = 1e-4\n    outs_match = torch.allclose(base_out.data, perm_out.data, atol=atol, rtol=rtol)\n    if not outs_match:\n        fail_str = fail_str + f\"\\n\\tOutputs matched: {outs_match}\"\n        if success:\n            diffs = base_out - perm_out\n            diff_locs = (diffs.abs() >= atol).nonzero(as_tuple=True)\n            fail_str = fail_str + f\"\\n{diff_locs}\\n{diffs[diff_locs]}\"\n        success = False\n\n    if success:\n        print(f\"{tag}: Success\\t{succ_str}\")\n    else:\n        print(f\"{tag}: FAIL\\t{fail_str}\")\n\n    return success\n\n\ndef main():\n    global_success = True\n\n    global_success &= test_model(simple_convs(2, 16), \"smoke test\")\n    global_success 
&= test_model(simple_convs(5, 64), \"simple 5 64\")\n    global_success &= test_model(simple_convs(10, 32), \"simple 10 32\")\n    # normalization\n    for norm in [\n        \"BatchNorm2d\",\n        \"LazyBatchNorm2d\",\n        \"InstanceNorm2d\",\n        \"LazyInstanceNorm2d\",\n        \"LayerNorm3\",\n        \"LocalResponseNorm\",\n    ]:\n        global_success &= test_model(simple_convs(4, 128, norm), norm)\n    # disallowed normalization\n    for norm in [\"GroupNorm\"]:\n        global_success &= test_model(simple_convs(4, 128, norm), norm)\n\n    global_success &= test_model(conv_1d(), \"conv1d\")\n    global_success &= test_model(conv_1d(with_2d=True), \"conv1d and conv2d\")\n    global_success &= test_model(grouped_convs(), \"grouped convs\")\n    global_success &= test_model(simple_forks_joins(), \"forks and joins\")\n    global_success &= test_model(different_grouped_convs(), \"GCD\")\n    global_success &= test_model(siblings_poison(), \"sibling poison\")\n    global_success &= test_model(coparent_poison(), \"coparent poison\")\n    global_success &= test_model(depthwise_child_is_sibling(), \"dw child is sibling\")\n    global_success &= test_model(module_attribute(complexity=0), \"single attribute\")\n    global_success &= test_model(module_attribute(complexity=1), \"single attribute thrice\")\n    global_success &= test_model(MHA_test(hidden_dim=256, seq_len=64, num_heads=16), \"stacked MHA\")\n    global_success &= test_model(one_sparse_sibling(), \"one sparse sibling\")\n    global_success &= test_model(test_concat(), \"simple concat\")  # concat along K\n    global_success &= test_model(test_concat(dim=0), \"concat dim0\")  # concat along C\n    global_success &= test_model(\n        test_concat(ratio=2), \"concat ratio2\"\n    )  # concat along K with different K values\n    global_success &= test_model(\n        test_concat(depth=2), \"concat depth2\"\n    )  # concat along K multiple times\n    global_success &= 
test_model(test_concat(depth=3), \"concat depth3\")\n    global_success &= test_model(test_concat(ratio=3, depth=4), \"concat ratio3 depth4\")\n    global_success &= test_model(test_concat(dim=0, depth=3), \"concat dim0 depth3\")\n    global_success &= test_model(test_flatten_op(), \"flatten op\")\n    global_success &= test_model(test_flatten_op(change_dims=False), \"useless flatten op\")\n    global_success &= test_model(test_flatten_module(), \"flatten module\")\n    global_success &= test_model(test_flatten_module(change_dims=False), \"useless flatten module\")\n    global_success &= test_model(test_trace_failure(), \"trace failure\")\n    global_success &= test_model(already_sparse(), \"skip already sparse\")\n    global_success &= test_model(square_attribute(), \"square attributes\")\n\n    if global_success:\n        print(\"All tests completed successfully.\")\n    else:\n        print(\"There was at least one failure.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "apex/contrib/sparsity/test/toy_problem.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom apex.optimizers import FusedAdam\nfrom apex.contrib.sparsity import ASP\n\n\ndef build_model(args):\n    od = OrderedDict()\n    for i in range(args.num_layers):\n        if i == 0:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.input_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n        elif i == args.num_layers - 1:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.output_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.output_features]\n            )\n        else:\n            od[\"linear_layer_%d\" % (i + 1)] = torch.nn.Linear(\n                args.hidden_features, args.hidden_features\n            )\n            od[\"layer_norm_%d\" % (i + 1)] = torch.nn.LayerNorm(\n                [args.batch_size, args.hidden_features]\n            )\n    return torch.nn.Sequential(od)\n\n\ndef train_step(args, model, optimizer, input_batch, target_batch, step):\n    predicted_target = model(input_batch)\n    loss = ((predicted_target - target_batch) ** 2).sum()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    step = step + 1\n    # print(\"Step %d :: loss=%e\" % (step, loss.item()))\n    return step\n\n\ndef train_loop(args, model, optimizer, step, num_steps):\n    for i in range(num_steps):\n        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()\n        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()\n        step = train_step(args, model, optimizer, input_batch, target_batch, step)\n    return step\n\n\ndef main(args):\n    model = build_model(args).cuda()\n    one_ll = next(model.children()).weight\n  
  optimizer = FusedAdam(model.parameters())\n    # only prune linear layers, even though we also support conv1d, conv2d and conv3d\n    ASP.init_model_for_pruning(\n        model, \"m4n2_1d\", whitelist=[torch.nn.Linear], allow_recompute_mask=True\n    )\n    ASP.init_optimizer_for_pruning(optimizer)\n\n    step = 0\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps)\n\n    # simulate sparsity by inserting zeros into existing dense weights\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)\n\n    # recompute sparse masks\n    ASP.compute_sparse_masks()\n\n    # train for a few steps with sparse weights\n    print(\"SPARSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)\n\n    # turn off sparsity\n    print(\"SPARSE :: \", one_ll)\n    ASP.restore_pruned_weights()\n\n    # train for a few steps with dense weights\n    print(\"DENSE :: \", one_ll)\n    step = train_loop(args, model, optimizer, step, args.num_dense_steps_2)\n\n\nif __name__ == \"__main__\":\n\n    class Args:\n        batch_size = 32\n        input_features = 16\n        output_features = 8\n        hidden_features = 40\n        num_layers = 4\n        num_dense_steps = 2000\n        num_sparse_steps = 3000\n        num_sparse_steps_2 = 1000\n        num_dense_steps_2 = 1500\n\n    args = Args()\n\n    main(args)\n"
  },
  {
    "path": "apex/contrib/test/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/bottleneck/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/bottleneck/test_bottleneck_module.py",
    "content": "import unittest\n\nimport torch\nfrom torch.testing._internal import common_utils\n\nfrom apex.distributed_testing.distributed_test_base import NcclDistributedTestBase\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck\n    from apex.contrib.bottleneck import HaloExchangerPeer\n    from apex.contrib.peer_memory import PeerMemoryPool\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\ndef ground_truth_bottleneck(C, dtype, explicit_nhwc):\n    bottleneck = Bottleneck(C, C, C, use_cudnn=True, explicit_nhwc=explicit_nhwc)\n    bottleneck.to(dtype=dtype, device=\"cuda\")\n    for p in bottleneck.parameters():\n        torch.distributed.broadcast(p, 0)\n    for b in bottleneck.buffers():\n        torch.distributed.broadcast(b, 0)\n    return bottleneck\n\n\ndef print_bottleneck_p_and_b(bottleneck):\n    with torch.no_grad():\n        for n, p in bottleneck.named_parameters():\n            print(\"%s :: %s\" % (n, str(p.norm(p=2, dtype=torch.float32))))\n        for n, p in bottleneck.named_buffers():\n            print(\"%s :: %s\" % (n, str(p.norm(p=2, dtype=torch.float32))))\n\n\ndef has_nan(x):\n    if isinstance(x, list) or isinstance(x, tuple):\n        for xx in x:\n            if torch.any(torch.isnan(xx)):\n                return True\n        return False\n    elif isinstance(x, dict):\n        for k, v in x.items():\n            if torch.any(torch.isnan(v)):\n                return True\n    else:\n        return torch.any(torch.isnan(x))\n\n\ndef rel_diff_t(xx1, xx2):\n    return (\n        (xx1 - xx2).norm(p=2, dtype=torch.float32) / (xx1 + xx2).norm(p=2, dtype=torch.float32)\n    ).item()\n\n\ndef rel_diff(x1, x2):\n    if isinstance(x1, list) or isinstance(x1, tuple):\n        return [rel_diff_t(xx1, xx2) for xx1, xx2 in zip(x1, x2)]\n    elif isinstance(x1, dict):\n        return [rel_diff_t(xx1, xx2) for (k1, xx1), (k2, xx2) in zip(x1.items(), x2.items())]\n    else:\n        return 
rel_diff_t(x1, x2)\n\n\ndef graph_it(bottleneck, x):\n    print(\"Graphing\")\n    with torch.no_grad():\n        x = x.clone()\n        x.grad = None\n        x.requires_grad = True\n    return torch.cuda.make_graphed_callables(bottleneck, (x,))\n\n\ndef clone_inputs(bottleneck, x, dy=None):\n    with torch.no_grad():\n        x = x.clone()\n        x.grad = None\n        x.requires_grad = True\n        if dy is None:\n            y = bottleneck(x)\n            dy = torch.randn_like(y) / 1e2\n            torch.distributed.broadcast(dy, 0)\n    return x, dy\n\n\ndef fprop_and_bprop(bottleneck, x, dy):\n    y = bottleneck(x)\n    y.backward(dy)\n    dgrad = x.grad.detach()\n    wgrad = {}\n    for n, p in bottleneck.named_parameters():\n        wgrad[n] = p.grad.detach()\n    return x, y, dy, dgrad, wgrad\n\n\ndef ground_truth(N, C, H, W, dtype, memory_format, bottleneck):\n    if memory_format == 1:\n        # 1 -> explicit nhwc\n        explicit_nhwc = True\n        with torch.no_grad():\n            x = torch.randn([N, H, W, C], dtype=dtype, device=\"cuda\")\n            torch.distributed.broadcast(x, 0)\n            x, dy = clone_inputs(bottleneck, x)\n        return fprop_and_bprop(bottleneck, x, dy)\n    else:\n        # 2 -> native nhwc\n        # 3 -> nchw\n        explicit_nhwc = False\n        assert False, \"Not implemented yet\"\n\n\ndef print_ground_truth(gt):\n    x, y, dy, dgrad, wgrad = gt\n    if has_nan(y) or has_nan(dgrad) or has_nan(wgrad):\n        print(\"Error! Ground truth has NAN\")\n    else:\n        print(\"Ok! 
No NAN found in ground truth\")\n\n\ndef apply_to_different_bottleneck(gt, bottleneck):\n    with torch.no_grad():\n        x, _, dy, _, _ = gt\n        x, dy = clone_inputs(bottleneck, x, dy)\n    return fprop_and_bprop(bottleneck, x, dy)\n\n\ndef compare_single_field(results, f1, f2, l0, l1, l2):\n    if has_nan(f1) and has_nan(f2):\n        results[l0] = \"both NAN\"\n    elif has_nan(f1):\n        results[l0] = \"%s.%s NAN\" % (l1, l0)\n    elif has_nan(f2):\n        results[l0] = \"%s.%s NAN\" % (l2, l0)\n    else:\n        results[l0] = \"%s\" % (str(rel_diff(f1, f2)))\n\n\ndef compare(gt, bt):\n    x1, y1, dy1, dgrad1, wgrad1 = gt\n    x2, y2, dy2, dgrad2, wgrad2 = bt\n    results = {}\n    compare_single_field(results, y1, y2, \"y\", \"gt\", \"bt\")\n    compare_single_field(results, dy1, dy2, \"dy\", \"gt\", \"bt\")\n    compare_single_field(results, dgrad1, dgrad2, \"dgrad\", \"gt\", \"bt\")\n    compare_single_field(results, wgrad1, wgrad2, \"wgrad\", \"gt\", \"bt\")\n    for i in range(torch.distributed.get_world_size()):\n        if i == torch.distributed.get_rank():\n            print(i, results)\n        torch.distributed.barrier()\n\n\ndef spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args):\n    spatial_bottleneck = SpatialBottleneck(\n        C,\n        C,\n        C,\n        use_cudnn=True,\n        explicit_nhwc=explicit_nhwc,\n        spatial_parallel_args=spatial_parallel_args,\n    )\n    spatial_bottleneck.to(dtype=dtype, device=\"cuda\")\n    with torch.no_grad():\n        sp = {}\n        for n, p in spatial_bottleneck.named_parameters():\n            sp[n] = p\n        for n, p in gt_bottleneck.named_parameters():\n            sp[n].copy_(p)\n        sb = {}\n        for n, b in spatial_bottleneck.named_buffers():\n            sb[n] = b\n        for n, b in gt_bottleneck.named_buffers():\n            sb[n].copy_(b)\n    return spatial_bottleneck\n\n\ndef n_way_spatial(halex, gt_bottleneck, gt, 
explicit_nhwc, world_size, rank, fp32_reduce=False):\n    assert explicit_nhwc, \"Only tested for explicit nhwc\"\n\n    x, _, dy, _, _ = gt\n    N, H, W, C = list(x.shape)  # Tensor is already shaped properly for n-way parallel\n    dtype = x.dtype\n\n    spatial_group_size = world_size\n    spatial_group_rank = rank\n    spatial_communicator = None\n    spatial_halo_exchanger = halex\n    spatial_method = 1  # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x\n    use_delay_kernel = False\n    spatial_parallel_args = (\n        spatial_group_size,\n        spatial_group_rank,\n        spatial_communicator,\n        spatial_halo_exchanger,\n        spatial_method,\n        use_delay_kernel,\n    )\n    spatial_bottleneck = spatial_parallel_bottleneck(\n        C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args\n    )\n\n    with torch.no_grad():\n        Hs = H // spatial_group_size\n        xs = x[:, spatial_group_rank * Hs : (spatial_group_rank + 1) * Hs, :, :].clone()\n        dys = dy[:, spatial_group_rank * Hs : (spatial_group_rank + 1) * Hs, :, :].clone()\n        xs.requires_grad = True\n\n    spatial_bottleneck = graph_it(spatial_bottleneck, xs)\n    _, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys)\n\n    # gather output pieces\n    for n, p in wgrad.items():\n        if fp32_reduce:\n            p32 = p.float()\n            torch.distributed.all_reduce(p32)\n            p.copy_(p32.half())\n        else:\n            torch.distributed.all_reduce(p)\n    ys = [torch.empty_like(y) for _ in range(spatial_group_size)]\n    torch.distributed.all_gather(ys, y)\n    y = torch.cat(ys, dim=1)\n    dgrads = [torch.empty_like(dgrad) for _ in range(spatial_group_size)]\n    torch.distributed.all_gather(dgrads, dgrad)\n    dgrad = torch.cat(dgrads, dim=1)\n    return x, y, dy, dgrad, wgrad\n\n\ndef main():\n    torch.use_deterministic_algorithms(True)\n\n    torch.distributed.init_process_group(\"nccl\")\n    rank = 
torch.distributed.get_rank()\n    world_size = torch.distributed.get_world_size()\n    torch.cuda.set_device(rank)\n\n    explicit_nhwc = True\n\n    dtype = torch.float16\n    N, C, H, W = 1, 64, 200, 336\n    Hs = ((H + 8 * world_size - 1) // (8 * world_size)) * 8\n    H = Hs * world_size\n    gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc)\n    gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck)\n\n    # verify that spatial bottleneck with group_size 1 produces same results as ground truth bottleneck\n    spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, None)\n    bt = apply_to_different_bottleneck(gt, spatial_bottleneck)\n    compare(gt, bt)\n    # print_bottleneck_p_and_b(gt_bottleneck)\n    # print_bottleneck_p_and_b(spatial_bottleneck)\n\n    group_size = world_size\n    group = rank // group_size\n    ranks = [group * group_size + i for i in range(group_size)]\n    rank_in_group = rank % group_size\n\n    spatial_group_size = world_size\n    spatial_communicator = None\n\n    peer_pool = PeerMemoryPool(0, 64 * 1024 * 1024, ranks)\n\n    # class HaloExchangerNoComm(HaloExchanger):\n    #    def __init__(self, ranks, rank_in_group):\n    # class HaloExchangerAllGather(HaloExchanger):\n    #    def __init__(self, ranks, rank_in_group, comm):\n    # class HaloExchangerSendRecv(HaloExchanger):\n    #    def __init__(self, ranks, rank_in_group):\n    # class HaloExchangerPeer(HaloExchanger):\n    #    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):\n\n    # halex = HaloExchangerAllGather(ranks, rank_in_group)\n    # halex = HaloExchangerSendRecv(ranks, rank_in_group)\n\n    halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0)\n    # print(\"halex.signals = %s\" % (str(halex.signals)))\n    # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding\n    # torch.cuda.synchronize()\n    # 
torch.distributed.barrier()\n\n    bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)\n    compare(gt, bt2)\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass TestBottleneck(NcclDistributedTestBase):\n    # PyTorch's float16 tolerance values, see https://pytorch.org/docs/stable/testing.html#torch.testing.assert_close\n    fp16_tolerance = {\"atol\": 1e-5, \"rtol\": 1e-3}\n\n    @property\n    def world_size(self) -> int:\n        return min(torch.cuda.device_count(), 2)\n\n    def test_bottleneck_without_peer_memory(self) -> None:\n        explicit_nhwc: bool = True\n        dtype: torch.dtype = torch.float16\n        N, C, H, W = 1, 64, 200, 336\n        Hs = ((H + 8 * self.world_size - 1) // (8 * self.world_size)) * 8\n        H = Hs * self.world_size\n\n        gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc)\n        gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck)\n\n        spatial_bottleneck = spatial_parallel_bottleneck(\n            C, dtype, explicit_nhwc, gt_bottleneck, None\n        )\n        bt = apply_to_different_bottleneck(gt, spatial_bottleneck)\n        self.assertEqual(gt, bt, **self.fp16_tolerance)\n\n    @unittest.skipIf(\n        torch.cuda.device_count() < 2 or not torch.cuda.can_device_access_peer(0, 1),\n        \"peer memory access not supported\",\n    )\n    def test_bottleneck_with_peer_memory(self) -> None:\n        explicit_nhwc: bool = True\n        dtype: torch.dtype = torch.float16\n        N, C, H, W = 1, 64, 200, 336\n        Hs = ((H + 8 * self.world_size - 1) // (8 * self.world_size)) * 8\n        H = Hs * self.world_size\n\n        gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc)\n        gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck)\n\n        group = self.rank // self.world_size\n        ranks = [group * self.world_size + i for i in range(self.world_size)]\n        rank_in_group = self.rank % self.world_size\n\n        
spatial_group_size, spatial_communicator = self.world_size, None\n        peer_pool = PeerMemoryPool(0, 64 * 1024 * 1024, ranks)\n        halo_exchanger_peer = HaloExchangerPeer(\n            ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0\n        )\n        bt2 = n_way_spatial(\n            halo_exchanger_peer,\n            gt_bottleneck,\n            gt,\n            explicit_nhwc,\n            self.world_size,\n            self.rank,\n            fp32_reduce=True,\n        )\n        # TODO(crcrpar): Investigate the implementation to mitigate the numerical errors.\n        # NOTE(crcrpar): This assert often fails due to numerical errors.\n        # self.assertEqual(gt, bt2, **self.fp16_tolerance)\n\n\nif __name__ == \"__main__\":\n    common_utils.run_tests()\n"
  },
  {
    "path": "apex/contrib/test/clip_grad/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/clip_grad/test_clip_grad.py",
    "content": "import random\nimport unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.clip_grad import clip_grad_norm_\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\ndef make_params(\n    num_params,\n    sizes=[1, 2, 3, 4, 5],\n    num_dims=[1, 2, 3],\n    dtypes=[torch.float32],\n    devices=[\"cuda\"],\n    make_copy=False,\n):\n    \"\"\"Construct parameters with random configurations\"\"\"\n\n    # Construct parameters\n    params = []\n    for _ in range(num_params):\n        dims = [random.choice(sizes) for _ in range(random.choice(num_dims))]\n        dtype = random.choice(dtypes)\n        device = random.choice(devices)\n        p = torch.nn.Parameter(torch.randn(dims, dtype=dtype, device=device))\n        p.grad = torch.randn_like(p)\n        params.append(p)\n\n    # Copy parameters if needed\n    if make_copy:\n        params_copy = []\n        for p in params:\n            p_copy = p.clone().detach()\n            p_copy.grad = p.grad.clone().detach()\n            params_copy.append(p_copy)\n        return params, params_copy\n    else:\n        return params\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass ClipGradNormTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        super().setUp()\n        random.seed(seed)\n        torch.manual_seed(seed)\n\n    def test_matches_pytorch(\n        self,\n        num_params=41,\n        dtypes=[torch.float32, torch.float16, torch.float64],\n        devices=[\"cuda\", \"cpu\"],\n        max_norm=0.54321,\n        norm_type=2.0,\n        rtol=1e-3,\n        atol=1e-20,\n    ):\n        \"\"\"Make sure PyTorch and Apex gradient clipping produce same results\"\"\"\n\n        # Construct identical sets of parameters\n        torch_params, apex_params = make_params(\n            num_params,\n            dtypes=dtypes,\n            devices=devices,\n            make_copy=True,\n        )\n\n        # Apply gradient clipping\n        torch_norm = 
torch.nn.utils.clip_grad_norm_(\n            torch_params,\n            max_norm,\n            norm_type=norm_type,\n        )\n        apex_norm = clip_grad_norm_(\n            apex_params,\n            max_norm,\n            norm_type=norm_type,\n        )\n\n        # Make sure PyTorch and Apex get same results\n        torch.testing.assert_close(\n            apex_norm,\n            torch_norm,\n            rtol=rtol,\n            atol=atol,\n            check_dtype=False,\n        )\n        for torch_p, apex_p in zip(torch_params, apex_params):\n            torch.testing.assert_close(\n                apex_p,\n                torch_p,\n                rtol=0,\n                atol=0,\n            )  # Params should be unaffected\n            torch.testing.assert_close(\n                apex_p.grad,\n                torch_p.grad,\n                rtol=rtol,\n                atol=atol,\n            )\n\n    def test_matches_pytorch_fp16(self):\n        self.test_matches_pytorch(num_params=11, dtypes=[torch.float16])\n\n    def test_matches_pytorch_fp32(self):\n        self.test_matches_pytorch(dtypes=[torch.float32], rtol=1e-6)\n\n    def test_matches_pytorch_fp64(self):\n        self.test_matches_pytorch(dtypes=[torch.float64], rtol=1e-15)\n\n    def test_matches_pytorch_cpu(self):\n        self.test_matches_pytorch(devices=[\"cpu\"])\n\n    def test_matches_pytorch_infnorm(self):\n        self.test_matches_pytorch(norm_type=float(\"inf\"))\n\n    def test_matches_pytorch_1norm(self):\n        self.test_matches_pytorch(norm_type=1.0)\n\n    def test_raises_on_mismatch(self):\n        # Construct different sets of parameters\n        torch_params, apex_params = make_params(7, make_copy=True)\n        with torch.no_grad():\n            torch_params[0].grad.view(-1)[0] = 1.23\n            apex_params[0].grad.view(-1)[0] = 3.21\n\n        # Apply gradient clipping\n        torch_norm = torch.nn.utils.clip_grad_norm_(\n            torch_params,\n            
0.54321,\n        )\n        apex_norm = clip_grad_norm_(\n            apex_params,\n            0.54321,\n        )\n\n        # Make sure PyTorch and Apex get different results\n        self.assertRaises(\n            AssertionError,\n            torch.testing.assert_close,\n            apex_norm,\n            torch_norm,\n            rtol=1e-3,\n            atol=1e-20,\n            check_dtype=False,\n        )\n        for torch_p, apex_p in zip(torch_params, apex_params):\n            self.assertRaises(\n                AssertionError,\n                torch.testing.assert_close,\n                apex_p.grad,\n                torch_p.grad,\n                rtol=1e-3,\n                atol=1e-20,\n            )\n\n    def test_raises_on_nan(self):\n        params = make_params(5, num_dims=[1])\n        params[2].grad[-1] = float(\"NaN\")\n        self.assertRaises(RuntimeError, clip_grad_norm_, params, 1.0, error_if_nonfinite=True)\n\n    def test_raises_on_inf(self):\n        params = make_params(5, num_dims=[1])\n        params[2].grad[-1] = float(\"inf\")\n        self.assertRaises(RuntimeError, clip_grad_norm_, params, 1.0, error_if_nonfinite=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/conv_bias_relu/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py",
    "content": "import copy\nimport math\nimport random\nimport unittest\n\nimport torch\nimport torch.nn.functional as F\n\nHAS_CONV_BIAS_RELU = None\ntry:\n    from apex.contrib.conv_bias_relu import (\n        ConvBiasReLU,\n        ConvBias,\n        ConvBiasMaskReLU,\n        ConvFrozenScaleBiasReLU,\n    )\nexcept ImportError:\n    HAS_CONV_BIAS_RELU = False\nelse:\n    HAS_CONV_BIAS_RELU = True\n\n\n@unittest.skipIf(not HAS_CONV_BIAS_RELU, \"`apex.contrib.conv_bias_relu` is not found.\")\nclass FusedDenseTest(unittest.TestCase):\n    def setUp(self, seed=0):\n        super().setUp()\n        torch.manual_seed(seed)\n\n        self.batch_size = random.randint(1, 64)\n        self.in_channels = random.randint(1, 64) * 8\n        self.out_channels = random.randint(1, 64) * 8\n        self.in_height = self.in_width = random.randint(5, 100)\n        self.conv_kernel_size = random.randint(1, 5)\n        self.conv_pad = random.randint(0, int(self.conv_kernel_size / 2))\n        self.conv_stride = random.randint(1, 5)\n        self.conv_dilation = 1\n        self.out_height = self.out_width = math.floor(\n            (\n                self.in_height\n                + 2 * self.conv_pad\n                - self.conv_dilation * (self.conv_kernel_size - 1)\n                - 1\n            )\n            / self.conv_stride\n            + 1\n        )\n\n        self.x = (\n            torch.randint(\n                low=-16,\n                high=16,\n                size=[self.batch_size, self.in_channels, self.in_height, self.in_width],\n            )\n            .cuda()\n            .to(memory_format=torch.channels_last)\n            .float()\n        )\n        self.x_ = self.x.clone()\n        self.x.requires_grad_()\n        self.x_.requires_grad_()\n\n        self.mask = (\n            torch.randn([self.batch_size, self.out_channels, self.out_height, self.out_width])\n            .cuda()\n            .to(memory_format=torch.channels_last)\n        )\n        
self.mask = (self.mask > 0).to(torch.int8)\n        self.mask_ = self.mask.clone()\n\n        self.scale = torch.randn([1, self.out_channels, 1, 1]).half().cuda()\n        self.scale_ = self.scale.clone()\n        self.bias = torch.randn([1, self.out_channels, 1, 1]).half().cuda()\n        self.bias_ = self.bias.clone()\n\n        self.conv1 = (\n            torch.nn.Conv2d(\n                self.in_channels,\n                self.out_channels,\n                self.conv_kernel_size,\n                stride=self.conv_stride,\n                padding=self.conv_pad,\n            )\n            .cuda()\n            .to(memory_format=torch.channels_last)\n        )\n        self.conv1_ = copy.deepcopy(self.conv1)\n\n        self.conv2 = (\n            torch.nn.Conv2d(\n                self.in_channels,\n                self.out_channels,\n                self.conv_kernel_size,\n                stride=self.conv_stride,\n                padding=self.conv_pad,\n                bias=False,\n            )\n            .cuda()\n            .to(memory_format=torch.channels_last)\n        )\n        self.conv2_ = copy.deepcopy(self.conv2)\n\n        print()\n        print(\n            \"> input=[{}, {}, {}, {}]\".format(\n                self.batch_size, self.in_channels, self.in_height, self.in_width\n            )\n        )\n        print(\n            \"> kernel=[{}, {}, {}, {}], stride={}, pad={}\".format(\n                self.out_channels,\n                self.in_channels,\n                self.conv_kernel_size,\n                self.conv_kernel_size,\n                self.conv_stride,\n                self.conv_pad,\n            )\n        )\n\n    def test_conv_bias_relu(self):\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out = ConvBiasReLU(\n                self.x,\n                self.conv1.weight,\n                self.conv1.bias.reshape(1, -1, 1, 1),\n                self.conv_pad,\n                self.conv_stride,\n            
)\n            loss = (out.float() ** 2).sum() / out.numel()\n        loss.backward()\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out_ = F.relu(self.conv1_(self.x_))\n            loss_ = (out_**2).sum() / out_.numel()\n        loss_.backward()\n\n        torch.testing.assert_close(out_, out, atol=1e-3, rtol=1e-3, equal_nan=True)\n        torch.testing.assert_close(\n            self.conv1_.bias.grad,\n            self.conv1.bias.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(\n            self.conv1_.weight.grad,\n            self.conv1.weight.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)\n\n    def test_conv_bias(self):\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out = ConvBias(\n                self.x,\n                self.conv1.weight,\n                self.conv1.bias.reshape(1, -1, 1, 1),\n                self.conv_pad,\n                self.conv_stride,\n            )\n            loss = (out.float() ** 2).sum() / out.numel()\n        loss.backward()\n\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out_ = self.conv1_(self.x_)\n            loss_ = (out_**2).sum() / out_.numel()\n        loss_.backward()\n\n        torch.testing.assert_close(out, out_, atol=1e-3, rtol=1e-3, equal_nan=True)\n        torch.testing.assert_close(\n            self.conv1_.bias.grad,\n            self.conv1.bias.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(\n            self.conv1_.weight.grad,\n            self.conv1.weight.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(self.x_.grad, self.x.grad, 
atol=1e-3, rtol=1e-3, equal_nan=True)\n\n    def test_conv_bias_mask_relu(self):\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out = ConvBiasMaskReLU(\n                self.x,\n                self.conv1.weight,\n                self.conv1.bias.reshape(1, -1, 1, 1),\n                self.mask,\n                self.conv_pad,\n                self.conv_stride,\n            )\n            loss = (out.float() ** 2).sum() / out.numel()\n        loss.backward()\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out_ = F.relu(self.conv1_(self.x_) * self.mask_)\n            loss_ = (out_**2).sum() / out_.numel()\n        loss_.backward()\n\n        torch.testing.assert_close(out, out_, atol=1e-3, rtol=1e-3, equal_nan=True)\n        torch.testing.assert_close(\n            self.conv1_.bias.grad,\n            self.conv1.bias.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(\n            self.conv1_.weight.grad,\n            self.conv1.weight.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)\n\n    def test_conv_frozen_scale_bias_relu(self):\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out = ConvFrozenScaleBiasReLU(\n                self.x,\n                self.conv2.weight,\n                self.scale,\n                self.bias,\n                self.conv_pad,\n                self.conv_stride,\n            )\n            loss = (out.float() ** 2).sum() / out.numel()\n        loss.backward()\n        with torch.amp.autocast(\"cuda\", dtype=torch.half):\n            out_ = F.relu(self.conv2_(self.x_) * self.scale_ + self.bias_)\n            loss_ = (out_**2).sum() / out_.numel()\n        loss_.backward()\n\n        torch.testing.assert_close(out, out_, atol=2.5e-3, 
rtol=2.5e-3, equal_nan=True)\n        torch.testing.assert_close(\n            self.conv2_.weight.grad,\n            self.conv2.weight.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/cudnn_gbn/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py",
    "content": "import copy\nimport typing\nimport unittest\n\nimport torch\nimport torch.nn as nn\nfrom torch.testing._internal import common_utils\n\nSKIP_TEST = None\nfrom apex.distributed_testing.distributed_test_base import NcclDistributedTestBase\n\ntry:\n    from apex.contrib.cudnn_gbn import GroupBatchNorm2d as GBN\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n# Usage: python /path/to/cudnn_gbn/test_gbn_with_two_gpus.py\n\ninput_shapes = [\n    [1, 1024, 48, 72],\n    [1, 128, 192, 288],\n    [1, 128, 384, 576],\n    [1, 1536, 48, 72],\n    [1, 2048, 48, 72],\n    [1, 256, 1, 1],\n    [1, 256, 192, 288],\n    [1, 256, 384, 576],\n    [1, 256, 48, 72],\n    [1, 256, 96, 144],\n    [1, 32, 384, 576],\n    [1, 48, 192, 288],\n    [1, 64, 384, 576],\n    [1, 728, 48, 72],\n    [1, 728, 96, 144],\n]\n\n\nclass BNModelRef(nn.Module):\n    def __init__(self, num_features, num_layers=1000):\n        super().__init__()\n        self.fwd = nn.Sequential(\n            *[\n                nn.BatchNorm2d(\n                    num_features,\n                    eps=1e-05,\n                    momentum=0.1,\n                    affine=True,\n                    track_running_stats=True,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n    def forward(self, x):\n        return self.fwd(x)\n\n\nclass BNModel(nn.Module):\n    def __init__(self, num_features, num_layers=1000):\n        super().__init__()\n        self.fwd = nn.Sequential(\n            *[\n                GBN(\n                    num_features,\n                    group_size=2,\n                    eps=1e-05,\n                    momentum=0.1,\n                    affine=True,\n                    track_running_stats=True,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n    def forward(self, x):\n        return self.fwd(x)\n\n\ndef get_rand_tensors(global_shape, device):\n    inp_t = torch.rand(global_shape, 
dtype=torch.float32, device=device).to(\n        memory_format=torch.channels_last\n    )\n    weight = torch.rand(global_shape[1], dtype=torch.float32, device=device)\n    bias = torch.rand(global_shape[1], dtype=torch.float32, device=device)\n    _grad_out = torch.rand(global_shape, dtype=torch.float32, device=device).to(\n        memory_format=torch.channels_last\n    )\n    return inp_t, weight, bias, _grad_out\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass TestCudnnGBN(NcclDistributedTestBase):\n    def _prep(self):\n        torch.cuda.manual_seed(333)\n        torch.manual_seed(333)\n\n    @property\n    def world_size(self) -> int:\n        return min(torch.cuda.device_count(), 2)\n\n    @torch.backends.cudnn.flags(enabled=True, benchmark=True)\n    def _test_cudnn_gbn(\n        self,\n        num_layers: int,\n        shape: typing.List[int],\n        *,\n        memory_format: torch.memory_format = torch.channels_last,\n    ) -> None:\n        global_shape = copy.deepcopy(shape)\n        global_shape[0] = self.world_size\n\n        device = torch.device(\"cuda\", self.rank)\n        cudnn_gbn_model = BNModel(\n            num_features=shape[1],\n            num_layers=num_layers,\n        ).to(device=device, memory_format=memory_format)\n        ref_model = BNModelRef(\n            num_features=shape[1],\n            num_layers=num_layers,\n        ).to(device=device, memory_format=memory_format)\n\n        input, weight, bias, grad_out = get_rand_tensors(global_shape, device)\n        with torch.no_grad():\n            ref_model.fwd[0].weight.copy_(weight)\n            ref_model.fwd[0].bias.copy_(bias)\n            cudnn_gbn_model.fwd[0].weight.copy_(weight)\n            cudnn_gbn_model.fwd[0].bias.copy_(bias)\n\n            ref_input = input.clone().detach().requires_grad_()\n            input = input[self.rank : self.rank + 1, ...].clone().detach().requires_grad_()\n\n            ref_grad_out = grad_out.half().clone().detach()\n            
grad_out = grad_out[self.rank : self.rank + 1, ...].half().clone().detach()\n\n        with torch.amp.autocast(\"cuda\"):\n            out = cudnn_gbn_model(input)\n            ref_out = ref_model(ref_input.half())\n        out.backward(grad_out)\n        ref_out.backward(ref_grad_out)\n\n        kwargs = {\"rtol\": 3.5e-3, \"atol\": 3e-2, \"msg\": f\"shape: {shape}\"}\n\n        torch.testing.assert_close(ref_out[self.rank : self.rank + 1], out, **kwargs)\n        torch.testing.assert_close(ref_input.grad[self.rank : self.rank + 1], input.grad, **kwargs)\n        # Compensate for the averaging over processes done by DDP\n        # in order to produce a mathematically equivalent result, see\n        # https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368\n        torch.testing.assert_close(\n            ref_model.fwd[0].weight.grad / self.world_size,\n            cudnn_gbn_model.fwd[0].weight.grad,\n            **kwargs,\n        )\n        torch.testing.assert_close(\n            ref_model.fwd[0].bias.grad / self.world_size,\n            cudnn_gbn_model.fwd[0].bias.grad,\n            **kwargs,\n        )\n\n    def test_cudnngbn(self):\n        if self.world_size != 2:\n            self.skipTest(f\"This test requires world_size == 2, but got {self.world_size}\")\n        for shape in input_shapes:\n            self._prep()\n            self._test_cudnn_gbn(1, shape)\n\n\nif __name__ == \"__main__\":\n    common_utils.run_tests()\n"
  },
  {
    "path": "apex/contrib/test/fmha/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/fmha/test_fmha.py",
    "content": "###############################################################################\n# Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#     * Redistributions of source code must retain the above copyright\n#       notice, this list of conditions and the following disclaimer.\n#     * Redistributions in binary form must reproduce the above copyright\n#       notice, this list of conditions and the following disclaimer in the\n#       documentation and/or other materials provided with the distribution.\n#     * Neither the name of the NVIDIA CORPORATION nor the\n#       names of its contributors may be used to endorse or promote products\n#       derived from this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n#\n###############################################################################\n\nimport math\nimport unittest\n\nimport torch\nimport numpy as np\n\nSKIP_TEST = None\ntry:\n    import fmhalib as mha\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\ndef _get_device_properties(device=torch.device(\"cuda\")):\n    # type: (str or torch.device) -> Tuple[int, int]\n    properties = torch.cuda.get_device_properties(device)\n    return properties.major, properties.minor\n\n\ndef py_mha(qkv, amask, b, s, h, d):\n    qkv = qkv.view(b, s, h, 3, d)\n    q = qkv[:, :, :, 0, :].permute(0, 2, 1, 3)\n    k = qkv[:, :, :, 1, :].permute(0, 2, 1, 3)\n    v = qkv[:, :, :, 2, :].permute(0, 2, 1, 3)\n    p = torch.matmul(q.float(), k.permute(0, 1, 3, 2).float())\n    p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0\n    s = torch.softmax(p_masked, -1).to(qkv.dtype)\n    ctx = torch.matmul(s, v)\n    ctx = ctx.permute(0, 2, 1, 3).contiguous()\n\n    ctx.retain_grad()\n\n    return ctx\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\n@unittest.skipIf(\n    _get_device_properties() not in [(8, 0), (9, 0), (10, 0), (12, 0)],\n    \"FMHA only supports sm80\",\n)\nclass TestFMHA(unittest.TestCase):\n    def run_test(self, s: int, b: int, zero_tensors: bool):\n        print(f\"Test s={s} b={b}, zero_tensors={zero_tensors}\")\n\n        torch.manual_seed(1234)\n        torch.cuda.manual_seed(1234)\n\n        dtype = torch.float16\n        device = torch.device(\"cuda\")\n\n 
       h = 16\n        d = 64\n\n        slens = [s] * b\n        a = torch.tensor(np.array([0] + slens), dtype=torch.int32)\n        amask = torch.ones(b, h, s, s, dtype=dtype, device=device)\n        seqlens = torch.tensor(slens, dtype=torch.int32, device=device)\n        cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)\n        total = cu_seqlens[-1].item()\n\n        qkv = torch.randn((b, s, h, 3, d), device=device, dtype=dtype)\n\n        qkv_vs = qkv.permute(0, 1, 3, 2, 4).contiguous().view(b * s, 3, h, d)\n\n        qkv.requires_grad = True\n\n        if b < 4:\n            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True, zero_tensors, None)\n        else:\n            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False, zero_tensors, None)\n        ctx = ctx.view(b, s, h, d)\n\n        ctx_ref = py_mha(qkv, amask, b, s, h, d)\n        torch.testing.assert_close(ctx_ref.float(), ctx.float(), atol=1e-3, rtol=1e-5)\n\n        labels = torch.randn_like(ctx_ref)\n        diff = ctx_ref - labels\n        l = (diff * diff).sum() / b\n        l.backward()\n\n        dw = ctx_ref.grad.permute(0, 2, 1, 3)\n\n        dw2 = dw.permute(0, 2, 1, 3).clone().detach().contiguous()\n\n        if b < 4:\n            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)\n        else:\n            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)\n\n        dqkv2 = dqkv2.permute(0, 2, 1, 3).view(b, s, h, 3, d)\n\n        torch.testing.assert_close(qkv.grad.float(), dqkv2.float(), atol=1e-3, rtol=1e-5)\n\n    def test_128(self):\n        self.run_test(128, 32, False)\n        self.run_test(128, 32, True)\n        self.run_test(128, 56, False)\n        self.run_test(128, 56, True)\n\n    def test_256(self):\n        self.run_test(256, 32, False)\n        self.run_test(256, 32, True)\n        self.run_test(256, 56, False)\n        self.run_test(256, 56, True)\n\n    def test_384(self):\n        
self.run_test(384, 32, False)\n        self.run_test(384, 32, True)\n        self.run_test(384, 56, False)\n        self.run_test(384, 56, True)\n\n    def test_512(self):\n        self.run_test(512, 32, False)\n        self.run_test(512, 32, True)\n        self.run_test(512, 56, False)\n        self.run_test(512, 56, True)\n        self.run_test(512, 2, False)\n        self.run_test(512, 2, True)\n        self.run_test(512, 3, False)\n        self.run_test(512, 3, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/focal_loss/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/focal_loss/test_focal_loss.py",
    "content": "import unittest\n\nimport torch\nimport torch.nn.functional as F\n\nreference_available = True\ntry:\n    from torchvision.ops.focal_loss import sigmoid_focal_loss\nexcept ImportError:\n    reference_available = False\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.focal_loss import focal_loss\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\n@unittest.skipIf(\n    not reference_available,\n    \"Reference implementation `torchvision.ops.focal_loss.sigmoid_focal_loss` is not available.\",\n)\nclass FocalLossTest(unittest.TestCase):\n    N_SAMPLES = 12\n    N_CLASSES = 8\n    ALPHA = 0.24\n    GAMMA = 2.0\n    REDUCTION = \"sum\"\n\n    def test_focal_loss(self) -> None:\n        if not reference_available:\n            self.skipTest(\n                \"This test needs `torchvision` for `torchvision.ops.focal_loss.sigmoid_focal_loss`.\"\n            )\n        else:\n            x = torch.randn(FocalLossTest.N_SAMPLES, FocalLossTest.N_CLASSES).cuda()\n            with torch.no_grad():\n                x_expected = x.clone()\n                x_actual = x.clone()\n            x_expected.requires_grad_()\n            x_actual.requires_grad_()\n\n            classes = torch.randint(0, FocalLossTest.N_CLASSES, (FocalLossTest.N_SAMPLES,)).cuda()\n            with torch.no_grad():\n                y = F.one_hot(classes, FocalLossTest.N_CLASSES).float()\n\n            expected = sigmoid_focal_loss(\n                x_expected,\n                y,\n                alpha=FocalLossTest.ALPHA,\n                gamma=FocalLossTest.GAMMA,\n                reduction=FocalLossTest.REDUCTION,\n            )\n\n            actual = sum(\n                [\n                    focal_loss.FocalLoss.apply(\n                        x_actual[i : i + 1],\n                        classes[i : i + 1].long(),\n                        torch.ones([], device=\"cuda\"),\n                        FocalLossTest.N_CLASSES,\n               
         FocalLossTest.ALPHA,\n                        FocalLossTest.GAMMA,\n                        0.0,\n                    )\n                    for i in range(FocalLossTest.N_SAMPLES)\n                ]\n            )\n\n            # forward parity\n            torch.testing.assert_close(expected, actual)\n\n            expected.backward()\n            actual.backward()\n\n            # grad parity\n            torch.testing.assert_close(x_expected.grad, x_actual.grad)\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/fused_dense/test_fused_dense.py",
    "content": "import unittest\nimport os\n\nimport torch\nfrom torch.testing._internal import common_utils\nfrom torch.testing._internal.common_device_type import instantiate_device_type_tests\n\nSKIP_TEST = None\ntry:\n    from apex import fused_dense\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass FusedDenseTest(common_utils.TestCase):\n    def _test_fused_dense(self, dtype, seed=0):\n        os.environ[\"TORCH_ALLOW_TF32_CUBLAS_OVERRIDE\"] = \"0\"\n        torch.manual_seed(seed)\n\n        seq_length = 512\n        sequences = 3\n        hidden_dim = 1024\n\n        ref_inputs = torch.randn(\n            sequences * seq_length, hidden_dim, dtype=dtype, device=torch.device(\"cuda\")\n        ).requires_grad_(True)\n\n        tst_inputs = ref_inputs.clone().detach().requires_grad_(True)\n        dense = fused_dense.FusedDense(1024, 3072)\n        dense.to(dtype=dtype)\n        dense.cuda()\n\n        y_tst = dense(tst_inputs)\n        y_ref = torch.matmul(ref_inputs, dense.weight.t()) + dense.bias\n        dy = torch.randn_like(y_tst).to(dtype=dtype)\n        y_tst.backward(dy)\n        dw_ref = torch.matmul(dy.t(), ref_inputs)\n        dx_ref = torch.matmul(dy, dense.weight.clone())\n        db_ref = dy.sum(0, False)\n\n        torch.testing.assert_close(ref_inputs, tst_inputs, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True)\n        torch.testing.assert_close(dw_ref, dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)\n        torch.testing.assert_close(dx_ref, tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True)\n        torch.testing.assert_close(db_ref, dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True)\n\n    @common_utils.parametrize(\"dtype\", [torch.half, torch.float, torch.bfloat16])\n    def test_fused_dense(self, dtype):\n        self._test_fused_dense(dtype)\n\n\ninstantiate_device_type_tests(FusedDenseTest, globals(), 
only_for=(\"cuda\",))\n\nif __name__ == \"__main__\":\n    common_utils.run_tests()\n"
  },
  {
    "path": "apex/contrib/test/group_norm/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/group_norm/test_group_norm.py",
    "content": "#!/usr/bin/env python\n# coding: utf-8\n\n#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: BSD-3-Clause\n#\n\nimport functools\nimport importlib\nimport pathlib\nimport sys\nimport torch\nimport unittest\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.group_norm.group_norm import cuda_group_norm_nhwc_one_pass\n    from apex.contrib.group_norm.group_norm import cuda_group_norm_nhwc_two_pass\n    from apex.contrib.group_norm.group_norm import cuda_group_norm_v2_nhwc\n    from apex.contrib.group_norm.group_norm import get_cc_and_sm_count\n    from apex.contrib.group_norm import GroupNorm\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\ndef torch_group_norm_high_precision(x, g, w, b, eps, act=\"\", *, compute_type):\n    xdtype = x.dtype\n    y = torch.nn.functional.group_norm(\n        x.to(compute_type),\n        g,\n        w.to(compute_type),\n        b.to(compute_type),\n        eps,\n    )\n    if act in [\"silu\", \"swish\"]:\n        y = torch.nn.functional.silu(y)\n    y = y.to(dtype=xdtype)\n    return y\n\n\ntorch_group_norm_high_precision_fp64 = functools.partial(\n    torch_group_norm_high_precision,\n    compute_type=torch.float64,\n)\n\n\n@functools.cache\ndef relative_ulp(dtype, device):\n    # Unit in the Last Place\n    one = torch.tensor(1.0, dtype=dtype, device=device)\n    two = torch.tensor(2.0, dtype=dtype, device=device)\n    return (torch.nextafter(one, two) - one).item()\n\n\ndef _ref_compute_type(ref_func, xdtype: torch.dtype) -> torch.dtype:\n    # `torch_group_norm_high_precision_fp64` is a functools.partial with compute_type keyword.\n    if isinstance(ref_func, functools.partial):\n        compute_type = (ref_func.keywords or {}).get(\"compute_type\", None)\n        if compute_type is not None:\n            return compute_type\n    return xdtype\n\n\ndef _estimate_group_norm_test_bytes(\n    *,\n    N: int,\n    C: int,\n    H: int,\n    W: 
int,\n    xdtype: torch.dtype,\n    wdtype: torch.dtype,\n    ref_func,\n) -> int:\n    \"\"\"\n    Conservative VRAM estimate for `verify_group_norm`.\n\n    The reference path converts to a high-precision compute type (fp64 by default)\n    and runs both forward+backward while retaining graphs, which can roughly require\n    multiple full-size buffers at once. We intentionally over-estimate to avoid OOMs.\n    \"\"\"\n    numel = int(N) * int(C) * int(H) * int(W)\n    ref_dtype = _ref_compute_type(ref_func, xdtype)\n\n    x_bytes = numel * int(xdtype.itemsize)\n    ref_bytes = numel * int(ref_dtype.itemsize)\n\n    # Live tensors: x, dy, y_ref, y_tst, dx_ref/dx_tst + autograd saved buffers.\n    # Empirically, a ~10x multiplier on the reference compute buffers is a safer\n    # lower bound for fp64 reference on large tensors.\n    #\n    # Keep the estimate simple and intentionally conservative:\n    # - Base fp16/bf16 buffers: ~6x (x, dy, y, grads/temps)\n    # - Reference high-precision buffers: ~10x\n    estimate = (6 * x_bytes) + (10 * ref_bytes)\n\n    # Small extras: weights/bias/grads.\n    estimate += 6 * int(C) * int(wdtype.itemsize)\n    return int(estimate)\n\n\ndef _has_sufficient_cuda_memory(required_bytes: int, *, safety_factor: float = 0.90) -> bool:\n    if not torch.cuda.is_available():\n        return False\n    # `mem_get_info` reports free/total for the current device.\n    free_bytes, _total_bytes = torch.cuda.mem_get_info()\n    return required_bytes <= int(free_bytes * safety_factor)\n\n\n@unittest.skipIf(\n    torch.cuda.get_device_properties().multi_processor_count < 16,\n    \"GroupNorm is unsupported on low SM count devices\",\n)\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass GroupNormTest(unittest.TestCase):\n    def setUp(self, seed=0):\n        super().setUp()\n        torch.manual_seed(seed)\n\n    def verify_group_norm(\n        self,\n        tst_func,\n        N=32,\n        C=128,\n        H=256,\n        W=256,\n        
G=32,\n        ref_func=torch_group_norm_high_precision_fp64,\n        xdtype=torch.float16,\n        wdtype=torch.float32,\n        eps=1e-5,\n        memory_format=torch.channels_last,\n        device=\"cuda\",\n        act=\"\",\n    ):\n        # create data\n        x_shape = (N, C, H, W)\n        w_shape = (C,)\n        weight = torch.rand(w_shape, dtype=wdtype, device=\"cuda\", requires_grad=True)\n        bias = torch.rand(w_shape, dtype=wdtype, device=\"cuda\", requires_grad=True)\n        x = -2.3 + 0.5 * torch.randn(x_shape, dtype=xdtype, device=\"cuda\")\n        x = x.to(memory_format=memory_format)\n        dy = 0.1 * torch.randn_like(x)\n        x.requires_grad_(True)\n\n        # forward pass\n        y_ref = ref_func(x, G, weight, bias, eps, act)\n        if tst_func is GroupNorm:\n            gn = GroupNorm(G, C, eps, device=device, dtype=wdtype, act=act)\n            with torch.no_grad():\n                gn.weight = torch.nn.Parameter(weight)\n                gn.bias = torch.nn.Parameter(bias)\n            y_tst = gn(x)\n        else:\n            y_tst = tst_func(x, G, weight, bias, eps, act)\n\n        # backward pass\n        y_ref.backward(dy, retain_graph=True)\n        dx_ref, dw_ref, db_ref = [t.grad.clone() for t in [x, weight, bias]]\n        x.grad.zero_()\n        weight.grad.zero_()\n        bias.grad.zero_()\n        y_tst.backward(dy, retain_graph=True)\n        if tst_func is GroupNorm:\n            dx_tst, dw_tst, db_tst = x.grad, gn.weight.grad, gn.bias.grad\n        else:\n            dx_tst, dw_tst, db_tst = [t.grad.clone() for t in [x, weight, bias]]\n\n        # compare\n        torch.testing.assert_close(\n            y_tst, y_ref, atol=1e-2, rtol=relative_ulp(y_ref.dtype, y_ref.device)\n        )\n        torch.testing.assert_close(\n            dx_tst, dx_ref, atol=1e-2, rtol=relative_ulp(dx_ref.dtype, dx_ref.device)\n        )\n        torch.testing.assert_close(\n            dw_tst, dw_ref, atol=1e-2, 
rtol=relative_ulp(dw_ref.dtype, dw_ref.device)\n        )\n        torch.testing.assert_close(\n            db_tst, db_ref, atol=1e-2, rtol=relative_ulp(db_ref.dtype, db_ref.device)\n        )\n\n    def test_fp16_one_pass_algo(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_one_pass, act=\"\")\n\n    def test_fp16_two_pass_algo(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_two_pass, act=\"\")\n\n    def test_fp16_one_pass_algo_with_swish(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_one_pass, act=\"swish\")\n\n    def test_fp16_two_pass_algo_with_swish(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_two_pass, act=\"swish\")\n\n    def test_bf16_one_pass_algo(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_one_pass, xdtype=torch.bfloat16, act=\"\")\n\n    def test_bf16_two_pass_algo(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_two_pass, xdtype=torch.bfloat16, act=\"\")\n\n    def test_bf16_one_pass_algo_with_swish(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_one_pass, xdtype=torch.bfloat16, act=\"swish\")\n\n    def test_bf16_two_pass_algo_with_swish(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_two_pass, xdtype=torch.bfloat16, act=\"swish\")\n\n    def test_fp32_one_pass_algo(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_one_pass, xdtype=torch.float32, act=\"\")\n\n    def test_fp32_two_pass_algo(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_two_pass, xdtype=torch.float32, act=\"\")\n\n    def test_fp32_one_pass_algo_with_swish(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_one_pass, xdtype=torch.float32, act=\"swish\")\n\n    def test_fp32_two_pass_algo_with_swish(self):\n        self.verify_group_norm(cuda_group_norm_nhwc_two_pass, xdtype=torch.float32, act=\"swish\")\n\n    def test_group_norm_module(self):\n        self.verify_group_norm(GroupNorm, G=16, act=\"swish\")\n\n    def test_group_norm_inductor(self):\n  
      N, C, H, W, G = 32, 320, 256, 256, 16\n\n        model = (\n            torch.nn.Sequential(\n                GroupNorm(G, C, act=\"silu\", dtype=torch.float16),\n                torch.nn.Conv2d(C, C, kernel_size=3, padding=\"same\"),\n            )\n            .cuda()\n            .half()\n        )\n        compiled = torch.compile(model)\n\n        x = -2.3 + 0.5 * torch.randn((N, C, H, W), dtype=torch.float16, device=\"cuda\")\n        x = x.to(memory_format=torch.channels_last)\n        dy = 0.1 * torch.randn_like(x)\n        x.requires_grad_(True)\n\n        for _ in range(4):\n            y = compiled(x)\n            y.backward(dy)\n\n        from torch._dynamo.utils import counters\n\n        # TODO: Remove this when 3.9 is no longer supported\n        if sys.version_info < (3, 10):\n            num_graph_breaks = sum(counters[\"graph_break\"].values())\n        else:\n            num_graph_breaks = counters[\"graph_break\"].total()\n        self.assertEqual(num_graph_breaks, 0, \"Shouldn't see any graph breaks.\")\n        self.assertEqual(counters[\"stats\"][\"unique_graphs\"], 1, \"Expect only one graph.\")\n\n    def test_16_groups(self):\n        sizes = [\n            [8, 2560, 16, 16],\n            [8, 1920, 32, 32],\n            [8, 1920, 16, 16],\n            [8, 2560, 8, 8],\n            [1, 128, 16128, 1200],\n        ]\n        for sz in sizes:\n            with self.subTest(size=sz):\n                n, c, h, w = sz\n                required = _estimate_group_norm_test_bytes(\n                    N=n,\n                    C=c,\n                    H=h,\n                    W=w,\n                    xdtype=torch.float16,\n                    wdtype=torch.float32,\n                    ref_func=torch_group_norm_high_precision_fp64,\n                )\n                if not _has_sufficient_cuda_memory(required):\n                    free_bytes, total_bytes = torch.cuda.mem_get_info()\n                    raise unittest.SkipTest(\n           
             f\"Skipping large GroupNorm case {sz}: estimated {required / 1e9:.1f} GB \"\n                        f\"requires more than available free VRAM ({free_bytes / 1e9:.1f} GB free, \"\n                        f\"{total_bytes / 1e9:.1f} GB total).\"\n                    )\n                self.verify_group_norm(GroupNorm, N=n, C=c, H=h, W=w, G=16, act=\"swish\")\n\n    def test_large_batch_two_pass(self):\n        \"\"\"Regression test for divide-by-zero when batch size is large.\n\n        When batch_size >= 256 and c >= 640, blocks_per_act_slice = 256 / n\n        truncates to 0, causing div_up(hw, 0). Test all three heuristic branches.\n        \"\"\"\n        sizes = [\n            [256, 1280, 8, 8],\n            [512, 640, 16, 16],\n            [1024, 512, 8, 8],\n        ]\n        for sz in sizes:\n            with self.subTest(size=sz):\n                n, c, h, w = sz\n                required = _estimate_group_norm_test_bytes(\n                    N=n,\n                    C=c,\n                    H=h,\n                    W=w,\n                    xdtype=torch.float16,\n                    wdtype=torch.float32,\n                    ref_func=torch_group_norm_high_precision_fp64,\n                )\n                if not _has_sufficient_cuda_memory(required):\n                    free_bytes, total_bytes = torch.cuda.mem_get_info()\n                    raise unittest.SkipTest(\n                        f\"Skipping large-batch GroupNorm case {sz}: estimated \"\n                        f\"{required / 1e9:.1f} GB requires more than available \"\n                        f\"free VRAM ({free_bytes / 1e9:.1f} GB free, \"\n                        f\"{total_bytes / 1e9:.1f} GB total).\"\n                    )\n                self.verify_group_norm(\n                    cuda_group_norm_nhwc_two_pass, N=n, C=c, H=h, W=w, G=32, act=\"silu\"\n                )\n\n    def test_fp16_parameters(self):\n        n, c, h, w = 8, 2560, 16, 16\n        
self.verify_group_norm(\n            GroupNorm,\n            N=n,\n            C=c,\n            H=h,\n            W=w,\n            G=16,\n            xdtype=torch.float16,\n            wdtype=torch.float16,\n            act=\"swish\",\n        )\n\n    @staticmethod\n    @functools.cache\n    def get_v2_hw_c_list():\n        srcpath = pathlib.Path(__file__).parent.absolute()\n        gen_module_path = (\n            srcpath / \"..\" / \"..\" / \"csrc\" / \"group_norm_v2\" / \"generate_gn_cuda_inst.py\"\n        )\n        spec = importlib.util.spec_from_file_location(\"generate_gn_cuda_inst\", gen_module_path)\n        generate_gn_cuda_inst = importlib.util.module_from_spec(spec)\n        spec.loader.exec_module(generate_gn_cuda_inst)\n        return generate_gn_cuda_inst.hw_c_list\n\n    def check_v2_cc_and_sm_count(self):\n        cc, sm_count = get_cc_and_sm_count(torch.cuda.current_device())\n        return (\n            cc in GroupNorm.GN_V2_SUPPORTED_LOWER_BOUND_SM_COUNT\n            and sm_count >= GroupNorm.GN_V2_SUPPORTED_LOWER_BOUND_SM_COUNT[cc]\n        )\n\n    def skip_if_v2_not_supported(self):\n        if not self.check_v2_cc_and_sm_count():\n            cc, sm_count = get_cc_and_sm_count(torch.cuda.current_device())\n            self.skipTest(\n                f\"SM count {sm_count} is not supported for compute capability {cc[0]}.{cc[1]}\"\n            )\n\n    def test_check_v2_legality(self):\n        gn = GroupNorm(\n            num_groups=16,\n            num_channels=640,\n            device=\"cuda\",\n            dtype=torch.float16,\n            act=\"swish\",\n        )\n        self.skip_if_v2_not_supported()\n        # Correct\n        x = torch.empty(\n            8,\n            640,\n            32,\n            32,\n            dtype=torch.float16,\n            device=\"cuda\",\n            memory_format=torch.channels_last,\n        )\n        self.assertTrue(gn._check_legality(x) and gn._check_v2_legality(x))\n        # Wrong 
layout\n        x = torch.empty(8, 640, 32, 32, dtype=torch.float16, device=\"cuda\")\n        self.assertFalse(gn._check_legality(x) and gn._check_v2_legality(x))\n        # Wrong shape\n        x = torch.empty(\n            8,\n            640,\n            32,\n            24,\n            dtype=torch.float16,\n            device=\"cuda\",\n            memory_format=torch.channels_last,\n        )\n        self.assertFalse(gn._check_legality(x) and gn._check_v2_legality(x))\n        # Wrong dtype\n        x = torch.empty(\n            8,\n            640,\n            32,\n            32,\n            dtype=torch.float32,\n            device=\"cuda\",\n            memory_format=torch.channels_last,\n        )\n        self.assertFalse(gn._check_legality(x) and gn._check_v2_legality(x))\n\n    def test_fp16_v2_32_groups(self):\n        self.skip_if_v2_not_supported()\n        for n in [1, 2, 4, 8, 16, 32]:\n            for hw, c in self.get_v2_hw_c_list():\n                h = w = int(hw**0.5)\n                assert hw == h * w\n                self.verify_group_norm(\n                    cuda_group_norm_v2_nhwc,\n                    N=n,\n                    C=c,\n                    H=h,\n                    W=w,\n                    G=32,\n                    xdtype=torch.float16,\n                    wdtype=torch.float16,\n                    act=\"\",\n                )\n\n    def test_fp16_v2_16_groups_with_swish(self):\n        self.skip_if_v2_not_supported()\n        for n in [1, 2, 4, 8, 16, 32]:\n            for hw, c in self.get_v2_hw_c_list():\n                h = w = int(hw**0.5)\n                assert hw == h * w\n                self.verify_group_norm(\n                    cuda_group_norm_v2_nhwc,\n                    N=n,\n                    C=c,\n                    H=h,\n                    W=w,\n                    G=16,\n                    xdtype=torch.float16,\n                    wdtype=torch.float16,\n                    
act=\"swish\",\n                )\n\n    def test_bf16_v2_32_groups(self):\n        self.skip_if_v2_not_supported()\n        for n in [1, 2, 4, 8, 16, 32]:\n            for hw, c in self.get_v2_hw_c_list():\n                h = w = int(hw**0.5)\n                assert hw == h * w\n                self.verify_group_norm(\n                    cuda_group_norm_v2_nhwc,\n                    N=n,\n                    C=c,\n                    H=h,\n                    W=w,\n                    G=32,\n                    xdtype=torch.bfloat16,\n                    wdtype=torch.bfloat16,\n                    act=\"\",\n                )\n\n    def test_bf16_v2_16_groups_with_swish(self):\n        self.skip_if_v2_not_supported()\n        for n in [1, 2, 4, 8, 16, 32]:\n            for hw, c in self.get_v2_hw_c_list():\n                h = w = int(hw**0.5)\n                assert hw == h * w\n                self.verify_group_norm(\n                    cuda_group_norm_v2_nhwc,\n                    N=n,\n                    C=c,\n                    H=h,\n                    W=w,\n                    G=16,\n                    xdtype=torch.bfloat16,\n                    wdtype=torch.bfloat16,\n                    act=\"swish\",\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/index_mul_2d/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/index_mul_2d/test_index_mul_2d.py",
    "content": "import random\nimport unittest\n\nimport torch\n\nHAS_INDEX_MUL_2D_RELU = None\ntry:\n    from apex.contrib.index_mul_2d import index_mul_2d\nexcept ImportError:\n    HAS_INDEX_MUL_2D_RELU = False\nelse:\n    HAS_INDEX_MUL_2D_RELU = True\n\n\n@unittest.skipIf(not HAS_INDEX_MUL_2D_RELU, \"`apex.contrib.index_mul_2d` is not found.\")\nclass IndexMul2dTest(unittest.TestCase):\n    def setUp(self, seed=0):\n        torch.manual_seed(seed)\n\n        self.input1_size = random.randint(1, 1000)\n        self.input2_size = random.randint(1, 100000)\n        self.feature_size = random.randint(1, 256)\n\n        self.input1_float = torch.randn(\n            size=(self.input1_size, self.feature_size),\n        ).cuda()\n        self.input2_float = torch.randn(\n            size=(self.input2_size, self.feature_size),\n        ).cuda()\n        self.index1 = torch.randint(low=0, high=self.input1_size, size=(self.input2_size,)).cuda()\n\n        self.input1_float_ = self.input1_float.clone()\n        self.input2_float_ = self.input2_float.clone()\n\n        self.input1_float.requires_grad_()\n        self.input1_float_.requires_grad_()\n        self.input2_float.requires_grad_()\n        self.input2_float_.requires_grad_()\n\n        self.input1_half = (\n            torch.randn(\n                size=(self.input1_size, self.feature_size),\n            )\n            .cuda()\n            .half()\n        )\n        self.input2_half = (\n            torch.randn(\n                size=(self.input2_size, self.feature_size),\n            )\n            .cuda()\n            .half()\n        )\n\n        self.input1_half_ = self.input1_half.clone()\n        self.input2_half_ = self.input2_half.clone()\n\n        self.input1_half.requires_grad_()\n        self.input2_half.requires_grad_()\n        self.input1_half_.requires_grad_()\n        self.input2_half_.requires_grad_()\n\n    def test_index_mul_float(self):\n        out = index_mul_2d(self.input1_float, 
self.input2_float, self.index1)\n        energy = (out.float() ** 2).sum() / out.numel()\n        force = torch.autograd.grad(\n            energy,\n            self.input1_float,\n            grad_outputs=torch.ones_like(energy),\n            create_graph=True,\n        )[0]\n        loss = (out.float() ** 2).sum() / out.numel() + (force.float() ** 2).sum()\n        loss.backward()\n\n        out_ = self.input1_float_[self.index1] * self.input2_float_\n        energy_ = (out_.float() ** 2).sum() / out.numel()\n        force_ = torch.autograd.grad(\n            energy_,\n            self.input1_float_,\n            grad_outputs=torch.ones_like(energy),\n            create_graph=True,\n        )[0]\n        loss = (out_.float() ** 2).sum() / out_.numel() + (force_.float() ** 2).sum()\n        loss.backward()\n\n        torch.testing.assert_close(\n            self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True\n        )\n        torch.testing.assert_close(\n            self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True\n        )\n        torch.testing.assert_close(\n            self.input1_float.grad,\n            self.input1_float_.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(\n            self.input2_float.grad,\n            self.input2_float_.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n\n    def test_index_mul_half(self):\n        out = index_mul_2d(self.input1_half, self.input2_half, self.index1)\n        energy = (out.float() ** 2).sum() / out.numel()\n        force = torch.autograd.grad(\n            energy,\n            self.input1_half,\n            grad_outputs=torch.ones_like(energy),\n            create_graph=True,\n        )[0]\n        loss = (out.float() ** 2).sum() / out.numel() + (force.float() ** 2).sum()\n        loss.backward()\n\n        out_ = 
self.input1_half_[self.index1] * self.input2_half_\n        energy_ = (out_.float() ** 2).sum() / out.numel()\n        force_ = torch.autograd.grad(\n            energy_,\n            self.input1_half_,\n            grad_outputs=torch.ones_like(energy),\n            create_graph=True,\n        )[0]\n        loss = (out_.float() ** 2).sum() / out_.numel() + (force_.float() ** 2).sum()\n        loss.backward()\n\n        torch.testing.assert_close(\n            self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True\n        )\n        torch.testing.assert_close(\n            self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True\n        )\n        torch.testing.assert_close(\n            self.input1_half.grad,\n            self.input1_half_.grad,\n            atol=2e-3,\n            rtol=5e-2,\n            equal_nan=True,\n        )\n        torch.testing.assert_close(\n            self.input2_half.grad,\n            self.input2_half_.grad,\n            atol=1e-3,\n            rtol=1e-3,\n            equal_nan=True,\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/layer_norm/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/layer_norm/test_fast_layer_norm.py",
    "content": "import itertools\nimport unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.layer_norm.layer_norm import FastLayerNorm\n    import fast_layer_norm as fln\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\nclass GPUTimer:\n    def __init__(self, stream):\n        self.start_ = torch.cuda.Event(enable_timing=True)\n        self.stop_ = torch.cuda.Event(enable_timing=True)\n        self.stream_ = stream\n\n    def start(self):\n        self.stream_.record_event(self.start_)\n\n    def stop(self):\n        self.stream_.record_event(self.stop_)\n\n    def sync(self):\n        self.stream_.synchronize()\n\n    def millis(self):\n        return self.start_.elapsed_time(self.stop_)\n\n\ndef size_in_bytes(t):\n    return torch.numel(t) * t.element_size()\n\n\ndef metrics(y_ref, y, epsilon=1e-6):\n    y_ref = y_ref.float()\n    y = y.float()\n    relerr, mse = (\n        (y_ref - y).abs().sum() / (y_ref.abs().sum() + epsilon),\n        (y_ref - y).square().mean(),\n    )\n    return relerr.item(), mse.item()\n\n\ndevice = torch.device(\"cuda\")\nfp32 = torch.float32\nfp16 = torch.float16\nbf16 = torch.bfloat16\n\n\ndef backward_(dz, x, mu, rs, gamma):\n    wtype = gamma.dtype\n    itype = x.dtype\n    otype = dz.dtype\n    ctype = mu.dtype\n    mu = mu.unsqueeze(1)\n    rs = rs.unsqueeze(1)\n\n    hidden_size = gamma.numel()\n    y = rs * (x.to(ctype) - mu)\n    dbeta = dz.view(-1, hidden_size).sum(0, dtype=ctype)\n    dgamma = (dz * y).view(-1, hidden_size).sum(0, dtype=ctype)\n    dy = dz.view(-1, hidden_size).to(ctype) * gamma.unsqueeze(0).to(ctype)\n    mdy = dy.mean(1, keepdim=True, dtype=ctype)\n\n    mdyy = (dy * y).mean(1, keepdim=True, dtype=ctype)\n    dx = rs * (dy - mdyy * y - mdy)\n\n    return dx.to(itype), dgamma.to(wtype), dbeta.to(wtype)\n\n\ndef benchmark_(S, B, hidden_size, itype, wtype, runs=100):\n    epsilon = 1e-5\n\n    x = torch.randn((S * B, hidden_size), dtype=itype, device=device)\n    beta = 
torch.randn(hidden_size, dtype=wtype, device=device)\n    gamma = torch.randn(hidden_size, dtype=wtype, device=device)\n    dz = torch.randn(x.shape, dtype=wtype, device=device)\n\n    stream = torch.cuda.Stream()\n    with torch.cuda.stream(stream):\n        timer = GPUTimer(stream)\n\n        # warmup\n        for r in range(runs):\n            z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)\n\n        timer.start()\n        for r in range(runs):\n            z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon)\n        timer.stop()\n        timer.sync()\n\n        total_bytes_fwd = sum([size_in_bytes(t) for t in [x, z, gamma, beta, mu, rsigma]])\n\n        ms_fwd = timer.millis() / runs\n\n        print(\n            \"[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec\".format(\n                ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd\n            )\n        )\n\n        timer.start()\n        for r in range(runs):\n            dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, z, mu, rsigma, gamma, beta, True)\n        timer.stop()\n        timer.sync()\n\n        total_bytes_bwd = sum(\n            [\n                size_in_bytes(t)\n                for t in [\n                    dz,\n                    x,\n                    mu,\n                    rsigma,\n                    gamma,\n                    dx,\n                    dgamma,\n                    dbeta,\n                    dbp,\n                    dbp,\n                    dgp,\n                    dgp,\n                ]\n            ]\n        )\n\n        ms_bwd = timer.millis() / runs\n\n        print(\n            \"[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec\".format(\n                ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd\n            )\n        )\n\n\ndef _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32, mem_eff=False):\n    seed = 1243\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n\n    otype = wtype\n    
print(\"========================================================\")\n    print(f\"S={S} B={B} Hidden={hidden_size} {itype} {wtype} Mem_Eff={mem_eff}\")\n    print(\"--------------------------------------------------------\")\n\n    x = torch.randn(S * B, hidden_size, dtype=itype, device=device)\n    gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2\n    beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2\n    epsilon = 1e-5\n\n    x.requires_grad = True\n    gamma.requires_grad = True\n    beta.requires_grad = True\n\n    mu_ref = x.mean(1, dtype=ctype, keepdim=True)\n    v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True)\n    rs_ref = torch.rsqrt(v + epsilon)\n    y_ref = rs_ref * (x.to(ctype) - mu_ref)\n    z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) + beta.unsqueeze(0)).to(otype)\n\n    mu_ref = mu_ref.flatten()\n    rs_ref = rs_ref.flatten()\n\n    dz = torch.randn_like(z_ref)\n\n    # z_ref.backward(dz)\n    # dx_ref = x.grad\n    # dgamma_ref = gamma.grad\n    # dbeta_ref = beta.grad\n\n    dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma)\n\n    z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon)\n    if mem_eff:\n        dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, z, mu, rs, gamma, beta, True)\n    else:\n        dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma, beta, False)\n\n    re_z, mse_z = metrics(z_ref, z)\n    re_mu, mse_mu = metrics(mu_ref, mu)\n    re_rs, mse_rs = metrics(rs_ref, rs)\n\n    re_dx, mse_dx = metrics(dx_ref, dx)\n    re_dg, mse_dg = metrics(dg_ref, dg)\n    re_db, mse_db = metrics(db_ref, db)\n\n    print(f\" z: relerr={re_z:.4e} mse={mse_z:.4e}\")\n    print(f\"mu: relerr={re_mu:.4e} mse={mse_mu:.4e}\")\n    print(f\"rs: relerr={re_rs:.4e} mse={mse_rs:.4e}\")\n\n    print(f\"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}\")\n    print(f\"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}\")\n    print(f\"db: relerr={re_db:.4e} mse={mse_db:.4e}\")\n\n    def check_err(x, 
relerr):\n        tol = 2e-2 if x.dtype in (torch.float16, torch.bfloat16) else 1e-5\n        return relerr < tol\n\n    return [\n        check_err(x, re)\n        for x, re in zip([z, mu, rs, dx, dg, db], [re_z, re_mu, re_rs, re_dx, re_dg, re_db])\n    ]\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass TestFastLayerNorm(unittest.TestCase):\n    # TODO(crcrpar): Try `torch.testing.assert_close` instead and migrate to it if it's working.\n    def assertAll(self, l):\n        if not all(l):\n            print(l)\n        for x in l:\n            self.assertTrue(x)\n\n    def test_all_configs(self):\n        hidden_sizes = [\n            768,\n            1024,\n            1536,\n            2048,\n            2304,\n            3072,\n            3840,\n            4096,\n            5120,\n            6144,\n            8192,\n            10240,\n            12288,\n            12800,\n            14336,\n            15360,\n            16384,\n            18432,\n            20480,\n            24576,\n            25600,\n            30720,\n            32768,\n            40960,\n            49152,\n            65536,\n        ]\n\n        for h, mem_eff in itertools.product(hidden_sizes, (True, False)):\n            with self.subTest(f\"hidden_size={h}\"):\n                self.assertAll(_test_impl(256, 2, h, fp32, fp32, mem_eff=mem_eff))\n                self.assertAll(_test_impl(256, 2, h, fp16, fp16, mem_eff=mem_eff))\n                self.assertAll(_test_impl(256, 2, h, fp32, fp16, mem_eff=mem_eff))\n                self.assertAll(_test_impl(256, 2, h, bf16, bf16, mem_eff=mem_eff))\n                self.assertAll(_test_impl(256, 2, h, fp32, bf16, mem_eff=mem_eff))\n\n    def test_run_benchmark(self):\n        for S, B, hidden_size, runs in (\n            (512, 32, 768, 1000),\n            (512, 32, 1024, 1000),\n            (512, 8, 4096, 1000),\n            (512, 8, 5120, 1000),\n            (512, 8, 6144, 1000),\n            (256, 2, 20480, 
500),\n            (256, 2, 25600, 500),\n            (256, 2, 40960, 250),\n            (256, 2, 65536, 250),\n        ):\n            with self.subTest(f\"(S, B, hidden_size)=({S}, {B}, {hidden_size})\"):\n                benchmark_(S, B, hidden_size, fp16, fp16, runs)\n\n    def test_compat_with_autocast(self):\n        autocast_dtypes = (\n            (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)\n        )\n        input_shape = (512, 32, 768)\n        layer_norm = FastLayerNorm(input_shape[-1]).cuda()\n        input = torch.randn(input_shape).cuda()\n\n        for dtype in autocast_dtypes:\n            layer_norm.zero_grad(set_to_none=True)\n            with self.subTest(f\"autocast_dtype={dtype}\"):\n                with torch.amp.autocast(\"cuda\", enabled=True, dtype=dtype):\n                    out = layer_norm(input)\n                    self.assertEqual(dtype, out.dtype)\n                grad = torch.randn_like(out)\n                out.backward(grad)\n                self.assertEqual(torch.float32, layer_norm.weight.grad.dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/multihead_attn/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py",
    "content": "import unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.multihead_attn import EncdecMultiheadAttn\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass EncdecMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n\n        self.seq_length = 80\n        self.sequences = 10\n        self.hidden_dim = 1024\n        self.heads = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = EncdecMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=False,\n            impl=\"default\",\n        )\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs_q = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n        self.ref_inputs_k = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.tst_layer = EncdecMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=False,\n            impl=\"fast\",\n        )\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n\n        self.tst_inputs_q = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        
).requires_grad_(True)\n        self.tst_inputs_k = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n    def test_encdec_multihead_attn(self):\n        ref_outputs, _ = self.ref_layer.forward(\n            self.ref_inputs_q,\n            self.ref_inputs_k,\n            self.ref_inputs_k,\n            key_padding_mask=None,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        tst_outputs, _ = self.tst_layer.forward(\n            self.tst_inputs_q,\n            self.tst_inputs_k,\n            self.tst_inputs_k,\n            key_padding_mask=None,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n        torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n\n        with torch.no_grad():\n            ref_grads = torch.randn_like(ref_outputs)\n            tst_grads = ref_grads.clone()\n        ref_outputs.backward(ref_grads)\n        tst_outputs.backward(tst_grads)\n        torch.testing.assert_close(\n            self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3\n        )\n\n    def test_encdec_multihead_attn_time_mask(self):\n        grads = torch.randn_like(self.tst_inputs_q)\n        time_mask_byte = torch.triu(\n            torch.ones(\n                self.tst_inputs_q.size(0),\n                self.tst_inputs_k.size(0),\n                device=torch.device(\"cuda\"),\n                dtype=torch.uint8,\n            ),\n            1,\n        )\n        time_mask_bool = time_mask_byte.to(torch.bool)\n\n        ref_outputs, _ = self.ref_layer.forward(\n   
         self.ref_inputs_q,\n            self.ref_inputs_k,\n            self.ref_inputs_k,\n            key_padding_mask=None,\n            need_weights=False,\n            attn_mask=time_mask_bool,\n            is_training=True,\n        )\n\n        tst_outputs, _ = self.tst_layer.forward(\n            self.tst_inputs_q,\n            self.tst_inputs_k,\n            self.tst_inputs_k,\n            key_padding_mask=None,\n            need_weights=False,\n            attn_mask=time_mask_byte,\n            is_training=True,\n        )\n\n        self.ref_inputs_q.backward(grads)\n        self.tst_inputs_q.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n        torch.testing.assert_close(\n            self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3\n        )\n\n    def test_encdec_multihead_attn_pad_mask(self):\n        grads = torch.randn_like(self.tst_inputs_q)\n        pad_mask_byte = torch.tril(\n            torch.ones(\n                self.tst_inputs_k.size(1),\n                self.tst_inputs_k.size(0),\n                device=torch.device(\"cuda\"),\n                dtype=torch.uint8,\n            ),\n            1,\n        )\n        pad_mask_bool = pad_mask_byte.to(torch.bool)\n\n        ref_outputs, _ = self.ref_layer.forward(\n            self.ref_inputs_q,\n            self.ref_inputs_k,\n            self.ref_inputs_k,\n            key_padding_mask=pad_mask_bool,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        tst_outputs, _ = self.tst_layer.forward(\n            self.tst_inputs_q,\n            self.tst_inputs_k,\n            self.tst_inputs_k,\n            key_padding_mask=pad_mask_byte,\n            need_weights=False,\n  
          attn_mask=None,\n            is_training=True,\n        )\n\n        self.ref_inputs_q.backward(grads)\n        self.tst_inputs_q.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n        torch.testing.assert_close(\n            self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py",
    "content": "import unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.multihead_attn import EncdecMultiheadAttn\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass EncdecMultiheadAttnNormAddTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length = 80\n        self.sequences = 10\n        self.hidden_dim = 1024\n        self.heads = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = EncdecMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=True,\n            impl=\"default\",\n        )\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs_q = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n        self.ref_inputs_k = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.tst_layer = EncdecMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=True,\n            impl=\"fast\",\n        )\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n\n        self.tst_inputs_q = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            
device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n        self.tst_inputs_k = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n    def test_encdec_multihead_attn_norm_add(self):\n        grads = torch.randn_like(self.tst_inputs_q)\n\n        for _ in range(5):\n            ref_outputs, _ = self.ref_layer.forward(\n                self.ref_inputs_q,\n                self.ref_inputs_k,\n                self.ref_inputs_k,\n                key_padding_mask=None,\n                need_weights=False,\n                attn_mask=None,\n                is_training=True,\n            )\n\n            tst_outputs, _ = self.tst_layer.forward(\n                self.tst_inputs_q,\n                self.tst_inputs_k,\n                self.tst_inputs_k,\n                key_padding_mask=None,\n                need_weights=False,\n                attn_mask=None,\n                is_training=True,\n            )\n\n            self.ref_inputs_q.backward(grads)\n            self.tst_inputs_q.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n        torch.testing.assert_close(\n            self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py",
    "content": "import unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.multihead_attn import SelfMultiheadAttn\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass SelfMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length = 80\n        self.sequences = 10\n        self.hidden_dim = 1024\n        self.heads = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=True,\n            include_norm_add=False,\n            separate_qkv_params=True,\n            mask_additive=True,\n            impl=\"default\",\n        )\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.tst_layer = SelfMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=True,\n            include_norm_add=False,\n            separate_qkv_params=True,\n            mask_additive=True,\n            impl=\"fast\",\n        )\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n\n        self.tst_inputs = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n    def test_self_multihead_attn_additive_mask(self):\n   
     grads = torch.randn_like(self.tst_inputs)\n        mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()\n\n        ref_outputs, _ = self.ref_layer.forward(\n            self.ref_inputs,\n            self.ref_inputs,\n            self.ref_inputs,\n            key_padding_mask=mask,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        tst_outputs, _ = self.tst_layer.forward(\n            self.tst_inputs,\n            self.tst_inputs,\n            self.tst_inputs,\n            key_padding_mask=mask,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        self.ref_inputs.backward(grads)\n        self.tst_inputs.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n        torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/multihead_attn/test_mha_fused_softmax.py",
    "content": "import unittest\n\nimport torch\nimport torch.nn.functional as F\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass FusedSoftmaxTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length = 80\n        self.sequences = 10\n        self.hidden_dim = 1024\n        self.heads = 16\n        self.dropout_prob = 0.0\n\n        self.mask = (torch.randn(self.sequences, self.seq_length) > 0).cuda()\n        self.mask = self.mask.half() * -10000\n        self.ref_inputs = torch.randn(\n            self.heads * self.sequences,\n            self.seq_length,\n            self.seq_length,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n        self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)\n\n    def test_fused_softmax(self):\n        grads = torch.randn_like(self.tst_inputs)\n        y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length)\n        y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2)\n        y_ref = y_ref.view(self.sequences * self.heads, self.seq_length, self.seq_length)\n        y_ref = F.softmax(y_ref, dim=-1)\n        y_ref = torch._fused_dropout(y_ref, 1.0)\n\n        y_tst = fast_mask_softmax_dropout_func(\n            True, self.heads, self.tst_inputs, self.mask, True, 0.0\n        )\n        y_ref[0].backward(grads)\n        y_tst.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(y_ref[0], y_tst, atol=1e-3, rtol=1e-3)\n        torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/multihead_attn/test_self_multihead_attn.py",
    "content": "import unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.multihead_attn import SelfMultiheadAttn\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass SelfMultiheadAttnTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.seq_length = 80\n        self.sequences = 10\n        self.hidden_dim = 1024\n        self.heads = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=False,\n            impl=\"default\",\n        )\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.tst_layer = SelfMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=False,\n            impl=\"fast\",\n        )\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n\n        self.tst_inputs = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n    def test_self_multihead_attn(self):\n        ref_outputs, _ = self.ref_layer.forward(\n            self.ref_inputs,\n            self.ref_inputs,\n            self.ref_inputs,\n            
key_padding_mask=None,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        tst_outputs, _ = self.tst_layer.forward(\n            self.tst_inputs,\n            self.tst_inputs,\n            self.tst_inputs,\n            key_padding_mask=None,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n\n        with torch.no_grad():\n            ref_grads = torch.randn_like(self.tst_inputs)\n            tst_grads = ref_grads.clone()\n\n        ref_outputs.backward(ref_grads)\n        tst_outputs.backward(tst_grads)\n        torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)\n\n    def test_self_multihead_attn_time_mask(self):\n        grads = torch.randn_like(self.tst_inputs)\n        time_mask_byte = torch.triu(\n            torch.ones(\n                self.tst_inputs.size(0),\n                self.tst_inputs.size(0),\n                device=torch.device(\"cuda\"),\n                dtype=torch.uint8,\n            ),\n            1,\n        )\n        time_mask_bool = time_mask_byte.to(torch.bool)\n\n        ref_outputs, _ = self.ref_layer.forward(\n            self.ref_inputs,\n            self.ref_inputs,\n            self.ref_inputs,\n            key_padding_mask=None,\n            need_weights=False,\n            attn_mask=time_mask_bool,\n            is_training=True,\n        )\n\n        tst_outputs, _ = self.tst_layer.forward(\n            self.tst_inputs,\n            self.tst_inputs,\n            self.tst_inputs,\n            key_padding_mask=None,\n            need_weights=False,\n            attn_mask=time_mask_byte,\n            is_training=True,\n        )\n\n        self.ref_inputs.backward(grads)\n        
self.tst_inputs.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=5e-3, rtol=1e-3)\n        torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)\n\n    def test_self_multihead_attn_pad_mask(self):\n        grads = torch.randn_like(self.tst_inputs)\n        pad_mask_byte = torch.tril(\n            torch.ones(\n                self.tst_inputs.size(1),\n                self.tst_inputs.size(0),\n                device=torch.device(\"cuda\"),\n                dtype=torch.uint8,\n            ),\n            1,\n        )\n        pad_mask_bool = pad_mask_byte.to(torch.bool)\n\n        ref_outputs, _ = self.ref_layer.forward(\n            self.ref_inputs,\n            self.ref_inputs,\n            self.ref_inputs,\n            key_padding_mask=pad_mask_bool,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        tst_outputs, _ = self.tst_layer.forward(\n            self.tst_inputs,\n            self.tst_inputs,\n            self.tst_inputs,\n            key_padding_mask=pad_mask_byte,\n            need_weights=False,\n            attn_mask=None,\n            is_training=True,\n        )\n\n        self.ref_inputs.backward(grads)\n        self.tst_inputs.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n        torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py",
    "content": "import unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.multihead_attn import SelfMultiheadAttn\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass SelfMultiheadAttnNormAddTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n\n        self.seq_length = 80\n        self.sequences = 10\n        self.hidden_dim = 1024\n        self.heads = 16\n        self.dropout_prob = 0.0\n\n        self.ref_layer = SelfMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=True,\n            impl=\"default\",\n        )\n        self.ref_layer.cuda().half()\n        self.ref_layer.reset_parameters()\n        self.ref_inputs = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n        # Reset seed so parameters are identical\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n        self.tst_layer = SelfMultiheadAttn(\n            self.hidden_dim,\n            self.heads,\n            dropout=self.dropout_prob,\n            bias=False,\n            include_norm_add=True,\n            impl=\"fast\",\n        )\n        self.tst_layer.cuda().half()\n        self.tst_layer.reset_parameters()\n\n        self.tst_inputs = torch.randn(\n            self.seq_length,\n            self.sequences,\n            self.hidden_dim,\n            dtype=torch.float16,\n            device=torch.device(\"cuda\"),\n        ).requires_grad_(True)\n\n    def test_self_multihead_attn_norm_add(self):\n        grads = torch.randn_like(self.tst_inputs)\n\n        for _ in range(0, 5):\n            ref_outputs, _ = self.ref_layer.forward(\n                self.ref_inputs,\n             
   self.ref_inputs,\n                self.ref_inputs,\n                key_padding_mask=None,\n                need_weights=False,\n                attn_mask=None,\n                is_training=True,\n            )\n\n            tst_outputs, _ = self.tst_layer.forward(\n                self.tst_inputs,\n                self.tst_inputs,\n                self.tst_inputs,\n                key_padding_mask=None,\n                need_weights=False,\n                attn_mask=None,\n                is_training=True,\n            )\n\n            self.ref_inputs.backward(grads)\n            self.tst_inputs.backward(grads)\n\n        torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)\n        torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/openfold_triton/test_fused_adam_swa.py",
    "content": "# Copyright 2023 NVIDIA CORPORATION\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom itertools import chain\nimport random\nimport unittest\n\nimport torch\nimport torch.nn as nn\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.openfold_triton.fused_adam_swa import AdamMathType, FusedAdamSWA\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n# Stochastic weight average (SWA) reference code from\n# https://github.com/mlcommons/hpc_results_v3.0/blob/350e46f7/NVIDIA/benchmarks/openfold/implementations/pytorch/openfold/swa.py#L21-L70\nclass AlphaFoldSWA(nn.Module):\n    \"\"\"AlphaFold SWA (Stochastic Weight Averaging) module wrapper.\"\"\"\n\n    def __init__(self, alphafold: nn.Module, enabled: bool, decay_rate: float) -> None:\n        super(AlphaFoldSWA, self).__init__()\n        if enabled:\n            self.averaged_model = torch.optim.swa_utils.AveragedModel(\n                model=alphafold,\n                avg_fn=swa_avg_fn(decay_rate=decay_rate),\n            )\n            self.enabled = True\n        else:\n            self.averaged_model = None\n            self.enabled = False\n\n    def update(self, alphafold: nn.Module) -> None:\n        if self.enabled:\n            self.averaged_model.update_parameters(model=alphafold)\n\n    def forward(self, batch):\n        if not self.enabled:\n            raise RuntimeError(\"AlphaFoldSWA is not enabled\")\n        return self.averaged_model(batch)\n\n\nclass swa_avg_fn:\n    
\"\"\"Averaging function for EMA with configurable decay rate\n    (Supplementary '1.11.7 Evaluator setup').\"\"\"\n\n    def __init__(self, decay_rate: float) -> None:\n        self._decay_rate = decay_rate\n\n    def __call__(\n        self,\n        averaged_model_parameter: torch.Tensor,\n        model_parameter: torch.Tensor,\n        num_averaged: torch.Tensor,\n    ) -> torch.Tensor:\n        # for decay_rate = 0.999:\n        # return averaged_model_parameter * 0.999 + model_parameter * 0.001\n        # avg * 0.999 + m * 0.001\n        # 999*avg/1000 + m/1000\n        # (999*avg + avg - avg)/1000 + m/1000\n        # (1000*avg - avg)/1000 + m/1000\n        # 1000*avg/1000 - avg/1000 + m/1000\n        # avg + (m - avg)/1000\n        # avg + (m - avg)*0.001\n        return averaged_model_parameter + (model_parameter - averaged_model_parameter) * (\n            1.0 - self._decay_rate\n        )\n\n\n@unittest.skipIf(SKIP_TEST, f\"Skip testing FusedAdamSWA: {SKIP_TEST}\")\nclass FusedAdamSWATestCase(unittest.TestCase):\n    def setUp(self):\n        super().setUp()\n        self._seed = 19260817\n        random.seed(self._seed)\n        torch.manual_seed(self._seed)\n        # FIXME: correctly fix: \"\"\"NameError(\"Cannot access global variable _DTYPE2TRITON from within @jit'ed function.\n        # Triton kernels can only access global variables that are instanstiated as constexpr (`x = triton.language.constexpr(42)`).\n        # Note that this is different from annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported.\n        # Alternatively, set the envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not promise to support this forever.\")\"\"\"\n        os.environ[\"TRITON_ALLOW_NON_CONSTEXPR_GLOBALS\"] = \"1\"\n\n    def tearDown(self):\n        os.environ.pop(\"TRITON_ALLOW_NON_CONSTEXPR_GLOBALS\", None)\n\n    def test_fused_update_on_random_data(self):\n        with 
torch.backends.cudnn.flags(deterministic=True):\n            self._run_fused_update_on_random_data()\n\n    def _run_fused_update_on_random_data(self):\n        device = torch.device(\"cuda:0\")\n        compute_dtype = torch.float32\n        state_dtype = torch.float64\n        atol = 1e-5  # Default: 1e-8, raise error at 1e-6 for FP32 compute and FP64 state.\n        rtol = 1e-4  # Default: 1e-5\n        lr = 1e-1\n        bias_correction = True\n        beta1, beta2 = 0.9, 0.999\n        eps = 1e-6\n        adam_math_mode = AdamMathType.PyTorchAdam\n        weight_decay = 1e-3  # PyTorchAdam impl will fail non-zero weight decay.\n        amsgrad = False\n        adam_step = 1900\n        swa_decay_rate = 0.9\n        swa_n_averaged = 1\n\n        state_params = [\n            torch.empty(random.randint(128, 2048), device=device, dtype=state_dtype).uniform_(-5, 5)\n            for _ in range(32)\n        ]\n        compute_dtypes = [\n            compute_dtype if random.uniform(0.0, 1.0) <= 0.5 else state_dtype for _ in range(32)\n        ]\n        grads = [\n            torch.empty_like(p, dtype=d).uniform_(-5, 5)\n            for d, p in zip(compute_dtypes, state_params)\n        ]\n        moments = [torch.empty_like(p).uniform_(-5, 5) for p in state_params]\n        velocities = [torch.empty_like(p).uniform_(0, 10) for p in state_params]\n\n        # Ground truth: Apex FusedAdam, optimized-hpc SWA.\n        compute_params_gt = [p.clone().to(d) for d, p in zip(compute_dtypes, state_params)]\n        dummy_model = torch.nn.Module()\n        for i, p in enumerate(state_params):\n            dummy_model.register_parameter(f\"param_{i}\", torch.nn.Parameter(p.clone()))\n        state_params_gt = list(dummy_model.parameters())\n        swa_model = AlphaFoldSWA(dummy_model, enabled=True, decay_rate=swa_decay_rate)\n        swa_params_gt = list(swa_model.parameters())\n        optimizer = torch.optim.Adam(\n            state_params_gt,\n            lr=lr,\n          
  betas=(beta1, beta2),\n            eps=eps,\n            weight_decay=weight_decay,\n            amsgrad=amsgrad,\n        )\n        moments_gt, velocities_gt = [], []\n        for i, p in enumerate(optimizer.param_groups[0][\"params\"]):\n            s = optimizer.state[p]\n            self.assertTrue(moments[i].shape == velocities[i].shape == p.shape)\n            s[\"step\"] = torch.tensor(adam_step, dtype=state_dtype, device=device)\n            s[\"exp_avg\"] = moments[i].clone()\n            s[\"exp_avg_sq\"] = velocities[i].clone()\n            moments_gt.append(s[\"exp_avg\"])\n            velocities_gt.append(s[\"exp_avg_sq\"])\n        for p, g in zip(state_params_gt, grads):\n            p.grad = g.clone().to(state_dtype)\n        optimizer.step()\n        swa_model.averaged_model.n_averaged.copy_(swa_n_averaged)\n        swa_model.update(dummy_model)\n        for c, s in zip(compute_params_gt, state_params_gt):\n            c.detach().copy_(s.detach().to(c.dtype))\n\n        # Fused AdamSWA, all at once.\n        state_params_test = [torch.nn.Parameter(p.clone()) for p in state_params]\n        compute_params_test = [p.clone().to(d) for d, p in zip(compute_dtypes, state_params)]\n        swa_params_test = [p.clone() for p in state_params]\n        fused_optimizer = FusedAdamSWA(\n            params=state_params_test,\n            compute_params=compute_params_test,\n            swa_params=swa_params_test,\n            swa_decay_rate=swa_decay_rate,\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=(beta1, beta2),\n            eps=eps,\n            adam_math_mode=adam_math_mode,\n            weight_decay=weight_decay,\n            amsgrad=amsgrad,\n        )\n        moments_test, velocities_test = [], []\n        for i, p in enumerate(fused_optimizer.param_groups[0][\"params\"]):\n            s = fused_optimizer.state[p]\n            self.assertTrue(moments[i].shape == velocities[i].shape == p.shape)\n            
s[\"exp_avg\"] = moments[i].clone()\n            s[\"exp_avg_sq\"] = velocities[i].clone()\n            moments_test.append(s[\"exp_avg\"])\n            velocities_test.append(s[\"exp_avg_sq\"])\n        for c, g in zip(compute_params_test, grads):\n            c.grad = g.clone()\n        fused_optimizer.param_groups[0][\"step\"] = adam_step\n        fused_optimizer.swa_param_groups[0][\"n_averaged\"] = swa_n_averaged\n        fused_optimizer.step()\n\n        # Ensure parameters are actually updated.\n        for i, (p_gt, p_test, p_origin) in enumerate(\n            zip(state_params_gt, state_params_test, state_params)\n        ):\n            self.assertFalse(torch.allclose(p_gt, p_origin, rtol=rtol, atol=atol))\n            self.assertFalse(torch.allclose(p_test, p_origin, rtol=rtol, atol=atol))\n        # Ensure FusedAdamSWA correctness.\n        self.assertEqual(\n            swa_model.averaged_model.n_averaged.item(),\n            fused_optimizer.swa_param_groups[0][\"n_averaged\"],\n        )\n        for i, (p_test, p_gt) in enumerate(\n            zip(\n                chain(state_params_test, compute_params_test, swa_params_test),\n                chain(state_params_gt, compute_params_gt, swa_params_gt),\n            )\n        ):\n            self.assertTrue(torch.allclose(p_test, p_gt, rtol=rtol, atol=atol))\n        # Ensure moments are updated correctly.\n        for i, (m, m_gt) in enumerate(\n            zip(\n                chain(moments_test, velocities_test),\n                chain(moments_gt, velocities_gt),\n            )\n        ):\n            self.assertTrue(torch.allclose(m, m_gt, rtol=rtol, atol=atol))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/openfold_triton/test_openfold_mha.py",
    "content": "import math\nimport random\nfrom typing import Optional\nimport torch\nimport unittest\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.openfold_triton import AttnTri as openfold_attention_triton\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\ndef openfold_attention_eager(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    mask: torch.Tensor,\n    bias: Optional[torch.Tensor],\n    inf: float,\n) -> torch.Tensor:\n    # query:  [*, num_heads, Q, c_hidden]\n    # key:    [*, num_heads, K, c_hidden]\n    # value:  [*, num_heads, V, c_hidden]\n    # mask:   Logit mask tensor broadcastable to [*, num_heads, Q, K]\n    # bias:   Optional logit bias tensor broadcastable to [*, num_heads, Q, K]\n    # inf:    Safe infinity value.\n    # assuming K == V\n\n    key = torch.swapdims(key, -2, -1)\n    # key: [*, num_heads, c_hidden, K]\n\n    scaling = 1.0 / math.sqrt(query.size(-1))\n    a = torch.matmul(query * scaling, key)\n    # a: [*, num_heads, Q, K]\n\n    a += (mask - 1.0) * inf\n    # a: [*, num_heads, Q, K]\n\n    if bias is not None:\n        a += bias\n    # a: [*, num_heads, Q, K]\n\n    a = torch.softmax(a, dim=-1)\n    # a: [*, num_heads, Q, K]\n\n    a = torch.matmul(a, value)\n    # a: [*, num_heads, Q, c_hidden]\n\n    return a\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass OpenfoldMhaTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        super().setUp()\n        random.seed(seed)\n        torch.manual_seed(seed)\n\n    # representative workload in openfold\n    def test_openfold_triton_mha(self, Z=256, H=4, N_CTX=256, D_HEAD=32, dtype=torch.float16):\n        One = 1\n        q = (\n            torch.empty((One, Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\")\n            .normal_(mean=0.1, std=0.2)\n            .requires_grad_()\n        )\n        k = (\n            torch.empty((One, Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\")\n            .normal_(mean=0.4, std=0.2)\n     
       .requires_grad_()\n        )\n        v = (\n            torch.empty((One, Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\")\n            .normal_(mean=0.3, std=0.2)\n            .requires_grad_()\n        )\n        bias = (\n            torch.empty((One, One, H, N_CTX, N_CTX), dtype=dtype, device=\"cuda\")\n            .normal_(mean=0.2, std=0.2)\n            .requires_grad_()\n        )\n        mask = (\n            torch.empty((One, N_CTX, One, One, N_CTX), device=\"cuda\").normal_(mean=0, std=0.5) > 0\n        )\n        mask = mask.to(device=torch.device(\"cuda\"), dtype=dtype).requires_grad_(False)\n\n        dout = torch.randn_like(q)\n        inf = 1e9\n\n        # reference implementation\n        ref_out = openfold_attention_eager(q, k, v, mask, bias, inf)\n        ref_out.backward(dout)\n\n        ref_dv, v.grad = v.grad.clone(), None\n        ref_dk, k.grad = k.grad.clone(), None\n        ref_dq, q.grad = q.grad.clone(), None\n        ref_dbias, bias.grad = bias.grad.clone(), None\n\n        # triton implementation\n        tri_out = openfold_attention_triton(q, k, v, mask, bias, inf, torch.is_grad_enabled())\n        tri_out.backward(dout)\n\n        tri_dv, v.grad = v.grad.clone(), None\n        tri_dk, k.grad = k.grad.clone(), None\n        tri_dq, q.grad = q.grad.clone(), None\n        tri_dbias, bias.grad = bias.grad.clone(), None\n\n        # check results\n        torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)\n        torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0)\n        torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=0)\n        torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=0)\n        torch.testing.assert_close(ref_dbias, tri_dbias, atol=1e-2, rtol=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py",
    "content": "import os\n\nimport torch\nimport torch.distributed as dist\nfrom torch.testing._internal.common_utils import FILE_SCHEMA, run_tests\nfrom torch.testing._internal.common_distributed import (\n    MultiProcessTestCase,\n    requires_nccl,\n    skip_if_lt_x_gpu,\n)\n\nfrom apex.contrib.openfold_triton import (\n    LayerNormSmallShapeOptImpl,\n    sync_triton_auto_tune_cache_across_gpus,\n    _tuneable_triton_kernels,\n)\n\n\nclass SyncTritonAutoTuneCacheTest(MultiProcessTestCase):\n    device_type = \"cuda\"\n\n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n\n    def setUp(self) -> None:\n        super().setUp()\n        self._spawn_processes()\n\n    def tearDown(self) -> None:\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        super().tearDown()\n\n    @property\n    def world_size(self) -> int:\n        return min(torch.cuda.device_count(), 2)\n\n    @property\n    def init_method(self):\n        return f\"{FILE_SCHEMA}{self.file_name}\"\n\n    @property\n    def destroy_pg_upon_exit(self) -> bool:\n        return True\n\n    def _create_process_group_nccl(self):\n        def maybe_export(env, val):\n            # Set the environment variable only if it is not already defined.\n            if not isinstance(env, str):\n                raise ValueError(f\"Type of env is expected to be str, but got {type(env)}\")\n            if not isinstance(val, str):\n                raise ValueError(f\"Type of val is expected to be str, but got {type(val)}\")\n            if os.getenv(env) is None:\n                os.environ[env] = val\n\n        maybe_export(\"MASTER_PORT\", \"29500\")\n        maybe_export(\"MASTER_ADDR\", \"localhost\")\n\n        # Create the NCCL process group (world size is at most two ranks)\n        dist.init_process_group(\n            \"nccl\",\n            world_size=self.world_size,\n            rank=self.rank,\n        )\n        pg = dist.distributed_c10d._get_default_group()\n        return pg\n\n    @requires_nccl()\n    @skip_if_lt_x_gpu(1)\n    def 
test_sync_triton_auto_tune_cache_across_gpus(self):\n        self._create_process_group_nccl()\n        device = torch.device(f\"cuda:{self.rank % torch.cuda.device_count()}\")\n        torch.cuda.set_device(device)\n\n        # Only rank 0 runs the kernel, so only rank 0 populates the auto-tune cache.\n        if self.rank == 0:\n            eps = 1e-5\n            normalized_shape = (\n                128,\n                64,\n            )\n\n            weight = torch.ones(normalized_shape, device=device, requires_grad=True)\n            bias = torch.zeros(normalized_shape, device=device, requires_grad=True)\n\n            x = torch.randn(\n                (\n                    2,\n                    2,\n                )\n                + normalized_shape,\n                device=device,\n            )\n            y = LayerNormSmallShapeOptImpl.apply(x, normalized_shape, weight, bias, eps)\n            loss = torch.sum(y)\n            loss.backward()\n\n        sync_triton_auto_tune_cache_across_gpus(strict=False, verbose=True)\n\n        # After the sync, every rank should hold at least one non-empty cache.\n        caches_synced = 0\n        for func_name, func in _tuneable_triton_kernels.items():\n            if len(func.cache) > 0:\n                caches_synced += 1\n                print(\n                    f\"caches were synchronized for {func_name} at rank = {self.rank}:\",\n                    func.cache,\n                )\n\n        self.assertGreater(caches_synced, 0)\n\n\nif __name__ == \"__main__\":\n    run_tests()\n"
  },
  {
    "path": "apex/contrib/test/optimizers/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/optimizers/test_dist_adam.py",
    "content": "from contextlib import contextmanager\nimport io\nfrom typing import Callable, Optional\nimport unittest\nimport warnings\nfrom contextlib import nullcontext\n\nimport torch\nfrom torch.testing._internal import common_utils\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam\nexcept ImportError as e:\n    SKIP_TEST = e\nfrom apex.distributed_testing.distributed_test_base import NcclDistributedTestBase\n\n\nclass SimpleModel(torch.nn.Module):\n    def __init__(self, num_layers, size):\n        super().__init__()\n        self.params = torch.nn.ParameterList(\n            [torch.nn.Parameter(torch.rand(1, size) + 1) for _ in range(num_layers)]\n        )\n\n    def forward(self, x):\n        y = 0\n        for i, param in enumerate(self.params):\n            y += (i + 1) * param * x\n        return y\n\n\ndef make_models(\n    num_layers: int,\n    size: int,\n    *,\n    lr: float = 0.1,\n    adam_w_mode: bool = True,\n    model_dtype: torch.dtype = torch.float32,\n    optim_dtype: Optional[torch.dtype] = None,\n    grad_sync_dtype: Optional[torch.dtype] = None,\n    param_sync_dtype: Optional[torch.dtype] = None,\n    device: torch.device = \"cuda\",\n    process_group: Optional[torch.distributed.ProcessGroup] = None,\n    average_grad_sync: bool = True,\n    overlap_communication: bool = True,\n    bucket_cap_mb: float = 71 / (4 * 1024 * 1024),\n    contiguous_buffers: bool = False,\n    store_params: bool = False,\n    store_param_remainders: bool = False,\n    with_scaled_states: bool = False,\n    nccl_ub: bool = False,\n    with_cuda_graph: bool = False,\n):\n    # Construct models with same parameters\n    ref_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device)\n    dist_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device)\n    with torch.no_grad():\n        for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()):\n 
           dist_param.copy_(ref_param)\n\n    # Initialize reference model with data-parallelism\n    rank = torch.distributed.get_rank()\n    ref_model = torch.nn.parallel.DistributedDataParallel(\n        ref_model,\n        device_ids=[rank] if device == \"cuda\" else None,\n        output_device=rank if device == \"cuda\" else None,\n        process_group=process_group,\n    )\n\n    # Construct optimizers with same hyperparameters\n    if optim_dtype is None:\n        optim_dtype = model_dtype\n    optim_args = dict(lr=lr, betas=(0.1, 0.2), eps=0.25, weight_decay=0.1)\n    ref_optim_class = torch.optim.AdamW if adam_w_mode else torch.optim.Adam\n    ref_optim = ref_optim_class(\n        [\n            {\"params\": list(ref_model.parameters())[1::2], \"lr\": lr * 2},\n            {\"params\": list(ref_model.parameters())[0::2]},\n        ],\n        **optim_args,\n    )\n    dist_optim = DistributedFusedAdam(\n        [\n            {\"params\": list(dist_model.parameters())[1::2], \"lr\": lr * 2},\n            {\"params\": list(dist_model.parameters())[0::2]},\n        ],\n        adam_w_mode=adam_w_mode,\n        overlap_grad_sync=overlap_communication,\n        overlap_param_sync=overlap_communication,\n        bucket_cap_mb=bucket_cap_mb,\n        dtype=optim_dtype,\n        grad_sync_dtype=grad_sync_dtype,\n        param_sync_dtype=param_sync_dtype,\n        process_group=process_group,\n        average_grad_sync=average_grad_sync,\n        contiguous_param_buffer=contiguous_buffers,\n        contiguous_grad_buffer=contiguous_buffers,\n        store_params=store_params,\n        store_param_remainders=store_param_remainders,\n        with_scaled_states=with_scaled_states,\n        nccl_ub=nccl_ub,\n        capturable=with_cuda_graph,\n        **optim_args,\n    )\n\n    return ref_model, ref_optim, dist_model, dist_optim\n\n\n@contextmanager\ndef dummy_context():\n    try:\n        yield\n    finally:\n        pass\n\n\n@unittest.skipIf(SKIP_TEST, 
f\"{SKIP_TEST}\")\nclass TestDistributedFusedAdam(NcclDistributedTestBase):\n    seed = 1234\n\n    def test_matches_pytorch(\n        self,\n        rtol: Optional[float] = None,\n        atol: Optional[float] = None,\n        num_layers: int = 11,\n        layer_size: int = 7,\n        batch_size: int = 3,\n        num_steps: int = 3,\n        micro_batch_steps: int = 3,\n        adam_w_mode: bool = True,\n        overlap_communication: bool = True,\n        use_nosync: bool = True,\n        model_dtype: torch.dtype = torch.float32,\n        optim_dtype: Optional[torch.dtype] = None,\n        grad_sync_dtype: Optional[torch.dtype] = None,\n        param_sync_dtype: Optional[torch.dtype] = None,\n        device: torch.device = \"cuda\",\n        bucket_cap_mb: float = 71 / (4 * 1024 * 1024),\n        contiguous_buffers: bool = False,\n        store_params: bool = False,\n        store_param_remainders: bool = False,\n        with_scaled_states: bool = False,\n        nccl_ub: bool = False,\n        init_optim_func: Optional[Callable[[DistributedFusedAdam], None]] = None,\n        with_cuda_graph: bool = False,\n    ):\n        torch.manual_seed(self.seed + self.rank)\n\n        # Identical models with data-parallel and ZeRO\n        stream = torch.cuda.Stream()\n        with torch.cuda.stream(stream):\n            ref_model, ref_optim, dist_model, dist_optim = make_models(\n                num_layers,\n                layer_size,\n                adam_w_mode=adam_w_mode,\n                model_dtype=model_dtype,\n                optim_dtype=optim_dtype,\n                grad_sync_dtype=grad_sync_dtype,\n                param_sync_dtype=param_sync_dtype,\n                device=device,\n                overlap_communication=overlap_communication,\n                bucket_cap_mb=bucket_cap_mb,\n                contiguous_buffers=contiguous_buffers,\n                store_params=store_params,\n                store_param_remainders=store_param_remainders,\n            
    with_scaled_states=with_scaled_states,\n                nccl_ub=nccl_ub,\n                with_cuda_graph=with_cuda_graph,\n            )\n\n        # Initialize distributed optimizer\n        if init_optim_func is not None:\n            with torch.cuda.stream(stream):\n                init_optim_func(dist_optim)\n\n        # Static data\n        static_xs, static_dys = [], []\n        ys_ref, grad_xs_ref = [], []\n        ys_dist, grad_xs_dist = [], []\n\n        graph = torch.cuda.CUDAGraph() if with_cuda_graph else None\n        CAPTURE_ITERATION = 11\n        if with_cuda_graph:\n            assert num_steps > CAPTURE_ITERATION + 3, \"Not enough iterations for CUDA graph test.\"\n\n        # Training loop\n        with torch.cuda.stream(stream):\n            for step in range(num_steps):\n                # Synthetic data\n                for micro_step in range(micro_batch_steps):\n                    x = torch.rand(batch_size, layer_size) - 0.5\n                    dy = torch.rand_like(x) - 0.5\n                    x = x.to(dtype=model_dtype, device=device)\n                    dy = dy.to(dtype=model_dtype, device=device)\n                    if step == 0:\n                        static_xs.append(x)\n                        static_dys.append(dy)\n                    else:\n                        static_xs[micro_step].copy_(x)\n                        static_dys[micro_step].copy_(dy)\n\n                # Reference implementation\n                ref_optim.zero_grad()\n                for micro_step in range(micro_batch_steps):\n                    x, dy = static_xs[micro_step], static_dys[micro_step]\n\n                    x_ref = x.detach().clone().requires_grad_(True)\n                    y_ref = ref_model(x_ref)\n                    y_ref.backward(dy)\n\n                    if step == 0:\n                        ys_ref.append(y_ref)\n                        grad_xs_ref.append(x_ref.grad)\n                    else:\n                        with 
torch.no_grad():\n                            ys_ref[micro_step].copy_(y_ref)\n                            grad_xs_ref[micro_step].copy_(x_ref.grad)\n                ref_optim.step()\n\n                # Distributed implementation\n                if not with_cuda_graph or step <= CAPTURE_ITERATION:\n                    if with_cuda_graph and step == CAPTURE_ITERATION:\n                        ctx = torch.cuda.graph(graph)\n                        torch.cuda.synchronize()\n                    else:\n                        ctx = nullcontext()\n\n                    with ctx:\n                        dist_optim.zero_grad()\n                        for micro_step in range(micro_batch_steps):\n                            x, dy = static_xs[micro_step], static_dys[micro_step]\n\n                            x_dist = x.detach().clone().requires_grad_(True)\n                            y_dist = dist_model(x_dist)\n                            backward_context = dummy_context\n                            if use_nosync and micro_step < micro_batch_steps - 1:\n                                backward_context = dist_optim.no_sync\n                            with backward_context():\n                                y_dist.backward(dy)\n\n                            if step == 0:\n                                ys_dist.append(y_dist)\n                                grad_xs_dist.append(x_dist.grad)\n                            else:\n                                with torch.no_grad():\n                                    ys_dist[micro_step].copy_(y_dist)\n                                    grad_xs_dist[micro_step].copy_(x_dist.grad)\n                        dist_optim.step()\n\n                    if with_cuda_graph and step == CAPTURE_ITERATION:\n                        graph.replay()\n                else:\n                    graph.replay()\n\n                # Check that data tensors match\n                for mbs in range(micro_batch_steps):\n                    
torch.testing.assert_close(ys_dist[mbs], ys_ref[mbs], rtol=rtol, atol=atol)\n                    torch.testing.assert_close(\n                        grad_xs_dist[mbs], grad_xs_ref[mbs], rtol=rtol, atol=atol\n                    )\n\n                # Check that parameters match\n                for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()):\n                    torch.testing.assert_close(dist_param, ref_param, rtol=rtol, atol=atol)\n\n    def test_matches_pytorch_l2_reg(self):\n        self.test_matches_pytorch(adam_w_mode=False)\n\n    def test_matches_pytorch_no_overlap(self):\n        self.test_matches_pytorch(\n            overlap_communication=False,\n            use_nosync=False,\n        )\n\n    def test_matches_pytorch_sync_every_step(self):\n        self.test_matches_pytorch(use_nosync=False)\n\n    def test_matches_pytorch_contiguous_buffers(self):\n        self.test_matches_pytorch(contiguous_buffers=True)\n\n    def test_matches_pytorch_fp64(self):\n        self.test_matches_pytorch(\n            rtol=1.3e-6,\n            atol=1e-5,\n            model_dtype=torch.float64,\n            optim_dtype=torch.float32,\n        )\n\n    def test_matches_pytorch_fp16(self):\n        self.test_matches_pytorch(\n            rtol=5e-3,\n            atol=1e-5,\n            micro_batch_steps=1,\n            model_dtype=torch.float16,\n            optim_dtype=torch.float16,\n        )\n\n    def test_matches_pytorch_bf16(self):\n        self.test_matches_pytorch(\n            rtol=5e-2,\n            atol=1e-5,\n            micro_batch_steps=1,\n            model_dtype=torch.bfloat16,\n            optim_dtype=torch.bfloat16,\n        )\n\n    def test_matches_pytorch_fp16_params(self):\n        self.test_matches_pytorch(\n            rtol=5e-3,\n            atol=1e-5,\n            micro_batch_steps=1,\n            model_dtype=torch.float16,\n            optim_dtype=torch.float32,\n            param_sync_dtype=torch.float16,\n       
     store_params=True,\n        )\n\n    def test_matches_pytorch_bf16_grads(self):\n        self.test_matches_pytorch(\n            rtol=5e-2,\n            atol=1e-5,\n            micro_batch_steps=1,\n            model_dtype=torch.float32,\n            optim_dtype=torch.float32,\n            grad_sync_dtype=torch.bfloat16,\n        )\n\n    def test_matches_pytorch_bf16_param_remainders(self):\n        self.test_matches_pytorch(\n            rtol=5e-2,\n            atol=1e-5,\n            micro_batch_steps=1,\n            model_dtype=torch.bfloat16,\n            optim_dtype=torch.float32,\n            param_sync_dtype=torch.bfloat16,\n            store_params=False,\n            store_param_remainders=True,\n        )\n\n    def test_matches_pytorch_multi_dtypes(self):\n        def init_optim(optim: DistributedFusedAdam):\n            params = list(optim.parameters())\n            optim.init_params(params[0::3], grad_sync_dtype=torch.bfloat16)\n            optim.init_params(params[1::3], param_sync_dtype=torch.bfloat16)\n\n        self.test_matches_pytorch(\n            rtol=5e-2,\n            atol=1e-5,\n            init_optim_func=init_optim,\n        )\n\n    def test_matches_pytorch_int64_param_sync(self):\n        self.test_matches_pytorch(\n            param_sync_dtype=torch.int64,\n        )\n\n    def test_matches_pytorch_int32_param_sync_contiguous_buffers(self):\n        self.test_matches_pytorch(\n            param_sync_dtype=torch.int32,\n            contiguous_buffers=True,\n        )\n\n    def test_matches_pytorch_uint8_param_sync(self):\n        self.test_matches_pytorch(\n            rtol=0.5,\n            atol=0.05,\n            model_dtype=torch.float16,\n            optim_dtype=torch.float16,\n            micro_batch_steps=1,\n            param_sync_dtype=torch.uint8,\n        )\n\n    def test_matches_pytorch_scaled_state(self):\n        self.test_matches_pytorch(\n            rtol=5e-2,\n            atol=1e-5,\n            
micro_batch_steps=1,\n            model_dtype=torch.bfloat16,\n            optim_dtype=torch.float16,\n            param_sync_dtype=torch.int,\n            store_params=True,\n            with_scaled_states=True,\n        )\n\n    def test_matches_pytorch_nccl_ub(self):\n        self.test_matches_pytorch(\n            contiguous_buffers=True,\n            nccl_ub=True,\n        )\n\n    def test_raises_on_mismatch(self):\n        torch.manual_seed(self.seed + self.rank)\n\n        # Identical models with data-parallel and ZeRO\n        num_layers = 11\n        layer_size = 7\n        ref_model, ref_optim, dist_model, dist_optim = make_models(\n            num_layers,\n            layer_size,\n        )\n\n        # Only perform training step with distributed model\n        dist_optim.zero_grad()\n        x = torch.rand(3, layer_size) - 0.5\n        x = x.to(dtype=torch.float32, device=\"cuda\")\n        dy = torch.rand_like(x) - 0.5\n        y = dist_model(x)\n        y.backward(dy)\n        dist_optim.step()\n\n        # Check that parameters do not match\n        for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()):\n            self.assertRaises(\n                AssertionError,\n                torch.testing.assert_close,\n                dist_param,\n                ref_param,\n            )\n\n    def test_clip_grad_norm(self):\n        torch.manual_seed(self.seed + self.rank)\n\n        # Identical models with data-parallel and ZeRO\n        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)\n\n        # Training steps with pre-determined gradients\n        xs = [3, 1, 4, 1, 5, 9]\n        dys = [1, -1, 1, -1, 1, -1]\n        for x, dy in zip(xs, dys):\n            x = torch.tensor([[x]], dtype=torch.float32, device=\"cuda\")\n            dy = torch.tensor([[dy]], dtype=torch.float32, device=\"cuda\")\n\n            # Reference implementation\n            ref_optim.zero_grad()\n            y_ref = 
ref_model(x.detach())\n            y_ref.backward(dy.detach())\n            ref_grad_norm = torch.nn.utils.clip_grad_norm_(ref_model.parameters(), 3.5)\n            ref_optim.step()\n\n            # Distributed implementation\n            dist_optim.zero_grad()\n            y_dist = dist_model(x.detach())\n            y_dist.backward(dy.detach())\n            dist_grad_norm = dist_optim.clip_grad_norm(3.5)\n            dist_optim.step()\n\n            # Check that parameters match\n            torch.testing.assert_close(dist_grad_norm, ref_grad_norm)\n            for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()):\n                torch.testing.assert_close(dist_param, ref_param)\n\n    def test_grad_scaler(self):\n        torch.manual_seed(self.seed + self.rank)\n\n        # Identical models with data-parallel and ZeRO\n        ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1)\n        grad_scaler_args = dict(\n            init_scale=3.21,\n            growth_factor=1.23,\n            backoff_factor=0.876,\n            growth_interval=1,\n        )\n        ref_scaler = torch.amp.GradScaler(\"cuda\", **grad_scaler_args)\n        dist_scaler = torch.amp.GradScaler(\"cuda\", **grad_scaler_args)\n\n        # Training steps with pre-determined gradients\n        xs = [3, 1, 4, 1, 5, 9]\n        dys = [1, float(\"inf\"), 1, 1, float(\"nan\"), -1]\n        for x, dy in zip(xs, dys):\n            x = torch.tensor([[x]], dtype=torch.float32, device=\"cuda\")\n            dy = torch.tensor([[dy]], dtype=torch.float32, device=\"cuda\")\n\n            # Reference implementation\n            ref_optim.zero_grad()\n            y_ref = ref_model(x.detach())\n            ref_scaler.scale(y_ref).backward(dy.detach())\n            ref_scaler.step(ref_optim)\n            ref_scaler.update()\n\n            # Distributed implementation\n            dist_optim.zero_grad()\n            y_dist = dist_model(x.detach())\n            
dist_scaler.scale(y_dist).backward(dy.detach())\n            dist_scaler.step(dist_optim)\n            dist_scaler.update()\n\n            # Check that parameters match\n            for ref_param, dist_param in zip(ref_model.parameters(), dist_model.parameters()):\n                torch.testing.assert_close(dist_param, ref_param)\n\n    def test_checkpoint(\n        self,\n        rtol: Optional[float] = None,\n        atol: Optional[float] = None,\n        num_layers: int = 2,\n        layer_size: int = 2,\n        num_steps: int = 3,\n        save_group_size: Optional[int] = None,\n        load_group_size: Optional[int] = None,\n        save_model_kwargs: Optional[dict] = None,\n        load_model_kwargs: Optional[dict] = None,\n    ):\n        \"\"\"Test state_dict and load_state_dict functions\n\n        Two models are constructed, possibly on different process\n        groups. One of the models is trained for a few steps, a\n        checkpoint is saved, and the checkpoint is loaded on the other\n        model. Both models are then trained for a few steps and\n        checked to make sure that they produce identical results.\n\n        Arguments:\n            rtol (float): Relative tolerance for numerical checks (see\n                torch.allclose).\n            atol (float): Absolute tolerance for numerical checks (see\n                torch.allclose).\n            num_layers (int): Number of layers in test model.\n            layer_size (int): Number of features in model layers.\n            num_steps (int): Number of training steps to perform\n                before and after checkpointing.\n            save_group_size (int): Process group size for model that\n                saves the checkpoint. Uses the default process group\n                by default.\n            load_group_size (int): Process group size for model that\n                loads the checkpoint. 
Uses the default process group\n                by default.\n            save_model_kwargs (dict): keyword arguments passed to\n                make_models when constructing the model that saves the\n                checkpoint.\n            load_model_kwargs (dict): keyword arguments passed to\n                make_models when constructing the model that loads the\n                checkpoint.\n\n        \"\"\"\n\n        # Initialize process groups\n        world_size = torch.distributed.get_world_size()\n        if save_group_size is None:\n            save_group_size = world_size\n            save_group = None\n        else:\n            if save_group_size > world_size:\n                self.skipTest(f\"Requires {save_group_size} ranks, found {world_size}\")\n            save_ranks = list(range(save_group_size))\n            save_group = torch.distributed.new_group(ranks=save_ranks)\n        if load_group_size is None:\n            load_group_size = world_size\n            load_group = None\n        else:\n            if load_group_size > world_size:\n                self.skipTest(f\"Requires {load_group_size} ranks, found {world_size}\")\n            load_ranks = list(range(load_group_size))\n            load_group = torch.distributed.new_group(ranks=load_ranks)\n\n        # Construct two models with same config and different params\n        torch.manual_seed(self.seed)\n        if self.rank < save_group_size:\n            if not save_model_kwargs:\n                save_model_kwargs = {}\n            _, _, model_save, optim_save = make_models(\n                num_layers,\n                layer_size,\n                lr=0.1,\n                process_group=save_group,\n                average_grad_sync=False,\n                overlap_communication=False,\n                **save_model_kwargs,\n            )\n            optim_save.init_params(reversed(list(model_save.parameters())))\n        torch.manual_seed(self.seed + 1)\n        if self.rank < 
load_group_size:\n            if not load_model_kwargs:\n                load_model_kwargs = {}\n            _, _, model_load, optim_load = make_models(\n                num_layers,\n                layer_size,\n                lr=1234.0,\n                process_group=load_group,\n                average_grad_sync=False,\n                overlap_communication=False,\n                **load_model_kwargs,\n            )\n            optim_load.init_params(list(model_load.parameters()))\n\n        batch_size = 2 * save_group_size * load_group_size\n\n        def make_global_batch() -> torch.Tensor:\n            \"\"\"Generate random tensor on root rank and broadcast\"\"\"\n            x = torch.empty(batch_size, layer_size, device=\"cuda\")\n            if self.rank == 0:\n                torch.rand(x.size(), out=x)\n                x -= 0.5\n            torch.distributed.broadcast(x, src=0)\n            return x\n\n        def to_local_batch(\n            global_batch: torch.Tensor,\n            group: Optional[torch.distributed.ProcessGroup],\n        ) -> Optional[torch.Tensor]:\n            \"\"\"Get local portion of tensor that is replicated across all ranks\"\"\"\n            group_size = torch.distributed.get_world_size(group)\n            if group_size < 0:\n                return None\n            local_batch_size = batch_size // group_size\n            batch_start = self.rank * local_batch_size\n            batch_end = (self.rank + 1) * local_batch_size\n            return global_batch[batch_start:batch_end, ...]\n\n        def to_global_batch(\n            local_batch: torch.Tensor,\n            group: Optional[torch.distributed.ProcessGroup],\n        ) -> torch.Tensor:\n            \"\"\"Gather distributed tensor and broadcast to all ranks\"\"\"\n\n            # Allocate buffer\n            global_batch = torch.empty(batch_size, layer_size, device=\"cuda\")\n\n            # Gather data on root rank\n            group_size = 
torch.distributed.get_world_size(group)\n            if group_size > 0:\n                local_batches = None\n                if self.rank == 0:\n                    local_batch_size = batch_size // group_size\n                    local_batches = [\n                        global_batch[rank * local_batch_size : (rank + 1) * local_batch_size, ...]\n                        for rank in range(group_size)\n                    ]\n                torch.distributed.gather(\n                    local_batch,\n                    local_batches,\n                    dst=0,\n                    group=group,\n                )\n\n            # Broadcast data to all ranks\n            torch.distributed.broadcast(global_batch, src=0)\n            return global_batch\n\n        # Train one of the models\n        torch.manual_seed(self.seed + 2)\n        for step in range(num_steps):\n            if self.rank < save_group_size:\n                optim_save.zero_grad()\n            x = make_global_batch()\n            dy = make_global_batch()\n            if self.rank < save_group_size:\n                x = to_local_batch(x, save_group)\n                dy = to_local_batch(dy, save_group)\n                y = model_save(x)\n                y.backward(dy)\n                optim_save.step()\n\n        # Make sure models are different\n        if self.rank < min(save_group_size, load_group_size):\n            for param_save, param_load in zip(model_save.parameters(), model_load.parameters()):\n                self.assertRaises(\n                    AssertionError,\n                    torch.testing.assert_close,\n                    param_load,\n                    param_save,\n                    rtol=rtol,\n                    atol=atol,\n                )\n\n        # Save state\n        state_bytes = None\n        if self.rank < save_group_size:\n            state_dict = {\n                \"model\": model_save.state_dict(),\n                \"optim\": optim_save.state_dict(),\n     
       }\n            byte_stream = io.BytesIO()\n            torch.save(state_dict, byte_stream)\n            state_bytes = byte_stream.getvalue()\n\n        # Broadcast state from root rank and load\n        if self.rank < load_group_size:\n            if load_group_size != save_group_size:\n                if self.rank != 0:\n                    state_bytes = None\n                state_bytes = [state_bytes]\n                torch.distributed.broadcast_object_list(\n                    state_bytes,\n                    src=0,\n                    group=load_group,\n                )\n                state_bytes = state_bytes[0]\n            state_dict = torch.load(io.BytesIO(state_bytes))\n            model_load.load_state_dict(state_dict[\"model\"])\n            optim_load.load_state_dict(state_dict[\"optim\"])\n\n        # Make sure models are identical\n        if self.rank < min(save_group_size, load_group_size):\n            for param_save, param_load in zip(model_save.parameters(), model_load.parameters()):\n                torch.testing.assert_close(param_load, param_save, rtol=rtol, atol=atol)\n\n        # Train both models\n        torch.manual_seed(self.seed + 3)\n        for step in range(num_steps):\n            # Reset grads\n            if self.rank < save_group_size:\n                optim_save.zero_grad()\n            if self.rank < load_group_size:\n                optim_load.zero_grad()\n\n            # Synthetic data\n            x = make_global_batch()\n            dy = make_global_batch()\n\n            # Training step for model that saved checkpoint\n            y_save = None\n            dx_save = None\n            if self.rank < save_group_size:\n                x_save = to_local_batch(x, save_group)\n                x_save = x_save.detach().clone().requires_grad_(True)\n                dy_save = to_local_batch(dy, save_group)\n                y_save = model_save(x_save)\n                y_save.backward(dy_save)\n                dx_save = 
x_save.grad\n            y_save = to_global_batch(y_save, save_group)\n            dx_save = to_global_batch(dx_save, save_group)\n\n            # Training step for model that loaded checkpoint\n            y_load = None\n            dx_load = None\n            if self.rank < load_group_size:\n                x_load = to_local_batch(x, load_group)\n                x_load = x_load.detach().clone().requires_grad_(True)\n                dy_load = to_local_batch(dy, load_group)\n                y_load = model_load(x_load)\n                y_load.backward(dy_load)\n                dx_load = x_load.grad\n            y_load = to_global_batch(y_load, load_group)\n            dx_load = to_global_batch(dx_load, load_group)\n\n            # Check that data tensors match\n            torch.testing.assert_close(y_load, y_save, rtol=rtol, atol=atol)\n            torch.testing.assert_close(dx_load, dx_save, rtol=rtol, atol=atol)\n\n            # Optimizer step\n            if self.rank < save_group_size:\n                optim_save.step()\n            if self.rank < load_group_size:\n                optim_load.step()\n\n            # Check that parameters match\n            if self.rank < min(save_group_size, load_group_size):\n                for param_save, param_load in zip(model_save.parameters(), model_load.parameters()):\n                    torch.testing.assert_close(\n                        param_load,\n                        param_save,\n                        rtol=rtol,\n                        atol=atol,\n                    )\n\n    def test_checkpoint_save_1gpu(self):\n        \"\"\"Test saving checkpoint with one GPU\"\"\"\n        self.test_checkpoint(save_group_size=1)\n\n    def test_checkpoint_load_1gpu(self):\n        \"\"\"Test loading checkpoint with one GPU\"\"\"\n        self.test_checkpoint(load_group_size=1)\n\n    def test_checkpoint_bf16(self):\n        \"\"\"Test checkpoint with BF16 model\"\"\"\n        self.test_checkpoint(\n            
rtol=5e-2,\n            atol=1e-5,\n            save_model_kwargs=dict(\n                model_dtype=torch.bfloat16,\n                optim_dtype=torch.float32,\n                param_sync_dtype=torch.bfloat16,\n                store_params=False,\n                store_param_remainders=True,\n            ),\n            load_model_kwargs=dict(\n                model_dtype=torch.bfloat16,\n                optim_dtype=torch.float32,\n                param_sync_dtype=torch.bfloat16,\n                store_params=False,\n                store_param_remainders=True,\n            ),\n        )\n\n    def test_checkpoint_scaled_state(self):\n        \"\"\"Test checkpoint with scaled FP16 state\"\"\"\n        self.test_checkpoint(\n            rtol=5e-2,\n            atol=1e-5,\n            save_model_kwargs=dict(\n                model_dtype=torch.bfloat16,\n                optim_dtype=torch.float16,\n                param_sync_dtype=torch.int,\n                store_params=True,\n                with_scaled_states=True,\n            ),\n            load_model_kwargs=dict(\n                model_dtype=torch.bfloat16,\n                optim_dtype=torch.float16,\n                param_sync_dtype=torch.int,\n                store_params=True,\n                with_scaled_states=True,\n            ),\n        )\n\n    def test_bucket_low_utilization_warning(self):\n        \"\"\"Test warning when bucket utilization is low\"\"\"\n        layer_size = 2 * 1024 * 1024\n        num_layers = 4\n        fairish_bucket_cap_mb = 4 * num_layers * layer_size / (1024 * 1024)\n\n        # Check that warning is raised when bucket utilization is low\n        with self.assertWarnsRegex(Warning, \".*Consider decreasing the bucket_cap_mb argument.\"):\n            self.test_matches_pytorch(\n                num_layers=num_layers,\n                layer_size=layer_size,\n                overlap_communication=False,\n                bucket_cap_mb=fairish_bucket_cap_mb * 2,\n            )\n\n   
     # Check that warning is not raised when bucket utilization is high\n        with warnings.catch_warnings(record=True) as warns:\n            self.test_matches_pytorch(\n                num_layers=num_layers,\n                layer_size=layer_size,\n                overlap_communication=False,\n                bucket_cap_mb=fairish_bucket_cap_mb,\n            )\n            for w in warns:\n                self.assertNotRegex(\n                    str(w.message), \".*Consider decreasing the bucket_cap_mb argument.\"\n                )\n\n    def test_cuda_graph(self):\n        \"\"\"Test distributed adam with CUDA graph\"\"\"\n        # Skip message matches the condition: the test only runs with more than 8 ranks\n        if self.world_size <= 8:\n            self.skipTest(f\"{self.world_size=} is expected to be > 8\")\n        self.test_matches_pytorch(\n            rtol=5e-3,\n            atol=1e-5,\n            num_steps=15,\n            micro_batch_steps=1,\n            model_dtype=torch.float16,\n            optim_dtype=torch.float16,\n            contiguous_buffers=True,\n            with_cuda_graph=True,\n        )\n\n\nif __name__ == \"__main__\":\n    # Assume script has been run with torchrun\n    common_utils.run_tests()\n"
  },
  {
    "path": "apex/contrib/test/optimizers/test_distributed_fused_lamb.py",
    "content": "import inspect\n\nimport torch\nfrom torch.cuda.amp import GradScaler\nfrom torch.testing._internal import common_utils\nfrom torch.distributed.distributed_c10d import _coalescing_manager\n\nfrom apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB\nfrom apex.distributed_testing.distributed_test_base import NcclDistributedTestBase\n\n\ndef flat_dist_call(param_list: list[torch.Tensor], op, args):\n    with _coalescing_manager(async_ops=True) as cm:\n        for p in param_list:\n            op(p, *args)\n\n    cm.wait()\n\n\ndef get_init_weights_func():\n    @torch.no_grad()\n    def init_weights(m):\n        if isinstance(m, torch.nn.Linear):\n            m.weight.fill_(1.0)\n\n    return init_weights\n\n\nclass ModelFoo(torch.nn.Module):\n    def __init__(self):\n        super(ModelFoo, self).__init__()\n        self.linear = torch.nn.Linear(128, 128, bias=False)\n        self.loss = torch.nn.MSELoss()\n\n    def forward(self, input_tensor, gt):\n        y = self.linear(input_tensor)\n        loss = self.loss(y, gt)\n        return loss\n\n\n# A test for the distributed fused LAMB optimizer: run several iterations and check that the loss decreases.\n# There are two instances of the same test because the optimizer decides which collective operations to use based on `world_size`.\n# If torch.distributed.get_world_size() == torch.cuda.device_count() it uses only `all_gather`.\n# If torch.distributed.get_world_size() < torch.cuda.device_count() it uses both `all_gather` and `reduce_scatter`.\nclass NcclDistributedFusedLAMB(NcclDistributedTestBase):\n    @property\n    def world_size(self) -> int:\n        return torch.cuda.device_count()\n\n    @common_utils.parametrize(\"no_copy\", [False, True])\n    @common_utils.parametrize(\n        \"opt_kwargs\",\n        [\n            dict(\n                overlap_reductions=True,\n                dwu_num_blocks=2,\n                dwu_num_chunks=2,\n                fused_norm=False,\n                
fuse_scale=False,\n                clip_after_ar=True,\n                full_ar=False,\n            ),\n            dict(\n                overlap_reductions=False,\n                dwu_num_blocks=1,\n                dwu_num_chunks=1,\n                fused_norm=True,\n                fuse_scale=True,\n                clip_after_ar=False,\n            ),\n        ],\n    )\n    def test_distributed_fused_lamb(self, no_copy, opt_kwargs):\n        if (\n            no_copy\n            and \"no_copy\" not in inspect.getfullargspec(torch.distributed.reduce_scatter).args\n        ):\n            self.skipTest(\"does not support no_copy\")\n        if no_copy and \"no_copy\" not in inspect.getfullargspec(torch.distributed.all_gather).args:\n            self.skipTest(\"does not support no_copy\")\n\n        assert torch.distributed.is_initialized()\n        gpu_count = torch.distributed.get_world_size()\n\n        init_scale = 100\n        lr = torch.tensor(0.1).cuda()\n        grad_scaler = GradScaler(init_scale=init_scale, growth_interval=1000)\n\n        model = ModelFoo()\n        model = model.cuda().half()\n        model.apply(get_init_weights_func())\n\n        param_optimizer = list(model.named_parameters())\n        no_decay = [\"bias\", \"gamma\", \"beta\", \"LayerNorm\"]\n        optimizer_grouped_parameters = [\n            {\n                \"params\": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n                \"weight_decay\": 0.01,\n            },\n            {\n                \"params\": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],\n                \"weight_decay\": 0.0,\n            },\n        ]\n\n        if \"full_ar\" not in opt_kwargs:\n            opt_kwargs[\"full_ar\"] = gpu_count == torch.cuda.device_count()\n\n        # Aidyn-A: not sure what parameters are the best for testing purposes,\n        # setting up whatever I think appropriate.\n        optimizer = DistributedFusedLAMB(\n        
    optimizer_grouped_parameters,\n            lr=0.1,\n            betas=(0.9, 0.9),\n            eps=1e-6,\n            max_grad_norm=1.0,\n            dwu_group_size=gpu_count,\n            dwu_num_rs_pg=1,\n            dwu_num_ar_pg=1,\n            dwu_num_ag_pg=1,\n            use_nvlamb=False,\n            set_param_views_to_flat_buffer=False,\n            e5m2_allgather=False,\n            **opt_kwargs,\n        )\n        optimizer.set_global_scale(init_scale)\n\n        optimizer._reduce_scatter_no_copy = no_copy\n        optimizer._all_gather_no_copy = no_copy\n\n        flat_dist_call(\n            [param.data for param in model.parameters()],\n            torch.distributed.broadcast,\n            (0,),\n        )\n\n        x = torch.randn(4096, 128, dtype=torch.float16).cuda()\n        y = torch.randn(4096, 128, dtype=torch.float16).cuda()\n\n        losses = []\n        for _ in range(10):\n            loss = model(x, y)\n            optimizer._lazy_init_stage1()\n            grad_scaler.scale(loss).backward()\n            optimizer._lazy_init_stage2()\n            optimizer._lr = lr\n            optimizer.complete_reductions()\n            optimizer.set_global_scale(grad_scaler._get_scale_async())\n            grad_scaler.step(optimizer)\n            grad_scaler.update()\n            optimizer.zero_grad(set_to_none=True)\n\n            losses.append(loss.item())\n\n        self.assertTrue(losses == sorted(losses, reverse=True))\n\n\ncommon_utils.instantiate_parametrized_tests(NcclDistributedFusedLAMB)\n\n\nclass NcclDistributedFusedLAMB_partial_ar(NcclDistributedFusedLAMB):\n    @property\n    def world_size(self) -> int:\n        return max(torch.cuda.device_count() - 1, 1)\n\n\nif __name__ == \"__main__\":\n    common_utils.run_tests()\n"
  },
  {
    "path": "apex/contrib/test/peer_memory/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py",
    "content": "import unittest\n\nimport torch\nfrom torch.testing._internal import common_utils\n\nSKIP_TEST = None\nfrom apex.distributed_testing.distributed_test_base import NcclDistributedTestBase\n\ntry:\n    from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d\nexcept ImportError as e:\n    SKIP_TEST = e\n\n# How to run:\n# python /path/to/test_peer_halo_exchange_module.py\n\n\n# Output of this function is used as ground truth in module tests.\ndef nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split):\n    if explicit_nhwc:\n        if H_split:\n            _, Hp, _, _ = list(y.shape)\n            H = Hp - 2 * half_halo\n            top_out_halo = y[:, half_halo : 2 * half_halo, :, :]\n            top_inp_halo = y[:, :half_halo, :, :]\n            btm_out_halo = y[:, H : H + half_halo, :, :]\n            btm_inp_halo = y[:, H + half_halo : H + 2 * half_halo, :, :]\n        else:\n            _, _, Wp, _ = list(y.shape)\n            W = Wp - 2 * half_halo\n            top_out_halo = y[:, :, half_halo : 2 * half_halo, :]\n            top_inp_halo = y[:, :, :half_halo, :]\n            btm_out_halo = y[:, :, W : W + half_halo, :]\n            btm_inp_halo = y[:, :, W + half_halo : W + 2 * half_halo, :]\n    else:\n        if H_split:\n            _, _, Hp, _ = list(y.shape)\n            H = Hp - 2 * half_halo\n            top_out_halo = y[:, :, half_halo : 2 * half_halo, :]\n            top_inp_halo = y[:, :, :half_halo, :]\n            btm_out_halo = y[:, :, H : H + half_halo, :]\n            btm_inp_halo = y[:, :, H + half_halo : H + 2 * half_halo, :]\n        else:\n            _, _, _, Wp = list(y.shape)\n            W = Wp - 2 * half_halo\n            top_out_halo = y[:, :, :, half_halo : 2 * half_halo]\n            top_inp_halo = y[:, :, :, :half_halo]\n            btm_out_halo = y[:, :, :, W : W + half_halo]\n            btm_inp_halo = y[:, :, :, W + half_halo : W + 2 * half_halo]\n\n    mf = (\n        
torch.channels_last\n        if y.is_contiguous(memory_format=torch.channels_last)\n        else torch.contiguous_format\n    )\n    top_out_halo = top_out_halo.contiguous()\n    btm_out_halo = btm_out_halo.contiguous()\n\n    top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]\n    torch.distributed.all_gather(top_inp_halos, top_out_halo)\n    btm_inp_halos = [torch.empty_like(btm_out_halo) for _ in range(peer_group_size)]\n    torch.distributed.all_gather(btm_inp_halos, btm_out_halo)\n    top_rank = (peer_rank + peer_group_size - 1) % peer_group_size\n    btm_rank = (peer_rank + 1) % peer_group_size\n    if peer_rank == 0:\n        top_inp_halo.zero_()\n    else:\n        top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))\n    if peer_rank == peer_group_size - 1:\n        btm_inp_halo.zero_()\n    else:\n        btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))\n\n\ndef single_test(\n    peer_rank,\n    peer_group_size,\n    halo_ex,\n    C,\n    H,\n    W,\n    half_halo,\n    dtype,\n    memory_format,\n    H_split,\n    num_steps,\n    numSM=1,\n):\n    if memory_format == 1:\n        # 1 -> explicit nhwc\n        explicit_nhwc = True\n        if H_split:\n            y = torch.randn([1, H + 2 * half_halo, W, C], dtype=dtype, device=\"cuda\")\n            ym = y[:, half_halo : H + half_halo, :, :]\n        else:\n            y = torch.randn([1, H, W + 2 * half_halo, C], dtype=dtype, device=\"cuda\")\n            ym = y[:, :, half_halo : W + half_halo, :]\n    else:\n        # 2 -> native nhwc\n        # 3 -> nchw\n        explicit_nhwc = False\n        if H_split:\n            y = torch.randn([1, C, H + 2 * half_halo, W], dtype=dtype, device=\"cuda\")\n            if memory_format == 2:\n                y = y.to(memory_format=torch.channels_last)\n            ym = y[:, :, half_halo : H + half_halo, :]\n        else:\n            y = torch.randn([1, C, H, W + 2 * half_halo], dtype=dtype, 
device=\"cuda\")\n            if memory_format == 2:\n                y = y.to(memory_format=torch.channels_last)\n            ym = y[:, :, :, half_halo : W + half_halo]\n    y3 = y.clone()\n    list_y = []\n    for step in range(num_steps):\n        halo_ex(y, H_split, explicit_nhwc, numSM)\n        list_y.append(y.clone())\n        y.copy_(y3)\n        halo_ex.peer_pool.reset()\n        torch.distributed.barrier()\n    y2 = y3.clone()\n    list_y2 = []\n    for step in range(num_steps):\n        nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)\n        list_y2.append(y2.clone())\n        y2.copy_(y3)\n    if memory_format == 1:\n        memory_format_str = \"explicit_nhwc\"\n    elif memory_format == 2:\n        memory_format_str = \"native nhwc\"\n    elif memory_format == 3:\n        memory_format_str = \"nchw\"\n    else:\n        memory_format_str = \"???\"\n    torch.testing.assert_close(list_y, list_y2, msg=memory_format_str)\n    # is_equal = [torch.all(torch.eq(yy, yy2)) for yy, yy2 in zip(list_y, list_y2)]\n    # is_equal = torch.tensor(is_equal, dtype=torch.bool)\n    # is_equal = torch.all(is_equal)\n    # if peer_rank == 0:\n    #     if is_equal:\n    #         print(\n    #             \"SUCCESS : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s\"\n    #             % (\n    #                 C,\n    #                 H,\n    #                 W,\n    #                 half_halo,\n    #                 str(dtype),\n    #                 memory_format_str,\n    #                 \"H-split\" if H_split else \"W-split\",\n    #             )\n    #         )\n    #     else:\n    #         print(\n    #             \"FAILURE : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s\"\n    #             % (\n    #                 C,\n    #                 H,\n    #                 W,\n    #                 half_halo,\n    #                 str(dtype),\n    #                 memory_format_str,\n    #                 \"H-split\" if 
H_split else \"W-split\",\n    #             )\n    #         )\n    #\n    # peer memory flag sync relies on there being at least one barrier per step\n    # torch.distributed.barrier()\n\n\ndef H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):\n    Hr = 8 * world_size\n    Hp = ((H + Hr - 1) // Hr) * 8\n\n    for i in range(4):\n        div = int(pow(2, i))\n        single_test(\n            rank,\n            world_size,\n            halo_ex,\n            C * div,\n            Hp // div,\n            W // div,\n            half_halo,\n            torch.float16,\n            1,\n            True,\n            num_steps,\n        )\n        single_test(\n            rank,\n            world_size,\n            halo_ex,\n            C * div,\n            Hp // div,\n            W // div,\n            half_halo,\n            torch.float16,\n            2,\n            True,\n            num_steps,\n        )\n        single_test(\n            rank,\n            world_size,\n            halo_ex,\n            C * div,\n            Hp // div,\n            W // div,\n            half_halo,\n            torch.float16,\n            3,\n            True,\n            num_steps,\n        )\n\n\ndef W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):\n    Wr = 8 * world_size\n    Wp = ((W + Wr - 1) // Wr) * 8\n\n    for i in range(4):\n        div = int(pow(2, i))\n        single_test(\n            rank,\n            world_size,\n            halo_ex,\n            C * div,\n            H // div,\n            Wp // div,\n            half_halo,\n            torch.float16,\n            1,\n            False,\n            num_steps,\n        )\n        single_test(\n            rank,\n            world_size,\n            halo_ex,\n            C * div,\n            H // div,\n            Wp // div,\n            half_halo,\n            torch.float16,\n            2,\n            False,\n            num_steps,\n        )\n        
single_test(\n            rank,\n            world_size,\n            halo_ex,\n            C * div,\n            H // div,\n            Wp // div,\n            half_halo,\n            torch.float16,\n            3,\n            False,\n            num_steps,\n        )\n\n\ndef main():\n    # for this trivial example peer_rank == rank and peer_group_size == world_size\n\n    torch.distributed.init_process_group(\"nccl\")\n    rank = torch.distributed.get_rank()\n    world_size = torch.distributed.get_world_size()\n    torch.cuda.set_device(rank)\n    peer_ranks = [i for i in range(world_size)]\n    pool = PeerMemoryPool(0, 2 * 1024 * 1024, peer_ranks)\n\n    num_steps = 100\n\n    half_halo = 1\n    halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)\n\n    H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex, num_steps)\n    W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex, num_steps)\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass TestPeerMemory(NcclDistributedTestBase):\n    HALF_HALO = 1\n    NUM_STEPS = 100\n\n    @property\n    def world_size(self) -> int:\n        return min(torch.cuda.device_count(), 2)\n\n    # TODO(crcrpar): Check whether `world_size` being a multiple of 2 is a must.\n    def _check_world_size_and_may_skip(self) -> None:\n        if not (self.world_size >= 2 and self.world_size % 2 == 0):\n            self.skipTest(f\"world_size is expected to be a multiple of 2, but got {self.world_size}\")\n\n    def get_halo_excnahger_1d(self):\n        peer_ranks = [i for i in range(self.world_size)]\n        pool = PeerMemoryPool(64 * 1024, 2 * 1024 * 1024, peer_ranks)\n        halo_exchanger_1d = PeerHaloExchanger1d(\n            peer_ranks, self.rank, pool, TestPeerMemory.HALF_HALO\n        )\n        return halo_exchanger_1d\n\n    def test_height_split(self):\n        self._check_world_size_and_may_skip()\n        H_split_tests(\n            1,\n            64,\n            336,\n            200,\n     
       TestPeerMemory.HALF_HALO,\n            self.rank,\n            self.world_size,\n            self.get_halo_excnahger_1d(),\n            TestPeerMemory.NUM_STEPS,\n        )\n\n    def test_width_split(self):\n        self._check_world_size_and_may_skip()\n        W_split_tests(\n            1,\n            64,\n            200,\n            336,\n            TestPeerMemory.HALF_HALO,\n            self.rank,\n            self.world_size,\n            self.get_halo_excnahger_1d(),\n            TestPeerMemory.NUM_STEPS,\n        )\n\n\nif __name__ == \"__main__\":\n    common_utils.run_tests()\n"
  },
  {
    "path": "apex/contrib/test/transducer/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/transducer/test_transducer_joint.py",
    "content": "import unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.transducer import TransducerJoint\n    from apex.contrib.transducer import _transducer_ref as transducer_ref\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass TransducerJointTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n\n    def gen_input(self, for_vector_kernel):\n        self.B = 4\n        T_min = 51\n        T_max = 101\n        U_min = 12\n        U_max = 25\n        if for_vector_kernel:\n            H = 512\n        else:\n            H = 509\n        dtype = torch.float16\n        device = \"cuda\"\n\n        self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)\n        self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)\n        self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)\n        self.f_len = torch.randint(T_min, T_max + 1, (self.B,), dtype=torch.int, device=device)\n        self.g_len = torch.randint(U_min, U_max + 1, (self.B,), dtype=torch.int, device=device)\n        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max\n        self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max\n        self.dropout_prob = 0.5\n\n        # Make sure gradients from out-of-bound locations are zero. 
This should be guaranteed by\n        # the loss function\n        for b in range(self.B):\n            self.h_grad[b, self.f_len[b] :, :, :] = 0\n            self.h_grad[b, :, self.g_len[b] :, :] = 0\n        self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)\n\n    def _pack(self, x, f_len, g_len):\n        B = x.size(0)\n        list_x = []\n        for b in range(B):\n            list_x_row = [x[b, t, : g_len[b]] for t in range(f_len[b])]\n            x_row = torch.cat(list_x_row)\n            list_x.append(x_row)\n        x_packed = torch.cat(list_x).data.clone()\n        x_packed.requires_grad = True\n        batch_offset = torch.cumsum(f_len * g_len, dim=0)\n        return x_packed\n\n    def _unpack(self, x, f_len, g_len):\n        batch_offset = torch.cumsum(f_len * g_len, dim=0)\n        x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)\n        B = self.h_grad.size(0)\n        H = self.h_grad.size(-1)\n        for b in range(B):\n            my_batch_offset = 0 if b == 0 else batch_offset[b - 1]\n            my_f_len = f_len[b]\n            my_g_len = g_len[b]\n            for t in range(my_f_len):\n                x_unpacked[b, t, :my_g_len] = x[\n                    my_batch_offset + t * my_g_len : my_batch_offset + t * my_g_len + my_g_len\n                ]\n        return x_unpacked\n\n    def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):\n        self.gen_input(for_vector_kernel=for_vector_kernel)\n        # Generate reference\n        f_ref = self.f_tst.data.clone()\n        g_ref = self.g_tst.data.clone()\n        f_ref.requires_grad = True\n        g_ref.requires_grad = True\n\n        my_joint = TransducerJoint(\n            pack_output=pack_output,\n            relu=relu,\n            dropout=dropout,\n            dropout_prob=self.dropout_prob,\n            probe_mask=True,\n        )\n        if not pack_output:\n            h_tst = my_joint(f=self.f_tst, g=self.g_tst, 
f_len=self.f_len, g_len=self.g_len)\n            h_tst.backward(self.h_grad)\n            if dropout:\n                mask = my_joint.mask_probe[0]\n        else:\n            batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)\n            h_tst = my_joint(\n                f=self.f_tst,\n                g=self.g_tst,\n                f_len=self.f_len,\n                g_len=self.g_len,\n                batch_offset=batch_offset,\n                packed_batch=batch_offset[-1],\n            )\n            h_tst.backward(self.h_grad_packed)\n            if dropout:\n                mask_packed = my_joint.mask_probe[0]\n                mask = self._unpack(mask_packed, self.f_len, self.g_len)\n\n        # reference\n        h_ref, f_grad_ref, g_grad_ref = transducer_ref.transducer_joint_reference(\n            f=f_ref,\n            g=g_ref,\n            h_grad=self.h_grad,\n            f_len=self.f_len,\n            g_len=self.g_len,\n            pack_output=pack_output,\n            relu=relu,\n            dropout=dropout,\n            dropout_prob=self.dropout_prob,\n            mask=mask if dropout else None,\n        )\n\n        f_grad_tst = self.f_tst.grad\n        g_grad_tst = self.g_tst.grad\n\n        torch.testing.assert_close(h_ref, h_tst, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(f_grad_ref, f_grad_tst, atol=5e-5, rtol=1e-3)\n        torch.testing.assert_close(g_grad_ref, g_grad_tst, atol=1e-3, rtol=1e-3)\n\n    def test_transducer_joint(self):\n        self.run_transducer_joint(\n            for_vector_kernel=True, pack_output=True, relu=False, dropout=False\n        )\n\n    def test_transducer_joint_vec(self):\n        self.run_transducer_joint(\n            for_vector_kernel=True, pack_output=False, relu=False, dropout=False\n        )\n\n    def test_transducer_joint_pack(self):\n        self.run_transducer_joint(\n            for_vector_kernel=False, pack_output=True, relu=False, dropout=False\n        )\n\n    def 
test_transducer_joint_vec_pack(self):\n        self.run_transducer_joint(\n            for_vector_kernel=True, pack_output=True, relu=False, dropout=False\n        )\n\n    def test_transducer_joint_relu(self):\n        self.run_transducer_joint(\n            for_vector_kernel=True, pack_output=True, relu=True, dropout=False\n        )\n\n    def test_transducer_joint_vec_relu(self):\n        self.run_transducer_joint(\n            for_vector_kernel=True, pack_output=False, relu=True, dropout=False\n        )\n\n    def test_transducer_joint_pack_relu(self):\n        self.run_transducer_joint(\n            for_vector_kernel=False, pack_output=True, relu=True, dropout=False\n        )\n\n    def test_transducer_joint_vec_pack_relu(self):\n        self.run_transducer_joint(\n            for_vector_kernel=True, pack_output=True, relu=True, dropout=False\n        )\n\n    @unittest.expectedFailure\n    def test_transducer_joint_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)\n\n    @unittest.expectedFailure\n    def test_transducer_joint_vec_relu_dropout(self):\n        self.run_transducer_joint(\n            for_vector_kernel=True, pack_output=False, relu=True, dropout=True\n        )\n\n    @unittest.expectedFailure\n    def test_transducer_joint_pack_relu_dropout(self):\n        self.run_transducer_joint(\n            for_vector_kernel=False, pack_output=True, relu=True, dropout=True\n        )\n\n    @unittest.expectedFailure\n    def test_transducer_joint_vec_pack_relu_dropout(self):\n        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/transducer/test_transducer_loss.py",
    "content": "import unittest\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib.transducer import TransducerLoss\n    from apex.contrib.transducer import _transducer_ref as transducer_ref\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass TransducerLossTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        torch.manual_seed(seed)\n\n    def gen_input(self, scalar_t, for_vector_kernel):\n        self.B = 5\n        T_min = 23\n        T_max = 51\n        U_min = 12\n        U_max = 25\n        V = 16 if for_vector_kernel else 14\n        self.blank_idx = V - 1\n        device = \"cuda\"\n\n        self.x_tst = torch.randn(\n            (self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, device=device\n        )\n        self.y = torch.randint(\n            0, self.blank_idx, (self.B, U_max - 1), dtype=torch.int, device=device\n        )\n        self.f_len = torch.randint(T_min, T_max + 1, (self.B,), dtype=torch.int, device=device)\n        self.y_len = torch.randint(U_min - 1, U_max, (self.B,), dtype=torch.int, device=device)\n        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max\n        self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max - 1\n        self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)\n        # Generate reference\n        x_ref = self.x_tst.data.clone()\n        x_ref.requires_grad = True\n        loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device) / x_ref.size(\n            0\n        )\n        _, _, self.grad_ref, self.loss_ref = transducer_ref.transducer_loss_reference(\n            x=x_ref,\n            label=self.y,\n            f_len=self.f_len,\n            y_len=self.y_len,\n            blank_idx=self.blank_idx,\n            loss_grad=loss_grad,\n        )\n\n    def _pack(self, x):\n        list_x = []\n        for b in range(self.B):\n            list_x_row = [x[b, t, : self.y_len[b] + 1] 
for t in range(self.f_len[b])]\n            x_row = torch.cat(list_x_row)\n            list_x.append(x_row)\n        x_packed = torch.cat(list_x).data.clone()\n        x_packed.requires_grad = True\n        batch_offset = torch.cumsum(self.f_len * (self.y_len + 1), dim=0)\n        return x_packed, batch_offset\n\n    def _unpack(self, x):\n        x_unpacked = torch.zeros(\n            self.B,\n            self.f_len.max(),\n            self.y_len.max() + 1,\n            x.size(-1),\n            dtype=x.dtype,\n            device=x.device,\n        )\n        for b in range(self.B):\n            my_batch_offset = 0 if b == 0 else self.batch_offset[b - 1]\n            my_f_len = self.f_len[b]\n            my_g_len = self.y_len[b] + 1\n            for t in range(my_f_len):\n                for u in range(my_g_len):\n                    x_unpacked[b, t, u] = x[my_batch_offset + t * my_g_len + u]\n        return x_unpacked\n\n    def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):\n        self.gen_input(scalar_t, for_vector_kernel)\n        my_loss = TransducerLoss(\n            fuse_softmax_backward=fuse_softmax_backward, packed_input=packed_input\n        )\n        if not packed_input:\n            loss_tst = my_loss(\n                x=self.x_tst,\n                label=self.y,\n                f_len=self.f_len,\n                y_len=self.y_len,\n                blank_idx=self.blank_idx,\n            )\n            loss_tst.mean().backward()\n            grad_tst = self.x_tst.grad\n        else:\n            loss_tst = my_loss(\n                x=self.x_tst_packed,\n                label=self.y,\n                f_len=self.f_len,\n                y_len=self.y_len,\n                blank_idx=self.blank_idx,\n                batch_offset=self.batch_offset,\n                max_f_len=max(self.f_len),\n            )\n            loss_tst.mean().backward()\n            grad_tst_packed = self.x_tst_packed.grad\n            
grad_tst = self._unpack(grad_tst_packed)\n\n        return loss_tst, grad_tst\n\n    def test_transducer_loss_fp32(self):\n        loss_tst, grad_tst = self.run_transducer_loss(\n            scalar_t=torch.float32,\n            fuse_softmax_backward=False,\n            packed_input=False,\n            for_vector_kernel=False,\n        )\n        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5)\n\n    def test_transducer_loss_fp16(self):\n        loss_tst, grad_tst = self.run_transducer_loss(\n            scalar_t=torch.float16,\n            fuse_softmax_backward=False,\n            packed_input=False,\n            for_vector_kernel=False,\n        )\n        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)\n\n    def test_transducer_loss_fp16_backward_fusion(self):\n        loss_tst, grad_tst = self.run_transducer_loss(\n            scalar_t=torch.float16,\n            fuse_softmax_backward=True,\n            packed_input=False,\n            for_vector_kernel=False,\n        )\n        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)\n\n    def test_transducer_loss_fp16_backward_fusion_packed(self):\n        loss_tst, grad_tst = self.run_transducer_loss(\n            scalar_t=torch.float16,\n            fuse_softmax_backward=True,\n            packed_input=True,\n            for_vector_kernel=False,\n        )\n        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)\n\n    def test_transducer_loss_fp16_backward_fusion_packed_vec(self):\n        loss_tst, grad_tst = self.run_transducer_loss(\n            scalar_t=torch.float16,\n            
fuse_softmax_backward=True,\n            packed_input=True,\n            for_vector_kernel=True,\n        )\n        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)\n        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/test/xentropy/__init__.py",
    "content": ""
  },
  {
    "path": "apex/contrib/test/xentropy/test_label_smoothing.py",
    "content": "import unittest\nimport random\nimport time\n\nimport numpy as np\n\nimport torch\n\nSKIP_TEST = None\ntry:\n    from apex.contrib import xentropy as label_smoothing\nexcept ImportError as e:\n    SKIP_TEST = e\n\n\ndef label_smoothing_raw(x, target, padding_idx, smoothing):\n    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)\n\n    non_pad_mask = target != padding_idx\n    nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))\n    nll_loss = nll_loss.squeeze(1)[non_pad_mask]\n    smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]\n    loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss\n    return loss\n\n\ndef label_smoothing_opt_1(x, target, padding_idx, smoothing):\n    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)\n\n    pad_mask = target == padding_idx\n    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)\n    smooth_loss = logprobs.mean(dim=-1)\n    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss\n    loss.masked_fill_(pad_mask, 0)\n    return loss\n\n\n@unittest.skipIf(SKIP_TEST, f\"{SKIP_TEST}\")\nclass LabelSmoothingTest(unittest.TestCase):\n    def setUp(self, seed=1234):\n        super().setUp()\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n\n        # Set pytorch print precision\n        torch.set_printoptions(precision=10)\n\n    def gen_test_inputs(self, N, T, H, smoothing, padding_idx, dtype=torch.half):\n        logits = torch.randn((N * T, H), dtype=dtype, device=\"cuda\", requires_grad=True)\n        labels = torch.randint(0, H, [N * T], device=\"cuda\")\n        for i in random.sample(range(N * T), N * T // 6):\n            labels[i] = padding_idx\n        half_to_float = logits.dtype == torch.half\n\n        return logits, labels, half_to_float\n\n    def print_max_diff_elem(self, ref, tst):\n        ref, tst = ref.flatten(), tst.flatten()\n        diff = (ref - 
tst).abs().max()\n        idx = (ref - tst).abs().argmax()\n        print(\n            \"Max atol idx: {}, diff: {:.6f}, ref: {:.6f}, tst: {:.6f}\".format(\n                idx, diff, ref[idx], tst[idx]\n            )\n        )\n\n    def _test_label_smoothing_function(self, dtype):\n        # Set label smoothing configuration\n        smoothing, padding_idx = 0.1, 0\n        N, T, H = 128, 74, 32320\n        iters = 10\n        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply\n\n        for i in range(iters):\n            logits, labels, half_to_float = self.gen_test_inputs(N, T, H, smoothing, padding_idx)\n\n            # Run original softmax cross entropy with label smoothing\n            logits.grad = None\n            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)\n            loss = losses.sum()\n            loss.backward()\n\n            ref_loss = loss.clone().detach()\n            ref_grad = logits.grad.clone().detach()\n\n            # Run optimized softmax cross entropy with label smoothing\n            logits.grad = None\n            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)\n            loss = losses.sum()\n            loss.backward()\n\n            val_loss = loss.clone().detach()\n            val_grad = logits.grad.clone().detach()\n\n            # Validate\n            self.print_max_diff_elem(ref_grad, val_grad)\n            torch.testing.assert_close(val_loss, ref_loss)\n            torch.testing.assert_close(val_grad, ref_grad)\n\n    def test_label_smoothing_function_fp16(self):\n        self._test_label_smoothing_function(torch.half)\n\n    def test_label_smoothing_function_bf16(self):\n        self._test_label_smoothing_function(torch.bfloat16)\n\n    def test_label_smoothing_perf(self):\n        # Set label smoothing configuration\n        smoothing, padding_idx = 0.1, 0\n        N, T, H = 128, 74, 32320\n        iters = 1000\n        loss_func = 
label_smoothing.SoftmaxCrossEntropyLoss.apply\n        print()\n\n        logits, labels, half_to_float = self.gen_test_inputs(N, T, H, smoothing, padding_idx)\n\n        # Run original softmax cross entropy with label smoothing\n        torch.cuda.synchronize()\n        ts = time.time()\n        for i in range(iters):\n            logits.grad = None\n            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)\n            loss = losses.sum() / N\n            loss.backward()\n        torch.cuda.synchronize()\n        print(\n            \"Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}\".format(\n                time.time() - ts, iters, logits.grad.norm()\n            )\n        )\n\n        # Run optimized softmax cross entropy with label smoothing\n        torch.cuda.synchronize()\n        ts = time.time()\n        for i in range(iters):\n            logits.grad = None\n            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)\n            loss = losses.sum() / N\n            loss.backward()\n        torch.cuda.synchronize()\n        print(\n            \"Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}\".format(\n                time.time() - ts, iters, logits.grad.norm()\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "apex/contrib/torchsched/__init__.py",
    "content": "\"\"\"Graph scheduler package.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nimport torch\nimport torch._inductor\nfrom torch._dynamo import list_backends\nfrom torch._dynamo import register_backend\nfrom torch._inductor.compile_fx import compile_fx_inner\n\nfrom .backend import get_backend\n\nif TYPE_CHECKING:\n    from collections.abc import Callable\n    from typing import Any\n\n    from torch._ops import OpOverload\n\n__all__ = [\"get_backend\", \"set_default_backend\"]\n\n# Register custom operators\ntorch.ops.import_module(\"apex.contrib.torchsched.ops\")\n\n\n# Register torch-sched backend\n# Same API as torch._inductor.compile_fx\n@register_backend\ndef torchsched(\n    model_: torch.fx.GraphModule,\n    example_inputs_: list[torch.Tensor],\n    inner_compile: Callable[..., Any] = compile_fx_inner,\n    config_patches: dict[str, Any] | None = None,\n    decompositions: dict[OpOverload, Callable[..., Any]] | None = None,\n) -> Callable:\n    backend = get_backend(backend=\"torchsched\", scheme=\"dwb\")\n    return backend(model_, example_inputs_, inner_compile, config_patches, decompositions)\n\n\n_SUPPORTED_BACKENDS = list_backends()\n_DEFAULT_BACKEND = \"inductor\"\n__torch_compile__ = torch.compile\n\n\ndef set_default_backend(backend: str) -> None:\n    \"\"\"\n    Set the default backend for torch.compile.\n\n    Parameters:\n        backend (str): The backend to use as the default for torch.compile.\n    \"\"\"\n    global _DEFAULT_BACKEND\n    assert backend in _SUPPORTED_BACKENDS, f\"Unknown backend {backend}\"\n    _DEFAULT_BACKEND = backend\n\n\ndef torchsched_compile(\n    *args: object,\n    backend: str | Callable | None = None,\n    **kwargs: object,\n) -> object:\n    \"\"\"\n    Wrap around the original torch.compile to support default backend.\n\n    Parameters:\n        *args (object): Positional arguments for torch.compile.\n        backend (Union[str, Callable, None]): The backend to 
use.\n            If None, the default backend is used.\n        **kwargs (object): Additional keyword arguments for torch.compile.\n\n    Returns:\n        object: Compiler or compiled model.\n    \"\"\"\n    if backend is None:\n        backend = _DEFAULT_BACKEND\n    return __torch_compile__(*args, backend=backend, **kwargs)\n\n\n# Monkey patch torch.compile to set default backend\ntorch.compile = torchsched_compile\n"
  },
  {
    "path": "apex/contrib/torchsched/backend.py",
    "content": "\"\"\"Graph scheduler backend.\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nfrom copy import copy\nfrom typing import TYPE_CHECKING\nfrom typing import ParamSpec\nfrom typing import TypeVar\n\nif TYPE_CHECKING:\n    from collections.abc import Callable\n    from types import NotImplementedType\n\nimport torch\nfrom torch import Tensor\nfrom torch import _TorchCompileInductorWrapper\nfrom torch._dynamo import lookup_backend\nfrom torch._inductor.compile_fx import compile_fx\nfrom torch._inductor.compile_fx import compile_fx_inner\nfrom torch._inductor.decomposition import select_decomp_table\n\nimport apex.contrib.torchsched.config as config\nfrom apex.contrib.torchsched.inductor import patch_graph_lowering\nfrom apex.contrib.torchsched.passes import pre_grad_custom_pass\n\naten = torch.ops.aten\nprims = torch.ops.prims\n\n__all__ = [\"get_backend\"]\n\n\nP = ParamSpec(\"P\")\nR = TypeVar(\"R\")\n\n\ndef enable_multi_stream_scheduling(compile_fn: Callable[P, R]) -> Callable[P, R]:\n    assert callable(compile_fn)\n\n    @functools.wraps(compile_fn)\n    def _compile_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:\n        patch_graph_lowering(patch=True)\n        compile_results = compile_fn(*args, **kwargs)\n        patch_graph_lowering(patch=False)\n        return compile_results\n\n    return _compile_wrapper\n\n\n# Refer: https://github.com/pytorch/pytorch/blob/v2.6.0/torch/_inductor/decomposition.py#L213\ndef convolution_backward_decomp_dwb(\n    grad_output: Tensor,\n    input: Tensor,\n    weight: Tensor,\n    bias_sizes: tuple[int, ...],\n    stride: tuple[int, ...],\n    padding: tuple[int, ...],\n    dilation: tuple[int, ...],\n    transposed: bool,\n    output_padding: tuple[int, ...],\n    groups: int,\n    output_mask: tuple[bool, bool, bool],\n) -> tuple[Tensor, Tensor, Tensor] | NotImplementedType:\n    \"\"\"Decomposite convolution bprop using the dgrad/wgrad/bgrad scheme.\n\n    Args:\n        grad_output 
(Tensor): The gradient w.r.t output.\n        input (Tensor): The input tensor.\n        weight (Tensor): The weight tensor.\n        bias_sizes (Tuple[int, ...]): The sizes of the bias tensor.\n        stride (Tuple[int, ...]): The stride of the convolution.\n        padding (Tuple[int, ...]): The padding of the convolution.\n        dilation (Tuple[int, ...]): The dilation of the convolution.\n        transposed (bool): Whether the convolution is transposed.\n        output_padding (Tuple[int, ...]): The output padding for the transposed convolution.\n        groups (int): The number of groups for the convolution.\n        output_mask (Tuple[bool, bool, bool]): A mask indicating which gradients to compute.\n\n    Returns:\n        Union[Tuple[Tensor, Tensor, Tensor], NotImplemented]: A tuple containing the\n            gradients of the input, weight, and bias, or NotImplemented if the\n            conditions are not met.\n    \"\"\"\n    if not output_mask[2] or grad_output.device.type != \"cuda\":\n        return NotImplemented\n    grad_inp, _, _ = aten.convolution_backward(\n        grad_output,\n        input,\n        weight,\n        bias_sizes,\n        stride,\n        padding,\n        dilation,\n        transposed,\n        output_padding,\n        groups,\n        [output_mask[0], False, False],\n    )\n    _, grad_weight, _ = aten.convolution_backward(\n        grad_output,\n        input,\n        weight,\n        bias_sizes,\n        stride,\n        padding,\n        dilation,\n        transposed,\n        output_padding,\n        groups,\n        [False, output_mask[1], False],\n    )\n    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))\n    return (grad_inp, grad_weight, grad_bias)\n\n\ndef convolution_backward_decomp_wbd(\n    grad_output: Tensor,\n    input: Tensor,\n    weight: Tensor,\n    bias_sizes: tuple[int, ...],\n    stride: tuple[int, ...],\n    padding: tuple[int, ...],\n    dilation: tuple[int, ...],\n    
transposed: bool,\n    output_padding: tuple[int, ...],\n    groups: int,\n    output_mask: tuple[bool, bool, bool],\n) -> tuple[Tensor, Tensor, Tensor] | NotImplementedType:\n    \"\"\"Decomposite convolution bprop using the wgrad/bgrad/dgrad scheme.\n\n    Args:\n        grad_output (Tensor): The gradient w.r.t output.\n        input (Tensor): The input tensor.\n        weight (Tensor): The weight tensor.\n        bias_sizes (Tuple[int, ...]): The sizes of the bias tensor.\n        stride (Tuple[int, ...]): The stride of the convolution.\n        padding (Tuple[int, ...]): The padding of the convolution.\n        dilation (Tuple[int, ...]): The dilation of the convolution.\n        transposed (bool): Whether the convolution is transposed.\n        output_padding (Tuple[int, ...]): The output padding for the transposed convolution.\n        groups (int): The number of groups for the convolution.\n        output_mask (Tuple[bool, bool, bool]): A mask indicating which gradients to compute.\n\n    Returns:\n        Union[Tuple[Tensor, Tensor, Tensor], NotImplemented]: A tuple containing the\n            gradients of the input, weight, and bias, or NotImplemented if the\n            conditions are not met.\n    \"\"\"\n    if not output_mask[2] or grad_output.device.type != \"cuda\":\n        return NotImplemented\n    _, grad_weight, _ = aten.convolution_backward(\n        grad_output,\n        input,\n        weight,\n        bias_sizes,\n        stride,\n        padding,\n        dilation,\n        transposed,\n        output_padding,\n        groups,\n        [False, output_mask[1], False],\n    )\n    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))\n    grad_inp, _, _ = aten.convolution_backward(\n        grad_output,\n        input,\n        weight,\n        bias_sizes,\n        stride,\n        padding,\n        dilation,\n        transposed,\n        output_padding,\n        groups,\n        [output_mask[0], False, False],\n    )\n   
 return (grad_inp, grad_weight, grad_bias)\n\n\nclass DecompositionsWrapper(_TorchCompileInductorWrapper):\n    \"\"\"A wrapper class for handling decompositions in model compilation.\n\n    This class extends the `_TorchCompileInductorWrapper` to include additional\n    decompositions for model compilation.\n\n    Args:\n        mode (str): The mode for the wrapper.\n        options (Optional[Dict]): Additional options for the wrapper.\n        dynamic (bool): Whether the wrapper is dynamic.\n        decompositions (Dict): A dictionary of decompositions to use.\n\n    Attributes:\n        decompositions (Dict): The decompositions used by the wrapper.\n    \"\"\"\n\n    def __init__(\n        self,\n        mode: str,\n        options: dict | None,\n        dynamic: bool,\n        decompositions: dict,\n    ) -> None:\n        \"\"\"Initialize the DecompositionsWrapper.\"\"\"\n        super().__init__(mode, options, dynamic)\n        self.decompositions = decompositions\n        # Force skip the type checking in self.apply_options() since default values are None type.\n        self.config.update(\n            {\n                \"pre_grad_custom_pass\": (\n                    pre_grad_custom_pass if config.enable_pre_grad_pass else None\n                ),\n            },\n        )\n\n    def __eq__(self, rhs: object) -> bool:\n        \"\"\"Check equality with another DecompositionsWrapper.\n\n        Args:\n            rhs (object): The other object to compare with.\n\n        Returns:\n            bool: True if the wrappers are equal, False otherwise.\n        \"\"\"\n        eq = (\n            isinstance(rhs, DecompositionsWrapper)\n            and super().__eq__(rhs)\n            and rhs.decompositions == self.decompositions\n        )\n        return eq\n\n    def __call__(\n        self,\n        model_: torch.nn.Module,\n        inputs_: list,\n        *args: object,\n        **kwargs: object,\n    ) -> Callable:\n        \"\"\"Compiles the model with the 
given inputs and decompositions.\n\n        Args:\n            model_ (torch.nn.Module): The model to compile.\n            inputs_ (list): The inputs to the model.\n            args (object): Positional arguments.\n            kwargs (object): Keyword arguments.\n\n        Returns:\n            Callable: The compiled model.\n        \"\"\"\n        # Modifications to the compilation process should be isolated between compilations.\n        decompositions = copy(select_decomp_table())\n        decompositions.update(self.decompositions)\n        return compile_fx(\n            model_,\n            inputs_,\n            inner_compile=enable_multi_stream_scheduling(compile_fx_inner),\n            config_patches=self.config,\n            decompositions=decompositions,\n        )\n\n\ndef get_backend(\n    backend: str = \"torch\",\n    scheme: str = \"dwb\",\n) -> Callable | DecompositionsWrapper:\n    \"\"\"Get the graph scheduler backend for model compilation.\n\n    This function returns the appropriate backend for model compilation based on\n    the specified parameters.\n\n    Args:\n        backend (str, optional): The backend to use. Defaults to \"torch\".\n        scheme (str, optional): The decomposition scheme to use. 
Defaults to \"dwb\".\n\n    Returns:\n        Union[Callable, DecompositionsWrapper]: The backend for model compilation.\n\n    Raises:\n        ValueError: If an unknown backend or scheme is specified.\n    \"\"\"\n    if backend not in (\"torch\", \"torchsched\"):\n        raise ValueError(f\"Unknown compilation {backend=}\")\n    if scheme not in (\"dwb\", \"wbd\"):\n        raise ValueError(f\"Invalid {scheme=}, use scheme=dwb or wbd instead\")\n\n    if backend == \"torch\":\n        return lookup_backend(\"inductor\")\n\n    # [NOTE] Disable buffer reuse and inplace buffers to avoid inter-stream conflicts.\n    #\n    # In PyTorch Inductor, the safety of buffer reuse and in-place buffer update is ensured by the\n    # program's single-stream, serial execution. That is, if op2 is launched only after op1 has\n    # completed execution, then these cases are safe:\n    #\n    #   Case 1: Safe to reuse buffer `workspace1` as `op2`'s workspace.\n    #\n    #         op1   ->   op2              op1   ->   op2\n    #          ↕          ↕       ⇒        ↕          ↑\n    #     workspace1 workspace2       workspace1 ←----┘\n    #\n    #   Case 2: Safe to write `op1`'s output in place to `buf1`, then send it to `op2` as input.\n    #\n    #     buf1 -> op1 -> buf2 -> op2  ⇒  buf1 ↔\top1\n    #                                     └-------> op2\n    #\n    # However, if operators are dispatched to distinct CUDA Streams and execute in parallel, the\n    # above cases are no longer safe:\n    #\n    #   Counter example 1: Case 1 is not safe if op1 and op2 are in parallel.\n    #\n    #        op1\n    #         ↕\n    #     workspace1 (Buffer modified concurrently by op1 and op2.)\n    #         ↕\n    #        op2\n    #\n    #   Counter example 2: Case 2 is not safe if op1 and op2 are in parallel.\n    #\n    #     buf1 <-->\top1\n    #      └------> op2 (Op2 could read op1's input data.)\n    #\n    # Thus currently we disable both buffer reuse and inplace buffer update to ensure 
multi-stream\n    # correctness.\n    #\n    # TODO(@davidli): Add cross-stream dependency to Inductor scheduling's dependency system so we\n    # can safely reuse and inplace update buffers even in multi-stream scenario.\n\n    if scheme == \"dwb\":\n        return DecompositionsWrapper(\n            mode=\"default\",\n            options={\"allow_buffer_reuse\": False, \"inplace_buffers\": False},\n            dynamic=False,\n            decompositions={\n                aten.convolution_backward.default: convolution_backward_decomp_dwb,\n            },\n        )\n    elif scheme == \"wbd\":\n        return DecompositionsWrapper(\n            mode=\"default\",\n            options={\"allow_buffer_reuse\": False, \"inplace_buffers\": False},\n            dynamic=False,\n            decompositions={\n                aten.convolution_backward.default: convolution_backward_decomp_wbd,\n            },\n        )\n    else:\n        # To please mypy\n        raise ValueError(f\"Invalid {scheme=}, use scheme=dwb or wbd instead\")\n"
  },
  {
    "path": "apex/contrib/torchsched/config.py",
    "content": "\"\"\"Configurations for graph scheduler.\"\"\"\n\nimport functools\nimport os\nimport re\nimport sys\n\n# Debug info and dump grpahs\ndebug = os.getenv(\"TORCH_SCHED_DEBUG\", \"0\") == \"1\"\n\n# Toggle pre_grad_pass for various pattern matches\nenable_pre_grad_pass = False\n\n# Pre grad pass patterns\npre_grad_pass_options: list[str] = [\"cudnn_layer_norm\"]\n\n# Number of CUDA streams used for multi-stream scheduling.\n# The first stream will be critical path stream, operators on non-critical path will be\n# scheduled to other streams in a round-robin way.\nnum_streams = int(os.getenv(\"TORCH_SCHED_NUM_STREAMS\", \"8\"))\n\n\ndef _get_skip_post_grad_graph_ids() -> set[int]:\n    if ids := os.environ.get(\"TORCH_SCHED_SKIP_GRAPH_IDS\"):\n        result: set[int] = set()\n        for part in ids.split(\",\"):\n            if \"-\" in part:\n                start, end = map(int, part.split(\"-\"))\n                result.update(range(start, end + 1))\n            else:\n                result.add(int(part))\n        return result\n    else:\n        return set()\n\n\n# IDs of post AOT-autograd graphs that should be skipped for multi-stream scheduling. Can be\n# specified via TORCH_SCHED_SKIP_GRAPH_IDS environment variable in a SLURM-like scheme, e.g.,\n# TORCH_SCHED_SKIP_GRAPH_IDS=1,2,3-5,7-10\nskip_post_grad_graph_ids: set[int] = _get_skip_post_grad_graph_ids()\n\n# Reduce the number of allocated CUDA Events in the generated program by:\n# 1. Track reference count of each CUDA Event in the scheduling phase. Skip generating CUDA Events\n#    that have no reference counts, i.e., have not been waited by other streams;\n# 2. 
Reuse allocated CUDA Events when feasible.\n# This option is enabled by default.\nreuse_cuda_event: bool = os.getenv(\"TORCH_SCHED_REUSE_CUDA_EVENT\", \"1\") == \"1\"\n\n\n@functools.lru_cache\ndef __get_dump_code_backends_and_dir(\n    dump_code: str | None,\n) -> tuple[list[str], str | None]:\n    pattern = r\"(?:\\+(?P<backend>\\w+),)?(?P<dir>[\\w\\/\\.\\-\\s@#~]+)\"\n    backends, dir = [\"torchsched\"], None\n    if dump_code and (match := re.match(pattern, dump_code)):\n        if backend := match.group(\"backend\"):\n            backends.append(backend)\n        dir = os.path.abspath(match.group(\"dir\"))\n    return backends, dir\n\n\n# Specify which backends' code to dump and the output directory with::\n#\n#   TORCH_SCHED_DUMP_CODE='+inductor,/dir/to/save/code'\n#\n# Where `+inductor` enables dumping both Inductor and torchsched code. If omitted, only\n# torchsched code is dumped. `/dir/to/save/code` specifies the directory to dump code to.\n(\n    dump_code_backends,\n    dump_code_dir,\n) = __get_dump_code_backends_and_dir(os.getenv(\"TORCH_SCHED_DUMP_CODE\"))\n\nfrom torch.utils._config_module import install_config_module  # noqa: E402\n\n# adds patch, save_config, etc\ninstall_config_module(sys.modules[__name__])\n"
  },
  {
    "path": "apex/contrib/torchsched/inductor/__init__.py",
    "content": "\"\"\"Scheduling abstractions on PyTorch Inductor level.\"\"\"\n\nfrom apex.contrib.torchsched.inductor.graph import patch_graph_lowering\n\n__all__ = [\"patch_graph_lowering\"]\n"
  },
  {
    "path": "apex/contrib/torchsched/inductor/_utils.py",
    "content": "from __future__ import annotations\n\nimport functools\nimport queue\nfrom typing import TYPE_CHECKING\n\nif TYPE_CHECKING:\n    from types import TracebackType\n\nimport torch\n\n__all__ = [\n    \"DEFAULT_STREAM\",\n    \"DEFAULT_STREAM_IDX\",\n    \"ENTRANCE_EVENT\",\n    \"EVENT_NAME_TEMPLATE\",\n    \"STREAM_NAME_TEMPLATE\",\n    \"CUDAStreamPool\",\n    \"get_cuda_stream_pool\",\n]\n\nDEFAULT_STREAM: str = \"default_stream\"\nDEFAULT_STREAM_IDX: int = 0\nENTRANCE_EVENT: str = \"event0\"\nEVENT_NAME_TEMPLATE: str = \"event{event_idx:d}\"\nSTREAM_NAME_TEMPLATE: str = \"stream{stream_idx:d}\"\n\n\n@functools.lru_cache\ndef get_stream_name(stream_idx: int) -> str:\n    \"\"\"Generate CUDA Stream name from stream index number.\n\n    Args:\n        stream_idx: Non-negative index number. 0 refers to the default stream, others refer to side\n            streams.\n    \"\"\"\n    if stream_idx == 0:\n        return DEFAULT_STREAM\n    else:\n        return STREAM_NAME_TEMPLATE.format(stream_idx=stream_idx)\n\n\nclass CUDAStreamPool:\n    \"\"\"A pool managing reusable CUDA streams to optimize GPU operations.\n\n    Attributes:\n        pool_size (int): The maximum number of CUDA streams managed by the pool.\n        stream_queue (queue.Queue): Queue holding the available CUDA streams.\n    \"\"\"\n\n    def __init__(self, device: int | None = None, pool_size: int = 8) -> None:\n        \"\"\"Initializesthe CUDAStreamPool instance.\n\n        Args:\n            device (Optional[int], optional): The CUDA device ID.\n                Defaults to None (current device).\n            pool_size (int, optional): The maximum number of CUDA streams in the pool.\n                Defaults to 8.\n        \"\"\"\n        self.pool_size: int = pool_size\n        self.stream_queue: queue.Queue[torch.cuda.Stream] = queue.Queue(maxsize=pool_size)\n\n        for _ in range(pool_size):\n            stream = torch.cuda.Stream(device=device)\n            
self.stream_queue.put(stream)\n\n    def acquire(self) -> torch.cuda.Stream:\n        \"\"\"Acquire a CUDA stream from the pool.\n\n        Returns:\n            torch.cuda.Stream: A CUDA stream object from the pool.\n        \"\"\"\n        return self.stream_queue.get()\n\n    def release(self, stream: torch.cuda.Stream | None) -> None:\n        \"\"\"Return a CUDA stream to the pool.\n\n        Args:\n            stream (Optional[torch.cuda.Stream]): The CUDA stream to return to the pool.\n        \"\"\"\n        if stream is not None:\n            self.stream_queue.put(stream)\n\n    def __enter__(self) -> torch.cuda.Stream:\n        \"\"\"Enter the runtime context and acquire a CUDA stream.\n\n        Returns:\n            torch.cuda.Stream: The acquired CUDA stream.\n        \"\"\"\n        self.stream = self.acquire()\n        self.stream.__enter__()\n        return self.stream\n\n    def __exit__(\n        self,\n        exc_type: type[BaseException] | None,\n        exc_val: BaseException | None,\n        exc_tb: TracebackType | None,\n    ) -> None:\n        \"\"\"Exit the runtime context and release the acquired CUDA stream.\n\n        Args:\n            exc_type (type[BaseException] | None): Exception type, if raised.\n            exc_val (BaseException | None): Exception instance, if raised.\n            exc_tb (TracebackType | None): Traceback object, if raised.\n        \"\"\"\n        self.stream.__exit__(exc_type, exc_val, exc_tb)\n        self.release(self.stream)\n\n\n_cuda_stream_pool: CUDAStreamPool | None = None\n\n\ndef get_cuda_stream_pool(device: int | None = None, pool_size: int = 32) -> CUDAStreamPool:\n    \"\"\"Retrieve a global CUDA stream pool, creating it if necessary.\n\n    This function ensures that only one CUDAStreamPool instance exists globally.\n\n    Args:\n        device (Optional[int], optional): The CUDA device ID to initialize the pool on.\n            Defaults to None (current device).\n        pool_size (int, 
optional): The number of streams in the pool. Defaults to 32.\n\n    Returns:\n        CUDAStreamPool: The global CUDA stream pool instance.\n    \"\"\"\n    global _cuda_stream_pool\n    if _cuda_stream_pool is None:\n        _cuda_stream_pool = CUDAStreamPool(device=device, pool_size=pool_size)\n    return _cuda_stream_pool\n"
  },
  {
    "path": "apex/contrib/torchsched/inductor/event.py",
    "content": "\"\"\"CUDA Event abstractions used in Inductor multi-stream scheduling.\n\nAttributes:\n    ENTRANCE_EVENT: Name of the first event on the default CUDA Stream that got recorded before all\n        kernels.\n    EVENT_NAME_TEMPLATE: Python string template to generate event names. Can be used as:\n\n            idx: int = ...\n            event = EVENT_NAME_TEMPLATE.format(event_idx=idx)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport dataclasses\nimport functools\nimport itertools\n\nfrom torch._inductor.codegen.wrapper import IndentedBuffer\nfrom torch._inductor.codegen.wrapper import WrapperLine\n\nimport apex.contrib.torchsched.config as torchsched_config\nfrom apex.contrib.torchsched.inductor._utils import DEFAULT_STREAM_IDX\nfrom apex.contrib.torchsched.inductor._utils import ENTRANCE_EVENT\nfrom apex.contrib.torchsched.inductor._utils import EVENT_NAME_TEMPLATE\nfrom apex.contrib.torchsched.inductor._utils import get_stream_name\n\n\n@functools.total_ordering\n@dataclasses.dataclass\nclass CudaEventSym:\n    \"\"\"Symbolic representation of CUDA Events in the Inductor scheduling phase.\n\n    Args:\n        factory: The CUDAEventFactory that generate this event.\n        idx: Indexing number assigned in chronological order during scheduling.\n        originate_stream_idx: The index of the CUDA stream that this event originated from.\n        ref_count: Reference count of this event instance.\n        materialized_event: The actual CUDA Event name that will be used in the final PyTorch\n            program. Only symbolic event with reference count larger than one will be materialized.\n\n    Note:\n        In most cases this class should not be used standalone. 
Use\n        `CudaEventFactory.get_sym_event()` to instantiate one.\n    \"\"\"\n\n    factory: CudaEventFactory\n    idx: int\n    originate_stream_idx: int\n    ref_count: int = 0\n    materialized_event: str | None = None\n\n    def __lt__(self, rhs: CudaEventSym) -> bool:\n        \"\"\"Whether the current event is generated before the rhs event.\"\"\"\n        if self.factory is not rhs.factory:\n            return NotImplemented\n        return (self.idx, self.originate_stream_idx) < (\n            rhs.idx,\n            rhs.originate_stream_idx,\n        )\n\n    def __eq__(self, rhs: object) -> bool:\n        \"\"\"Whether the current event is identical to the rhs event.\"\"\"\n        if not isinstance(rhs, CudaEventSym):\n            return NotImplemented\n        return (\n            self.idx == rhs.idx\n            and self.originate_stream_idx == rhs.originate_stream_idx\n            and self.factory is rhs.factory\n        )\n\n    def __str__(self) -> str:\n        \"\"\"Represent this symbolic event as a string.\"\"\"\n        ret = f\"{self.__class__.__name__} (idx={self.idx}\"\n        ret += f\", originate_stream_idx={self.originate_stream_idx}\"\n        if self.ref_count:\n            ret += f\", ref_count={self.ref_count}\"\n        if self.materialized_event:\n            ret += f\", materialized to `{self.materialized_event}`\"\n        ret += \")\"\n        return ret\n\n    def __hash__(self) -> int:\n        \"\"\"Hash this symbolic event.\"\"\"\n        return hash((id(self.factory), self.idx, self.originate_stream_idx))\n\n    def record(self, stream_idx: int) -> _CudaEventRecordLine:\n        \"\"\"Record this event on a given stream.\n\n        Args:\n            stream_idx: The index of the stream that this event will be recorded on.\n\n        Returns:\n            An internal data structure that depicts stream <-> event dependency.\n\n        Note:\n            This method doesn't necessarily generate an event recording in the final 
program.\n            Instead it records the dependence between the stream and the current event. Whether\n            or not this event recording shows up in the final program depends on the reference\n            count of the current event. I.e., if this event is never waited for by the later\n            code, this event recording will not be code-generated.\n        \"\"\"\n        stream = get_stream_name(stream_idx)\n        return _CudaEventRecordLine(self, stream)\n\n    def wait(self, stream_idx: int) -> _CudaEventWaitLine:\n        \"\"\"Make a given stream wait for this event to complete.\n\n        Args:\n            stream_idx: The index of the stream that will be waiting for this event to complete.\n\n        Returns:\n            An internal data structure that depicts stream <-> event dependency.\n\n        Note:\n            This method doesn't necessarily generate an event wait in the final program. Instead\n            it records the dependence between the stream and the current event and also increases\n            the reference count of this event. 
If an event object has called this method, it is\n            guaranteed to be generated in the final program.\n        \"\"\"\n        assert stream_idx != self.originate_stream_idx\n        self.ref_count += 1\n        stream = get_stream_name(stream_idx)\n        return _CudaEventWaitLine(self, stream)\n\n\n@dataclasses.dataclass\nclass _CudaEventRecordLine(WrapperLine):\n    event: CudaEventSym\n    stream: str\n    _reuse_cuda_event: bool = torchsched_config.reuse_cuda_event\n\n    def codegen(self, code: IndentedBuffer) -> None:\n        assert 0 <= self.event.ref_count\n        assert self.event.materialized_event is None\n        if self.event.ref_count or not self._reuse_cuda_event:\n            self.event.materialized_event = self.event.factory.get_materialized_event(code)\n            code.writeline(f\"{self.event.materialized_event}.record({self.stream})\")\n\n\n@dataclasses.dataclass\nclass _CudaEventWaitLine(WrapperLine):\n    event: CudaEventSym\n    stream: str\n\n    def codegen(self, code: IndentedBuffer) -> None:\n        assert 0 < self.event.ref_count\n        assert self.event.materialized_event is not None\n        code_line = f\"{self.event.materialized_event}.wait({self.stream})\"\n        self.event.ref_count -= 1\n        if self.event.ref_count == 0:\n            self.event.factory.deposit_materialized_event(self.event.materialized_event)\n            self.event.materialized_event = None\n            code_line += f\"  # End lifecycle of {self.event}\"\n        code.writeline(code_line)\n\n\nclass CudaEventFactory:\n    \"\"\"A factory that manages CUDA event creation and materialization.\n\n    This factory maintains internal states to ensure that created cuda events get monotonically\n    increasing indices as compilation goes along. 
It also maintains a pool of materialized cuda\n    events that symbolic events can reuse.\n    \"\"\"\n\n    def __init__(self) -> None:\n        \"\"\"Initialize an event factory.\"\"\"\n        self.symbolic_event_idx: itertools.count = itertools.count(start=1)\n        self.materialized_event_idx: itertools.count = itertools.count(start=1)\n        self.available_materialized_events: set[str] = set()\n        self._entrance_event: CudaEventSym | None = None\n        self._reuse_cuda_event: bool = torchsched_config.reuse_cuda_event\n\n    def get_entrance_event(self) -> CudaEventSym:\n        \"\"\"Return the cuda event that corresponds to entering the compute graph.\"\"\"\n        if self._entrance_event is None:\n            self._entrance_event = CudaEventSym(\n                factory=self,\n                idx=0,\n                originate_stream_idx=DEFAULT_STREAM_IDX,\n            )\n            # Code-gen for entrance event is almost hard-coded in device guard enter so the\n            # materialization is slightly different here.\n            self._entrance_event.materialized_event = ENTRANCE_EVENT\n        return self._entrance_event\n\n    def get_sym_event(self, originate_stream_idx: int) -> CudaEventSym:\n        \"\"\"Allocate a symbolic cuda event.\"\"\"\n        return CudaEventSym(\n            factory=self,\n            idx=next(self.symbolic_event_idx),\n            originate_stream_idx=originate_stream_idx,\n        )\n\n    def get_materialized_event(self, code: IndentedBuffer) -> str:\n        \"\"\"Allocate or reuse a materialized cuda event.\"\"\"\n        if self._reuse_cuda_event and self.available_materialized_events:\n            return self.available_materialized_events.pop()\n        else:\n            event = EVENT_NAME_TEMPLATE.format(event_idx=next(self.materialized_event_idx))\n            code.writeline(f\"{event} = torch.cuda.Event()\")\n            return event\n\n    def deposit_materialized_event(self, event: str) -> None:\n     
   \"\"\"Give back a materialized cuda event when the corresponding sym event ends lifecycle.\"\"\"\n        assert event not in self.available_materialized_events\n        self.available_materialized_events.add(event)\n"
  },
  {
    "path": "apex/contrib/torchsched/inductor/graph.py",
    "content": "\"\"\"Scheduling abstractions on PyTorch Inductor GraphLowering level.\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom torch._inductor.codegen.common import get_scheduling_for_device\nfrom torch._inductor.codegen.common import get_wrapper_codegen_for_device\nfrom torch._inductor.codegen.common import register_backend_for_device\nfrom torch._inductor.codegen.wrapper import PythonWrapperCodegen\nfrom torch._inductor.graph import GraphLowering\nfrom torch._inductor.scheduler import Scheduler\nfrom torch._inductor.virtualized import V\n\nif TYPE_CHECKING:\n    from torch._inductor.utils import ValueWithLineMap\n\nfrom apex.contrib.torchsched import config as torchsched_config\nfrom apex.contrib.torchsched.inductor.scheduler import MultiCudaStreamScheduler\nfrom apex.contrib.torchsched.inductor.wrapper import MultiStreamWrapperCodegen\n\n_inductor_codegen = GraphLowering.codegen\npatching_device_type = \"cuda\"\nschedule_log = torch._logging.getArtifactLogger(__name__, \"schedule\")\n\n\n@functools.wraps(GraphLowering.codegen)\ndef _torchsched_codegen(\n    graph: GraphLowering,\n) -> tuple[ValueWithLineMap, ValueWithLineMap]:\n    # Move patching logic here as post_grad_graph_id was not available until now.\n    cpp_wrapper_cls = get_wrapper_codegen_for_device(patching_device_type, cpp_wrapper=True)\n    only_cpu = len(graph.device_types - {\"cpu\", \"meta\"}) == 0\n    scheduling_cls = get_scheduling_for_device(patching_device_type)\n    wrapper_cls = get_wrapper_codegen_for_device(patching_device_type)\n    write_get_raw_stream = PythonWrapperCodegen.write_get_raw_stream\n    if not only_cpu and graph.post_grad_graph_id not in torchsched_config.skip_post_grad_graph_ids:\n        patched_scheduler_cls = MultiCudaStreamScheduler\n        patched_wrapper_cls = MultiStreamWrapperCodegen\n        # torch.compile explicitly calls `write_get_raw_stream` via 
wrapper's class method in its\n        # lowering process to work around the wrapper-stream LRU cache mechanism. To be compatible\n        # with this, we have to patch the wrapper's class method as well.\n        PythonWrapperCodegen.write_get_raw_stream = MultiStreamWrapperCodegen._write_get_raw_stream\n    else:\n        patched_scheduler_cls = Scheduler\n        patched_wrapper_cls = PythonWrapperCodegen\n    register_backend_for_device(\n        device=patching_device_type,\n        device_scheduling=scheduling_cls,\n        device_wrapper_codegen=patched_wrapper_cls,\n        device_cpp_wrapper_codegen=cpp_wrapper_cls,\n    )\n\n    graph.init_wrapper_code()\n    graph.scheduler = patched_scheduler_cls(graph.operations)\n    V.debug.draw_orig_fx_graph(graph.orig_gm, graph.scheduler.nodes)\n    graph.wrapper_code.push_codegened_graph(graph)\n    graph.scheduler.codegen()\n    result = graph.wrapper_code.generate(graph.is_inference)\n    graph.wrapper_code.pop_codegened_graph()\n\n    PythonWrapperCodegen.write_get_raw_stream = write_get_raw_stream\n    register_backend_for_device(\n        device=patching_device_type,\n        device_scheduling=scheduling_cls,\n        device_wrapper_codegen=wrapper_cls,\n        device_cpp_wrapper_codegen=cpp_wrapper_cls,\n    )\n\n    return result\n\n\n@functools.wraps(GraphLowering.codegen)\ndef _mixed_codegen(graph: GraphLowering) -> tuple[ValueWithLineMap, ValueWithLineMap]:\n    assert torchsched_config.dump_code_dir\n    output_code_per_backend: dict[str, tuple[ValueWithLineMap, ValueWithLineMap]] = {}\n\n    for backend in torchsched_config.dump_code_backends:\n        if backend == \"torchsched\":\n            codegen = _torchsched_codegen\n        elif backend == \"inductor\":\n            codegen = _inductor_codegen\n        else:\n            raise ValueError(f\"Unknown {backend=} from {torchsched_config.dump_code_backends=}\")\n        wrapper_code, kernel_code = codegen(graph)\n        
output_code_per_backend[backend] = (wrapper_code, kernel_code)\n\n    for backend, (wrapper_code, kernel_code) in output_code_per_backend.items():\n        backend_dir = Path(torchsched_config.dump_code_dir) / backend\n        backend_dir.mkdir(parents=True, exist_ok=True)\n        graph_id = graph.post_grad_graph_id\n        (backend_dir / f\"graph_{graph_id}_wrapper_code.py\").write_text(wrapper_code.value)\n        if kernel_code.value.strip():\n            # Kernel_code is only available in AOTInductor mode.\n            (backend_dir / f\"graph_{graph_id}_kernel_code.py\").write_text(kernel_code.value)\n\n    return output_code_per_backend[\"torchsched\"]\n\n\ndef patch_graph_lowering(patch: bool = True) -> None:\n    \"\"\"Patch PyTorch Inductor lowerings with multi-stream scheduling.\n\n    This function patches the `torch.compile` stack on the GraphLowering level,\n    i.e., the compute graph has been captured by Dynamo and it has undergone\n    post-auto-gradient passes, including pattern-matching optimizations and\n    preliminary operator fusions. At that point, most nodes in the graph are\n    either fused Triton templates, or function calls to external libraries. The\n    multi-stream scheduler then finds the longest critical path in this graph,\n    and schedule other nodes to side streams to exploit the inherent parallelism\n    of the given compute graph.\n\n    Args:\n        patch: Whether to patch Inductor `GraphLowering` with multi-stream\n            scheduler. Set to `False` to restore the default `torch.compile`\n            behavior. (default: `True`)\n    \"\"\"\n    if patch and torchsched_config.dump_code_dir:\n        GraphLowering.codegen = _mixed_codegen\n    elif patch:\n        GraphLowering.codegen = _torchsched_codegen\n    else:\n        GraphLowering.codegen = _inductor_codegen\n"
  },
  {
    "path": "apex/contrib/torchsched/inductor/scheduler.py",
    "content": "\"\"\"Scheduling abstractions on PyTorch Inductor Scheduler level.\"\"\"\n\nfrom __future__ import annotations\n\nimport collections\nimport itertools\nimport re\nfrom typing import TYPE_CHECKING\nfrom typing import cast\n\nimport torch\nimport torch._inductor.config as inductor_config\nfrom torch._inductor import ir\nfrom torch._inductor.dependencies import WeakDep\nfrom torch._inductor.scheduler import BaseSchedulerNode\nfrom torch._inductor.scheduler import ExternKernelSchedulerNode\nfrom torch._inductor.scheduler import ForeachKernelSchedulerNode\nfrom torch._inductor.scheduler import FusedSchedulerNode\nfrom torch._inductor.scheduler import NopKernelSchedulerNode\nfrom torch._inductor.scheduler import Scheduler\nfrom torch._inductor.scheduler import SchedulerNode\nfrom torch._inductor.utils import device_need_guard\nfrom torch._inductor.virtualized import V\n\nfrom apex.contrib.torchsched import config\nfrom apex.contrib.torchsched.inductor._utils import DEFAULT_STREAM_IDX\nfrom apex.contrib.torchsched.inductor._utils import get_stream_name\nfrom apex.contrib.torchsched.inductor.event import CudaEventFactory\nfrom apex.contrib.torchsched.inductor.event import CudaEventSym\nfrom apex.contrib.torchsched.inductor.wrapper import EnterCudaStreamContextLine\n\nif TYPE_CHECKING:\n    from apex.contrib.torchsched.inductor.wrapper import MultiStreamWrapperCodegen\n\n\nschedule_log = torch._logging.getArtifactLogger(__name__, \"schedule\")\n\n\nclass MultiCudaStreamScheduler(Scheduler):\n    \"\"\"Scheduling post-fusion graph with multi-stream awareness.\n\n    This class introduced a new optimization pass on top of the Inductor :class:`Scheduler`. I.e.,\n    it firstly searches for the longest critical path in the given compute graph, currently using\n    the path depth as a proxy of execution cost. 
Then it executes the non-critical computations in\n    parallel with the critical path computations by launching them to side CUDA Streams, with the\n    goal of scheduling critical path computations back-to-back while saturating GPU resources at\n    runtime.\n\n    Args:\n        operations: A list of Inductor IR nodes representing fused computations.\n    \"\"\"\n\n    def __init__(self, operations: list[ir.Operation]) -> None:\n        \"\"\"Construct a scheduler object from a list of Inductor IR nodes.\n\n        Refer to :class:`MultiCudaStreamScheduler` doc for argument specification.\n        \"\"\"\n        super().__init__(operations)\n        self.event_factory = CudaEventFactory()\n        self.buff_to_event: dict[str, CudaEventSym] = collections.defaultdict(\n            lambda: self.event_factory.get_sym_event(\n                originate_stream_idx=self.current_stream_idx,  # type: ignore[arg-type]\n            ),\n        )\n        self.unjoined_events: dict[int, set[CudaEventSym]] = collections.defaultdict(set)\n        self.buffers_requiring_device_check: set[str] = set()\n        # The only source of which stream context are we currently in at the scheduling phase.\n        self._current_stream_ctx: EnterCudaStreamContextLine | None = None\n        self.schedule_multi_cuda_streams()\n\n    @property\n    def current_stream_idx(self) -> int | None:\n        \"\"\"CUDA Stream index that current scheduler node assigned to.\"\"\"\n        if self._current_stream_ctx is not None:\n            return self._current_stream_ctx.stream_idx\n        else:\n            return None\n\n    @property\n    def current_stream_name(self) -> str | None:\n        \"\"\"CUDA Stream name that current scheduler node assigned to.\"\"\"\n        if (stream_idx := self.current_stream_idx) is not None:\n            return get_stream_name(stream_idx)\n        else:\n            return None\n\n    @property\n    def buffers_recorded_on_current_stream(self) -> set[str]:\n     
   \"\"\"Buffer names that have been recorded on the current stream context.\"\"\"\n        assert self._current_stream_ctx is not None\n        return self._current_stream_ctx.buffers_recorded_on_this_stream\n\n    @buffers_recorded_on_current_stream.setter\n    def buffers_recorded_on_current_stream(self, buffs: set[str]) -> None:\n        \"\"\"Set buffer names that have been recorded on the current stream context.\n\n        Note:\n            The name of buffers recorded on the current stream context should be a superset of the\n            buffers recorded on the previous stream context.\n        \"\"\"\n        assert self._current_stream_ctx is not None\n        assert buffs.issuperset(self._current_stream_ctx.buffers_recorded_on_this_stream)\n        self._current_stream_ctx.buffers_recorded_on_this_stream = buffs\n\n    def debug_str_short(self, node: BaseSchedulerNode) -> str:\n        \"\"\"Generate short string representing scheduler node's calling function or indices.\"\"\"\n        if node.is_extern() and isinstance(node.node, ir.MultiOutput):\n            kernel_str = node.node.codegen_list_tuple_access(\n                basename=\"getitem\",\n                indices=node.node.indices,\n            )\n            return f\"{node.get_name()} ({kernel_str})\"\n        elif node.is_extern():\n            kernel_name = node.node.get_kernel_name() or str(node.node.op_overload)\n            return f\"{node.get_name()} ({kernel_name})\"\n        else:\n            return node.get_name()\n\n    def get_last_event(self, events: set[CudaEventSym]) -> CudaEventSym:\n        \"\"\"Identify the latest generated CUDA event among all given events.\"\"\"\n        return sorted(events, reverse=True)[0]  # CudaEventSym is total-ordering.\n\n    def schedule_multi_cuda_streams(self) -> None:\n        \"\"\"Assign each fused Inductor IR nodes with the CUDA Stream to be launched to.\"\"\"\n        if not self.nodes:\n            # Empty graphs are sent to compiler in 
very rare circumstances. Just Skip scheduling.\n            return\n\n        buf_originate: dict[str, BaseSchedulerNode] = {}\n        node_users: dict[BaseSchedulerNode, set[BaseSchedulerNode]] = collections.defaultdict(set)\n        for node in self.nodes:\n            for n in node.get_buffer_names():\n                buf_originate[n] = node\n        for node in self.nodes:\n            for d in node.unmet_dependencies:\n                assert d.name in buf_originate\n                node_users[buf_originate[d.name]].add(node)\n\n        critical_path_per_depth: dict[int, set[BaseSchedulerNode]] = collections.defaultdict(set)\n        node_depth: dict[BaseSchedulerNode, int] = collections.defaultdict(lambda: -1)\n\n        def visit(node: BaseSchedulerNode, depth: int, prev: set[BaseSchedulerNode]) -> None:\n            if node_depth[node] < depth:\n                node_depth[node] = depth\n                path = prev | {node}\n                if len(critical_path_per_depth[depth]) < len(path):\n                    critical_path_per_depth[depth] = path\n                for user in node_users[node]:\n                    visit(user, depth + 1, path)\n\n        graph_entries = [n for n in self.nodes if not n.unmet_dependencies]\n        for entry in graph_entries:\n            visit(entry, depth=1, prev=set())\n\n        max_depth, longest_critical_path = sorted(critical_path_per_depth.items(), reverse=True)[0]\n\n        # Allocate CUDA Streams for each fused node:\n        # - Critical path nodes go to the default stream\n        # - Nodes without GPU operations (currently only covered getitem nodes) go to their\n        #   producer's stream\n        # - Other nodes go to a set of pre-defined number of side-streams in a round-robin manner\n        num_streams = config.num_streams\n        if num_streams == 1:\n            node_to_stream = {node: DEFAULT_STREAM_IDX for node in self.nodes}\n        else:\n            node_to_stream = {}\n            
side_stream_indices = itertools.cycle(range(1, num_streams))\n            for node in self.nodes:\n                if node in longest_critical_path:\n                    node_to_stream[node] = DEFAULT_STREAM_IDX\n                elif node.is_extern() and isinstance(node.node, ir.MultiOutput):\n                    assert len(node.unmet_dependencies) == 1\n                    producer = buf_originate[next(iter(node.unmet_dependencies)).name]\n                    node_to_stream[node] = node_to_stream[producer]\n                else:\n                    node_to_stream[node] = next(side_stream_indices)\n        self.node_to_stream = node_to_stream\n\n        # Also remember buffer originate streams.\n        buff_to_stream = {}\n        for node, stream_idx in node_to_stream.items():\n            for buf_name in node.get_buffer_names():\n                buff_to_stream[buf_name] = stream_idx\n        self.buff_to_stream = buff_to_stream\n\n        schedule_log.debug(f\"{' Multi-CUDA-Stream scheduling results ':=^79}\")\n        schedule_log.debug(\"Post-fusion graph depth: %d\", max_depth)\n        schedule_log.debug(\"Total number of allocated CUDA Streams: %d\", num_streams)\n        schedule_log.debug(f\"{' Critical path ':-^79}\")\n        for node in self.nodes:\n            if node in longest_critical_path:\n                schedule_log.debug(\"- %s\", self.debug_str_short(node))\n        schedule_log.debug(f\"{' Stream assignments of other nodes ':-^79}\")\n        for node, stream_idx in node_to_stream.items():\n            if node not in longest_critical_path:\n                schedule_log.debug(\"- %s -> Stream %d\", self.debug_str_short(node), stream_idx)\n\n    def get_final_events_to_sync(self) -> set[CudaEventSym]:\n        \"\"\"Return the CUDA Events that need to be synced at the end of the program.\n\n        Raises:\n            ValueError: If there is hanging event on the default stream. 
This usually means the\n                user didn't properly use :meth:`add_unjointed_event` to register hanging events.\n        \"\"\"\n        if self.unjoined_events.get(DEFAULT_STREAM_IDX):\n            raise ValueError(\n                f\"Unexpected {self.unjoined_events[DEFAULT_STREAM_IDX]=} on default stream\",\n            )\n        events_to_sync = set()\n        for stream, events in self.unjoined_events.items():\n            if len(events) == 0:\n                schedule_log.debug(f\"All events on stream{stream} have been consumed\")\n                continue\n            last_event = self.get_last_event(events)\n            if 1 < len(events):\n                schedule_log.debug(\n                    f\"Seeing multiple hanging {events=} on stream{stream}, scheduling the \"\n                    f\"{last_event=} to sync\",\n                )\n            else:\n                schedule_log.debug(\n                    f\"Scheduling the {last_event=} on stream{stream} to sync\",\n                )\n            events_to_sync.add(last_event)\n        return events_to_sync\n\n    def clear_unjoined_events(self) -> None:\n        \"\"\"Clear hanging event syncs registered by :meth:`add_unjointed_event`.\"\"\"\n        self.unjoined_events.clear()\n\n    def register_downstream_event(\n        self,\n        node: BaseSchedulerNode,\n    ) -> CudaEventSym:\n        \"\"\"Register one CUDA event indicating that node execution is complete.\n\n        For ordinary Inductor IR nodes, the completion event is newly created using an internal\n        event counter. 
For Inductor no-op nodes, the last event corresponding to its inputs will be\n        used instead.\n\n        Args:\n            node: The Inductor IR node to generate a completion event for.\n\n        Returns:\n            The symbolic completion event.\n\n        Raises:\n            ValueError: If this function is called outside of any stream context.\n        \"\"\"\n        if isinstance(node, NopKernelSchedulerNode) and node.unmet_dependencies:\n            upstream_events = set()\n            for dep in node.unmet_dependencies:\n                assert dep.name in self.buff_to_event\n                upstream_events.add(self.buff_to_event[dep.name])\n            assert 1 <= len(upstream_events)\n            downstream_event = self.get_last_event(upstream_events)\n            for buff in node.get_buffer_names():\n                self.buff_to_event[buff] = downstream_event\n        else:\n            for i, buff in enumerate(sorted(node.get_buffer_names())):\n                if i == 0:\n                    downstream_event = self.buff_to_event[buff]\n                    assert downstream_event.originate_stream_idx == self.current_stream_idx\n                else:\n                    self.buff_to_event[buff] = downstream_event\n            if (node_stream := self.node_to_stream[node]) != DEFAULT_STREAM_IDX:\n                self.unjoined_events[node_stream].add(downstream_event)\n            V.graph.wrapper_code.writeline(downstream_event.record(node_stream))\n        return downstream_event\n\n    def get_cross_stream_dependencies(\n        self,\n        node: BaseSchedulerNode,\n    ) -> tuple[set[CudaEventSym], set[str]]:\n        \"\"\"Get CUDA Event and buffer dependencies of an IR node.\n\n        Args:\n            node: The Inductor IR node to generate code for.\n\n        Returns:\n            upstream_events: A set of CUDA Event symbols, these events need to be synced before\n                executing `node`'s code.\n            
buffer_from_other_streams: A set of buffer names, these buffers need to be recorded on\n                the CUDA Stream that `node` is running on.\n        \"\"\"\n        assert node in self.node_to_stream\n\n        # Process cross-cuda-stream dependencies.\n        node_stream = self.node_to_stream[node]\n        events_on_stream: dict[int, set[CudaEventSym]] = collections.defaultdict(set)\n        buffers_from_other_streams = set()\n        if not node.unmet_dependencies and node_stream != DEFAULT_STREAM_IDX:\n            # Graph entries on side streams should wait upon the main stream entrance.\n            entrance_event = self.event_factory.get_entrance_event()\n            events_on_stream[DEFAULT_STREAM_IDX].add(entrance_event)\n        for dep in node.read_writes.reads:\n            buff = dep.name  # To track stream number and cuda events.\n            buff_real = self.mutation_real_name.get(buff, buff)  # The real name in code.\n            if dep not in node.unmet_dependencies and not isinstance(dep, WeakDep):\n                # Materialized dependencies should be recorded on this stream.\n                buffers_from_other_streams.add(buff_real)\n                # The scalar tensor argument `dropout_p` of SDPA backward kernels might be on CUDA\n                # or CPU devices depending on execution scenario. 
To ensure program correctness we\n                # add a runtime check for it.\n                #\n                # TODO (@davidli): Remove this ad-hoc checking once PyTorch fix SDPA and\n                # MultiOutputLayout issues.\n                if node.is_extern() and re.match(\n                    r\"aten._scaled_dot_product_.*_attention_backward\",\n                    str(node.node.op_overload),\n                ):\n                    self.buffers_requiring_device_check.add(buff_real)\n                continue\n            elif isinstance(dep, WeakDep):\n                # Skip unmaterialized dependencies.\n                continue\n            assert buff in self.buff_to_event\n            assert buff in self.buff_to_stream\n            buff_event = self.buff_to_event[buff]\n            buff_stream = self.buff_to_stream[buff]\n            events_on_stream[buff_stream].add(buff_event)\n            if buff_stream != node_stream:\n                if node.is_extern() and isinstance(node.node, ir.MultiOutput):\n                    assert len(node.read_writes.reads) == 1\n                    buff_real = node.node.codegen_list_tuple_access(\n                        basename=buff_real,\n                        indices=node.node.indices,\n                    )\n                    self.buffers_requiring_device_check |= {\n                        buff_real,\n                        node.node.get_name(),\n                    }\n                buffers_from_other_streams.add(buff_real)\n\n        # Should only wait for the latest event from each stream.\n        upstream_events = set()\n        for stream, events in events_on_stream.items():\n            if stream != node_stream:\n                last_event = self.get_last_event(events)\n                upstream_events.add(last_event)\n\n        return upstream_events, buffers_from_other_streams\n\n    def generate_stream_ctx_enter(self, node: BaseSchedulerNode) -> None:\n        \"\"\"Code-gen to enter the Stream 
context assigned to node.\"\"\"\n        assert not isinstance(node, NopKernelSchedulerNode)\n        wrapper_code = cast(\"MultiStreamWrapperCodegen\", V.graph.wrapper_code)\n        upstream_events, buffers_from_other_streams = self.get_cross_stream_dependencies(node)\n        node_stream = self.node_to_stream[node]\n        self._current_stream_ctx = wrapper_code.codegen_cuda_stream_enter(\n            stream_idx=node_stream,\n            upstream_events=upstream_events,\n            buffers_from_other_streams=buffers_from_other_streams,\n            buffers_requiring_device_check=self.buffers_requiring_device_check,\n        )\n\n    def generate_stream_ctx_exit(self) -> None:\n        \"\"\"Code-gen to exit from the current Stream context.\"\"\"\n        assert self._current_stream_ctx is not None\n        wrapper_code = cast(\"MultiStreamWrapperCodegen\", V.graph.wrapper_code)\n        wrapper_code.codegen_cuda_stream_exit()\n        self._current_stream_ctx = None\n\n    def propagate_cross_stream_dependencies(self, node: BaseSchedulerNode) -> None:\n        \"\"\"Move input node's dependencies to the entrance of current CUDA Stream context.\n\n        If node is scheduled in the middle of a stream context, its dependencies should be properly\n        synced before entering this context. This function extracts `node`'s dependencies and move\n        them to the data structure that represents the entrance of current stream context.\n\n        Args:\n            node: The Inductor IR node to generate code for. 
This node must have an assigned stream\n                in :meth:`schedule_multi_cuda_streams`.\n        \"\"\"\n        assert self.current_stream_idx is not None\n        wrapper_code = cast(\"MultiStreamWrapperCodegen\", V.graph.wrapper_code)\n        upstream_events, buffers_from_other_streams = self.get_cross_stream_dependencies(node)\n        buffers_from_other_streams -= self.buffers_recorded_on_current_stream\n        wrapper_code.codegen_buffers_record_stream(\n            buffers=buffers_from_other_streams,\n            stream_idx=self.current_stream_idx,\n            buffers_requiring_device_check=self.buffers_requiring_device_check,\n        )\n        wrapper_code.codegen_events_wait_stream(\n            events=upstream_events,\n            stream_idx=self.current_stream_idx,\n        )\n        self.buffers_recorded_on_current_stream |= buffers_from_other_streams\n\n    def generate_stream_ctx_switching(self, node: BaseSchedulerNode) -> None:\n        \"\"\"Generate stream entering and exiting to properly run node in a multi-stream scenario.\n\n        Stream context switching is only generated if `node`'s assigned stream is different from\n        the previous node's stream. If the node is a no-op, its code will be generated in the same\n        context of previous node.\n        \"\"\"\n        assert node in self.node_to_stream\n        stream = None if isinstance(node, NopKernelSchedulerNode) else self.node_to_stream[node]\n        if self.current_stream_idx == stream:\n            if stream is not None:\n                self.propagate_cross_stream_dependencies(node)\n            return\n        elif self.current_stream_idx is not None and stream is None:\n            # Don't generate ctx switching. 
Memory planning code (e.g., delete buffers) on current\n            # node goes to previous stream ctx.\n            return\n        elif self.current_stream_idx is None and stream is not None:\n            # Enter new ctx, update current stream status.\n            self.generate_stream_ctx_enter(node)\n        else:\n            # Switching from previous stream ctx to the new stream ctx.\n            self.generate_stream_ctx_exit()\n            self.generate_stream_ctx_enter(node)\n\n    def codegen(self) -> None:\n        \"\"\"Generate Python code for each of the Scheduler IR nodes.\n\n        Note:\n            The overall `torch.compile` code-gen is a multi-pass process, which means that this\n            method doesn't necessarily generate final program strings for every IR node. For\n            certain types of IRs, e.g., those involving memory allocation/deletion and CUDA Stream\n            switching, this method only generates respective data structures, and the final\n            code-gen is delegated to :meth:`WrapperCodeGen.codegen` using information from these\n            data structures.\n\n        Raises:\n            AssertionError: If any of the following conditions is met:\n                * A node needs to switch device context but it didn't include device information;\n                * A node contains at least one non-weak dependence that was not seen in the\n                  :meth:`schedule_multi_cuda_streams` pass;\n                * A node contains at least one non-weak cross-stream dependence that the\n                  corresponding event was not generated before that point;\n                * The fused compute graph contains :class:`ForeachKernelSchedulerNode` but the\n                  target backend doesn't support SIMD scheme.\n        \"\"\"\n        wrapper_code = cast(\"MultiStreamWrapperCodegen\", V.graph.wrapper_code)\n        wrapper_code.codegen_graph_nvtx_range_push(V.graph.post_grad_graph_id)\n        for node in self.nodes:\n            
try:\n                schedule_log.debug(\n                    \"Generating code for node %s with estimated runtime %f\",\n                    node.get_name(),\n                    node.get_estimated_runtime(),\n                )\n            except Exception:\n                schedule_log.debug(\n                    \"Generating code for node %s with estimated runtime 0.0\",\n                    node.get_name(),\n                )\n\n            self.enter_context(node)\n\n            if not isinstance(node, NopKernelSchedulerNode) and (device := node.get_device()):\n                if device != self.current_device or node.is_extern() or node.is_template():\n                    self.flush()\n                if device != self.current_device:\n                    if self.current_device and device_need_guard(\n                        self.current_device.type,\n                    ):\n                        wrapper_code.codegen_device_guard_exit()\n                    if device_need_guard(device.type):\n                        assert device.index is not None, \"device should have an index\"\n                        wrapper_code.codegen_device_guard_enter(device.index)\n                    self.current_device: torch.device | None = device\n\n            self.generate_stream_ctx_switching(node)\n            self.buffer_names_to_free.update(node.last_usage)\n\n            if node.is_template():\n                node, *epilogue = node.get_nodes()\n                self.get_backend(device).codegen_template(node, epilogue)\n            elif node.is_extern():\n                node = cast(\"ExternKernelSchedulerNode\", node)\n                self.codegen_extern_call(node)\n            elif node.is_foreach():\n                node = cast(\"ForeachKernelSchedulerNode\", node)\n                backend_ = self.get_backend(device)\n                from torch._inductor.codegen.cuda_combined_scheduling import (\n                    CUDACombinedScheduling,\n                )\n        
        from torch._inductor.codegen.simd import SIMDScheduling\n\n                if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):\n                    backend = backend_\n                else:\n                    raise AssertionError(f\"{type(self)=}\")\n                backend.codegen_combo_kernel(node)\n            elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):\n                self.get_backend(device).codegen_node(node)\n            else:\n                assert isinstance(node, NopKernelSchedulerNode)\n                node.mark_run()\n\n            if inductor_config.triton.debug_sync_kernel:\n                self.get_backend(device).codegen_sync()\n\n            self.available_buffer_names.update(node.get_buffer_names())\n            self.completed_operations.update(node.get_operation_names())\n            self.register_downstream_event(node)\n\n            if not isinstance(node, NopKernelSchedulerNode):\n                device = node.get_device()\n                if device is not None and self.get_backend(device).ready_to_flush():\n                    self.flush()\n\n        if self.current_device and device_need_guard(self.current_device.type):\n            # Exit the last stream context.\n            if self._current_stream_ctx is not None:\n                self.generate_stream_ctx_exit()\n            # Record the default stream on buffers from other streams.\n            side_stream_buffers = set()\n            for output in V.graph.get_output_names():\n                if self.buff_to_stream.get(output, DEFAULT_STREAM_IDX) != DEFAULT_STREAM_IDX:\n                    side_stream_buffers.add(output)\n            wrapper_code.codegen_buffers_record_stream(\n                buffers=side_stream_buffers,\n                stream_idx=DEFAULT_STREAM_IDX,\n                buffers_requiring_device_check=self.buffers_requiring_device_check,\n            )\n            # Sync hanging events from other streams.\n            if 
events_to_sync := self.get_final_events_to_sync():\n                wrapper_code.codegen_events_wait_stream(\n                    events=events_to_sync,\n                    stream_idx=DEFAULT_STREAM_IDX,\n                )\n            # Exit the outermost CUDA device guard. This is\n            # important for nested indentation codegen-ing.\n            wrapper_code.codegen_device_guard_exit()\n\n        wrapper_code.codegen_graph_nvtx_range_pop()\n        self.flush()\n"
  },
  {
    "path": "apex/contrib/torchsched/inductor/wrapper.py",
    "content": "\"\"\"Scheduling abstractions on PyTorch Inductor WrapperCodeGen level.\n\nAttributes:\n    DEFAULT_STREAM: Name of the default CUDA Stream on the final generated Python code.\n    DEFAULT_STREAM_IDX: Index number of the default CUDA Stream in `torchsched` internal passes.\n    STREAM_NAME_TEMPLATE: Python string template to generate stream names. Can be used as:\n\n            idx: int = ...\n            stream = STREAM_NAME_TEMPLATE.format(stream_idx=idx)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport dataclasses\nfrom typing import TYPE_CHECKING\n\nfrom torch._inductor.codegen.wrapper import EnterDeviceContextManagerLine\nfrom torch._inductor.codegen.wrapper import ExitDeviceContextManagerLine\nfrom torch._inductor.codegen.wrapper import IndentedBuffer\nfrom torch._inductor.codegen.wrapper import PythonWrapperCodegen\nfrom torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen\nfrom torch._inductor.codegen.wrapper import WrapperLine\nfrom torch._inductor.virtualized import V\n\nimport apex.contrib.torchsched.config as config\nfrom apex.contrib.torchsched.inductor._utils import DEFAULT_STREAM\nfrom apex.contrib.torchsched.inductor._utils import ENTRANCE_EVENT\nfrom apex.contrib.torchsched.inductor._utils import STREAM_NAME_TEMPLATE\nfrom apex.contrib.torchsched.inductor._utils import get_stream_name\n\nif TYPE_CHECKING:\n    from torch._inductor.graph import GraphLowering\n    from torch._inductor.ir import GraphPartitionSignature\n\n    from apex.contrib.torchsched.inductor.event import CudaEventSym\n\n\n@dataclasses.dataclass\nclass EnterDeviceContextManagerWithStreamInfoLine(EnterDeviceContextManagerLine):\n    \"\"\"Enter a CUDA device context and allocate required side streams.\n\n    Note:\n        - The number of allocated streams is controlled by :attr:`torchsched.config.num_streams`;\n    \"\"\"\n\n    def codegen(self, code: IndentedBuffer) -> None:\n        \"\"\"Generate context switching and stream allocation 
code.\"\"\"\n        if V.graph.cpp_wrapper:\n            super().codegen(code)\n        else:\n            super().codegen(code)\n            code.writeline(f\"{DEFAULT_STREAM} = torch.cuda.current_stream()\")\n            code.writeline(f\"{ENTRANCE_EVENT} = {DEFAULT_STREAM}.record_event()\")\n\n            code.writeline(\n                \"from apex.contrib.torchsched.inductor._utils import get_cuda_stream_pool\"\n            )\n            code.writeline(\n                f\"cuda_stream_pool = get_cuda_stream_pool(device={self.device_idx}, \"\n                f\"pool_size={config.num_streams})\",\n            )\n\n            for i in range(1, config.num_streams):\n                code.writeline(\n                    f\"{STREAM_NAME_TEMPLATE.format(stream_idx=i)} = cuda_stream_pool.acquire()\",\n                )\n\n\n@dataclasses.dataclass\nclass ExitDeviceContextManagerWithStreamInfoLine(ExitDeviceContextManagerLine):\n    \"\"\"Exit a CUDA device context and release allocated streams.\"\"\"\n\n    def codegen(self, code: IndentedBuffer) -> None:\n        \"\"\"Generate context switching and stream release code.\"\"\"\n        for i in range(1, config.num_streams):\n            code.writeline(\n                f\"cuda_stream_pool.release({STREAM_NAME_TEMPLATE.format(stream_idx=i)})\",\n            )\n        if not V.graph.cpp_wrapper:\n            code.do_unindent()\n\n\n@dataclasses.dataclass\nclass EnterCudaStreamContextLine(WrapperLine):\n    \"\"\"Enter a context executed by respective CUDA Stream and insert necessary syncs.\n\n    Attributes:\n        wrapper: The code-gen wrapper of the current compilation phase.\n        stream_idx: The index number corresponds to the entering CUDA Stream context.\n        upstream_events: Names of CUDA Events that the current stream should be waiting for before\n            the stream switching.\n        buffers_from_other_streams: Name of buffers produced by other CUDA Streams. 
Those buffers\n            should be recorded to the current stream to avoid accidental memory free.\n        buffers_requiring_device_check: Name of buffers that might not be on CUDA devices and\n            require runtime device checking before recording stream to them.\n    \"\"\"\n\n    stream_idx: int\n\n    def __post_init__(self) -> None:\n        \"\"\"Track buffers already recorded on this stream to reduce duplicate recording.\"\"\"\n        self.buffers_recorded_on_this_stream: set[str] = set()\n\n    def codegen(self, code: IndentedBuffer) -> None:\n        \"\"\"Generate stream switching and buffer recording code.\"\"\"\n        code.writeline(f\"with torch.cuda.stream({get_stream_name(self.stream_idx)}):\")\n        code.do_indent()\n\n        # [NOTE] The 3-indent-level assertion\n        #\n        #     Indent level 1: Inductor wrapper call indent\n        #         Indent level 2: Device guard context indent\n        #             Indent level 3: CUDA Stream context indent\n        #\n        # Over or under indenting usually means that :meth:`MultiCudaStreamScheduler.codegen`\n        # introduced bugs on stream context switching. This check also applies to stream context\n        # exiting, as in :meth:`ExitCudaStreamContextLine.codegen`.\n        assert code._indent == 3\n\n\n@dataclasses.dataclass\nclass ExitCudaStreamContextLine(WrapperLine):\n    \"\"\"Generate code to exit the current stream context.\n\n    Note:\n        Most attributes and checking logic of this class have been moved to\n        :meth:`MultiStreamWrapperCodegen.codegen_cuda_stream_exit`. 
We preserve this data structure\n        because the checking and unindent should be generated in the latter phase of code-gen.\n    \"\"\"\n\n    def codegen(self, code: IndentedBuffer) -> None:\n        \"\"\"Check indentation level and exit the current stream context.\"\"\"\n        assert code._indent == 3  # See :note:`The 3-indent-level assertion` above.\n        code.do_unindent()\n\n\nclass MultiStreamWrapperCodegen(PythonWrapperCodegen):\n    \"\"\"Wrapper code generator for graph scheduling.\"\"\"\n\n    def __init__(self) -> None:\n        \"\"\"Construct a code-gen wrapper and disable raw stream caching.\n\n        Note:\n            The :meth:`write_get_raw_stream` method processed in this constructor is invoked from\n            literally everywhere throughout the Inductor stack, but the current\n            :meth:`PythonWrapperCodegen.write_get_raw_stream` is LRU-cached and always returns a\n            const raw stream name. This is not what we wanted in a multi-stream environment. 
Thus\n            we need to re-patch this function in instance initialization.\n        \"\"\"\n        super().__init__()\n        self.write_get_raw_stream = self._write_get_raw_stream\n\n    @staticmethod\n    def create(\n        is_subgraph: bool,\n        subgraph_name: str,\n        parent_wrapper: MultiStreamWrapperCodegen,\n        partition_signatures: GraphPartitionSignature | None = None,\n    ) -> MultiStreamWrapperCodegen | SubgraphPythonWrapperCodegen:\n        \"\"\"Instantiate a wrapper codegen for an Inductor graph or a subgraph.\"\"\"\n        if is_subgraph:\n            assert subgraph_name is not None\n            assert parent_wrapper is not None\n            return SubgraphPythonWrapperCodegen(\n                subgraph_name,\n                parent_wrapper,\n                partition_signatures,\n            )\n        return MultiStreamWrapperCodegen()\n\n    def _write_get_raw_stream(self, device_idx: int, graph: GraphLowering | None = None) -> str:\n        self.write_triton_header_once()\n        if (current_stream_name := V.graph.scheduler.current_stream_name) is not None:\n            name = f\"{current_stream_name}_raw\"\n            self.writeline(f\"{name} = {current_stream_name}.cuda_stream\")\n        else:\n            name = f\"stream{device_idx}\"\n            self.writeline(f\"{name} = get_raw_stream({device_idx})\")\n        return name\n\n    def codegen_graph_nvtx_range_push(self, post_grad_graph_id: int) -> None:\n        \"\"\"Generate NVTX range push for graph.\"\"\"\n        self.writeline(f\"torch.cuda.nvtx.range_push('graph {post_grad_graph_id}')\")\n\n    def codegen_graph_nvtx_range_pop(self) -> None:\n        \"\"\"Generate NVTX range pop for graph.\"\"\"\n        self.writeline(\"torch.cuda.nvtx.range_pop()\")\n\n    def codegen_device_guard_enter(self, device_idx: int) -> None:\n        \"\"\"Generate data structure for device guard context.\n\n        Note:\n            Refer to 
:class:`EnterDeviceContextManagerWithStreamInfoLine` doc for more details.\n        \"\"\"\n        self.writeline(\n            EnterDeviceContextManagerWithStreamInfoLine(\n                device_idx,\n                self.last_seen_device_guard_index,\n            ),\n        )\n        self.last_seen_device_guard_index: int = device_idx\n\n    def codegen_device_guard_exit(self) -> None:\n        \"\"\"Generate data structure for exiting device guard context.\"\"\"\n        self.writeline(ExitDeviceContextManagerWithStreamInfoLine())\n\n    def codegen_cuda_stream_enter(\n        self,\n        stream_idx: int,\n        upstream_events: set[CudaEventSym],\n        buffers_from_other_streams: set[str],\n        buffers_requiring_device_check: set[str] | None = None,\n    ) -> EnterCudaStreamContextLine:\n        \"\"\"Generate data structure for entering a CUDA Stream context.\n\n        Args:\n            stream_idx: The index number of the entering CUDA Stream context.\n            upstream_events: Names of CUDA Events that the current stream should be waiting for\n                before the stream switching. This is usually the events that are generated by the\n                previous stream context.\n            buffers_from_other_streams: Name of buffers produced by other CUDA Streams. 
Those\n                buffers should be recorded to the current stream to avoid accidental memory free.\n            buffers_requiring_device_check: Name of buffers that might not be on CUDA devices and\n                require runtime device checking before recording stream to them.\n\n        Note:\n            - Refer to :class:`EnterCudaStreamContextLine` for argument specifications;\n            - Once entered a context, the stream associated with this context will also be recorded\n              such that kernels in subsequent code-gen can get the correct stream index.\n\n        Raises:\n            ValueError: If this function is called while the previous stream context isn't exited.\n        \"\"\"\n        if (current_stream_name := V.graph.scheduler.current_stream_name) is not None:\n            raise ValueError(\n                f\"Nested stream context switching: {current_stream_name} -> \"\n                f\"{get_stream_name(stream_idx)}\",\n            )\n        ctx_entrance = EnterCudaStreamContextLine(stream_idx=stream_idx)\n        self.writeline(ctx_entrance)\n        self.codegen_buffers_record_stream(\n            buffers=buffers_from_other_streams,\n            stream_idx=stream_idx,\n            buffers_requiring_device_check=buffers_requiring_device_check,\n        )\n        ctx_entrance.buffers_recorded_on_this_stream |= buffers_from_other_streams\n        self.codegen_events_wait_stream(\n            events=upstream_events,\n            stream_idx=stream_idx,\n        )\n        return ctx_entrance\n\n    def codegen_cuda_stream_exit(self) -> None:\n        \"\"\"Generate data structure for exiting a CUDA Stream context.\"\"\"\n        self.writeline(ExitCudaStreamContextLine())\n\n    def codegen_events_wait_stream(self, events: set[CudaEventSym], stream_idx: int) -> None:\n        \"\"\"Generate data structure for syncing hanging CUDA Events with certain stream.\n\n        Args:\n            events: Symbols of the events that need to 
be synchronized with the given stream.\n            stream_idx: Index of the CUDA stream to synchronize the events with.\n        \"\"\"\n        for event in events:\n            self.writeline(event.wait(stream_idx))\n\n    def codegen_buffers_record_stream(\n        self,\n        buffers: set[str],\n        stream_idx: int,\n        buffers_requiring_device_check: set[str] | None = None,\n    ) -> None:\n        \"\"\"Generate data structure for recording stream on return tensors before program exit.\n\n        Args:\n            buffers: Names of buffers that need to be recorded on the given stream.\n            stream_idx: Index of the CUDA stream to record the buffers to.\n            buffers_requiring_device_check: Name of buffers that might not be on CUDA devices and\n                require runtime device checking before recording stream to them. If not provided,\n                buffers will be recorded to the given stream without runtime device checking.\n        \"\"\"\n        for buff in buffers:\n            prefix = (\n                f\"if {buff}.is_cuda: \"\n                if buffers_requiring_device_check and buff in buffers_requiring_device_check\n                else \"\"\n            )\n            self.writeline(f\"{prefix}{buff}.record_stream({get_stream_name(stream_idx)})\")\n"
  },
  {
    "path": "apex/contrib/torchsched/ops/__init__.py",
    "content": "\"\"\"Custom PyTorch operators.\"\"\"\n\nimport torch\n\n__all__: list[str] = []\n\n# Register custom operators\ntorch.ops.import_module(\"apex.contrib.torchsched.ops.layer_norm\")\n"
  },
  {
    "path": "apex/contrib/torchsched/ops/layer_norm.py",
    "content": "\"\"\"Customized CuDNN frontend layer norm.\n\nPlease refer to:\n\n* https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/20_layernorm.ipynb\n\"\"\"\n\nfrom __future__ import annotations\n\nimport math\n\nimport cudnn\nimport torch\n\n__all__ = [\"get_cudnn_manager\"]\n\n\nclass CuDNNManager:\n    \"\"\"CuDNN fronted context manager.\n\n    Notice: CuDNN handle must be created after distributed process group initialization.\n    \"\"\"\n\n    def __init__(self) -> None:\n        self._handle = cudnn.create_handle()\n        self._cudnn_stream = torch.cuda.Stream()\n        self.reset_stream()\n\n    def __del__(self) -> None:\n        if cudnn is not None and hasattr(cudnn, \"destroy_handle\"):\n            cudnn.destroy_handle(self._handle)\n\n    def __enter__(self) -> CuDNNManager:\n        self._torch_stream = torch.cuda.current_stream()\n        self._cudnn_stream.wait_stream(self._torch_stream)\n        torch.cuda.set_stream(self._cudnn_stream)\n        return self\n\n    def __exit__(\n        self,\n        exc_type: type | None,\n        exc_val: Exception | None,\n        exc_tb: object | None,\n    ) -> None:\n        self._torch_stream.wait_stream(self._cudnn_stream)\n        torch.cuda.set_stream(self._torch_stream)\n        del self._torch_stream\n\n    def set_stream(self, stream: torch.cuda.Stream) -> None:\n        cudnn.set_stream(stream=stream.cuda_stream, handle=self._handle)\n\n    def reset_stream(self) -> None:\n        cudnn.set_stream(stream=self._cudnn_stream.cuda_stream, handle=self._handle)\n\n    @property\n    def handle(self) -> int:\n        return self._handle\n\n    @property\n    def stream(self) -> torch.cuda.Stream:\n        return self._cudnn_stream\n\n\n_global_cudnn_manager: CuDNNManager | None = None\n\n\ndef get_cudnn_manager() -> CuDNNManager:\n    \"\"\"Get the CuDNN front-end context manager.\n\n    Returns:\n        CuDNNManager: Global CuDNN manager.\n    \"\"\"\n    global 
_global_cudnn_manager\n    if _global_cudnn_manager is None:\n        _global_cudnn_manager = CuDNNManager()\n    return _global_cudnn_manager\n\n\nclass LayerNormGraphFactory:\n    \"\"\"cuDNN front-end layer norm graph factory.\n\n    cuDNN layer norm constraints:\n\n    * All tensors are 4-dimensional;\n    * `x` and `y` have the same layout in the graph;\n    \"\"\"\n\n    _graphs: dict = {}\n    _symbols: dict = {}\n    _TORCH2CUDNN: dict = {\n        torch.bool: cudnn.data_type.BOOLEAN,\n        torch.bfloat16: cudnn.data_type.BFLOAT16,\n        torch.float16: cudnn.data_type.HALF,\n        torch.float32: cudnn.data_type.FLOAT,\n        torch.uint8: cudnn.data_type.UINT8,\n    }\n\n    @classmethod\n    def get_forward_graph(\n        cls: type[LayerNormGraphFactory],\n        m: int,\n        n: int,\n        xdtype: torch.dtype,\n        wdtype: torch.dtype,\n    ) -> tuple[cudnn._compiled_module.pygraph, tuple]:\n        key = m, n, xdtype, wdtype, \"FORWARD\"\n        if key in cls._graphs:\n            if key not in cls._symbols:\n                raise RuntimeError(\n                    f\"Symbolic tensor was not constructed for layer-norm forward graph with input \"\n                    f\"shape {(m, n)} and data type {(xdtype, wdtype)}\",\n                )\n            return cls._graphs[key], cls._symbols[key]\n\n        cudnn_manager: CuDNNManager = get_cudnn_manager()\n        graph = cudnn.pygraph(\n            intermediate_data_type=cudnn.data_type.FLOAT,\n            compute_data_type=cudnn.data_type.FLOAT,\n            handle=cudnn_manager.handle,\n        )\n        x_sym = graph.tensor(\n            name=\"x_sym\",\n            dim=(m, n, 1, 1),\n            stride=(n, 1, n, n),  # Simulate the channel-last format.\n            data_type=cls._TORCH2CUDNN[xdtype],\n        )\n        scale_sym = graph.tensor(\n            name=\"scale_sym\",\n            dim=(1, n, 1, 1),\n            stride=(n, 1, n, n),\n            
data_type=cls._TORCH2CUDNN[wdtype],\n        )\n        bias_sym = graph.tensor(\n            name=\"bias_sym\",\n            dim=(1, n, 1, 1),\n            stride=(n, 1, n, n),\n            data_type=cls._TORCH2CUDNN[wdtype],\n        )\n        eps_sym = graph.tensor(\n            name=\"eps_sym\",\n            dim=(1, 1, 1, 1),\n            stride=(1, 1, 1, 1),\n            is_pass_by_value=True,\n            data_type=cudnn.data_type.FLOAT,\n        )\n\n        y_sym, x_mean_sym, x_invstd_sym = graph.layernorm(\n            name=f\"layer-norm-forward-{key}\",\n            norm_forward_phase=cudnn.norm_forward_phase.TRAINING,\n            input=x_sym,\n            scale=scale_sym,\n            bias=bias_sym,\n            epsilon=eps_sym,\n        )\n\n        y_sym.set_output(True).set_data_type(cls._TORCH2CUDNN[xdtype])\n        x_mean_sym.set_output(True).set_data_type(cls._TORCH2CUDNN[torch.float32])\n        x_invstd_sym.set_output(True).set_data_type(cls._TORCH2CUDNN[torch.float32])\n\n        graph.validate()\n        graph.build_operation_graph()\n        graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n        graph.check_support()\n        graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE)  # ALL\n        symbols = (\n            x_sym,\n            scale_sym,\n            bias_sym,\n            eps_sym,\n            y_sym,\n            x_mean_sym,\n            x_invstd_sym,\n        )\n\n        cls._graphs[key] = graph\n        cls._symbols[key] = symbols\n\n        return graph, symbols\n\n    @classmethod\n    def get_backward_graph(\n        cls: type[LayerNormGraphFactory],\n        m: int,\n        n: int,\n        xdtype: torch.dtype,\n        wdtype: torch.dtype,\n    ) -> tuple[cudnn._compiled_module.pygraph, tuple]:\n        key = m, n, xdtype, wdtype, \"BACKWARD\"\n        if key in cls._graphs:\n            if key not in cls._symbols:\n                raise RuntimeError(\n                    
f\"Symbolic tensor was not constructed for layer-norm backward \"\n                    f\"graph with input shape {(m, n)} and data type {(xdtype, wdtype)}\",\n                )\n            return cls._graphs[key], cls._symbols[key]\n\n        cudnn_manager: CuDNNManager = get_cudnn_manager()\n        graph = cudnn.pygraph(\n            intermediate_data_type=cudnn.data_type.FLOAT,\n            compute_data_type=cudnn.data_type.FLOAT,\n            handle=cudnn_manager.handle,\n        )\n        x_sym = graph.tensor(\n            name=\"x_sym\",\n            dim=(m, n, 1, 1),\n            stride=(n, 1, n, n),  # Simulate the channel-last format.\n            data_type=cls._TORCH2CUDNN[xdtype],\n        )\n        d_y_sym = graph.tensor(\n            name=\"d_y_sym\",\n            dim=(m, n, 1, 1),\n            stride=(n, 1, n, n),  # Simulate the channel-last format.\n            data_type=cls._TORCH2CUDNN[xdtype],\n        )\n        scale_sym = graph.tensor(\n            name=\"scale_sym\",\n            dim=(1, n, 1, 1),\n            stride=(n, 1, n, n),\n            data_type=cls._TORCH2CUDNN[wdtype],\n        )\n        x_mean_sym = graph.tensor(\n            name=\"x_mean_sym\",\n            dim=(m, 1, 1, 1),\n            stride=(1, 1, 1, 1),\n            data_type=cudnn.data_type.FLOAT,\n        )\n        x_invstd_sym = graph.tensor(\n            name=\"x_invstd_sym\",\n            dim=(m, 1, 1, 1),\n            stride=(1, 1, 1, 1),\n            data_type=cudnn.data_type.FLOAT,\n        )\n        d_x_sym, d_scale_sym, d_bias_sym = graph.layernorm_backward(\n            name=f\"layer-norm-backward-{key}\",\n            grad=d_y_sym,\n            input=x_sym,\n            scale=scale_sym,\n            mean=x_mean_sym,\n            inv_variance=x_invstd_sym,\n        )\n\n        d_x_sym.set_output(True).set_data_type(cls._TORCH2CUDNN[xdtype])\n        d_scale_sym.set_output(True).set_data_type(cls._TORCH2CUDNN[wdtype])\n        
d_bias_sym.set_output(True).set_data_type(cls._TORCH2CUDNN[wdtype])\n\n        graph.validate()\n        graph.build_operation_graph()\n        graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n        graph.check_support()\n        graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE)  # ALL\n        symbols = (\n            x_sym,\n            d_y_sym,\n            scale_sym,\n            x_mean_sym,\n            x_invstd_sym,\n            d_x_sym,\n            d_scale_sym,\n            d_bias_sym,\n        )\n\n        cls._graphs[key] = graph\n        cls._symbols[key] = symbols\n\n        return graph, symbols\n\n\n@torch.library.custom_op(\"cudnn::layer_norm\", mutates_args=(), device_types=\"cuda\")\ndef layer_norm(\n    x: torch.Tensor,\n    normalized_shape: list[int],\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    eps: float = 1e-05,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    # PyTorch LayerNorm:\n    #   * Shape (N, S, H), normalized_shape (H,);\n    #   * Shape (N, C, H, W), normalized_shape (C, H, W);\n    # cuDNN LayerNorm expects shape (M, N, 1, 1) and normalized_shape (1, N, 1, 1)\n    if tuple(x.shape[-len(normalized_shape) :]) != tuple(normalized_shape):  # noqa: E203\n        raise ValueError(\n            f\"CuDNN LayerNorm expects `x.shape[{-len(normalized_shape)}:]` equals to \"\n            f\"`normalized_shape`, but got:\\n    {x.shape=}, {normalized_shape=}\",\n        )\n    assert weight.dtype == bias.dtype\n    assert x.is_contiguous()\n\n    stream = torch.cuda.current_stream()\n    cudnn_manager: CuDNNManager = get_cudnn_manager()\n    cudnn_manager.set_stream(stream)\n\n    xdtype, wdtype, device = x.dtype, weight.dtype, x.device\n    m, n = math.prod(x.shape[: -len(normalized_shape)]), math.prod(normalized_shape)\n    (\n        forward_graph,\n        (\n            x_sym,\n            scale_sym,\n            bias_sym,\n            eps_sym,\n            y_sym,\n            
x_mean_sym,\n            x_invstd_sym,\n        ),\n    ) = LayerNormGraphFactory.get_forward_graph(m, n, xdtype, wdtype)\n\n    x_contiguous = x.reshape(m, n, 1, 1)  # NOTE: x could be noncontiguous.\n    weight = weight.view(1, n, 1, 1)\n    bias = bias.view(1, n, 1, 1)\n    eps_cpu = torch.full((1, 1, 1, 1), eps, dtype=torch.float32, device=\"cpu\")\n\n    y = torch.empty_like(x_contiguous)\n    x_mean = torch.empty(m, 1, 1, 1, dtype=torch.float32, device=device)\n    x_invstd = torch.empty(m, 1, 1, 1, dtype=torch.float32, device=device)\n    workspace = torch.empty(\n        forward_graph.get_workspace_size(),\n        dtype=torch.uint8,\n        device=device,\n    )\n\n    forward_graph.execute(\n        {\n            x_sym: x_contiguous.detach(),\n            scale_sym: weight.detach(),\n            bias_sym: bias.detach(),\n            eps_sym: eps_cpu.detach(),\n            y_sym: y.detach(),\n            x_mean_sym: x_mean.detach(),\n            x_invstd_sym: x_invstd.detach(),\n        },\n        workspace,\n    )\n    y = y.view(x.shape)\n\n    return y, x_mean, x_invstd\n\n\n@layer_norm.register_fake\ndef layer_norm_fake(\n    x: torch.Tensor,\n    normalized_shape: list[int],\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    eps: float = 1e-05,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    m = math.prod(x.shape[: -len(normalized_shape)])\n\n    y = torch.empty_like(x)\n    x_mean = torch.empty(m, 1, 1, 1, dtype=torch.float32, device=x.device)\n    x_invstd = torch.empty(m, 1, 1, 1, dtype=torch.float32, device=x.device)\n    return y, x_mean, x_invstd\n\n\n@torch.library.custom_op(\n    \"cudnn::layer_norm_backward\",\n    mutates_args=(),\n    device_types=\"cuda\",\n)\ndef layer_norm_backward(\n    d_y: torch.Tensor,\n    x_mean: torch.Tensor,\n    x_invstd: torch.Tensor,\n    x: torch.Tensor,\n    normalized_shape: list[int],\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]:\n    xdtype, wdtype, device = d_y.dtype, weight.dtype, d_y.device\n    m, n = math.prod(x.shape[: -len(normalized_shape)]), math.prod(normalized_shape)\n\n    stream = torch.cuda.current_stream()\n    cudnn_manager: CuDNNManager = get_cudnn_manager()\n    cudnn_manager.set_stream(stream)\n\n    (\n        backward_graph,\n        (\n            x_sym,\n            d_y_sym,\n            scale_sym,\n            x_mean_sym,\n            x_invstd_sym,\n            d_x_sym,\n            d_scale_sym,\n            d_bias_sym,\n        ),\n    ) = LayerNormGraphFactory.get_backward_graph(m, n, xdtype, wdtype)\n\n    d_y_contiguous = d_y.reshape(m, n, 1, 1)  # NOTE: d_y could also be noncontiguous.\n    d_x = torch.empty_like(x)\n    d_weight = torch.empty_like(weight)\n    d_bias = torch.empty_like(bias)\n    workspace = torch.empty(\n        backward_graph.get_workspace_size(),\n        dtype=torch.uint8,\n        device=device,\n    )\n\n    backward_graph.execute(\n        {\n            x_sym: x.detach(),\n            d_y_sym: d_y_contiguous.detach(),\n            scale_sym: weight.detach(),\n            x_mean_sym: x_mean.detach(),\n            x_invstd_sym: x_invstd.detach(),\n            d_x_sym: d_x.detach(),\n            d_scale_sym: d_weight.detach(),\n            d_bias_sym: d_bias.detach(),\n        },\n        workspace,\n    )\n\n    return d_x, d_weight, d_bias\n\n\n@layer_norm_backward.register_fake\ndef layer_norm_backward_fake(\n    d_y: torch.Tensor,\n    x_mean: torch.Tensor,\n    x_invstd: torch.Tensor,\n    x: torch.Tensor,\n    normalized_shape: list[int],\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    d_x = torch.empty_like(x)\n    d_weight = torch.empty_like(weight)\n    d_bias = torch.empty_like(bias)\n    return d_x, d_weight, d_bias\n\n\ndef layer_norm_setup_context(\n    ctx: torch.autograd.FunctionCtx,\n    inputs: tuple,\n    output: tuple,\n) -> torch.Tensor:\n   
 x, normalized_shape, weight, bias, eps = inputs\n    y, x_mean, x_invstd = output\n\n    ctx.save_for_backward(x, weight, bias, x_mean, x_invstd)\n    ctx.normalized_shape = normalized_shape\n\n    return y\n\n\ndef layer_norm_backward_wrapper(\n    ctx: torch.autograd.FunctionCtx,\n    d_y: torch.Tensor,\n    d_x_mean: torch.Tensor,\n    d_x_invstd: torch.Tensor,\n) -> tuple[torch.Tensor, None, torch.Tensor, torch.Tensor, None]:\n    x, weight, bias, x_mean, x_invstd = ctx.saved_tensors\n    normalized_shape = ctx.normalized_shape\n\n    d_x, d_weight, d_bias = layer_norm_backward(\n        d_y,\n        x_mean,\n        x_invstd,\n        x,\n        normalized_shape,\n        weight,\n        bias,\n    )\n\n    return d_x, None, d_weight, d_bias, None\n\n\ntorch.library.register_autograd(\n    \"cudnn::layer_norm\",\n    layer_norm_backward_wrapper,\n    setup_context=layer_norm_setup_context,\n)\n"
  },
  {
    "path": "apex/contrib/torchsched/passes/__init__.py",
    "content": "\"\"\"Customized compiler passes.\"\"\"\n\nfrom __future__ import annotations\n\nfrom apex.contrib.torchsched.passes.pre_grad_passes import pre_grad_custom_pass\n\n__all__ = [\"pre_grad_custom_pass\"]\n"
  },
  {
    "path": "apex/contrib/torchsched/passes/pre_grad_passes.py",
    "content": "\"\"\"Customized Inductor passes.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom torch._dynamo.utils import counters\nfrom torch.fx import replace_pattern\n\nif TYPE_CHECKING:\n    from collections.abc import Callable\n    from collections.abc import Sequence\n\nfrom apex.contrib.torchsched import config\n\n__all__ = [\"pre_grad_custom_pass\"]\n\n# pass name to (pattern replacement) mapping\nPRE_GRAD_PASS_PATTERNS: dict[str, tuple[Callable, Callable]] = {}\n\n\ndef register_pattern(name: str, pattern: Callable, replacement: Callable) -> None:\n    assert name not in PRE_GRAD_PASS_PATTERNS\n    PRE_GRAD_PASS_PATTERNS[name] = pattern, replacement\n\n\ndef replace_layer_norm(\n    x: torch.Tensor,\n    normalized_shape: Sequence[int],\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    eps: float,\n) -> torch.Tensor:\n    y, x_mean, x_invstd = torch.ops.cudnn.layer_norm(\n        x,\n        normalized_shape,\n        weight,\n        bias,\n        eps,\n    )\n    return y\n\n\nregister_pattern(\n    \"cudnn_layer_norm\",\n    torch.nn.functional.layer_norm,\n    replace_layer_norm,\n)\n\n\ndef run_pre_grad_pass(\n    name: str,\n    graph: torch.fx.Graph,\n    pattern: Callable,\n    replacement: Callable,\n) -> int:\n    \"\"\"Run a pre-gradient pass on the given graph.\n\n    Args:\n        name (str): A string identifier for the pass.\n        graph (torch.fx.Graph): The graph to be transformed.\n        pattern (Callable): A callable that defines the pattern to match in the graph.\n        replacement (Callable): A callable that defines the replacement for matched patterns.\n\n    Returns:\n        An integer representing the number of transformations applied.\n\n    Note:\n        These two doesn't match because of kwargs (Inductor vs. 
torch.fx.symbolic_trace):\n\n            %layer_norm : [num_users=1] = call_function[target=torch.nn.functional.layer_norm](\n                args = (%l_args_0_, (320,), %l_fn_parameters_weight_, %l_fn_parameters_bias_,\n                1e-05), kwargs = {})\n            %layer_norm : [num_users=1] = call_function[target=torch.nn.functional.layer_norm](\n                args = (%input_1, %normalized_shape), kwargs = {weight: %weight, bias: %bias,\n                eps: %eps})\n    \"\"\"\n    # Manually trace the graph and move kwargs to args\n    pattern_graph = torch.fx.symbolic_trace(pattern).graph\n    for node in pattern_graph.nodes:\n        if node.op == \"call_function\" and node.target == pattern:\n            node.args = node.args + tuple(node.kwargs.values())\n            node.kwargs = {}\n    pattern_graph.owning_module.recompile()\n\n    matched = replace_pattern(graph.owning_module, pattern_graph, replacement)\n    graph.owning_module.recompile()\n    graph.lint()\n\n    return len(matched)\n\n\ndef pre_grad_custom_pass(graph: torch.fx.Graph) -> None:\n    \"\"\"Run customized pre-grad passes.\n\n    Args:\n        graph (torch.fx.Graph): The FX graph to be optimized.\n    \"\"\"\n    passes = config.pre_grad_pass_options\n    for pass_name in passes:\n        assert pass_name in PRE_GRAD_PASS_PATTERNS, f\"Unknown pre_grad pass: {pass_name}\"\n        pattern, replacement = PRE_GRAD_PASS_PATTERNS[pass_name]\n        replaced = run_pre_grad_pass(pass_name, graph, pattern, replacement)\n        counters[\"torchsched\"][f\"pre_grad_{pass_name}\"] += replaced\n        logging.debug(\"Pre grad pass %s replaced %d sub-graphs\", pass_name, replaced)\n"
  },
  {
    "path": "apex/contrib/transducer/__init__.py",
    "content": "from .transducer import TransducerJoint\r\nfrom .transducer import TransducerLoss\r\nfrom . import _transducer_ref\r\n"
  },
  {
    "path": "apex/contrib/transducer/_transducer_ref.py",
    "content": "import torch\r\n\r\n\r\ndef transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):\r\n    def log_sum_exp(a, b):\r\n        if a >= b:\r\n            return a + torch.log(1 + torch.exp(b - a))\r\n        else:\r\n            return b + torch.log(1 + torch.exp(a - b))\r\n\r\n    def forward_alpha(x, label, f_len, y_len, blank_idx):\r\n        B, T, U, V = x.size()\r\n        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype\r\n        alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)\r\n        for b in range(B):\r\n            alpha[b, 0, 0] = 0\r\n            for t in range(1, f_len[b]):\r\n                alpha[b, t, 0] = alpha[b, t - 1, 0] + x[b, t - 1, 0, blank_idx]\r\n            for u in range(1, y_len[b] + 1):\r\n                alpha[b, 0, u] = alpha[b, 0, u - 1] + x[b, 0, u - 1, label[b, u - 1]]\r\n            for t in range(1, f_len[b]):\r\n                for u in range(1, y_len[b] + 1):\r\n                    curr_ = alpha[b, t - 1, u] + x[b, t - 1, u, blank_idx]\r\n                    next_ = alpha[b, t, u - 1] + x[b, t, u - 1, label[b, u - 1]]\r\n                    alpha[b, t, u] = log_sum_exp(curr_, next_)\r\n        return alpha\r\n\r\n    def forward_beta(x, label, f_len, y_len, blank_idx):\r\n        B, T, U, V = x.shape\r\n        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype\r\n        beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)\r\n        for b in range(B):\r\n            beta[b, f_len[b] - 1, y_len[b]] = x[b, f_len[b] - 1, y_len[b], blank_idx]\r\n            for t in range(f_len[b] - 2, -1, -1):\r\n                beta[b, t, y_len[b]] = beta[b, t + 1, y_len[b]] + x[b, t, y_len[b], blank_idx]\r\n            for u in range(y_len[b] - 1, -1, -1):\r\n                beta[b, f_len[b] - 1, u] = (\r\n                    beta[b, f_len[b] - 1, u + 1] + x[b, f_len[b] - 1, u, label[b, u]]\r\n                )\r\n            for t 
in range(f_len[b] - 2, -1, -1):\r\n                for u in range(y_len[b] - 1, -1, -1):\r\n                    curr_ = beta[b, t + 1, u] + x[b, t, u, blank_idx]\r\n                    next_ = beta[b, t, u + 1] + x[b, t, u, label[b, u]]\r\n                    beta[b, t, u] = log_sum_exp(curr_, next_)\r\n        return beta\r\n\r\n    def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):\r\n        grad = torch.zeros_like(x)\r\n        B, T, U, V = x.size()\r\n        for b in range(B):\r\n            common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0]\r\n            # next\r\n            for u in range(y_len[b]):\r\n                grad[b, : f_len[b], u, label[b, u]] = -torch.exp(\r\n                    common_factor[b, : f_len[b], u]\r\n                    + beta[b, : f_len[b], u + 1]\r\n                    + x[b, : f_len[b], u, label[b, u]]\r\n                )\r\n\r\n            # current\r\n            grad[b, : f_len[b] - 1, : y_len[b] + 1, blank_idx] = -torch.exp(\r\n                common_factor[b, : f_len[b] - 1, : y_len[b] + 1]\r\n                + beta[b, 1 : f_len[b], : y_len[b] + 1]\r\n                + x[b, : f_len[b] - 1, : y_len[b] + 1, blank_idx]\r\n            )\r\n\r\n            grad[b, f_len[b] - 1, y_len[b], blank_idx] = -torch.exp(\r\n                common_factor[b, f_len[b] - 1, y_len[b]] + x[b, f_len[b] - 1, y_len[b], blank_idx]\r\n            )\r\n\r\n        return grad\r\n\r\n    x_log = torch.nn.functional.log_softmax(x, dim=-1)\r\n    alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)\r\n    beta = forward_beta(x_log, label, f_len, y_len, blank_idx)\r\n    grad = backward(x_log, label, f_len, y_len, alpha, beta, loss_grad, blank_idx)\r\n    x_log.backward(grad)\r\n    loss = -beta[:, 0, 0]\r\n    loss = loss.to(x.dtype)\r\n    return alpha, beta, x.grad, loss\r\n\r\n\r\ndef transducer_joint_reference(\r\n    f, g, h_grad, f_len, g_len, pack_output, relu, dropout, dropout_prob=0, 
mask=None\r\n):\r\n    if dropout and mask is None:\r\n        raise NotImplementedError(\"mask needs to be supplied to test dropout.\")\r\n    B, T, H = f.size()\r\n    U = g.size(1)\r\n    f_expand = f.unsqueeze(dim=2)\r\n    g_expand = g.unsqueeze(dim=1)\r\n    h = f_expand + g_expand\r\n    if relu:\r\n        h = torch.nn.functional.relu(h)\r\n    if dropout:\r\n        h *= mask\r\n        scale = 1 / (1 - dropout_prob)\r\n        h *= scale\r\n    h.backward(h_grad)\r\n\r\n    if pack_output == False:\r\n        # intentionally set the don't-care region to -1 to test whether the transducer joint\r\n        # writes to these regions to avoid NaN and inf\r\n        for b in range(B):\r\n            h[b, f_len[b] :] = -1\r\n            h[b, :, g_len[b] :] = -1\r\n\r\n        return h, f.grad, g.grad\r\n\r\n    # packing\r\n    list_to_pack = []\r\n    for b in range(B):\r\n        list_to_pack.append(h[b, : f_len[b], : g_len[b], :].reshape(-1, H))\r\n    h_packed = torch.cat(list_to_pack)\r\n    return h_packed, f.grad, g.grad\r\n"
  },
  {
    "path": "apex/contrib/transducer/transducer.py",
    "content": "import torch\r\nimport transducer_loss_cuda\r\nimport transducer_joint_cuda\r\n\r\n\r\nclass TransducerJoint(torch.nn.Module):\r\n    \"\"\"Transducer joint\r\n    Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural\r\n    Networks\r\n\r\n    Arguments:\r\n        pack_output (bool, optional): whether to pack the output in a compact form with don't-care\r\n        data being removed. (default: False)\r\n        relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1\r\n        (default: False)\r\n        dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1\r\n        (default: False)\r\n        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.\r\n            (default: 1)\r\n        fwd_tile_size (int, optional): tile size used in forward operation. This argument will be\r\n        ignored if opt != 1. (default: 4)\r\n        dropout_prob (float, optional): dropout probability. (default: 0.0)\r\n        probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout\r\n        operation. When this argument is set to True, the mask can be accessed through\r\n        self.mask_probe. 
(default: False)\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        pack_output=False,\r\n        relu=False,\r\n        dropout=False,\r\n        opt=1,\r\n        fwd_tile_size=4,\r\n        dropout_prob=0,\r\n        probe_mask=False,\r\n    ):\r\n        super(TransducerJoint, self).__init__()\r\n        self.pack_output = pack_output\r\n        self.relu = relu\r\n        self.dropout = dropout\r\n        self.dropout_prob = dropout_prob\r\n        self.opt = opt\r\n        self.fwd_tile_size = fwd_tile_size\r\n        self.dummy_batch_offset = torch.empty(0)\r\n        masked = self.relu or self.dropout\r\n        self.mask_probe = [] if masked and probe_mask else None\r\n        if masked and opt != 1:\r\n            raise NotImplementedError(\"ReLU and dropout fusion is only supported with opt=1\")\r\n\r\n    def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):\r\n        \"\"\"Forward operation of transducer joint\r\n\r\n        Arguments:\r\n            f (tensor): transcription vector from encode block of shape (B, T, H).\r\n            g (tensor): prediction vector from predict block of shape (B, U, H).\r\n            f_len (tensor): length of transcription vector for each batch.\r\n            g_len (tensor): length of prediction vector minus 1 for each batch.\r\n            batch_offset (tensor, optional): tensor containing the offset of each batch\r\n                in the results. For example, batch offset can be obtained from:\r\n                batch_offset = torch.cumsum(f_len*g_len, dim=0)\r\n                This argument is required if pack_output == True, and is ignored if\r\n                pack_output == False. (default: None)\r\n            packed_batch (int, optional): the batch size after packing. This argument is\r\n                ignored if pack_output == False. 
(default: 0)\r\n        \"\"\"\r\n        my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset\r\n        if self.pack_output and (batch_offset is None or packed_batch == 0):\r\n            raise Exception(\"Please specify batch_offset and packed_batch when packing is enabled\")\r\n        dropout = self.dropout and self.training  # only dropout for training\r\n        return TransducerJointFunc.apply(\r\n            f,\r\n            g,\r\n            f_len,\r\n            g_len,\r\n            self.pack_output,\r\n            self.relu,\r\n            dropout,\r\n            my_batch_offset,\r\n            packed_batch,\r\n            self.opt,\r\n            self.fwd_tile_size,\r\n            self.dropout_prob,\r\n            self.mask_probe,\r\n        )\r\n\r\n\r\nclass TransducerLoss(torch.nn.Module):\r\n    \"\"\"Transducer loss\r\n    Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural\r\n    Networks\r\n\r\n    Arguments:\r\n        fuse_softmax_backward (bool, optional): whether to fuse the backward of transducer loss with\r\n            softmax. (default: True)\r\n        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized\r\n            algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)\r\n        packed_input (bool, optional): whether the input is packed in a compact form with don't-care\r\n        data being removed. 
(default: False)\r\n    \"\"\"\r\n\r\n    def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):\r\n        super(TransducerLoss, self).__init__()\r\n        self.fuse_softmax_backward = fuse_softmax_backward\r\n        self.opt = opt\r\n        self.packed_input = packed_input\r\n        self.dummy_batch_offset = torch.empty(0)\r\n\r\n    def forward(\r\n        self,\r\n        x,\r\n        label,\r\n        f_len,\r\n        y_len,\r\n        blank_idx,\r\n        batch_offset=None,\r\n        max_f_len=None,\r\n        debug_list=None,\r\n    ):\r\n        \"\"\"Forward operation of transducer loss\r\n\r\n        Arguments:\r\n            x (tensor): input tensor to the loss function with a shape of (B, T, U, H).\r\n            label (tensor): labels for the input data.\r\n            f_len (tensor): lengths of the inputs in the time dimension for each batch.\r\n            y_len (tensor): lengths of the labels for each batch.\r\n            blank_idx (int): index for the null symbol.\r\n            batch_offset (tensor, optional): tensor containing the offset of each batch\r\n                in the input. For example, batch offset can be obtained from:\r\n                batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)\r\n                This argument is required if packed_input == True, and is ignored if\r\n                packed_input == False. (default: None)\r\n            max_f_len (int, optional): maximum length of the input in the time dimension.\r\n                For example, it can be obtained as\r\n                max_f_len = max(f_len)\r\n                This argument is required if packed_input == True, and is ignored if\r\n                packed_input == False. 
(default: None)\r\n            debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated\r\n                in the forward operation will be appended to this list for debugging purposes.\r\n                (default: None)\r\n        \"\"\"\r\n        if self.packed_input:\r\n            if batch_offset is None or max_f_len is None:\r\n                raise Exception(\r\n                    \"Please specify batch_offset and max_f_len when packing is enabled\"\r\n                )\r\n            my_batch_offset = batch_offset\r\n            my_max_f_len = max_f_len\r\n        else:\r\n            my_batch_offset = self.dummy_batch_offset\r\n            my_max_f_len = x.size(1)\r\n        return TransducerLossFunc.apply(\r\n            x,\r\n            label,\r\n            f_len,\r\n            y_len,\r\n            my_batch_offset,\r\n            my_max_f_len,\r\n            blank_idx,\r\n            self.fuse_softmax_backward,\r\n            debug_list,\r\n            self.opt,\r\n            self.packed_input,\r\n        )\r\n\r\n\r\nclass TransducerLossFunc(torch.autograd.Function):\r\n    @staticmethod\r\n    def forward(\r\n        ctx,\r\n        x,\r\n        label,\r\n        f_len,\r\n        y_len,\r\n        batch_offset,\r\n        max_f_len,\r\n        blank_idx,\r\n        fuse_softmax_backward,\r\n        debug_list,\r\n        opt,\r\n        packed_input,\r\n    ):\r\n        if fuse_softmax_backward == False:\r\n            with torch.enable_grad():\r\n                x = torch.nn.functional.log_softmax(x, dim=-1)\r\n        else:\r\n            x = torch.nn.functional.log_softmax(x, dim=-1)\r\n        alpha, beta, loss = transducer_loss_cuda.forward(\r\n            x,\r\n            label,\r\n            f_len,\r\n            y_len,\r\n            batch_offset,\r\n            max_f_len,\r\n            blank_idx,\r\n            opt,\r\n            
packed_input,\r\n        )\r\n        if debug_list == []:\r\n            debug_list += [alpha, beta]\r\n        ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)\r\n        ctx.blank_idx = blank_idx\r\n        ctx.fuse_softmax_backward = fuse_softmax_backward\r\n        ctx.opt = opt\r\n        ctx.packed_input = packed_input\r\n        ctx.max_f_len = max_f_len\r\n        return loss\r\n\r\n    @staticmethod\r\n    def backward(ctx, loss_grad):\r\n        x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors\r\n        x_grad = transducer_loss_cuda.backward(\r\n            x,\r\n            loss_grad,\r\n            alpha,\r\n            beta,\r\n            f_len,\r\n            y_len,\r\n            label,\r\n            batch_offset,\r\n            ctx.max_f_len,\r\n            ctx.blank_idx,\r\n            ctx.opt,\r\n            ctx.fuse_softmax_backward,\r\n            ctx.packed_input,\r\n        )\r\n        if ctx.fuse_softmax_backward == False:\r\n            x_grad = x.backward(x_grad)\r\n        return x_grad, None, None, None, None, None, None, None, None, None, None\r\n\r\n\r\nclass TransducerJointFunc(torch.autograd.Function):\r\n    @staticmethod\r\n    def forward(\r\n        ctx,\r\n        f,\r\n        g,\r\n        f_len,\r\n        g_len,\r\n        pack_output,\r\n        relu,\r\n        dropout,\r\n        batch_offset,\r\n        packed_batch,\r\n        opt,\r\n        fwd_tile_size,\r\n        dropout_prob,\r\n        mask_probe,\r\n    ):\r\n        h = transducer_joint_cuda.forward(\r\n            f,\r\n            g,\r\n            f_len,\r\n            g_len,\r\n            batch_offset,\r\n            packed_batch,\r\n            opt,\r\n            pack_output,\r\n            relu,\r\n            dropout,\r\n            dropout_prob,\r\n            fwd_tile_size,\r\n        )\r\n        masked = relu or dropout\r\n        if masked:\r\n            ctx.save_for_backward(h[1], f_len, g_len, 
batch_offset)\r\n            if mask_probe is not None:\r\n                mask_probe.append(h[1])\r\n        else:\r\n            ctx.save_for_backward(f_len, g_len, batch_offset)\r\n\r\n        ctx.pack_output = pack_output\r\n        ctx.masked = relu or dropout\r\n        ctx.max_f_len = f.size(1)\r\n        ctx.max_g_len = g.size(1)\r\n        ctx.scale = 1 / (1 - dropout_prob) if dropout and dropout_prob != 1 else 1\r\n        return h[0]\r\n\r\n    @staticmethod\r\n    def backward(ctx, loss_grad):\r\n        if ctx.masked:\r\n            mask, f_len, g_len, batch_offset = ctx.saved_tensors\r\n            inp = [loss_grad, mask]\r\n        else:\r\n            f_len, g_len, batch_offset = ctx.saved_tensors\r\n            inp = [loss_grad]\r\n\r\n        f_grad, g_grad = transducer_joint_cuda.backward(\r\n            inp,\r\n            f_len,\r\n            g_len,\r\n            batch_offset,\r\n            ctx.max_f_len,\r\n            ctx.max_g_len,\r\n            ctx.pack_output,\r\n            ctx.scale,\r\n        )\r\n\r\n        # One gradient per forward input: only f and g are differentiable,\r\n        # the remaining 11 arguments receive None.\r\n        return (\r\n            f_grad,\r\n            g_grad,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n            None,\r\n        )\r\n"
  },
  {
    "path": "apex/contrib/xentropy/__init__.py",
    "content": "from .softmax_xentropy import SoftmaxCrossEntropyLoss\n\n\n__all__ = [\n    \"SoftmaxCrossEntropyLoss\",\n]\n"
  },
  {
    "path": "apex/contrib/xentropy/softmax_xentropy.py",
    "content": "import torch\n\nimport xentropy_cuda\n\n\nclass SoftmaxCrossEntropyLoss(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):\n        losses, max_log_sum_exp = xentropy_cuda.forward(logits, labels, smoothing, half_to_float)\n        losses.masked_fill_(labels == padding_idx, 0)\n\n        ctx.save_for_backward(\n            logits,\n            max_log_sum_exp,\n            labels,\n            torch.FloatTensor([smoothing]),\n            torch.LongTensor([padding_idx]),\n        )\n\n        return losses\n\n    @staticmethod\n    def backward(ctx, grad_loss):\n        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors\n\n        if not grad_loss.is_contiguous():\n            grad_loss = grad_loss.contiguous()\n        grad_loss.masked_fill_(labels == padding_idx.item(), 0)\n        grad_logits = xentropy_cuda.backward(\n            grad_loss.contiguous(), logits, max_log_sum_exp, labels, smoothing.item()\n        )\n\n        return grad_logits, None, None, None, None\n"
  },
  {
    "path": "apex/distributed_testing/__init__.py",
    "content": "\"\"\"Distributed testing utilities.\"\"\"\n\nfrom apex.distributed_testing.distributed_test_base import (\n    DistributedTestBase,\n    NcclDistributedTestBase,\n    UccDistributedTestBase,\n)\n\n__all__ = [\n    \"DistributedTestBase\",\n    \"NcclDistributedTestBase\",\n    \"UccDistributedTestBase\",\n]\n"
  },
  {
    "path": "apex/distributed_testing/_ucc_util.py",
    "content": "from torch import distributed as dist\n\nHAS_UCC = hasattr(dist, \"is_ucc_available\") and dist.is_ucc_available()\nif not HAS_UCC:\n    try:\n        import torch_ucc\n\n        HAS_UCC = True\n    except ImportError:\n        HAS_UCC = False\n"
  },
  {
    "path": "apex/distributed_testing/distributed_test_base.py",
    "content": "import os\nimport sys\nimport unittest\nfrom packaging.version import Version, parse\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.utils import collect_env\nfrom torch.testing._internal import common_utils\nfrom torch.testing._internal import common_distributed\n\nfrom apex.distributed_testing._ucc_util import HAS_UCC\n\n# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496\n_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version(\"470.42.01\")\n_driver_version = None\nif torch.cuda.is_available():\n    _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))\nHAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = (\n    _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION\n)\n\n\nclass DistributedTestBase(common_distributed.MultiProcessTestCase):\n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n\n    def setUp(self) -> None:\n        super().setUp()\n        self._setup_pre_spawn()\n        self._spawn_processes()\n\n    def tearDown(self) -> None:\n        torch.cuda.empty_cache()\n        super().tearDown()\n\n    @property\n    def world_size(self) -> int:\n        return min(torch.cuda.device_count(), 4)\n\n    @property\n    def init_method(self):\n        return f\"{common_utils.FILE_SCHEMA}{self.file_name}\"\n\n    @property\n    def destroy_pg_upon_exit(self) -> bool:\n        # Overriding base test class: do not auto destroy PG upon exit.\n        return False\n\n    @classmethod\n    def _run(cls, rank, test_name, file_name, pipe, **kwargs):\n        self = cls(test_name)\n        self.assertTrue(torch.cuda.is_available())\n        self.assertTrue(hasattr(self, \"DISTRIBUTED_BACKEND\"))\n        self.rank = rank\n        self.file_name = file_name\n\n        print(f\"[dist init] rank = {self.rank}, world_size = {self.world_size}\")\n\n        try:\n            dist.init_process_group(\n                
init_method=self.init_method,\n                backend=self.DISTRIBUTED_BACKEND,\n                world_size=int(self.world_size),\n                rank=self.rank,\n            )\n        except RuntimeError as e:\n            if \"recompile\" in e.args[0]:\n                print(f\"Backend of {self.DISTRIBUTED_BACKEND} not available\")\n                sys.exit(0)\n            raise\n\n        torch.cuda.set_device(self.rank % torch.cuda.device_count())\n\n        dist.barrier()\n        self.run_test(test_name, pipe)\n        dist.barrier()\n\n        dist.destroy_process_group()\n        sys.exit(0)\n\n    def _setup_pre_spawn(self):\n        pass\n\n\nclass NcclDistributedTestBase(DistributedTestBase):\n    DISTRIBUTED_BACKEND = \"nccl\"\n\n\n@unittest.skipUnless(\n    HAS_UCC,\n    \"Requires either torch ucc or pytorch build from source with native ucc installed and enabled\",\n)\n@unittest.skipUnless(\n    HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,\n    f\"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. 
\"\n    \"See https://github.com/openucx/ucc/issues/496\",\n)\nclass UccDistributedTestBase(DistributedTestBase):\n    DISTRIBUTED_BACKEND = \"ucc\"\n\n    def _setup_pre_spawn(self) -> None:\n        self.master_addr = \"localhost\"\n        os.environ[\"MASTER_ADDR\"] = \"localhost\"\n        self._has_master_port = \"MASTER_PORT\" in os.environ\n        if self._has_master_port:\n            self.master_port = os.environ[\"MASTER_PORT\"]\n        else:\n            try:\n                from caffe2.torch.fb.common.utils import get_free_port\n\n                self.master_port = str(get_free_port())\n            except ImportError:\n                self.master_port = \"12375\"\n            os.environ[\"MASTER_PORT\"] = self.master_port\n\n        self._has_ucx_tls = \"UCX_TLS\" in os.environ\n        if not self._has_ucx_tls:\n            os.environ[\"UCX_TLS\"] = \"tcp,cuda\"\n        print('os.environ[\"UCX_TLS\"] = {}'.format(os.environ[\"UCX_TLS\"]))\n\n    def tearDown(self) -> None:\n        super().tearDown()\n        if not self._has_master_port:\n            del os.environ[\"MASTER_PORT\"]\n        if not self._has_ucx_tls:\n            del os.environ[\"UCX_TLS\"]\n\n    @property\n    def init_method(self):\n        return \"tcp://localhost:\" + os.environ[\"MASTER_PORT\"]\n"
  },
  {
    "path": "apex/fused_dense/__init__.py",
    "content": "from .fused_dense import *\n"
  },
  {
    "path": "apex/fused_dense/fused_dense.py",
    "content": "import torch\nfrom torch import nn\nimport fused_dense_cuda\nfrom apex._autocast_utils import _cast_if_autocast_enabled\n\n\n# Implements fused GEMM + bias in the forward pass using the fused_dense_cuda extension from apex\nclass FusedDenseFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, weight, bias):\n        ctx.save_for_backward(input, weight)\n        output = fused_dense_cuda.linear_bias_forward(input, weight, bias)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, weight = ctx.saved_tensors\n        grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(\n            input, weight, grad_output\n        )\n        return grad_input, grad_weight, grad_bias\n\n\nclass DenseNoBiasFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, weight):\n        ctx.save_for_backward(input, weight)\n        output = torch.matmul(input, weight.t())\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, weight = ctx.saved_tensors\n        # Tensor.mm only accepts 2D tensors, so this backward assumes 2D input and grad_output\n        grad_input = grad_output.mm(weight)\n        grad_weight = grad_output.t().mm(input)\n        return grad_input, grad_weight\n\n\nclass FusedDenseGeluDenseFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, weight1, bias1, weight2, bias2):\n        output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(\n            input, weight1, bias1, weight2, bias2\n        )\n        # Save everything backward needs in a single call\n        ctx.save_for_backward(input, weight1, weight2, gelu_in, output1)\n        return output2\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, weight1, weight2, gelu_in, output1 = ctx.saved_tensors\n        grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = (\n            fused_dense_cuda.linear_gelu_linear_backward(\n                input, gelu_in, output1, weight1, weight2, grad_output\n            )\n        )\n        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2\n\n\ndef _fused_dense(input, weight, bias):\n    args = _cast_if_autocast_enabled(input, weight, bias)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        return FusedDenseFunc.apply(*args)\n\n\ndef _dense_no_bias(input, weight):\n    args = _cast_if_autocast_enabled(input, weight)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        return DenseNoBiasFunc.apply(*args)\n\n\ndef _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2):\n    args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        return FusedDenseGeluDenseFunc.apply(*args)\n\n\nclass FusedDense(nn.Module):\n    def __init__(self, in_features, out_features, bias=True):\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.weight = nn.Parameter(torch.empty(out_features, in_features))\n        if bias:\n            self.bias = nn.Parameter(torch.empty(out_features))\n        else:\n            # No bias: forward() falls back to the unfused DenseNoBiasFunc path\n            self.register_parameter(\"bias\", None)\n\n    def forward(self, input):\n        if self.bias is not None:\n            return _fused_dense(input, self.weight, self.bias)\n        else:\n            return _dense_no_bias(input, self.weight)\n\n\nclass FusedDenseGeluDense(nn.Module):\n    def __init__(self, in_features, intermediate_features, out_features, bias=True):\n        super().__init__()\n        assert bias, \"DenseGeluDense module without bias is currently not supported\"\n        self.in_features = in_features\n        self.intermediate_features = intermediate_features\n        self.out_features = out_features\n        self.weight1 = nn.Parameter(torch.empty(intermediate_features, in_features))\n        self.bias1 = nn.Parameter(torch.empty(intermediate_features))\n        self.weight2 = nn.Parameter(torch.empty(out_features, intermediate_features))\n        self.bias2 = nn.Parameter(torch.empty(out_features))\n\n    def forward(self, input):\n        return _fused_dense_gelu_dense(input, self.weight1, self.bias1, self.weight2, self.bias2)\n"
  },
  {
    "path": "apex/mlp/__init__.py",
    "content": "from .mlp import *\n"
  },
  {
    "path": "apex/mlp/mlp.py",
    "content": "from copy import copy\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom apex._autocast_utils import _cast_if_autocast_enabled\nimport mlp_cuda\n\n\nclass MlpFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, bias, activation, *args):\n        output = mlp_cuda.forward(bias, activation, args)\n        ctx.save_for_backward(*args)\n        ctx.outputs = output\n        ctx.bias = bias\n        ctx.activation = activation\n        return output[0]\n\n    @staticmethod\n    def backward(ctx, grad_o):\n        grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors)\n        del ctx.outputs\n        return (None, None, *grads)\n\n\ndef mlp_function(bias, activation, *args):\n    autocast_args = _cast_if_autocast_enabled(bias, activation, *args)\n    return MlpFunction.apply(*autocast_args)\n\n\nclass MLP(torch.nn.Module):\n    \"\"\"Launch MLP in C++\n\n    Args:\n        mlp_sizes (list of int): MLP sizes. Example: [1024, 1024, 1024] will create 2 MLP layers with shape 1024x1024\n        bias (bool): Whether each layer has a bias term. Default: True\n        activation (str): One of 'none', 'relu', or 'sigmoid'. Default: 'relu'\n    \"\"\"\n\n    def __init__(self, mlp_sizes, bias=True, activation=\"relu\"):\n        super().__init__()\n        self.num_layers = len(mlp_sizes) - 1\n        self.mlp_sizes = copy(mlp_sizes)\n        self.bias = 1 if bias else 0\n\n        if activation == \"none\":\n            self.activation = 0\n        elif activation == \"relu\":\n            self.activation = 1\n        elif activation == \"sigmoid\":\n            self.activation = 2\n        else:\n            raise TypeError(\"activation must be one of 'none', 'relu', or 'sigmoid'.\")\n\n        self.weights = []\n        self.biases = []\n        for i in range(self.num_layers):\n            w = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1], mlp_sizes[i]))\n            self.weights.append(w)\n            name = \"weight_{}\".format(i)\n            setattr(self, name, w)\n            if self.bias:\n                b = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1]))\n                self.biases.append(b)\n                name = \"bias_{}\".format(i)\n                setattr(self, name, b)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for weight in self.weights:\n            dimsum = weight.size(0) + weight.size(1)\n            std = math.sqrt(2.0 / float(dimsum))\n            nn.init.normal_(weight, 0.0, std)\n        if self.bias:\n            for bias in self.biases:\n                std = math.sqrt(1.0 / float(bias.size(0)))\n                nn.init.normal_(bias, 0.0, std)\n\n    def forward(self, input):\n        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)\n\n    def extra_repr(self):\n        s = f\"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}\"\n        return s\n"
  },
  {
    "path": "apex/multi_tensor_apply/__init__.py",
    "content": "from .multi_tensor_apply import MultiTensorApply\n\nmulti_tensor_applier = MultiTensorApply(2048 * 32)\n"
  },
  {
    "path": "apex/multi_tensor_apply/multi_tensor_apply.py",
    "content": "class MultiTensorApply(object):\n    available = False\n    warned = False\n\n    def __init__(self, chunk_size):\n        try:\n            import amp_C  # noqa: F401  (imported only to check that the extension is built)\n\n            MultiTensorApply.available = True\n            self.chunk_size = chunk_size\n        except ImportError as err:\n            MultiTensorApply.available = False\n            MultiTensorApply.import_err = err\n\n    def check_avail(self):\n        if not MultiTensorApply.available:\n            raise RuntimeError(\n                \"Attempted to call MultiTensorApply method, but MultiTensorApply \"\n                \"is not available, possibly because Apex was installed without \"\n                \"--cpp_ext --cuda_ext.  Original import error message: \"\n                f\"{MultiTensorApply.import_err}\"\n            )\n\n    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):\n        self.check_avail()\n\n        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)\n"
  },
  {
    "path": "apex/normalization/__init__.py",
    "content": "from .fused_layer_norm import (\n    FusedLayerNorm,\n    MixedFusedLayerNorm,\n    FusedRMSNorm,\n    MixedFusedRMSNorm,\n)\n"
  },
  {
    "path": "apex/normalization/fused_layer_norm.py",
    "content": "import importlib\nimport numbers\n\nimport torch\nfrom torch.nn.parameter import Parameter\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom typing import List, Tuple\n\nfrom apex._autocast_utils import _cast_if_autocast_enabled\n\nglobal fused_layer_norm_cuda\nfused_layer_norm_cuda = None\n\n\n# PyTorch supports `torch.library.custom_op` since 2.4.0.\ndef supports_custom_op() -> bool:\n    return hasattr(torch.library, \"custom_op\")\n\n\n# Reference implementation from Huggingface\ndef manual_rms_norm(input, normalized_shape, weight, eps):\n    # layer norm should always be calculated in float32\n    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))\n    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)\n    input = input * torch.rsqrt(variance + eps)\n\n    if weight is None:\n        return input\n\n    # convert into half-precision if necessary\n    if weight.dtype in [torch.float16, torch.bfloat16]:\n        input = input.to(weight.dtype)\n\n    return weight * input\n\n\nclass FusedLayerNormAffineFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False):\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n        input_ = input.contiguous()\n        weight_ = weight.contiguous()\n        bias_ = bias.contiguous()\n        output, mean, invvar = fused_layer_norm_cuda.forward_affine(\n            input_, ctx.normalized_shape, weight_, bias_, ctx.eps\n        )\n        if ctx.memory_efficient:\n            ctx.save_for_backward(output, weight_, bias_, None, invvar)\n        else:\n            ctx.save_for_backward(input_, weight_, bias_, mean, invvar)\n        return 
output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors\n        grad_input = grad_weight = grad_bias = None\n        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(\n            grad_output.contiguous(),\n            mean,\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            weight_,\n            bias_,\n            ctx.eps,\n            ctx.memory_efficient,\n        )\n        return grad_input, grad_weight, grad_bias, None, None, None\n\n\nif supports_custom_op():\n\n    @torch.library.custom_op(\"apex::fused_layer_norm_affine_fwd\", mutates_args=())\n    def fused_layer_norm_affine_fwd(\n        input: torch.Tensor,\n        weight: torch.Tensor,\n        bias: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        input_ = input.contiguous()\n        weight_ = weight.contiguous()\n        bias_ = bias.contiguous()\n        output, mean, invvar = fused_layer_norm_cuda.forward_affine(\n            input_, normalized_shape, weight_, bias_, eps\n        )\n        return output, mean, invvar\n\n    @fused_layer_norm_affine_fwd.register_fake\n    def fused_layer_norm_affine_fwd_fake(\n        input: torch.Tensor,\n        weight: torch.Tensor,\n        bias: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        input = input.contiguous()\n        weight = weight.contiguous()\n        bias = bias.contiguous()\n        idiff = input.ndim - len(normalized_shape)\n        n = 1\n        for i in 
range(idiff):\n            n *= input.shape[i]\n        if input.dtype in [torch.float16, torch.bfloat16]:\n            dtype = torch.float32\n        else:\n            dtype = input.dtype\n        mean = torch.empty([n], dtype=dtype, device=input.device)\n        invvar = torch.empty_like(mean)\n        return torch.empty_like(input), mean, invvar\n\n    @torch.library.custom_op(\"apex::fused_layer_norm_affine_bwd\", mutates_args=())\n    def fused_layer_norm_affine_bwd(\n        grad_output: torch.Tensor,\n        mean: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        weight: torch.Tensor,\n        bias: torch.Tensor,\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(\n            grad_output.contiguous(),\n            mean,\n            invvar,\n            input_or_output,\n            normalized_shape,\n            weight,\n            bias,\n            eps,\n            memory_efficient,\n        )\n        return grad_input, grad_weight, grad_bias\n\n    @fused_layer_norm_affine_bwd.register_fake\n    def fused_layer_norm_affine_bwd_fake(\n        grad_output: torch.Tensor,\n        mean: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        weight: torch.Tensor,\n        bias: torch.Tensor,\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        grad_input = torch.empty_like(input_or_output)\n        grad_weight = torch.empty_like(weight)\n        grad_bias = torch.empty_like(bias)\n        return grad_input, grad_weight, grad_bias\n\n    def _fused_layer_norm_affine_backward(ctx, grad_output, grad_mean, grad_invvar):\n        input_or_output, weight_, bias_, mean, invvar 
= ctx.saved_tensors\n        grad_input = grad_weight = grad_bias = None\n        grad_input, grad_weight, grad_bias = fused_layer_norm_affine_bwd(\n            grad_output,\n            mean,\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            weight_,\n            bias_,\n            ctx.eps,\n            ctx.memory_efficient,\n        )\n        return grad_input, grad_weight, grad_bias, None, None, None\n\n    def _fused_layer_norm_affine_setup_context(ctx, inputs, output):\n        input, weight, bias, normalized_shape, eps, memory_efficient = inputs\n        output, mean, invvar = output\n        input_ = input.contiguous()\n        weight_ = weight.contiguous()\n        bias_ = bias.contiguous()\n        if memory_efficient:\n            ctx.save_for_backward(output, weight_, bias_, None, invvar)\n        else:\n            ctx.save_for_backward(input_, weight_, bias_, mean, invvar)\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n\n    fused_layer_norm_affine_fwd.register_autograd(\n        _fused_layer_norm_affine_backward,\n        setup_context=_fused_layer_norm_affine_setup_context,\n    )\n\n\nclass FusedRMSNormAffineFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False):\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n        input_ = input.contiguous()\n        weight_ = weight.contiguous()\n        output, invvar = fused_layer_norm_cuda.rms_forward_affine(\n            input_, ctx.normalized_shape, weight_, ctx.eps\n        )\n        if ctx.memory_efficient:\n            ctx.save_for_backward(output, weight_, 
invvar)\n        else:\n            ctx.save_for_backward(input_, weight_, invvar)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input_or_output, weight_, invvar = ctx.saved_tensors\n        grad_input = grad_weight = None\n        grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(\n            grad_output.contiguous(),\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            weight_,\n            ctx.eps,\n            ctx.memory_efficient,\n        )\n        return grad_input, grad_weight, None, None, None\n\n\nif supports_custom_op():\n\n    @torch.library.custom_op(\"apex::fused_rms_norm_affine_fwd\", mutates_args=())\n    def fused_rms_norm_affine_fwd(\n        input: torch.Tensor,\n        weight: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        input_ = input.contiguous()\n        weight_ = weight.contiguous()\n        output, invvar = fused_layer_norm_cuda.rms_forward_affine(\n            input_, normalized_shape, weight_, eps\n        )\n        return output, invvar\n\n    @fused_rms_norm_affine_fwd.register_fake\n    def fused_rms_norm_affine_fwd_fake(\n        input: torch.Tensor,\n        weight: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        input = input.contiguous()\n        weight = weight.contiguous()\n        idiff = input.ndim - len(normalized_shape)\n        n = 1\n        for i in range(idiff):\n            n *= input.shape[i]\n        if input.dtype in [torch.float16, torch.bfloat16]:\n            dtype = torch.float32\n        else:\n   
         dtype = input.dtype\n        return (\n            torch.empty_like(input),\n            torch.empty(\n                [n],\n                dtype=dtype,\n                device=input.device,\n                requires_grad=input.requires_grad,\n                memory_format=torch.contiguous_format,\n            ),\n        )\n\n    @torch.library.custom_op(\"apex::fused_rms_norm_affine_bwd\", mutates_args=())\n    def fused_rms_norm_affine_bwd(\n        grad_output: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        weight: torch.Tensor,\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(\n            grad_output.contiguous(),\n            invvar,\n            input_or_output,\n            normalized_shape,\n            weight,\n            eps,\n            memory_efficient,\n        )\n        return grad_input, grad_weight\n\n    @fused_rms_norm_affine_bwd.register_fake\n    def fused_rms_norm_affine_bwd_fake(\n        grad_output: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        weight: torch.Tensor,\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        grad_input = torch.empty_like(input_or_output)\n        grad_weight = torch.empty_like(weight)\n        return grad_input, grad_weight\n\n    def _fused_rms_norm_affine_backward(ctx, grad_output, grad_invvar):\n        input_or_output, weight_, invvar = ctx.saved_tensors\n        grad_input = grad_weight = None\n        grad_input, grad_weight = fused_rms_norm_affine_bwd(\n            grad_output,\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            weight_,\n            ctx.eps,\n            
ctx.memory_efficient,\n        )\n        return grad_input, grad_weight, None, None, None\n\n    def _fused_rms_norm_affine_setup_context(ctx, inputs, output):\n        input_, weight_, normalized_shape, eps, memory_efficient = inputs\n        output_, invvar = output\n        input_ = input_.contiguous()\n        weight_ = weight_.contiguous()\n        if memory_efficient:\n            ctx.save_for_backward(output_, weight_, invvar)\n        else:\n            ctx.save_for_backward(input_, weight_, invvar)\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n\n    fused_rms_norm_affine_fwd.register_autograd(\n        _fused_rms_norm_affine_backward,\n        setup_context=_fused_rms_norm_affine_setup_context,\n    )\n\n\nclass FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):\n    @staticmethod\n    def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False):\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n        input_ = input.contiguous()\n        weight_ = weight.contiguous()\n        bias_ = bias.contiguous()\n        output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes(\n            input_, ctx.normalized_shape, weight_, bias_, ctx.eps\n        )\n        if ctx.memory_efficient:\n            ctx.save_for_backward(output, weight_, bias_, None, invvar)\n        else:\n            ctx.save_for_backward(input_, weight_, bias_, mean, invvar)\n        return output\n\n\nclass FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction):\n    @staticmethod\n    def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False):\n        global fused_layer_norm_cuda\n        
if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n        input_ = input.contiguous()\n        weight_ = weight.contiguous()\n        output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes(\n            input_, ctx.normalized_shape, weight_, ctx.eps\n        )\n        if ctx.memory_efficient:\n            ctx.save_for_backward(output, weight_, invvar)\n        else:\n            ctx.save_for_backward(input_, weight_, invvar)\n        return output\n\n\nclass FusedLayerNormFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, normalized_shape, eps, memory_efficient=False):\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n        input_ = input.contiguous()\n        output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps)\n        if ctx.memory_efficient:\n            ctx.save_for_backward(output, None, invvar)\n        else:\n            ctx.save_for_backward(input_, mean, invvar)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input_or_output, mean, invvar = ctx.saved_tensors\n        grad_input = fused_layer_norm_cuda.backward(\n            grad_output.contiguous(),\n            mean,\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            ctx.eps,\n            ctx.memory_efficient,\n        )\n        return grad_input, None, None, None\n\n\nif supports_custom_op():\n\n    @torch.library.custom_op(\"apex::fused_layer_norm_fwd\", mutates_args=())\n    def 
fused_layer_norm_fwd(\n        input: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        input_ = input.contiguous()\n        output, mean, invvar = fused_layer_norm_cuda.forward(input_, normalized_shape, eps)\n        return output, mean, invvar\n\n    @fused_layer_norm_fwd.register_fake\n    def fused_layer_norm_fwd_fake(\n        input: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        input = input.contiguous()\n        idiff = input.ndim - len(normalized_shape)\n        n = 1\n        for i in range(idiff):\n            n *= input.shape[i]\n        if input.dtype in [torch.float16, torch.bfloat16]:\n            dtype = torch.float32\n        else:\n            dtype = input.dtype\n        mean = torch.empty([n], dtype=dtype, device=input.device)\n        invvar = torch.empty_like(mean)\n        return torch.empty_like(input), mean, invvar\n\n    @torch.library.custom_op(\"apex::fused_layer_norm_bwd\", mutates_args=())\n    def fused_layer_norm_bwd(\n        grad_output: torch.Tensor,\n        mean: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> torch.Tensor:\n        grad_input = fused_layer_norm_cuda.backward(\n            grad_output.contiguous(),\n            mean,\n            invvar,\n            input_or_output,\n            normalized_shape,\n            eps,\n            memory_efficient,\n        )\n        return grad_input\n\n    @fused_layer_norm_bwd.register_fake\n    def 
fused_layer_norm_bwd_fake(\n        grad_output: torch.Tensor,\n        mean: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> torch.Tensor:\n        grad_input = torch.empty_like(input_or_output)\n        return grad_input\n\n    def _fused_layer_norm_backward(ctx, grad_output, grad_mean, grad_invvar):\n        input_or_output, mean, invvar = ctx.saved_tensors\n        grad_input = fused_layer_norm_bwd(\n            grad_output,\n            mean,\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            ctx.eps,\n            ctx.memory_efficient,\n        )\n        return grad_input, None, None, None\n\n    def _fused_layer_norm_setup_context(ctx, inputs, output):\n        input, normalized_shape, eps, memory_efficient = inputs\n        output, mean, invvar = output\n        input_ = input.contiguous()\n        if memory_efficient:\n            ctx.save_for_backward(output, None, invvar)\n        else:\n            ctx.save_for_backward(input_, mean, invvar)\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n\n    fused_layer_norm_fwd.register_autograd(\n        _fused_layer_norm_backward,\n        setup_context=_fused_layer_norm_setup_context,\n    )\n\n\nclass FusedRMSNormFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, normalized_shape, eps, memory_efficient=False):\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        ctx.memory_efficient = memory_efficient\n        input_ = input.contiguous()\n        output, invvar = fused_layer_norm_cuda.rms_forward(input_, 
ctx.normalized_shape, ctx.eps)\n        if ctx.memory_efficient:\n            ctx.save_for_backward(output, invvar)\n        else:\n            ctx.save_for_backward(input_, invvar)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input_or_output, invvar = ctx.saved_tensors\n        grad_input = None\n        grad_input = fused_layer_norm_cuda.rms_backward(\n            grad_output.contiguous(),\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            ctx.eps,\n            ctx.memory_efficient,\n        )\n        return grad_input, None, None, None\n\n\nif supports_custom_op():\n\n    @torch.library.custom_op(\"apex::fused_rms_norm_fwd\", mutates_args=())\n    def fused_rms_norm_fwd(\n        input: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        global fused_layer_norm_cuda\n        if fused_layer_norm_cuda is None:\n            fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        input_ = input.contiguous()\n        output, invvar = fused_layer_norm_cuda.rms_forward(input_, normalized_shape, eps)\n        return output, invvar\n\n    @fused_rms_norm_fwd.register_fake\n    def fused_rms_norm_fwd_fake(\n        input: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        input = input.contiguous()\n        idiff = input.ndim - len(normalized_shape)\n        n = 1\n        for i in range(idiff):\n            n *= input.shape[i]\n        if input.dtype in [torch.float16, torch.bfloat16]:\n            dtype = torch.float32\n        else:\n            dtype = input.dtype\n        return (\n            torch.empty_like(input),\n            torch.empty(\n                [n],\n                dtype=dtype,\n                
device=input.device,\n                requires_grad=input.requires_grad,\n                memory_format=torch.contiguous_format,\n            ),\n        )\n\n    @torch.library.custom_op(\"apex::fused_rms_norm_bwd\", mutates_args=())\n    def fused_rms_norm_bwd(\n        grad_output: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> torch.Tensor:\n        grad_input = fused_layer_norm_cuda.rms_backward(\n            grad_output.contiguous(),\n            invvar,\n            input_or_output,\n            normalized_shape,\n            eps,\n            memory_efficient,\n        )\n        return grad_input\n\n    @fused_rms_norm_bwd.register_fake\n    def fused_rms_norm_bwd_fake(\n        grad_output: torch.Tensor,\n        invvar: torch.Tensor,\n        input_or_output: torch.Tensor,\n        normalized_shape: List[int],\n        eps: float,\n        memory_efficient: bool = False,\n    ) -> torch.Tensor:\n        grad_input = torch.empty_like(input_or_output)\n        return grad_input\n\n    def _fused_rms_norm_backward(ctx, grad_output, grad_invvar):\n        input_or_output, invvar = ctx.saved_tensors\n        grad_input = None\n        grad_input = fused_rms_norm_bwd(\n            grad_output,\n            invvar,\n            input_or_output,\n            ctx.normalized_shape,\n            ctx.eps,\n            ctx.memory_efficient,\n        )\n        return grad_input, None, None, None\n\n    def _fused_rms_norm_setup_context(ctx, inputs, output):\n        input_, normalized_shape, eps, memory_efficient = inputs\n        output_, invvar = output\n        input_ = input_.contiguous()\n        if memory_efficient:\n            ctx.save_for_backward(output_, invvar)\n        else:\n            ctx.save_for_backward(input_, invvar)\n        ctx.normalized_shape = normalized_shape\n        ctx.eps = eps\n        
ctx.memory_efficient = memory_efficient\n\n    fused_rms_norm_fwd.register_autograd(\n        _fused_rms_norm_backward, setup_context=_fused_rms_norm_setup_context\n    )\n\n\ndef fused_layer_norm_affine(\n    input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False\n):\n    args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        if supports_custom_op():\n            return fused_layer_norm_affine_fwd(*args)[0]\n        else:\n            return FusedLayerNormAffineFunction.apply(*args)\n\n\ndef fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient=False):\n    args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        if supports_custom_op():\n            return fused_layer_norm_fwd(*args)[0]\n        else:\n            return FusedLayerNormFunction.apply(*args)\n\n\ndef mixed_dtype_fused_layer_norm_affine(\n    input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False\n):\n    args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        return FusedLayerNormAffineMixedDtypesFunction.apply(*args)\n\n\ndef fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False):\n    args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        if supports_custom_op():\n            return fused_rms_norm_affine_fwd(*args)[0]\n        else:\n            return FusedRMSNormAffineFunction.apply(*args)\n\n\ndef fused_rms_norm(input, normalized_shape, eps=1e-6, memory_efficient=False):\n    args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        if 
supports_custom_op():\n            return fused_rms_norm_fwd(*args)[0]\n        else:\n            return FusedRMSNormFunction.apply(*args)\n\n\ndef mixed_dtype_fused_rms_norm_affine(\n    input, weight, normalized_shape, eps=1e-6, memory_efficient=False\n):\n    args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient)\n    with torch.amp.autocast(\"cuda\", enabled=False):\n        return FusedRMSNormAffineMixedDtypesFunction.apply(*args)\n\n\nclass FusedLayerNorm(torch.nn.Module):\n    r\"\"\"Applies Layer Normalization over a mini-batch of inputs as described in\n    the paper `Layer Normalization`_ .\n\n    Currently only runs on cuda() tensors.\n\n    .. math::\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\n\n    The mean and standard-deviation are calculated separately over the last\n    certain number dimensions which have to be of the shape specified by\n    :attr:`normalized_shape`.\n    :math:`\\gamma` and :math:`\\beta` are learnable affine transform parameters of\n    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.\n\n    .. note::\n        Unlike Batch Normalization and Instance Normalization, which applies\n        scalar scale and bias for each entire channel/plane with the\n        :attr:`affine` option, Layer Normalization applies per-element scale and\n        bias with :attr:`elementwise_affine`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    Args:\n        normalized_shape (int or list or torch.Size): input shape from an expected input\n            of size\n\n            .. 
math::\n                [* \\times \\text{normalized}\\_\\text{shape}[0] \\times \\text{normalized}\\_\\text{shape}[1]\n                    \\times \\ldots \\times \\text{normalized}\\_\\text{shape}[-1]]\n\n            If a single integer is used, it is treated as a singleton list, and this module will\n            normalize over the last dimension which is expected to be of that specific size.\n        eps: a value added to the denominator for numerical stability. Default: 1e-5\n        elementwise_affine: a boolean value that when set to ``True``, this module\n            has learnable per-element affine parameters initialized to ones (for weights)\n            and zeros (for biases). Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    Examples::\n\n        >>> input = torch.randn(20, 5, 10, 10)\n        >>> # With Learnable Parameters\n        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])\n        >>> # Without Learnable Parameters\n        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)\n        >>> # Normalize over last two dimensions\n        >>> m = apex.normalization.FusedLayerNorm([10, 10])\n        >>> # Normalize over last dimension of size 10\n        >>> m = apex.normalization.FusedLayerNorm(10)\n        >>> # Activating the module\n        >>> output = m(input)\n\n    .. 
_`Layer Normalization`: https://arxiv.org/abs/1607.06450\n    \"\"\"\n\n    def __init__(\n        self,\n        normalized_shape,\n        eps=1e-5,\n        elementwise_affine=True,\n        memory_efficient=False,\n    ):\n        super().__init__()\n\n        global fused_layer_norm_cuda\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.eps = eps\n        self.elementwise_affine = elementwise_affine\n        self.memory_efficient = memory_efficient\n        if self.elementwise_affine:\n            self.weight = Parameter(torch.empty(*normalized_shape))\n            self.bias = Parameter(torch.empty(*normalized_shape))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.elementwise_affine:\n            init.ones_(self.weight)\n            init.zeros_(self.bias)\n\n    def forward(self, input):\n        if (\n            torch.jit.is_tracing()\n            or torch.jit.is_scripting()\n            or torch.compiler.is_compiling()\n            or not input.is_cuda\n        ):\n            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)\n        if self.elementwise_affine:\n            return fused_layer_norm_affine(\n                input,\n                self.weight,\n                self.bias,\n                self.normalized_shape,\n                self.eps,\n                self.memory_efficient,\n            )\n        else:\n            return fused_layer_norm(input, self.normalized_shape, self.eps, self.memory_efficient)\n\n    def extra_repr(self):\n        return \"{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}\".format(\n 
           **self.__dict__\n        )\n\n\nclass FusedRMSNorm(torch.nn.Module):\n    r\"\"\"Applies RMS Normalization over a mini-batch of inputs as described in\n    the paper `Root Mean Square Layer Normalization`_ .\n\n    Currently only runs on cuda() tensors.\n\n    .. math::\n        y = \\frac{x}{\\mathrm{RMS}[x]} * \\gamma\n\n    The root-mean-square is calculated separately over the last\n    certain number of dimensions which have to be of the shape specified by\n    :attr:`normalized_shape`.\n    :math:`\\gamma` is a learnable affine transform parameter of\n    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.\n    `epsilon` is added to the mean-square, then the root of the sum is taken.\n\n    .. note::\n        Unlike Batch Normalization and Instance Normalization, which apply\n        scalar scale and bias for each entire channel/plane with the\n        :attr:`affine` option, RMS Normalization applies per-element scale\n        with :attr:`elementwise_affine`.\n\n    This layer uses statistics computed from input data in both training and\n    evaluation modes.\n\n    Args:\n        normalized_shape (int or list or torch.Size): input shape from an expected input\n            of size\n\n            .. math::\n                [* \\times \\text{normalized}\\_\\text{shape}[0] \\times \\text{normalized}\\_\\text{shape}[1]\n                    \\times \\ldots \\times \\text{normalized}\\_\\text{shape}[-1]]\n\n            If a single integer is used, it is treated as a singleton list, and this module will\n            normalize over the last dimension which is expected to be of that specific size.\n        eps: a value added to the mean-square before the root is taken, for numerical stability. Default: 1e-5\n        elementwise_affine: a boolean value that when set to ``True``, this module\n            has a learnable per-element scale parameter (weight) initialized to ones. 
Default: ``True``.\n\n    Shape:\n        - Input: :math:`(N, *)`\n        - Output: :math:`(N, *)` (same shape as input)\n\n    Examples::\n\n        >>> input = torch.randn(20, 5, 10, 10)\n        >>> # With Learnable Parameters\n        >>> m = apex.normalization.FusedRMSNorm(input.size()[1:])\n        >>> # Without Learnable Parameters\n        >>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False)\n        >>> # Normalize over last two dimensions\n        >>> m = apex.normalization.FusedRMSNorm([10, 10])\n        >>> # Normalize over last dimension of size 10\n        >>> m = apex.normalization.FusedRMSNorm(10)\n        >>> # Activating the module\n        >>> output = m(input)\n\n    .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf\n    \"\"\"\n\n    def __init__(\n        self,\n        normalized_shape,\n        eps=1e-5,\n        elementwise_affine=True,\n        memory_efficient=False,\n    ):\n        super().__init__()\n\n        global fused_layer_norm_cuda\n        fused_layer_norm_cuda = importlib.import_module(\"fused_layer_norm_cuda\")\n\n        if isinstance(normalized_shape, numbers.Integral):\n            normalized_shape = (normalized_shape,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.eps = eps\n        self.elementwise_affine = elementwise_affine\n        self.memory_efficient = memory_efficient\n        if self.elementwise_affine:\n            self.weight = Parameter(torch.empty(*normalized_shape))\n        else:\n            self.register_parameter(\"weight\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.elementwise_affine:\n            init.ones_(self.weight)\n\n    def forward(self, input):\n        if (\n            torch.jit.is_tracing()\n            or torch.jit.is_scripting()\n            or torch.compiler.is_compiling()\n            or not input.is_cuda\n        ):\n            return 
manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)\n\n        if self.elementwise_affine:\n            return fused_rms_norm_affine(\n                input,\n                self.weight,\n                self.normalized_shape,\n                self.eps,\n                self.memory_efficient,\n            )\n        else:\n            return fused_rms_norm(input, self.normalized_shape, self.eps, self.memory_efficient)\n\n    def extra_repr(self):\n        return \"{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}\".format(\n            **self.__dict__\n        )\n\n\n# NOTE (mkozuki): Why \"mixed\"?\n# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype\n# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.\n# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in \"csrc/layer_norm_cuda.cpp\"\nclass MixedFusedLayerNorm(FusedLayerNorm):\n    def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs):\n        if \"elementwise_affine\" in kwargs:\n            import warnings\n\n            warnings.warn(\"MixedFusedLayerNorm does not support `elementwise_affine` argument\")\n            elementwise_affine = kwargs.pop(\"elementwise_affine\")\n            if not elementwise_affine:\n                raise RuntimeError(\n                    \"MixedFusedLayerNorm does not support `elementwise_affine = False`\"\n                )\n\n        super().__init__(\n            normalized_shape=normalized_shape,\n            eps=eps,\n            elementwise_affine=True,\n            memory_efficient=memory_efficient,\n        )\n\n    def forward(self, input: torch.Tensor):\n        # NOTE (mkozuki): CPU path is here mainly for unittest sake.\n        if (\n            torch.jit.is_tracing()\n            or torch.jit.is_scripting()\n            or torch.compiler.is_compiling()\n            or not input.is_cuda\n      
  ):\n            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)\n        return mixed_dtype_fused_layer_norm_affine(\n            input,\n            self.weight,\n            self.bias,\n            self.normalized_shape,\n            self.eps,\n            self.memory_efficient,\n        )\n\n\n# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype\n# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.\n# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in \"csrc/layer_norm_cuda.cpp\"\nclass MixedFusedRMSNorm(FusedRMSNorm):\n    def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs):\n        if \"elementwise_affine\" in kwargs:\n            import warnings\n\n            warnings.warn(\"MixedFusedRMSNorm does not support `elementwise_affine` argument\")\n            elementwise_affine = kwargs.pop(\"elementwise_affine\")\n            if not elementwise_affine:\n                raise RuntimeError(\n                    \"MixedFusedRMSNorm does not support `elementwise_affine = False`\"\n                )\n\n        super().__init__(\n            normalized_shape=normalized_shape,\n            eps=eps,\n            elementwise_affine=True,\n            memory_efficient=memory_efficient,\n        )\n\n    def forward(self, input: torch.Tensor):\n        # NOTE (mkozuki): CPU path is here mainly for unittest sake.\n        # TODO Manual RMS Norm Implementation Here\n        if (\n            torch.jit.is_tracing()\n            or torch.jit.is_scripting()\n            or torch.compiler.is_compiling()\n            or not input.is_cuda\n        ):\n            return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)\n        return mixed_dtype_fused_rms_norm_affine(\n            input, self.weight, self.normalized_shape, self.eps, self.memory_efficient\n        )\n"
  },
  {
    "path": "apex/optimizers/__init__.py",
    "content": "from .fused_sgd import FusedSGD\nfrom .fused_adam import FusedAdam\nfrom .fused_novograd import FusedNovoGrad\nfrom .fused_lamb import FusedLAMB\nfrom .fused_adagrad import FusedAdagrad\nfrom .fused_mixed_precision_lamb import FusedMixedPrecisionLamb\n"
  },
  {
    "path": "apex/optimizers/fused_adagrad.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedAdagrad(torch.optim.Optimizer):\n    \"\"\"Implements Adagrad algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused Adagrad implements 2 fusions.\n      * Fusion of the Adagrad update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedAdagrad`'s usage is identical to any ordinary Pytorch optimizer::\n        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedAdagrad` may be used with or without Amp.  If you wish to use :class:`FusedAdagrad` with Amp,\n    you may choose any ``opt_level``::\n        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    It has been proposed in `Adaptive Subgradient Methods for Online Learning\n    and Stochastic Optimization`_.\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-2)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-10)\n        adagrad_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay (also known as AdamW) (default: False)\n\n    .. 
_Adaptive Subgradient Methods for Online Learning and Stochastic\n        Optimization: http://jmlr.org/papers/v12/duchi11a.html\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=1e-2,\n        eps=1e-10,\n        weight_decay=0.0,\n        set_grad_none=True,\n        adagrad_w_mode=False,\n    ):\n        defaults = dict(lr=lr, eps=eps, weight_decay=weight_decay)\n        super(FusedAdagrad, self).__init__(params, defaults)\n        self.adagrad_w_mode = 1 if adagrad_w_mode else 0\n        self.set_grad_none = set_grad_none\n\n        if multi_tensor_applier.available:\n            import amp_C\n\n            # Skip buffer\n            self._dummy_overflow_buf = torch.cuda.IntTensor([0])\n            self.multi_tensor_adagrad = amp_C.multi_tensor_adagrad\n        else:\n            raise RuntimeError(\"apex.optimizers.FusedAdagrad requires cuda extensions\")\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group[\"params\"]:\n                    p.grad = None\n        else:\n            super(FusedAdagrad, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            # create lists for multi-tensor apply\n            g_16, p_16, h_16 = [], [], []\n            g_32, p_32, h_32 = [], [], []\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError(\"FusedAdagrad does not support sparse gradients\")\n\n                state = self.state[p]\n                # State initialization\n             
   if len(state) == 0:\n                    # Running sum of squared gradient values\n                    state[\"sum\"] = torch.zeros_like(p.data)\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    h_16.append(state[\"sum\"])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    h_32.append(state[\"sum\"])\n                else:\n                    raise RuntimeError(\"FusedAdagrad only supports fp16 and fp32.\")\n\n            if len(g_16) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_adagrad,\n                    self._dummy_overflow_buf,\n                    [g_16, p_16, h_16],\n                    group[\"lr\"],\n                    group[\"eps\"],\n                    self.adagrad_w_mode,\n                    group[\"weight_decay\"],\n                )\n            if len(g_32) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_adagrad,\n                    self._dummy_overflow_buf,\n                    [g_32, p_32, h_32],\n                    group[\"lr\"],\n                    group[\"eps\"],\n                    self.adagrad_w_mode,\n                    group[\"weight_decay\"],\n                )\n\n        return loss\n"
  },
  {
    "path": "apex/optimizers/fused_adam.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedAdam(torch.optim.Optimizer):\n    \"\"\"Implements Adam algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused Adam implements 2 fusions.\n\n      * Fusion of the Adam update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,\n    or ``torch.optim.Adam`` with ``adam_w_mode=False``::\n\n        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedAdam` may be used with or without Amp.  If you wish to use :class:`FusedAdam` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n\n    .. warning::\n        A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``.  These additional arguments\n        are now deprecated and unnecessary.\n\n    Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square. 
(default: (0.9, 0.999))\n        bias_correction (bool, optional): whether to apply bias correction to the\n            running averages. (default: True)\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): weight decay coefficient; see\n            ``adam_w_mode`` for how it is applied. (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False) NOT SUPPORTED in FusedAdam!\n        adam_w_mode (boolean, optional): whether to apply decoupled weight decay\n            (also known as AdamW) instead of L2 regularization. (default: True)\n        set_grad_none (bool, optional): whether to set grad to None when the\n            zero_grad() method is called. (default: True)\n        capturable (bool, optional): whether to use the version of the optimizer\n            that can be used with CUDA Graphs. (default: False)\n        master_weights (bool, optional): whether to maintain FP32 master weights\n           in the optimizer with FP16 mixed precision training; currently can\n           only be used with capturable set to True. (default: False)\n\n    .. _`Adam: A Method for Stochastic Optimization`:\n        https://arxiv.org/abs/1412.6980\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        bias_correction=True,\n        betas=(0.9, 0.999),\n        eps=1e-8,\n        adam_w_mode=True,\n        weight_decay=0.0,\n        amsgrad=False,\n        set_grad_none=True,\n        capturable=False,\n        master_weights=False,\n    ):\n        if amsgrad:\n            raise RuntimeError(\"FusedAdam does not support the AMSGrad variant.\")\n        if master_weights and not capturable:\n            raise RuntimeError(\n                \"Master weights is currently only supported with the capturable version.\"\n            )\n        # If the optimizer is capturable then LR should be a tensor (on GPU)\n        lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n        )\n        super(FusedAdam, self).__init__(params, defaults)\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n\n        self.capturable = capturable\n        self.master_weights = master_weights\n\n        # Create full precision master weights\n        self.param_groups_master = []\n        for i, pg in enumerate(self.param_groups):\n            param_list = pg[\"params\"]\n            self.param_groups_master.append(\n                {\n                    \"params\": [\n                        p.clone().detach().float() if self.master_weights else None\n                        for p in param_list\n                    ],\n                }\n            )\n\n        if capturable:\n            for idx, group in enumerate(self.param_groups):\n                if len(group[\"params\"]) == 0:\n                    continue\n                device = group[\"params\"][0].device\n     
           for item in [\"lr\"]:\n                    self.param_groups[idx][item] = group[item].to(device=device)\n\n            self._step_supports_amp_scaling = True\n\n        if multi_tensor_applier.available:\n            import amp_C\n\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=\"cuda\")\n            self.multi_tensor_adam = amp_C.multi_tensor_adam\n            self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable\n            self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master\n        else:\n            raise RuntimeError(\"apex.optimizers.FusedAdam requires cuda extensions\")\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group[\"params\"]:\n                    p.grad = None\n        else:\n            super(FusedAdam, self).zero_grad()\n\n    def step(\n        self,\n        closure=None,\n        grads=None,\n        output_params=None,\n        scale=None,\n        grad_norms=None,\n        grad_scaler=None,\n    ):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n\n        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.\n        \"\"\"\n        if any(p is not None for p in [grads, output_params, scale, grad_norms]):\n            raise RuntimeError(\n                \"FusedAdam has been updated.  
Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.\"\n            )\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group, group_master in zip(self.param_groups, self.param_groups_master):\n            if len(group[\"params\"]) == 0:\n                continue\n            device = group[\"params\"][0].device\n            bias_correction = 1 if group[\"bias_correction\"] else 0\n            beta1, beta2 = group[\"betas\"]\n\n            # Assume the same step count across the whole group to simplify things.\n            # A per-parameter step could easily be supported by making it a tensor,\n            # or by passing a list into the kernel.\n            if \"step\" in group:\n                group[\"step\"] += (\n                    1 if not self.capturable else (self._dummy_overflow_buf != 1).to(torch.int)\n                )\n            else:\n                group[\"step\"] = (\n                    1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device)\n                )\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_bf, p_bf, m_bf, v_bf = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n            p_16_master = []\n            p_32_master = []\n\n            for p, p_master in zip(group[\"params\"], group_master[\"params\"]):\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError(\n                        \"FusedAdam does not support sparse gradients, please consider SparseAdam instead\"\n                    )\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data).float()\n                    # Exponential moving 
average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p.data).float()\n\n                if p.dtype == torch.float16:\n                    if self.master_weights:\n                        p_16_master.append(p_master.data)\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state[\"exp_avg\"])\n                    v_16.append(state[\"exp_avg_sq\"])\n                elif p.dtype == torch.bfloat16:\n                    g_bf.append(p.grad)\n                    p_bf.append(p)\n                    m_bf.append(state[\"exp_avg\"])\n                    v_bf.append(state[\"exp_avg_sq\"])\n                elif p.dtype == torch.float32:\n                    if self.master_weights:\n                        p_32_master.append(p_master.data)\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state[\"exp_avg\"])\n                    v_32.append(state[\"exp_avg_sq\"])\n                else:\n                    raise RuntimeError(\"FusedAdam only supports fp16, bf16, and fp32.\")\n\n            # If the optimizer is capturable, then if there's a grad scaler it works\n            # on the GPU + a different multi_tensor_applier should be called\n            if self.capturable:\n                # overflow check of gradients\n                found_inf = (\n                    grad_scaler._check_inf_per_device(self)[device]\n                    if grad_scaler is not None\n                    else torch.zeros((1,), device=device)\n                )\n                self._dummy_overflow_buf.copy_(found_inf)\n\n                # get unscale scale factor\n                scale, inv_scale = None, None\n                if grad_scaler:\n                    scale = grad_scaler._get_scale_async()\n                    inv_scale = scale.double().reciprocal().float()\n                else:\n                    scale = 
torch.ones((1,), device=device)\n                    inv_scale = torch.ones((1,), device=device)\n\n                if len(g_16) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_adam_capturable_master\n                        if self.master_weights\n                        else self.multi_tensor_adam_capturable,\n                        self._dummy_overflow_buf,\n                        [g_16, p_16, m_16, v_16, p_16_master]\n                        if self.master_weights\n                        else [g_16, p_16, m_16, v_16],\n                        group[\"lr\"],\n                        beta1,\n                        beta2,\n                        group[\"eps\"],\n                        group[\"step\"],\n                        self.adam_w_mode,\n                        bias_correction,\n                        group[\"weight_decay\"],\n                        inv_scale,\n                    )\n\n                if len(g_bf) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_adam_capturable,\n                        self._dummy_overflow_buf,\n                        [g_bf, p_bf, m_bf, v_bf],\n                        group[\"lr\"],\n                        beta1,\n                        beta2,\n                        group[\"eps\"],\n                        group[\"step\"],\n                        self.adam_w_mode,\n                        bias_correction,\n                        group[\"weight_decay\"],\n                        inv_scale,\n                    )\n\n                if len(g_32) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_adam_capturable_master\n                        if self.master_weights\n                        else self.multi_tensor_adam_capturable,\n                        self._dummy_overflow_buf,\n                        [g_32, p_32, m_32, v_32, p_32_master]\n                        if 
self.master_weights\n                        else [g_32, p_32, m_32, v_32],\n                        group[\"lr\"],\n                        beta1,\n                        beta2,\n                        group[\"eps\"],\n                        group[\"step\"],\n                        self.adam_w_mode,\n                        bias_correction,\n                        group[\"weight_decay\"],\n                        inv_scale,\n                    )\n            else:\n                if len(g_16) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_adam,\n                        self._dummy_overflow_buf,\n                        [g_16, p_16, m_16, v_16],\n                        group[\"lr\"],\n                        beta1,\n                        beta2,\n                        group[\"eps\"],\n                        group[\"step\"],\n                        self.adam_w_mode,\n                        bias_correction,\n                        group[\"weight_decay\"],\n                    )\n\n                if len(g_bf) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_adam,\n                        self._dummy_overflow_buf,\n                        [g_bf, p_bf, m_bf, v_bf],\n                        group[\"lr\"],\n                        beta1,\n                        beta2,\n                        group[\"eps\"],\n                        group[\"step\"],\n                        self.adam_w_mode,\n                        bias_correction,\n                        group[\"weight_decay\"],\n                    )\n\n                if len(g_32) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_adam,\n                        self._dummy_overflow_buf,\n                        [g_32, p_32, m_32, v_32],\n                        group[\"lr\"],\n                        beta1,\n                        beta2,\n                        
group[\"eps\"],\n                        group[\"step\"],\n                        self.adam_w_mode,\n                        bias_correction,\n                        group[\"weight_decay\"],\n                    )\n\n        return loss\n"
  },
  {
    "path": "apex/optimizers/fused_lamb.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedLAMB(torch.optim.Optimizer):\n    \"\"\"Implements LAMB algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused LAMB implements 2 fusions.\n\n      * Fusion of the LAMB update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-6)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay\n            True for decoupled weight decay (also known as AdamW) (default: True)\n        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when\n            calculating running averages of gradient. (default: True)\n        set_grad_none (bool, optional): whether to set grad to None when zero_grad()\n            method is called. (default: True)\n        max_grad_norm (float, optional): value used to clip global grad norm\n            (default: 1.0)\n        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0\n            weight decay parameter (default: False)\n\n    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        bias_correction=True,\n        betas=(0.9, 0.999),\n        eps=1e-6,\n        weight_decay=0.01,\n        amsgrad=False,\n        adam_w_mode=True,\n        grad_averaging=True,\n        set_grad_none=True,\n        max_grad_norm=1.0,\n        use_nvlamb=False,\n    ):\n        if amsgrad:\n            raise RuntimeError(\"FusedLAMB does not support the AMSGrad variant.\")\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            grad_averaging=grad_averaging,\n            max_grad_norm=max_grad_norm,\n        )\n        super(FusedLAMB, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n\n            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor(\n                [0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device\n            )\n            self.multi_tensor_lamb = amp_C.multi_tensor_lamb\n        else:\n            raise RuntimeError(\"apex.optimizers.FusedLAMB requires cuda extensions\")\n\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.set_grad_none = set_grad_none\n        self.use_nvlamb = use_nvlamb\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group[\"params\"]:\n                    p.grad = None\n        else:\n            super(FusedLAMB, self).zero_grad()\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n 
       \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # create separate grad lists for fp32 and fp16 params\n        g_all_32, g_all_16 = [], []\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                if p.dtype == torch.float32:\n                    g_all_32.append(p.grad.data)\n                elif p.dtype == torch.float16:\n                    g_all_16.append(p.grad.data)\n                else:\n                    raise RuntimeError(\"FusedLAMB only support fp16 and fp32.\")\n\n        device = self.param_groups[0][\"params\"][0].device\n        g_norm_32, g_norm_16 = (\n            torch.zeros(1, device=device),\n            torch.zeros(1, device=device),\n        )\n        # compute grad norm for two lists\n        if len(g_all_32) > 0:\n            g_norm_32 = multi_tensor_applier(\n                self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False\n            )[0]\n        if len(g_all_16) > 0:\n            g_norm_16 = multi_tensor_applier(\n                self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False\n            )[0]\n\n        # blend two grad norms to get global grad norm\n        global_grad_norm = multi_tensor_applier(\n            self.multi_tensor_l2norm,\n            self._dummy_overflow_buf,\n            [[g_norm_32, g_norm_16]],\n            False,\n        )[0]\n        max_grad_norm = self.defaults[\"max_grad_norm\"]\n\n        for group in self.param_groups:\n            bias_correction = 1 if group[\"bias_correction\"] else 0\n            beta1, beta2 = group[\"betas\"]\n            grad_averaging = 1 if group[\"grad_averaging\"] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into kernel\n            if \"step\" in 
group:\n                group[\"step\"] += 1\n            else:\n                group[\"step\"] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16, v_16 = [], [], [], []\n            g_32, p_32, m_32, v_32 = [], [], [], []\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError(\n                        \"FusedLAMB does not support sparse gradients, please consider SparseAdam instead\"\n                    )\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n                    m_16.append(state[\"exp_avg\"])\n                    v_16.append(state[\"exp_avg_sq\"])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state[\"exp_avg\"])\n                    v_32.append(state[\"exp_avg_sq\"])\n                else:\n                    raise RuntimeError(\"FusedLAMB only support fp16 and fp32.\")\n\n            if len(g_16) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_lamb,\n                    self._dummy_overflow_buf,\n                    [g_16, p_16, m_16, v_16],\n                    group[\"lr\"],\n                    beta1,\n                    beta2,\n                    group[\"eps\"],\n                    group[\"step\"],\n                    bias_correction,\n           
         group[\"weight_decay\"],\n                    grad_averaging,\n                    self.adam_w_mode,\n                    global_grad_norm,\n                    max_grad_norm,\n                    self.use_nvlamb,\n                )\n            if len(g_32) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_lamb,\n                    self._dummy_overflow_buf,\n                    [g_32, p_32, m_32, v_32],\n                    group[\"lr\"],\n                    beta1,\n                    beta2,\n                    group[\"eps\"],\n                    group[\"step\"],\n                    bias_correction,\n                    group[\"weight_decay\"],\n                    grad_averaging,\n                    self.adam_w_mode,\n                    global_grad_norm,\n                    max_grad_norm,\n                    self.use_nvlamb,\n                )\n\n        return loss\n"
  },
  {
    "path": "apex/optimizers/fused_mixed_precision_lamb.py",
    "content": "import torch\nfrom copy import deepcopy\nfrom itertools import chain\nfrom collections import defaultdict, abc as container_abcs\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedMixedPrecisionLamb(torch.optim.Optimizer):\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        step=0,\n        bias_correction=True,\n        betas=(0.9, 0.999),\n        eps=1e-6,\n        weight_decay=0.01,\n        amsgrad=False,\n        adam_w_mode=True,\n        grad_averaging=True,\n        max_grad_norm=1.0,\n        use_nvlamb=False,\n        reduced_precision_dtype=None,\n    ):\n        if amsgrad:\n            raise RuntimeError(\"FusedLAMB does not support the AMSGrad variant.\")\n\n        # init defaults\n        defaults = dict(\n            lr=torch.tensor(lr, dtype=torch.float32),\n            step=torch.tensor([step], dtype=torch.int),\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            grad_averaging=grad_averaging,\n            max_grad_norm=max_grad_norm,\n        )\n\n        # init base module\n        super(FusedMixedPrecisionLamb, self).__init__(params, defaults)\n\n        # The learning rate (lr) and optimizer step (step) should be located on device\n        # in order to faciliated device sync free execution\n        device = self.param_groups[0][\"params\"][0].device\n        tensor_state = [\"lr\", \"step\"]\n        for idx, group in enumerate(self.param_groups):\n            for item in tensor_state:\n                self.param_groups[idx][item] = group[item].to(device=device)\n\n        if multi_tensor_applier.available:\n            import amp_C\n\n            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm_mp\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device)\n            self.multi_tensor_lamb = 
amp_C.multi_tensor_lamb_mp\n        else:\n            raise RuntimeError(\"apex.optimizers.FusedLAMB requires cuda extensions\")\n\n        # Mixed Precision support\n        self.reduced_precision_dtype = reduced_precision_dtype\n        self.param_groups_full_precision = []\n\n        self._step_supports_amp_scaling = True\n        self.adam_w_mode = 1 if adam_w_mode else 0\n        self.use_nvlamb = use_nvlamb\n\n    # This method is overridden from the parent class because there is not a way to override\n    # the nested function cast() that copies a saved piece of state to the device without\n    # redundantly doing the copy.\n    def load_state_dict(self, state_dict):\n        r\"\"\"Loads the optimizer state.\n\n        Args:\n            state_dict (dict): optimizer state. Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        # deepcopy, to be consistent with module API\n        state_dict = deepcopy(state_dict)\n        # Validate the state_dict\n        groups = self.param_groups\n        saved_groups = state_dict[\"param_groups\"]\n\n        if len(groups) != len(saved_groups):\n            raise ValueError(\"loaded state dict has a different number of parameter groups\")\n        param_lens = (len(g[\"params\"]) for g in groups)\n        saved_lens = (len(g[\"params\"]) for g in saved_groups)\n        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):\n            raise ValueError(\n                \"loaded state dict contains a parameter group \"\n                \"that doesn't match the size of optimizer's group\"\n            )\n\n        # Update the state\n        id_map = {\n            old_id: p\n            for old_id, p in zip(\n                chain.from_iterable((g[\"params\"] for g in saved_groups)),\n                chain.from_iterable((g[\"params\"] for g in groups)),\n            )\n        }\n\n        def cast(param, value):\n            r\"\"\"Make a deep copy of 
value, casting all tensors to device of param.\"\"\"\n            if isinstance(value, torch.Tensor):\n                # The original version casted the saved value to the params dtype\n                # This doesn't work for mixed precision Lamb where the momentum and\n                # velocity are expected to be in full precision while the params are\n                # in reduced precision\n                value = value.to(value.device)\n                return value\n            elif isinstance(value, dict):\n                return {k: cast(param, v) for k, v in value.items()}\n            elif isinstance(value, container_abcs.Iterable):\n                return type(value)(cast(param, v) for v in value)\n            else:\n                return value\n\n        # Copy state assigned to params (and cast tensors to appropriate types).\n        # State that is not assigned to params is copied as is (needed for\n        # backward compatibility).\n        state = defaultdict(dict)\n        for k, v in state_dict[\"state\"].items():\n            if k in id_map:\n                param = id_map[k]\n                state[param] = cast(param, v)\n            else:\n                state[k] = v\n\n        # Update parameter groups, setting their 'params' value\n        def update_group(group, new_group):\n            new_group[\"params\"] = group[\"params\"]\n            return new_group\n\n        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]\n        self.__setstate__({\"state\": state, \"param_groups\": param_groups})\n\n    def _setup_full_precision_params(self):\n        for i, pg in enumerate(self.param_groups):\n            param_list = pg[\"params\"]\n            self.param_groups_full_precision.append(\n                {\n                    \"params\": [\n                        p.clone().detach().to(dtype=torch.float32)\n                        if (self.reduced_precision_dtype is not None)\n                        and (p.dtype == 
self.reduced_precision_dtype)\n                        else None\n                        for p in param_list\n                    ],\n                }\n            )\n\n    # add_param_group() is overridden because default items can be tensors. The\n    # parent version does not clone the default item, so two param groups could\n    # accidentally share the same default tensor even though their values may\n    # differ between groups.\n    def add_param_group(self, param_group):\n        super().add_param_group(param_group)\n        for name, default in self.defaults.items():\n            if isinstance(default, torch.Tensor):\n                self.param_groups[len(self.param_groups) - 1][name] = default.clone()\n\n    @torch.no_grad()\n    def step(self, closure=None, grad_scaler=None):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # The full precision params are set up in the first step of the optimizer\n        # instead of in the constructor because the full precision params will get\n        # out of sync with the model params if DDP syncs the model params across devices\n        # after the optimizer is constructed.\n        if len(self.param_groups_full_precision) == 0:\n            self._setup_full_precision_params()\n\n        # create separate grad lists for params\n        grad_list = []\n        for gid, group in enumerate(self.param_groups):\n            for pid, p in enumerate(group[\"params\"]):\n                assert group[\"params\"][0].dtype == p.dtype, (\n                    \"Error: Parameters are not of the identical type: {} != {}\".format(\n                        group[\"params\"][0].dtype, p.dtype\n                    )\n                )\n                if p.grad is None:\n                    continue\n                grad_list.append(p.grad)\n\n        # Overflow check of gradients\n        device = self.param_groups[0][\"params\"][0].device\n        found_inf = (\n        
    grad_scaler._check_inf_per_device(self)[device]\n            if grad_scaler is not None\n            else torch.zeros((1,), device=device)\n        )\n        self._dummy_overflow_buf.copy_(found_inf)\n\n        # Get unscale scale factor\n        scale, inv_scale = None, None\n        if grad_scaler:\n            scale = grad_scaler._get_scale_async()\n            inv_scale = scale.double().reciprocal().float()\n        else:\n            scale = torch.ones((1,), device=device)\n            inv_scale = torch.ones((1,), device=device)\n\n        # grad_norm is of scaled gradients.\n        # So, multiply `max_grad_norm` by scale.\n        max_grad_norm = self.defaults[\"max_grad_norm\"] * scale\n        grad_norm = multi_tensor_applier(\n            self.multi_tensor_l2norm,\n            self._dummy_overflow_buf,\n            [grad_list],\n            False,\n        )[0]\n\n        # Run LAMB optimization math\n        for gid, (group, group_full) in enumerate(\n            zip(self.param_groups, self.param_groups_full_precision)\n        ):\n            bias_correction = 1 if group[\"bias_correction\"] else 0\n            beta1, beta2 = group[\"betas\"]\n            grad_averaging = 1 if group[\"grad_averaging\"] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into kernel\n            group[\"step\"] += (self._dummy_overflow_buf != 1).to(torch.int)\n\n            state_lists = [\n                [],  # (0) grads\n                [],  # (1) params\n                [],  # (2) momentum state\n                [],  # (3) velocity state\n            ]\n            if self.reduced_precision_dtype is not None:\n                state_lists.append([])  # (4) params reduced_dtype\n\n            for p, p_full in zip(group[\"params\"], group_full[\"params\"]):\n                if p.grad is None:\n                    continue\n                assert not 
p.grad.is_sparse\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    dtype = p.dtype\n                    if (\n                        self.reduced_precision_dtype is not None\n                        and p.dtype == self.reduced_precision_dtype\n                    ):\n                        dtype = torch.float32\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data, dtype=dtype)\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p.data, dtype=dtype)\n\n                if self.reduced_precision_dtype is not None:\n                    state_lists[0].append(p.grad.data)\n                    state_lists[1].append(p_full.data)\n                    state_lists[2].append(state[\"exp_avg\"])\n                    state_lists[3].append(state[\"exp_avg_sq\"])\n                    state_lists[4].append(p.data)\n                else:\n                    state_lists[0].append(p.grad.data)\n                    state_lists[1].append(p.data)\n                    state_lists[2].append(state[\"exp_avg\"])\n                    state_lists[3].append(state[\"exp_avg_sq\"])\n\n            multi_tensor_applier(\n                self.multi_tensor_lamb,\n                self._dummy_overflow_buf,\n                state_lists,\n                group[\"lr\"],\n                beta1,\n                beta2,\n                group[\"eps\"],\n                group[\"step\"],\n                bias_correction,\n                group[\"weight_decay\"],\n                grad_averaging,\n                self.adam_w_mode,\n                grad_norm,\n                max_grad_norm,\n                self.use_nvlamb,\n                found_inf,\n                inv_scale,\n            )\n\n        return loss\n"
  },
  {
    "path": "apex/optimizers/fused_novograd.py",
    "content": "import torch\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedNovoGrad(torch.optim.Optimizer):\n    \"\"\"Implements NovoGrad algorithm.\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused NovoGrad implements 2 fusions.\n\n      * Fusion of the NovoGrad update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedNovoGrad`'s usage is identical to any Pytorch optimizer::\n\n        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedNovoGrad` may be used with or without Amp.  If you wish to use :class:`FusedNovoGrad` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    It has been proposed in `Jasper: An End-to-End Convolutional Neural Acoustic Model`_.\n    More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its norm. (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. 
(default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            NOT SUPPORTED now! (default: False)\n        reg_inside_moment (bool, optional): whether to do regularization (norm and L2)\n            in momentum calculation. True to include it, False to exclude it and\n            only do it on the update term. (default: False)\n        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when\n            calculating running averages of gradient. (default: True)\n        norm_type (int, optional): which norm to calculate for each layer.\n            2 for L2 norm, and 0 for infinity norm. These are the only two\n            supported types for now. (default: 2)\n        init_zero (bool, optional): whether to init norm with 0 (start averaging on\n            1st step) or first step norm (start averaging on 2nd step). True for\n            init with 0. (default: False)\n        set_grad_none (bool, optional): whether to set grad to None when zero_grad()\n            method is called. (default: True)\n\n    .. _Jasper - An End-to-End Convolutional Neural Acoustic Model:\n        https://arxiv.org/abs/1904.03288\n    .. 
_On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        bias_correction=True,\n        betas=(0.9, 0.999),\n        eps=1e-8,\n        weight_decay=0.0,\n        amsgrad=False,\n        reg_inside_moment=False,\n        grad_averaging=True,\n        norm_type=2,\n        init_zero=False,\n        set_grad_none=True,\n    ):\n        if amsgrad:\n            raise RuntimeError(\"FusedNovoGrad does not support the AMSGrad variant.\")\n        defaults = dict(\n            lr=lr,\n            bias_correction=bias_correction,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            grad_averaging=grad_averaging,\n            norm_type=norm_type,\n            init_zero=init_zero,\n        )\n        super(FusedNovoGrad, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n            # Skip buffer\n\n            # Creating the overflow buffer on the same device as the params tensors.\n            self._dummy_overflow_buf = torch.tensor(\n                [0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device\n            )\n            self.multi_tensor_novograd = amp_C.multi_tensor_novograd\n        else:\n            raise RuntimeError(\"apex.optimizers.FusedNovoGrad requires cuda extensions\")\n\n        self.moment_mode = 0 if reg_inside_moment else 1\n        self.set_grad_none = set_grad_none\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group[\"params\"]:\n                    p.grad = None\n        else:\n            super(FusedNovoGrad, self).zero_grad()\n\n    def load_state_dict(self, state_dict):\n        super(FusedNovoGrad, self).load_state_dict(state_dict)\n        # in case exp_avg_sq is not on the same device as params, move it there\n   
     for group in self.param_groups:\n            if len(group[\"params\"]) > 0:\n                group[\"exp_avg_sq\"][0] = group[\"exp_avg_sq\"][0].to(group[\"params\"][0].device)\n                group[\"exp_avg_sq\"][1] = group[\"exp_avg_sq\"][1].to(group[\"params\"][0].device)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            bias_correction = 1 if group[\"bias_correction\"] else 0\n            beta1, beta2 = group[\"betas\"]\n            grad_averaging = 1 if group[\"grad_averaging\"] else 0\n\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support by making it tensor, or pass list into kernel\n            if \"step\" in group:\n                group[\"step\"] += 1\n            else:\n                group[\"step\"] = 1\n\n            # create lists for multi-tensor apply\n            g_16, p_16, m_16 = [], [], []\n            g_32, p_32, m_32 = [], [], []\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                if p.grad.data.is_sparse:\n                    raise RuntimeError(\n                        \"FusedNovoGrad does not support sparse gradients, please consider SparseAdam instead\"\n                    )\n\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data)\n\n                if p.dtype == torch.float16:\n                    g_16.append(p.grad.data)\n                    p_16.append(p.data)\n  
                  m_16.append(state[\"exp_avg\"])\n                elif p.dtype == torch.float32:\n                    g_32.append(p.grad.data)\n                    p_32.append(p.data)\n                    m_32.append(state[\"exp_avg\"])\n                else:\n                    raise RuntimeError(\"FusedNovoGrad only support fp16 and fp32.\")\n\n            # we store per weight norm as one tensor for one group/precision combination\n            # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types\n            if \"exp_avg_sq\" not in group:\n                group[\"exp_avg_sq\"] = [None, None]\n                if group[\"init_zero\"]:\n                    # Creating the following parameters on the same device as the params tensors.\n                    group[\"exp_avg_sq\"][0] = (\n                        torch.cuda.FloatTensor(\n                            len(g_16), device=self.param_groups[0][\"params\"][0].device\n                        )\n                        .contiguous()\n                        .fill_(0)\n                    )\n                    group[\"exp_avg_sq\"][1] = (\n                        torch.cuda.FloatTensor(\n                            len(g_32), device=self.param_groups[0][\"params\"][0].device\n                        )\n                        .contiguous()\n                        .fill_(0)\n                    )\n                else:  # init with first step norm, so first blend have no effect\n                    if group[\"norm_type\"] == 0:\n                        v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16]\n                        v_32 = [torch.max(torch.abs(g)).item() for g in g_32]\n                    elif group[\"norm_type\"] == 2:\n                        v_16 = [\n                            torch.sum(torch.pow(g.to(torch.float32), 2)).sqrt().item() for g in g_16\n                        ]\n                        v_32 = [torch.sum(torch.pow(g, 
2)).sqrt().item() for g in g_32]\n                    else:\n                        raise RuntimeError(\"FusedNovoGrad only support l2/inf norm now.\")\n                    # Creating the following parameters on the same device as the params tensors.\n                    group[\"exp_avg_sq\"][0] = torch.cuda.FloatTensor(\n                        v_16, device=self.param_groups[0][\"params\"][0].device\n                    )\n                    group[\"exp_avg_sq\"][1] = torch.cuda.FloatTensor(\n                        v_32, device=self.param_groups[0][\"params\"][0].device\n                    )\n            else:\n                assert len(g_16) == group[\"exp_avg_sq\"][0].numel()\n                assert len(g_32) == group[\"exp_avg_sq\"][1].numel()\n\n            if len(g_16) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_novograd,\n                    self._dummy_overflow_buf,\n                    [g_16, p_16, m_16],\n                    group[\"exp_avg_sq\"][0],\n                    group[\"lr\"],\n                    beta1,\n                    beta2,\n                    group[\"eps\"],\n                    group[\"step\"],\n                    bias_correction,\n                    group[\"weight_decay\"],\n                    grad_averaging,\n                    self.moment_mode,\n                    group[\"norm_type\"],\n                )\n            if len(g_32) > 0:\n                multi_tensor_applier(\n                    self.multi_tensor_novograd,\n                    self._dummy_overflow_buf,\n                    [g_32, p_32, m_32],\n                    group[\"exp_avg_sq\"][1],\n                    group[\"lr\"],\n                    beta1,\n                    beta2,\n                    group[\"eps\"],\n                    group[\"step\"],\n                    bias_correction,\n                    group[\"weight_decay\"],\n                    grad_averaging,\n                    self.moment_mode,\n          
          group[\"norm_type\"],\n                )\n\n        return loss\n"
  },
  {
    "path": "apex/optimizers/fused_sgd.py",
    "content": "import torch\nfrom torch.optim.optimizer import Optimizer, required\n\nfrom apex.multi_tensor_apply import multi_tensor_applier\n\n\nclass FusedSGD(Optimizer):\n    r\"\"\"Implements stochastic gradient descent (optionally with momentum).\n\n    Currently GPU-only.  Requires Apex to be installed via\n    ``pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./``.\n\n    This version of fused SGD implements 2 fusions.\n\n      * Fusion of the SGD update's elementwise operations\n      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.\n\n    :class:`apex.optimizers.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``::\n\n        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)\n        ...\n        opt.step()\n\n    :class:`apex.optimizers.FusedSGD` may be used with or without Amp.  If you wish to use :class:`FusedSGD` with Amp,\n    you may choose any ``opt_level``::\n\n        opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)\n        model, opt = amp.initialize(model, opt, opt_level=\"O0\" or \"O1 or \"O2\")\n        ...\n        opt.step()\n\n    In general, ``opt_level=\"O1\"`` is recommended.\n\n    Nesterov momentum is based on the formula from\n    `On the importance of initialization and momentum in deep learning`__.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float): learning rate\n        momentum (float, optional): momentum factor (default: 0)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        dampening (float, optional): dampening for momentum (default: 0)\n        nesterov (bool, optional): enables Nesterov momentum (default: False)\n\n    Example:\n        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n        
>>> optimizer.zero_grad()\n        >>> loss_fn(model(input), target).backward()\n        >>> optimizer.step()\n\n    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf\n\n    .. note::\n        The implementation of SGD with Momentum/Nesterov subtly differs from\n        Sutskever et al. and implementations in some other frameworks.\n\n        Considering the specific case of Momentum, the update can be written as\n\n        .. math::\n                  v = \\rho * v + g \\\\\n                  p = p - lr * v\n\n        where p, g, v and :math:`\\rho` denote the parameters, gradient,\n        velocity, and momentum respectively.\n\n        This is in contrast to Sutskever et al. and\n        other frameworks which employ an update of the form\n\n        .. math::\n             v = \\rho * v + lr * g \\\\\n             p = p - v\n\n        The Nesterov version is analogously modified.\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=required,\n        momentum=0,\n        dampening=0,\n        weight_decay=0,\n        nesterov=False,\n        wd_after_momentum=False,\n        materialize_master_grads=True,\n        set_grad_none=False,\n    ):\n        if lr is not required and lr < 0.0:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if momentum < 0.0:\n            raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n        if weight_decay < 0.0:\n            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n\n        defaults = dict(\n            lr=lr,\n            momentum=momentum,\n            dampening=dampening,\n            weight_decay=weight_decay,\n            nesterov=nesterov,\n        )\n        if nesterov and (momentum <= 0 or dampening != 0):\n            raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n        super(FusedSGD, self).__init__(params, defaults)\n\n        self.wd_after_momentum = 
wd_after_momentum\n        self.materialize_master_grads = materialize_master_grads\n        self.most_recent_scale = 1.0\n        self.scale_set_by_backward = False\n        self.set_grad_none = set_grad_none\n\n        if multi_tensor_applier.available:\n            import amp_C\n\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor(\n                [0], dtype=torch.int, device=self.param_groups[0][\"params\"][0].device\n            )\n            self.multi_tensor_sgd = amp_C.multi_tensor_sgd\n        else:\n            raise RuntimeError(\"apex.optimizers.FusedSGD requires cuda extensions\")\n\n    def __setstate__(self, state):\n        super(FusedSGD, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault(\"nesterov\", False)\n\n    def zero_grad(self):\n        if self.set_grad_none:\n            for group in self.param_groups:\n                for p in group[\"params\"]:\n                    p.grad = None\n        else:\n            super(FusedSGD, self).zero_grad()\n\n    def get_momentums(self, params):\n        momentums = []\n        first_run = True\n        for p in params:\n            param_state = self.state[p]\n            # torch.optim.SGD initializes momentum in the main loop, we have\n            # to do it here, and track whether or not we've done so, so that\n            # momentum application can be skipped in the main kernel.\n            if \"momentum_buffer\" not in param_state:\n                first_run = True\n                buf = param_state[\"momentum_buffer\"] = torch.zeros_like(p.data)\n                momentums.append(buf)\n            else:\n                first_run = False\n                momentums.append(param_state[\"momentum_buffer\"])\n        return momentums, first_run\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the 
model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        explicit_master_params = hasattr(self, \"_amp_stash\") and hasattr(\n            self._amp_stash, \"fp32_from_fp16_groups\"\n        )\n\n        for gid, group in enumerate(self.param_groups):\n            weight_decay = group[\"weight_decay\"]\n            momentum = group[\"momentum\"]\n            dampening = group[\"dampening\"]\n            nesterov = group[\"nesterov\"]\n\n            # For each group, there are 3 possible combinations we need to consider:\n            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy\n            # 1. fp16, fp16, fp16, No\n            # 2. fp32, fp32, fp32, No\n            # 3. fp16, fp32, fp32, Yes\n\n            first_runs = [True, True]\n\n            # I think a bit of code divergence in exchange for naming clarity is worthwhile\n            if explicit_master_params:\n                stash = self._amp_stash\n\n                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]\n                fp32_grads = [\n                    p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None\n                ]\n                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n\n                if self.materialize_master_grads:\n                    fp16_model_params = [\n                        p\n                        for i, p in enumerate(stash.fp16_groups[gid])\n                        if stash.fp32_from_fp16_groups[gid][i].grad is not None\n                    ]\n                    fp32_from_fp16_grads = [\n                        p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None\n                    ]\n                    fp32_from_fp16_params = [\n                        p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None\n                    ]\n   
                 fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(\n                        fp32_from_fp16_params\n                    )\n\n                    fp16_set = [\n                        fp32_from_fp16_grads,\n                        fp32_from_fp16_params,\n                        fp32_from_fp16_momentums,\n                        fp16_model_params,\n                    ]\n                else:\n                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]\n                    fp16_model_grads = [\n                        p.grad for p in stash.fp16_groups[gid] if p.grad is not None\n                    ]\n                    fp32_from_fp16_params = [\n                        p\n                        for i, p in enumerate(stash.fp32_from_fp16_groups[gid])\n                        if stash.fp16_groups[gid][i].grad is not None\n                    ]\n                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(\n                        fp32_from_fp16_params\n                    )\n\n                    fp16_set = [\n                        fp16_model_grads,\n                        fp32_from_fp16_params,\n                        fp32_from_fp16_momentums,\n                        fp16_model_params,\n                    ]\n\n                launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]\n            else:\n                fp16_params = [\n                    p for p in group[\"params\"] if (p.dtype == torch.float16 and p.grad is not None)\n                ]\n                fp16_grads = [\n                    p.grad\n                    for p in group[\"params\"]\n                    if (p.dtype == torch.float16 and p.grad is not None)\n                ]\n                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)\n\n                fp32_params = [\n                    p for p in group[\"params\"] if (p.dtype == torch.float32 and p.grad is not 
None)\n                ]\n                fp32_grads = [\n                    p.grad\n                    for p in group[\"params\"]\n                    if (p.dtype == torch.float32 and p.grad is not None)\n                ]\n                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)\n\n                launch_sets = [\n                    [fp16_grads, fp16_params, fp16_momentums],\n                    [fp32_grads, fp32_params, fp32_momentums],\n                ]\n\n            for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):\n                assert len(launch_set[0]) == len(launch_set[1])\n                assert len(launch_set[0]) == len(launch_set[2])\n                if len(launch_set[0]) > 0:\n                    multi_tensor_applier(\n                        self.multi_tensor_sgd,\n                        self._dummy_overflow_buf,\n                        launch_set,\n                        weight_decay,\n                        momentum,\n                        dampening,\n                        group[\"lr\"],\n                        nesterov,\n                        first_run,\n                        self.wd_after_momentum,\n                        1.0 / self.most_recent_scale,\n                    )\n\n        self.most_recent_scale = 1.0\n        self.scale_set_by_backward = False\n\n        return loss\n"
  },
  {
    "path": "csrc/amp_C_frontend.cpp",
    "content": "#include <torch/extension.h>\n\nvoid multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                             float scale);\n\nvoid multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                           float wd, float momentum, float dampening, float lr, bool nesterov, bool first_run,\n                           bool wd_after_momentum, float scale);\n\nvoid multi_tensor_axpby_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                             float a, float b, int arg_to_check);\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,\n                                                            std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                            at::optional<bool> per_tensor_python);\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(int chunk_size, at::Tensor noop_flag,\n                                                               std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                               at::optional<bool> per_tensor_python);\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(int chunk_size, at::Tensor noop_flag,\n                                                                  std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                                  float scale, at::optional<bool> per_tensor_python);\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_size, at::Tensor noop_flag,\n                                                                    std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                                    at::Tensor inv_scale,\n     
                                                               at::optional<bool> per_tensor_python);\n\nvoid multi_tensor_lamb_stage1_cuda(int chunk_size, at::Tensor noop_flag,\n                                   std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor per_tensor_decay,\n                                   const int step, const float beta1, const float beta2, const float epsilon,\n                                   at::Tensor global_grad_norm, const float max_global_grad_norm);\n\nvoid multi_tensor_lamb_stage2_cuda(int chunk_size, at::Tensor noop_flag,\n                                   std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor per_tensor_param_norm,\n                                   at::Tensor per_tensor_update_norm, const float lr, const float weight_decay,\n                                   at::optional<bool> use_nvlamb_python);\n\nvoid multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                            const float lr, const float beta1, const float beta2, const float epsilon, const int step,\n                            const int mode, const int bias_correction, const float weight_decay);\n\nvoid multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,\n                                       std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor lr,\n                                       const float beta1, const float beta2, const float epsilon, at::Tensor step,\n                                       const int mode, const int bias_correction, const float weight_decay,\n                                       at::Tensor inv_scale);\n\nvoid multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,\n                                              std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor lr,\n                                              const float beta1, const float beta2, const float 
epsilon,\n                                              at::Tensor step, const int mode, const int bias_correction,\n                                              const float weight_decay, at::Tensor inv_scale);\n\nvoid multi_tensor_adagrad_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                               const float lr, const float epsilon, const int mode, const float weight_decay);\n\nvoid multi_tensor_novograd_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                                at::Tensor grad_norms, const float lr, const float beta1, const float beta2,\n                                const float epsilon, const int step, const int bias_correction,\n                                const float weight_decay, const int grad_averaging, const int mode,\n                                const int norm_type);\n\nvoid multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                            const float lr, const float beta1, const float beta2, const float epsilon, const int step,\n                            const int bias_correction, const float weight_decay, const int grad_averaging,\n                            const int mode, at::Tensor global_grad_norm, const float max_grad_norm,\n                            at::optional<bool> use_nvlamb_python);\n\nvoid multi_tensor_lamb_mp_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                               at::Tensor lr, const float beta1, const float beta2, const float epsilon,\n                               at::Tensor step, const int bias_correction, const float weight_decay,\n                               const int grad_averaging, const int mode, at::Tensor global_grad_norm,\n                               at::Tensor max_grad_norm, at::optional<bool> use_nvlamb_python, at::Tensor 
found_inf,\n                               at::Tensor inv_scale);\n\nat::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, at::Tensor growth_tracker,\n                                        at::Tensor hysteresis_tracker, at::Tensor found_inf, const double growth_factor,\n                                        const double backoff_factor, const int64_t growth_interval,\n                                        const int hysteresis);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"multi_tensor_scale\", &multi_tensor_scale_cuda, \"Fused overflow check + scale for a list of contiguous tensors\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_sgd\", &multi_tensor_sgd_cuda, \"Fused SGD optimizer for list of contiguous tensors\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_axpby\", &multi_tensor_axpby_cuda, \"out = a*x + b*y for a list of contiguous tensors\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_l2norm\", &multi_tensor_l2norm_cuda, \"Computes L2 norm for a list of contiguous tensors\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_l2norm_mp\", &multi_tensor_l2norm_mp_cuda, \"Computes L2 norm for a list of contiguous tensors\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_l2norm_scale\", &multi_tensor_l2norm_scale_cuda,\n        \"Computes L2 norm for a list of contiguous tensors and does scaling\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_unscale_l2norm\", &multi_tensor_unscale_l2norm_cuda,\n        \"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only performed for L2 norm \"\n        \"computation, and tensors are not updated)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_lamb_stage1_cuda\", &multi_tensor_lamb_stage1_cuda, \"Computes update part of LAMB optimizer\",\n        
py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_lamb_stage2_cuda\", &multi_tensor_lamb_stage2_cuda,\n        \"Completes application of gradient to parameters for LAMB optimizer\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_adam\", &multi_tensor_adam_cuda,\n        \"Compute and apply gradient update to parameters for Adam optimizer\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_adam_capturable\", &multi_tensor_adam_capturable_cuda,\n        \"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support and LR scheduling\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_adam_capturable_master\", &multi_tensor_adam_capturable_master_cuda,\n        \"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support, LR scheduling and \"\n        \"FP32 master weights\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_adagrad\", &multi_tensor_adagrad_cuda,\n        \"Compute and apply gradient update to parameters for Adagrad optimizer\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_novograd\", &multi_tensor_novograd_cuda,\n        \"Compute and apply gradient update to parameters for NovoGrad optimizer\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_lamb\", &multi_tensor_lamb_cuda, \"Computes and applies update for LAMB optimizer\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"multi_tensor_lamb_mp\", &multi_tensor_lamb_mp_cuda, \"Computes and applies update for LAMB optimizer\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"update_scale_hysteresis\", &update_scale_hysteresis_cuda, \"Updates scale while accounting for hysteresis\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/flatten_unflatten.cpp",
    "content": "#include <torch/csrc/utils/tensor_flatten.h>\n#include <torch/extension.h>\n// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h\n\nat::Tensor flatten(std::vector<at::Tensor> tensors) { return torch::utils::flatten_dense_tensors(tensors); }\n\nstd::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors) {\n  return torch::utils::unflatten_dense_tensors(flat, tensors);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"flatten\", &flatten, \"Flatten dense tensors\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"unflatten\", &unflatten, \"Unflatten dense tensors\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/fused_dense.cpp",
    "content": "#include <stdio.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include <vector>\n\ntemplate <typename T>\nint linear_bias_forward_cuda(at::Tensor input, T* weight, at::Tensor bias, int in_features, int batch_size,\n                             int out_features, at::Tensor output, void* lt_workspace);\n\ntemplate <typename T>\nint linear_bias_backward_cuda(T* input, T* weight, T* d_output, int in_features, int batch_size, int out_features,\n                              T* d_weight, T* d_bias, T* d_input, void* lt_workspace);\n\ntemplate <typename T>\nint linear_gelu_linear_forward_cuda(T* input, T* weight1, T* bias1, T* weight2, T* bias2, int in_features,\n                                    int hidden_features, int batch_size, int out_features, T* output1, T* output2,\n                                    T* gelu_in, void* lt_workspace);\n\ntemplate <typename T>\nint linear_gelu_linear_backward_cuda(T* input, T* gelu_in, T* output1, T* weight1, T* weight2, T* d_output1,\n                                     T* d_output2, int in_features, int batch_size, int hidden_features,\n                                     int out_features, T* d_weight1, T* d_weight2, T* d_bias1, T* d_bias2, T* d_input,\n                                     void* lt_workspace);\n\nat::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {\n  auto batch_size = input.size(0);\n  auto in_features = input.size(1);\n\n  int out_features = weight.size(0);\n\n  // auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());\n\n  // create output/workspace tensor\n  auto out = at::empty({batch_size, out_features}, input.options());\n  // auto reserved_space = at::empty({reserved_size}, inputs[0].options());\n  //  allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB\n  auto lt_workspace = at::empty({1 << 22}, input.options());\n\n  AT_DISPATCH_FLOATING_TYPES_AND2(\n      
at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), \"linear_bias_forward\", [&] {\n        scalar_t* w_ptr = weight.data_ptr<scalar_t>();\n        scalar_t* b_ptr = bias.data_ptr<scalar_t>();\n        [[maybe_unused]] auto result =\n            linear_bias_forward_cuda<scalar_t>(input, w_ptr, bias, in_features, batch_size, out_features, out,\n                                               // out.data_ptr<scalar_t>(),\n                                               // reserved_space.data_ptr<scalar_t>(),\n                                               (void*)(lt_workspace.data_ptr<scalar_t>()));\n      });\n\n  return {out};\n}\n\nstd::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {\n  auto batch_size = input.size(0);\n  auto in_features = input.size(1);\n\n  int out_features = weight.size(0);\n\n  // auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());\n\n  // create output/workspace tensor\n  auto d_weight = at::empty({out_features, in_features}, input.options());\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600\n  auto d_bias = d_output.view({-1, out_features}).sum(0, false);\n#else\n  auto d_bias = at::empty({out_features}, input.options());\n#endif\n  auto d_input = at::empty({batch_size, in_features}, input.options());\n  // auto reserved_space = at::empty({reserved_size}, inputs[0].options());\n  //  allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB\n  auto lt_workspace = at::empty({1 << 22}, input.options());\n\n  AT_DISPATCH_FLOATING_TYPES_AND2(\n      at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), \"linear_bias_backward\", [&] {\n        scalar_t* w_ptr = weight.data_ptr<scalar_t>();\n        scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();\n        [[maybe_unused]] auto result = linear_bias_backward_cuda<scalar_t>(\n            input.data_ptr<scalar_t>(), w_ptr, 
d_output.data_ptr<scalar_t>(), in_features, batch_size, out_features,\n            d_weight.data_ptr<scalar_t>(), d_bias.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),\n            // reserved_space.data_ptr<scalar_t>(),\n            (void*)(lt_workspace.data_ptr<scalar_t>()));\n      });\n\n  return {d_input, d_weight, d_bias};\n}\n\nstd::vector<at::Tensor> linear_gelu_linear_forward(at::Tensor input, at::Tensor weight1, at::Tensor bias1,\n                                                   at::Tensor weight2, at::Tensor bias2) {\n  auto batch_size = input.size(0);\n  auto in_features = input.size(1);\n\n  int hidden_features = weight1.size(0);\n  int out_features = weight2.size(0);\n\n  // auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());\n\n  // create output/workspace tensor\n  auto output1 = at::empty({batch_size, hidden_features}, input.options());\n  auto gelu_in = at::empty({batch_size, hidden_features}, input.options());\n  auto output2 = at::empty({batch_size, out_features}, input.options());\n  // auto reserved_space = at::empty({reserved_size}, inputs[0].options());\n  //  allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB\n  auto lt_workspace = at::empty({1 << 22}, input.options());\n\n  AT_DISPATCH_FLOATING_TYPES_AND2(\n      at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), \"linear_gelu_linear_forward\", [&] {\n        scalar_t* w1_ptr = weight1.data_ptr<scalar_t>();\n        scalar_t* b1_ptr = bias1.data_ptr<scalar_t>();\n        scalar_t* w2_ptr = weight2.data_ptr<scalar_t>();\n        scalar_t* b2_ptr = bias2.data_ptr<scalar_t>();\n        [[maybe_unused]] auto result = linear_gelu_linear_forward_cuda<scalar_t>(\n            input.data_ptr<scalar_t>(), w1_ptr, b1_ptr, w2_ptr, b2_ptr, in_features, hidden_features, batch_size,\n            out_features, output1.data_ptr<scalar_t>(), output2.data_ptr<scalar_t>(), gelu_in.data_ptr<scalar_t>(),\n            
// reserved_space.data_ptr<scalar_t>(),\n            (void*)(lt_workspace.data_ptr<scalar_t>()));\n      });\n\n  return {output1, output2, gelu_in};\n}\n\nstd::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1,\n                                                    at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) {\n  auto batch_size = input.size(0);\n  auto in_features = input.size(1);\n\n  int hidden_features = weight1.size(0);\n  int out_features = weight2.size(0);\n\n  // auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());\n\n  // create output/workspace tensor\n  auto d_weight1 = at::empty({hidden_features, in_features}, input.options());\n  auto d_weight2 = at::empty({out_features, hidden_features}, input.options());\n  auto d_bias1 = at::empty({hidden_features}, input.options());\n  auto d_bias2 = at::empty({out_features}, input.options());\n  auto d_input = at::empty({batch_size, in_features}, input.options());\n  auto d_output1 = at::empty({batch_size, hidden_features}, input.options());\n  // auto reserved_space = at::empty({reserved_size}, inputs[0].options());\n  //  allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB\n  auto lt_workspace = at::empty({1 << 22}, input.options());\n\n  AT_DISPATCH_FLOATING_TYPES_AND2(\n      at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), \"linear_gelu_linear_backward\", [&] {\n        // scalar_t* w_ptr = weight.data_ptr<scalar_t>();\n        // scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();\n        [[maybe_unused]] auto result = linear_gelu_linear_backward_cuda<scalar_t>(\n            input.data_ptr<scalar_t>(), gelu_in.data_ptr<scalar_t>(), output1.data_ptr<scalar_t>(),\n            weight1.data_ptr<scalar_t>(), weight2.data_ptr<scalar_t>(), d_output1.data_ptr<scalar_t>(),\n            d_output2.data_ptr<scalar_t>(), in_features, batch_size, hidden_features, 
out_features,\n            d_weight1.data_ptr<scalar_t>(), d_weight2.data_ptr<scalar_t>(), d_bias1.data_ptr<scalar_t>(),\n            d_bias2.data_ptr<scalar_t>(), d_input.data_ptr<scalar_t>(),\n            // reserved_space.data_ptr<scalar_t>(),\n            (void*)(lt_workspace.data_ptr<scalar_t>()));\n      });\n\n  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"linear_bias_forward\", &linear_bias_forward, \"linear bias forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"linear_bias_backward\", &linear_bias_backward, \"linear bias backward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"linear_gelu_linear_forward\", &linear_gelu_linear_forward, \"linear gelu linear forward\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"linear_gelu_linear_backward\", &linear_gelu_linear_backward, \"linear gelu linear backward\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/fused_dense_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <assert.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n#include <torch/torch.h>\n\n/* Includes, cuda */\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000\n// includes cublaslt\n#include <cublasLt.h>\n#endif\n// FP64 Wrapper around cublas GEMMEx\ncublasStatus_t gemm_bias(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                         const float* alpha, double* A, int lda, double* B, int ldb, const float* beta, double* C,\n                         int ldc) {\n  return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_64F, lda, B, CUDA_R_64F, ldb, beta, C,\n                      CUDA_R_64F, ldc, CUDA_R_64F, CUBLAS_GEMM_DEFAULT);\n}\n\n// FP32 Wrapper around cublas GEMMEx\ncublasStatus_t gemm_bias(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                         const float* alpha, float* A, int lda, float* B, int ldb, const float* beta, float* C,\n                         int ldc) {\n  return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_32F, lda, B, CUDA_R_32F, ldb, beta, C,\n                      CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);\n}\n\n// FP16 Tensor core wrapper around cublas GEMMEx\ncublasStatus_t gemm_bias(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                         const float* alpha, at::Half* A, int lda, at::Half* B, int ldb, const float* beta, at::Half* C,\n                         int ldc) {\n  return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta, C,\n                      CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n}\n\n// BF16 Tensor core wrapper around cublas GEMMEx\ncublasStatus_t gemm_bias(cublasHandle_t 
handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                         const float* alpha, at::BFloat16* A, int lda, at::BFloat16* B, int ldb, const float* beta,\n                         at::BFloat16* C, int ldc) {\n  return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, beta, C,\n                      CUDA_R_16BF, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n}\n\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n\nint gemm_bias_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                 const float* alpha,                                            /* host pointer */\n                 at::Half* A, int lda, at::Half* B, int ldb, const float* beta, /* host pointer */\n                 at::Half* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                 const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if 
(use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_BIAS;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\nint gemm_bias_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                 const float* alpha,                                                    /* host pointer */\n                 at::BFloat16* A, int lda, at::BFloat16* B, int ldb, const float* beta, /* host pointer */\n                 at::BFloat16* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                 const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = 
cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_BIAS;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. 
However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\nint gemm_bias_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                 const float* alpha,                                        /* host pointer */\n                 double* A, int lda, double* B, int ldb, const float* beta, /* host pointer */\n                 double* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                 const void* bias) {\n  return 1;\n}\n\nint gemm_bias_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                 const float* alpha,                                      /* host pointer */\n                 float* A, int lda, float* B, int ldb, const float* beta, /* host pointer */\n                 float* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                 const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = 
cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_BIAS;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          &heuristicResult.algo, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\nint gemm_bias_gelu_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                      int k, const float* alpha,                                     /* host pointer */\n                      at::Half* A, int lda, at::Half* B, int ldb, const float* beta, /* host pointer */\n                      at::Half* C, int64_t ldc, void* workspace, size_t workspaceSize, cudaStream_t stream,\n                      bool use_bias, const void* gelu_in, const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = 
cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Check each attribute call so a failure is not overwritten by the next one.\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in,\n                                          sizeof(gelu_in));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. 
However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\nint gemm_bias_gelu_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                      int k, const float* alpha,                                             /* host pointer */\n                      at::BFloat16* A, int lda, at::BFloat16* B, int ldb, const float* beta, /* host pointer */\n                      at::BFloat16* C, int64_t ldc, void* workspace, size_t workspaceSize, cudaStream_t stream,\n                      bool use_bias, const void* gelu_in, const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Check each attribute call so a failure is not overwritten by the next one.\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in,\n                                          sizeof(gelu_in));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != 
CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\nint gemm_bias_gelu_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                      int k, const float* alpha,                                 /* host pointer */\n                      double* A, int lda, double* B, int ldb, const float* beta, /* host pointer */\n                      double* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                      const void* gelu_in, const void* bias) {\n  return 1;\n}\n\nint gemm_bias_gelu_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                      int k, const float* alpha,                               /* host pointer */\n                      float* A, int lda, float* B, int ldb, const float* beta, /* host pointer */\n                      float* C, int64_t ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                      const void* gelu_in, const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  
cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Check each attribute call so a failure is not overwritten by the next one.\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in,\n                                          sizeof(gelu_in));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? 
n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\nint gemm_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                   const float* alpha,                                            /* host pointer */\n                   at::Half* A, int lda, at::Half* B, int ldb, const float* beta, /* host pointer */\n                   at::Half* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                   const void* bgrad) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_BGRADB;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. 
Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\nint gemm_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                   const float* alpha,                                                    /* host pointer */\n                   at::BFloat16* A, int lda, at::BFloat16* B, int ldb, const float* beta, /* host pointer */\n                   at::BFloat16* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                   const void* bgrad) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = 
cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_BGRADB;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. 
However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\nint gemm_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                   const float* alpha,                                        /* host pointer */\n                   double* A, int lda, double* B, int ldb, const float* beta, /* host pointer */\n                   double* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                   const void* bgrad) {\n  return 1;\n}\n\nint gemm_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                   const float* alpha,                                      /* host pointer */\n                   float* A, int lda, float* B, int ldb, const float* beta, /* host pointer */\n                   float* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                   const void* bgrad) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = CUBLASLT_EPILOGUE_BGRADB;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          &heuristicResult.algo, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\nint gemm_dgelu_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                         int k, const float* alpha,                                     /* host pointer */\n                         at::Half* A, int lda, at::Half* B, int ldb, const float* beta, /* host pointer */\n                         at::Half* C, int64_t ldc, void* workspace, size_t workspaceSize, cudaStream_t stream,\n                         const void* gelu_in, const void* bgrad) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = 
cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in,\n                                          sizeof(gelu_in));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. 
However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\nint gemm_dgelu_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                         int k, const float* alpha,                                             /* host pointer */\n                         at::BFloat16* A, int lda, at::BFloat16* B, int ldb, const float* beta, /* host pointer */\n                         at::BFloat16* C, int64_t ldc, void* workspace, size_t workspaceSize, cudaStream_t stream,\n                         const void* gelu_in, const void* bgrad) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in,\n                                          sizeof(gelu_in));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = 
cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\nint gemm_dgelu_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                         int k, const float* alpha,                                 /* host pointer */\n                         double* A, int lda, double* B, int ldb, const float* beta, /* host pointer */\n                         double* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream,\n                         const void* gelu_in, const void* bgrad) {\n  return 1;\n}\n\nint gemm_dgelu_bgradb_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n,\n                         int k, const float* alpha,                               /* host pointer */\n                         float* A, int lda, float* B, int ldb, const float* beta, /* host pointer */\n                         float* C, int64_t ldc, void* workspace, size_t workspaceSize, cudaStream_t stream,\n                         const void* gelu_in, const void* bgrad) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  
cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in,\n                                          sizeof(gelu_in));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? 
n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\n#endif\n\ntemplate <typename T>\nint linear_bias_forward_cuda(at::Tensor input, T* weight, at::Tensor bias, int in_features, int batch_size,\n                             int out_features, at::Tensor output, void* lt_workspace) {\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  // Get the stream from cublas handle to reuse for biasReLU kernel.\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n  const float alpha = 1.0;\n  const float beta_zero = 0.0;\n  const float beta_one = 1.0;\n  int status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n  status = gemm_bias_lt((cublasLtHandle_t)handle, CUBLAS_OP_T, CUBLAS_OP_N, out_features, batch_size, in_features,\n                        &alpha,                                                            /* host pointer */\n                        weight, in_features, input.data_ptr<T>(), in_features, &beta_zero, /* host pointer */\n                        output.data_ptr<T>(), out_features, lt_workspace, 1 << 22, stream, true,\n                        static_cast<const void*>(bias.data_ptr<T>()));\n#endif\n  if (status != 0) {\n    output.copy_(bias);\n    status = gemm_bias(handle, CUBLAS_OP_T, CUBLAS_OP_N, out_features, batch_size, in_features, &alpha, weight,\n                       in_features, input.data_ptr<T>(), in_features, &beta_one, output.data_ptr<T>(), out_features);\n  }\n  return status;\n}\n\ntemplate <typename T>\nint linear_bias_backward_cuda(T* input, T* weight, T* d_output, int in_features, int batch_size, int out_features,\n                              T* d_weight, T* d_bias, T* d_input, void* lt_workspace) {\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  // Get the stream from cublas handle to reuse for biasReLU kernel.\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n  const float alpha = 1.0;\n  const float beta_zero = 0.0;\n  int status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n  status = 
gemm_bgradb_lt((cublasLtHandle_t)handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size,\n                          &alpha,                                                 /* host pointer */\n                          input, in_features, d_output, out_features, &beta_zero, /* host pointer */\n                          d_weight, in_features, lt_workspace, 1 << 22, stream, true, static_cast<const void*>(d_bias));\n#endif\n\n  if (status != 0) {\n    status = gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size, &alpha, input,\n                       in_features, d_output, out_features, &beta_zero, d_weight, in_features);\n  }\n\n  status = gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, out_features, &alpha, weight,\n                     in_features, d_output, out_features, &beta_zero, d_input, in_features);\n  return status;\n}\n\ntemplate <typename T>\nint linear_gelu_linear_forward_cuda(T* input, T* weight1, T* bias1, T* weight2, T* bias2, int in_features,\n                                    int hidden_features, int batch_size, int out_features, T* output1, T* output2,\n                                    T* gelu_in, void* lt_workspace) {\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  // Get the stream from cublas handle to reuse for biasReLU kernel.\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n  const float alpha = 1.0;\n  const float beta_zero = 0.0;\n  int status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n  status = gemm_bias_gelu_lt((cublasLtHandle_t)handle, CUBLAS_OP_T, CUBLAS_OP_N, hidden_features, batch_size,\n                             in_features, &alpha,                                  /* host pointer */\n                             weight1, in_features, input, in_features, &beta_zero, /* host pointer */\n                             output1, hidden_features, lt_workspace, 1 << 22, stream, true,\n                             
static_cast<const void*>(gelu_in), static_cast<const void*>(bias1));\n  status = gemm_bias_lt((cublasLtHandle_t)handle, CUBLAS_OP_T, CUBLAS_OP_N, out_features, batch_size, hidden_features,\n                        &alpha,                                                         /* host pointer */\n                        weight2, hidden_features, output1, hidden_features, &beta_zero, /* host pointer */\n                        output2, out_features, lt_workspace, 1 << 22, stream, true, static_cast<const void*>(bias2));\n  return status;\n#else\n  return 1;\n#endif\n}\n\ntemplate <typename T>\nint linear_gelu_linear_backward_cuda(T* input, T* gelu_in, T* output1, T* weight1, T* weight2, T* d_output1,\n                                     T* d_output2, int in_features, int batch_size, int hidden_features,\n                                     int out_features, T* d_weight1, T* d_weight2, T* d_bias1, T* d_bias2, T* d_input,\n                                     void* lt_workspace) {\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  // Get the stream from cublas handle to reuse for biasReLU kernel.\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n  const float alpha = 1.0;\n  const float beta_zero = 0.0;\n  int status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n  // wgrad for the second GEMM (also produces d_bias2 via the BGRADB epilogue)\n  status = gemm_bgradb_lt((cublasLtHandle_t)handle, CUBLAS_OP_N, CUBLAS_OP_T, hidden_features, out_features, batch_size,\n                          &alpha,                                                        /* host pointer */\n                          output1, hidden_features, d_output2, out_features, &beta_zero, /* host pointer */\n                          d_weight2, hidden_features, lt_workspace, 1 << 22, stream, true,\n                          static_cast<const void*>(d_bias2));\n  // dgrad for second GEMM\n  status = gemm_dgelu_bgradb_lt((cublasLtHandle_t)handle, CUBLAS_OP_N, CUBLAS_OP_N, hidden_features, batch_size,\n          
                      out_features, &alpha,                                          /* host pointer */\n                                weight2, hidden_features, d_output2, out_features, &beta_zero, /* host pointer */\n                                d_output1, hidden_features, lt_workspace, 1 << 22, stream,\n                                static_cast<const void*>(gelu_in), static_cast<const void*>(d_bias1));\n  // wgrad for the first GEMM\n  status = gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, hidden_features, batch_size, &alpha, input,\n                     in_features, d_output1, hidden_features, &beta_zero, d_weight1, in_features);\n\n  // dgrad for the first GEMM\n  status = gemm_bias(handle, CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, hidden_features, &alpha, weight1,\n                     in_features, d_output1, hidden_features, &beta_zero, d_input, in_features);\n#endif\n  return status;\n}\n\ntemplate int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half* weight, at::Tensor bias, int in_features,\n                                                int batch_size, int out_features, at::Tensor output,\n                                                void* lt_workspace);\n\ntemplate int linear_bias_forward_cuda<float>(at::Tensor input, float* weight, at::Tensor bias, int in_features,\n                                             int batch_size, int out_features, at::Tensor output, void* lt_workspace);\n\ntemplate int linear_bias_forward_cuda<double>(at::Tensor input, double* weight, at::Tensor bias, int in_features,\n                                              int batch_size, int out_features, at::Tensor output, void* lt_workspace);\n\ntemplate int linear_bias_backward_cuda<at::Half>(at::Half* input, at::Half* weight, at::Half* d_output, int in_features,\n                                                 int batch_size, int out_features, at::Half* d_weight, at::Half* d_bias,\n                                                 
at::Half* d_input, void* lt_workspace);\n\ntemplate int linear_bias_backward_cuda<float>(float* input, float* weight, float* d_output, int in_features,\n                                              int batch_size, int out_features, float* d_weight, float* d_bias,\n                                              float* d_input, void* lt_workspace);\n\ntemplate int linear_bias_backward_cuda<double>(double* input, double* weight, double* d_output, int in_features,\n                                               int batch_size, int out_features, double* d_weight, double* d_bias,\n                                               double* d_input, void* lt_workspace);\n\ntemplate int linear_gelu_linear_forward_cuda<at::Half>(at::Half* input, at::Half* weight1, at::Half* bias1,\n                                                       at::Half* weight2, at::Half* bias2, int in_features,\n                                                       int hidden_features, int batch_size, int out_features,\n                                                       at::Half* output1, at::Half* output2, at::Half* gelu_in,\n                                                       void* lt_workspace);\n\ntemplate int linear_gelu_linear_forward_cuda<float>(float* input, float* weight1, float* bias1, float* weight2,\n                                                    float* bias2, int in_features, int hidden_features, int batch_size,\n                                                    int out_features, float* output1, float* output2, float* gelu_in,\n                                                    void* lt_workspace);\n\ntemplate int linear_gelu_linear_forward_cuda<double>(double* input, double* weight1, double* bias1, double* weight2,\n                                                     double* bias2, int in_features, int hidden_features,\n                                                     int batch_size, int out_features, double* output1, double* output2,\n                                 
                    double* gelu_in, void* lt_workspace);\n\ntemplate int linear_gelu_linear_backward_cuda<at::Half>(at::Half* input, at::Half* gelu_in, at::Half* output1,\n                                                        at::Half* weight1, at::Half* weight2, at::Half* d_output1,\n                                                        at::Half* d_output2, int in_features, int batch_size,\n                                                        int hidden_features, int out_features, at::Half* d_weight1,\n                                                        at::Half* d_weight2, at::Half* d_bias1, at::Half* d_bias2,\n                                                        at::Half* d_input, void* lt_workspace);\n\ntemplate int linear_gelu_linear_backward_cuda<float>(float* input, float* gelu_in, float* output1, float* weight1,\n                                                     float* weight2, float* d_output1, float* d_output2,\n                                                     int in_features, int batch_size, int hidden_features,\n                                                     int out_features, float* d_weight1, float* d_weight2,\n                                                     float* d_bias1, float* d_bias2, float* d_input,\n                                                     void* lt_workspace);\n\ntemplate int linear_gelu_linear_backward_cuda<double>(double* input, double* gelu_in, double* output1, double* weight1,\n                                                      double* weight2, double* d_output1, double* d_output2,\n                                                      int in_features, int batch_size, int hidden_features,\n                                                      int out_features, double* d_weight1, double* d_weight2,\n                                                      double* d_bias1, double* d_bias2, double* d_input,\n                                                      void* lt_workspace);\n\ntemplate int 
linear_bias_forward_cuda<at::BFloat16>(at::Tensor input, at::BFloat16* weight, at::Tensor bias,\n                                                    int in_features, int batch_size, int out_features,\n                                                    at::Tensor output, void* lt_workspace);\n\ntemplate int linear_bias_backward_cuda<at::BFloat16>(at::BFloat16* input, at::BFloat16* weight, at::BFloat16* d_output,\n                                                     int in_features, int batch_size, int out_features,\n                                                     at::BFloat16* d_weight, at::BFloat16* d_bias,\n                                                     at::BFloat16* d_input, void* lt_workspace);\n\ntemplate int linear_gelu_linear_forward_cuda<at::BFloat16>(at::BFloat16* input, at::BFloat16* weight1,\n                                                           at::BFloat16* bias1, at::BFloat16* weight2,\n                                                           at::BFloat16* bias2, int in_features, int hidden_features,\n                                                           int batch_size, int out_features, at::BFloat16* output1,\n                                                           at::BFloat16* output2, at::BFloat16* gelu_in,\n                                                           void* lt_workspace);\n\ntemplate int linear_gelu_linear_backward_cuda<at::BFloat16>(\n    at::BFloat16* input, at::BFloat16* gelu_in, at::BFloat16* output1, at::BFloat16* weight1, at::BFloat16* weight2,\n    at::BFloat16* d_output1, at::BFloat16* d_output2, int in_features, int batch_size, int hidden_features,\n    int out_features, at::BFloat16* d_weight1, at::BFloat16* d_weight2, at::BFloat16* d_bias1, at::BFloat16* d_bias2,\n    at::BFloat16* d_input, void* lt_workspace);\n"
  },
  {
    "path": "csrc/layer_norm_cuda.cpp",
    "content": "#include <torch/extension.h>\n\n#include <cassert>\n#include <optional>\n#include <vector>\n\nnamespace {\nvoid compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {\n  int idiff = input.ndimension() - normalized_shape.size();\n  n2 = 1;\n  for (int i = 0; i < (int)normalized_shape.size(); ++i) {\n    assert(input.sizes()[i + idiff] == normalized_shape[i]);\n    n2 *= normalized_shape[i];\n  }\n  n1 = 1;\n  for (int i = 0; i < idiff; ++i) {\n    n1 *= input.sizes()[i];\n  }\n}\n\nvoid check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta) {\n  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));\n  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));\n}\n\nvoid check_args(at::IntArrayRef normalized_shape, at::Tensor gamma) {\n  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));\n}\n\nvoid check_args(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {\n  int64_t normalized_ndim = normalized_shape.size();\n\n  if (normalized_ndim < 1) {\n    std::stringstream ss;\n    ss << \"Expected normalized_shape to be at least 1-dimensional, i.e., \"\n       << \"containing at least one element, but got normalized_shape=\" << normalized_shape;\n    throw std::runtime_error(ss.str());\n  }\n\n  auto input_shape = input.sizes();\n  auto input_ndim = input.dim();\n\n  if (input_ndim < normalized_ndim || !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {\n    std::stringstream ss;\n    ss << \"Given normalized_shape=\" << normalized_shape << \", expected input with shape [*\";\n    for (auto size : normalized_shape) {\n      ss << \", \" << size;\n    }\n    ss << \"], but got input of size\" << input_shape;\n    throw std::runtime_error(ss.str());\n  }\n\n  compute_n1_n2(input, normalized_shape, n1, n2);\n}\n\nvoid check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor 
beta, int& n1,\n                int& n2) {\n  check_args(input, normalized_shape, n1, n2);\n  check_args(normalized_shape, gamma, beta);\n}\n\nvoid check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, int& n1, int& n2) {\n  check_args(input, normalized_shape, n1, n2);\n  check_args(normalized_shape, gamma);\n}\n}  // namespace\n\nvoid cuda_layer_norm(at::Tensor& output, at::Tensor& mean, at::Tensor& invvar, const at::Tensor& input, int n1, int n2,\n                     at::IntArrayRef normalized_shape, const std::optional<at::Tensor>& gamma,\n                     const std::optional<at::Tensor>& beta, double epsilon);\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\nstd::vector<at::Tensor> layer_norm(const at::Tensor& input, at::IntArrayRef normalized_shape, double epsilon) {\n  CHECK_INPUT(input);\n  int n1, n2;\n  check_args(input, normalized_shape, n1, n2);\n  at::Tensor output = at::empty_like(input);\n  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ||\n                                                                  input.scalar_type() == at::ScalarType::BFloat16\n                                                              ? 
at::ScalarType::Float\n                                                              : input.scalar_type()));\n  at::Tensor invvar = at::empty_like(mean);\n  cuda_layer_norm(output, mean, invvar, input, n1, n2, normalized_shape, std::nullopt, std::nullopt, epsilon);\n  return {output, mean, invvar};\n}\n\nstd::vector<at::Tensor> layer_norm_affine(const at::Tensor& input, at::IntArrayRef normalized_shape,\n                                          const at::Tensor& gamma, const at::Tensor& beta, double epsilon) {\n  CHECK_INPUT(input);\n  CHECK_INPUT(gamma);\n  CHECK_INPUT(beta);\n  int n1, n2;\n  check_args(input, normalized_shape, gamma, beta, n1, n2);\n  at::Tensor output = at::empty_like(input);\n  const auto stats_dtype =\n      (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16)\n          ? at::ScalarType::Float\n          : input.scalar_type();\n  at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype));\n  at::Tensor invvar = at::empty_like(mean);\n  cuda_layer_norm(output, mean, invvar, input, n1, n2, normalized_shape, gamma, beta, epsilon);\n  return {output, mean, invvar};\n}\n\nstd::vector<at::Tensor> layer_norm_affine_mixed_dtypes(const at::Tensor& input, at::IntArrayRef normalized_shape,\n                                                       const at::Tensor& gamma, const at::Tensor& beta,\n                                                       double epsilon) {\n  CHECK_INPUT(input);\n  int n1, n2;\n  check_args(input, normalized_shape, n1, n2);\n  at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));\n  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ||\n                                                                  input.scalar_type() == at::ScalarType::BFloat16\n                                                              ? 
at::ScalarType::Float\n                                                              : input.scalar_type()));\n  at::Tensor invvar = at::empty_like(mean);\n  cuda_layer_norm(output, mean, invvar, input, n1, n2, normalized_shape, gamma, beta, epsilon);\n  return {output, mean, invvar};\n}\n\nvoid cuda_layer_norm_gradient(at::Tensor& dout, const std::optional<at::Tensor>& mean, at::Tensor& invvar,\n                              at::Tensor& input_or_output, int n1, int n2, at::IntArrayRef normalized_shape,\n                              const std::optional<at::Tensor>& gamma, const std::optional<at::Tensor>& beta,\n                              double epsilon, at::Tensor& grad_input, const std::optional<at::Tensor>& grad_gamma,\n                              const std::optional<at::Tensor>& grad_beta, bool memory_efficient);\n\nat::Tensor layer_norm_gradient(at::Tensor& dout, const std::optional<at::Tensor>& mean_, at::Tensor& invvar,\n                               at::Tensor& input_or_output, at::IntArrayRef normalized_shape, double epsilon,\n                               bool memory_efficient) {\n  CHECK_INPUT(dout);\n  CHECK_INPUT(invvar);\n  CHECK_INPUT(input_or_output);\n  int n1, n2;\n  check_args(input_or_output, normalized_shape, n1, n2);\n  at::Tensor grad_input = at::empty_like(input_or_output);\n\n  cuda_layer_norm_gradient(dout, mean_, invvar, input_or_output, n1, n2, normalized_shape, std::nullopt, std::nullopt,\n                           epsilon, grad_input, std::nullopt, std::nullopt, memory_efficient);\n  return grad_input;\n}\n\nstd::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor& dout, const std::optional<at::Tensor>& mean_,\n                                                   at::Tensor& invvar, at::Tensor& input_or_output,\n                                                   at::IntArrayRef normalized_shape, at::Tensor& gamma,\n                                                   at::Tensor& beta, double epsilon, bool memory_efficient) {\n 
 CHECK_INPUT(dout);\n  CHECK_INPUT(invvar);\n  CHECK_INPUT(input_or_output);\n  CHECK_INPUT(gamma);\n  CHECK_INPUT(beta);\n  int n1, n2;\n  check_args(input_or_output, normalized_shape, gamma, beta, n1, n2);\n  at::Tensor grad_input = at::empty_like(input_or_output);\n  at::Tensor grad_gamma = at::empty_like(gamma);\n  at::Tensor grad_beta = at::empty_like(beta);\n  cuda_layer_norm_gradient(dout, mean_, invvar, input_or_output, n1, n2, normalized_shape, gamma, beta, epsilon,\n                           grad_input, grad_gamma, grad_beta, memory_efficient);\n  return {grad_input, grad_gamma, grad_beta};\n}\n\nvoid cuda_rms_norm(at::Tensor& output, at::Tensor& invvar, const at::Tensor& input, int n1, int n2,\n                   at::IntArrayRef normalized_shape, const std::optional<at::Tensor>& gamma, double epsilon);\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\nstd::vector<at::Tensor> rms_norm(const at::Tensor& input, at::IntArrayRef normalized_shape, double epsilon) {\n  CHECK_INPUT(input);\n  int n1, n2;\n  check_args(input, normalized_shape, n1, n2);\n  at::Tensor output = at::empty_like(input);\n  at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ||\n                                                                    input.scalar_type() == at::ScalarType::BFloat16\n                                                                ? 
at::ScalarType::Float\n                                                                : input.scalar_type()));\n  cuda_rms_norm(output, invvar, input, n1, n2, normalized_shape, std::nullopt, epsilon);\n  return {output, invvar};\n}\n\nstd::vector<at::Tensor> rms_norm_affine(const at::Tensor& input, at::IntArrayRef normalized_shape,\n                                        const at::Tensor& gamma, double epsilon) {\n  CHECK_INPUT(input);\n  CHECK_INPUT(gamma);\n  int n1, n2;\n  check_args(input, normalized_shape, gamma, n1, n2);\n  at::Tensor output = at::empty_like(input);\n  const auto stats_dtype =\n      (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16)\n          ? at::ScalarType::Float\n          : input.scalar_type();\n  at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype));\n  cuda_rms_norm(output, invvar, input, n1, n2, normalized_shape, gamma, epsilon);\n  return {output, invvar};\n}\n\nstd::vector<at::Tensor> rms_norm_affine_mixed_dtypes(const at::Tensor& input, at::IntArrayRef normalized_shape,\n                                                     const at::Tensor& gamma, double epsilon) {\n  CHECK_INPUT(input);\n  int n1, n2;\n  check_args(input, normalized_shape, n1, n2);\n  at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));\n  at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ||\n                                                                    input.scalar_type() == at::ScalarType::BFloat16\n                                                                ? 
at::ScalarType::Float\n                                                                : input.scalar_type()));\n\n  cuda_rms_norm(output, invvar, input, n1, n2, normalized_shape, gamma, epsilon);\n  return {output, invvar};\n}\n\nvoid cuda_rms_norm_gradient(at::Tensor& dout, at::Tensor& invvar, at::Tensor& input_or_output, int n1, int n2,\n                            at::IntArrayRef normalized_shape, const std::optional<at::Tensor>& gamma, double epsilon,\n                            at::Tensor& grad_input, const std::optional<at::Tensor>& grad_gamma, bool memory_efficient);\n\nat::Tensor rms_norm_gradient(at::Tensor& dout, at::Tensor& invvar, at::Tensor& input_or_output,\n                             at::IntArrayRef normalized_shape, double epsilon, bool memory_efficient) {\n  CHECK_INPUT(dout);\n  CHECK_INPUT(invvar);\n  CHECK_INPUT(input_or_output);\n  int n1, n2;\n  check_args(input_or_output, normalized_shape, n1, n2);\n  at::Tensor grad_input = at::empty_like(input_or_output);\n  cuda_rms_norm_gradient(dout, invvar, input_or_output, n1, n2, normalized_shape, std::nullopt, epsilon, grad_input,\n                         std::nullopt, memory_efficient);\n  return grad_input;\n}\n\nstd::vector<at::Tensor> rms_norm_gradient_affine(at::Tensor& dout, at::Tensor& invvar, at::Tensor& input_or_output,\n                                                 at::IntArrayRef normalized_shape, at::Tensor& gamma, double epsilon,\n                                                 bool memory_efficient) {\n  CHECK_INPUT(dout);\n  CHECK_INPUT(invvar);\n  CHECK_INPUT(input_or_output);\n  CHECK_INPUT(gamma);\n  int n1, n2;\n  check_args(input_or_output, normalized_shape, gamma, n1, n2);\n  at::Tensor grad_input = at::empty_like(input_or_output);\n  at::Tensor grad_gamma = at::empty_like(gamma);\n  cuda_rms_norm_gradient(dout, invvar, input_or_output, n1, n2, normalized_shape, gamma, epsilon, grad_input,\n                         grad_gamma, memory_efficient);\n  return {grad_input, 
grad_gamma};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward_affine\", &layer_norm_affine, \"LayerNorm forward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"forward\", &layer_norm, \"LayerNorm forward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_affine\", &layer_norm_gradient_affine, \"LayerNorm backward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &layer_norm_gradient, \"LayerNorm backward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"forward_affine_mixed_dtypes\", &layer_norm_affine_mixed_dtypes,\n        \"LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation\",\n        py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"rms_forward_affine\", &rms_norm_affine, \"RMSNorm forward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"rms_forward\", &rms_norm, \"RMSNorm forward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"rms_backward_affine\", &rms_norm_gradient_affine, \"RMSNorm backward (CUDA)\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"rms_backward\", &rms_norm_gradient, \"RMSNorm backward (CUDA)\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"rms_forward_affine_mixed_dtypes\", &rms_norm_affine_mixed_dtypes,\n        \"RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/layer_norm_cuda_kernel.cu",
    "content": "#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <optional>\n\n#include \"ATen/ATen.h\"\n#include \"ATen/AccumulateType.h\"\n#include \"ATen/cuda/CUDAContext.h\"\n#include \"ATen/cuda/DeviceUtils.cuh\"\n#include \"static_switch.h\"\n#include \"type_shim.h\"\n\ntemplate <typename U>\n__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {\n  count = count + U(1);\n  U delta = curr - mu;\n  U lmean = mu + delta / count;\n  mu = lmean;\n  U delta2 = curr - lmean;\n  sigma2 = sigma2 + delta * delta2;\n}\n\ntemplate <typename U>\n__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, U& mu, U& sigma2, U& count) {\n  U delta = muB - mu;\n  U nA = count;\n  U nB = countB;\n  count = count + countB;\n  U nX = count;\n  if (nX > U(0)) {\n    nA = nA / nX;\n    nB = nB / nX;\n    mu = nA * mu + nB * muB;\n    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;\n  } else {\n    mu = U(0);\n    sigma2 = U(0);\n  }\n}\n\ntemplate <typename U>\n__device__ void cuRMSOnlineSum(const U curr, U& sigma2) {\n  sigma2 = sigma2 + curr * curr;\n}\n\ntemplate <typename U>\n__device__ void cuChanRMSOnlineSum(const U sigma2B, U& sigma2) {\n  sigma2 = sigma2 + sigma2B;\n}\n\ntemplate <typename T, typename U>\n__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, const int n2, const int i1, U& mu,\n                                  U& sigma2, U* buf, bool rms_only) {\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean over n2\n  U count = U(0);\n  mu = U(0);\n  sigma2 = U(0);\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const T* lvals = 
vals + i1 * n2;\n    int l = 4 * thrx;\n    for (; l + 3 < n2; l += 4 * numx) {\n      for (int k = 0; k < 4; ++k) {\n        U curr = static_cast<U>(lvals[l + k]);\n        if (!rms_only) {\n          cuWelfordOnlineSum<U>(curr, mu, sigma2, count);\n        } else {\n          cuRMSOnlineSum<U>(curr, sigma2);\n        }\n      }\n    }\n    for (; l < n2; ++l) {\n      U curr = static_cast<U>(lvals[l]);\n      if (!rms_only) {\n        cuWelfordOnlineSum<U>(curr, mu, sigma2, count);\n      } else {\n        cuRMSOnlineSum<U>(curr, sigma2);\n      }\n    }\n    // intra-warp reductions\n    for (int l = 0; l <= 4; ++l) {\n      int srcLaneB = (threadIdx.x + (1 << l)) & 31;\n      U sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      if (!rms_only) {\n        U muB = WARP_SHFL(mu, srcLaneB);\n        U countB = WARP_SHFL(count, srcLaneB);\n        cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);\n      } else {\n        cuChanRMSOnlineSum<U>(sigma2B, sigma2);\n      }\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      U* ubuf = (U*)buf;\n      U* ibuf = (U*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n          const int wrt_y = threadIdx.y - offset;\n          if (!rms_only) {\n            ubuf[2 * wrt_y] = mu;\n            ibuf[wrt_y] = count;\n          }\n          ubuf[2 * wrt_y + 1] = sigma2;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          U sigma2B = ubuf[2 * threadIdx.y + 1];\n          if (!rms_only) {\n            U muB = ubuf[2 * threadIdx.y];\n            U countB = ibuf[threadIdx.y];\n            cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);\n          } else {\n            
cuChanRMSOnlineSum<U>(sigma2B, sigma2);\n          }\n        }\n        __syncthreads();\n      }\n      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        if (!rms_only) {\n          ubuf[0] = mu;\n        }\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      if (!rms_only) {\n        mu = ubuf[0];\n      }\n      sigma2 = ubuf[1] / U(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      if (!rms_only) {\n        mu = WARP_SHFL(mu, 0);\n      }\n      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);\n    }\n  }\n}\n\ntemplate <>\n__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, const int n1, const int n2, const int i1,\n                                  float& mu, float& sigma2, float* buf, bool rms_only) {\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensor is contiguous\n  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.\n  //\n  // compute variance and mean over n2\n  float count = 0.0f;\n  mu = float(0);\n  sigma2 = float(0);\n  if (i1 < n1) {\n    // one warp normalizes one n1 index,\n    // synchronization is implicit\n    // initialize with standard Welford algorithm\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    const at::Half* lvals = vals + i1 * n2;\n    int l = 8 * thrx;\n    if ((((size_t)lvals) & 3) != 0) {\n      // 16 bit alignment\n      // first thread consumes first point\n      if (thrx == 0) {\n        float curr = static_cast<float>(lvals[0]);\n        if (!rms_only) {\n          cuWelfordOnlineSum(curr, mu, sigma2, count);\n        } else {\n          cuRMSOnlineSum(curr, sigma2);\n        }\n      }\n      ++l;\n    }\n    // at this point, lvals[l] are 32 bit aligned for all threads.\n    for (; l + 7 < n2; l += 8 * numx) {\n      for (int k = 0; k < 8; k += 2) {\n        float2 curr 
= __half22float2(*((__half2*)(lvals + l + k)));\n        if (!rms_only) {\n          cuWelfordOnlineSum(curr.x, mu, sigma2, count);\n          cuWelfordOnlineSum(curr.y, mu, sigma2, count);\n        } else {\n          cuRMSOnlineSum(curr.x, sigma2);\n          cuRMSOnlineSum(curr.y, sigma2);\n        }\n      }\n    }\n    for (; l < n2; ++l) {\n      float curr = static_cast<float>(lvals[l]);\n      if (!rms_only) {\n        cuWelfordOnlineSum(curr, mu, sigma2, count);\n      } else {\n        cuRMSOnlineSum(curr, sigma2);\n      }\n    }\n    // intra-warp reductions\n    for (int l = 0; l <= 4; ++l) {\n      int srcLaneB = (threadIdx.x + (1 << l)) & 31;\n      float sigma2B = WARP_SHFL(sigma2, srcLaneB);\n      if (!rms_only) {\n        float muB = WARP_SHFL(mu, srcLaneB);\n        float countB = WARP_SHFL(count, srcLaneB);\n        cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);\n      } else {\n        cuChanRMSOnlineSum(sigma2B, sigma2);\n      }\n    }\n    // threadIdx.x == 0 has correct values for each warp\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      float* ubuf = (float*)buf;\n      float* ibuf = (float*)(ubuf + blockDim.y);\n      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n          const int wrt_y = threadIdx.y - offset;\n          ubuf[2 * wrt_y + 1] = sigma2;\n          if (!rms_only) {\n            ubuf[2 * wrt_y] = mu;\n            ibuf[wrt_y] = count;\n          }\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.x == 0 && threadIdx.y < offset) {\n          float sigma2B = ubuf[2 * threadIdx.y + 1];\n          if (!rms_only) {\n            float muB = ubuf[2 * threadIdx.y];\n            float countB = ibuf[threadIdx.y];\n            cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);\n          } else {\n            
cuChanRMSOnlineSum(sigma2B, sigma2);\n          }\n        }\n        __syncthreads();\n      }\n      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values\n      if (threadIdx.x == 0 && threadIdx.y == 0) {\n        if (!rms_only) {\n          ubuf[0] = mu;\n        }\n        ubuf[1] = sigma2;\n      }\n      __syncthreads();\n      if (!rms_only) {\n        mu = ubuf[0];\n      }\n      sigma2 = ubuf[1] / float(n2);\n      // don't care about final value of count, we know count == n2\n    } else {\n      if (!rms_only) {\n        mu = WARP_SHFL(mu, 0);\n      }\n      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);\n    }\n  }\n}\n\ntemplate <typename U>\nU rsqrt(U v) {\n  return U(1) / sqrt(v);\n}\ntemplate <>\nfloat rsqrt(float v) {\n  return rsqrtf(v);\n}\ntemplate <>\ndouble rsqrt(double v) {\n  return rsqrt(v);\n}\n\nnamespace {\n// This is the un-specialized struct.  Note that we prevent instantiation of this\n// struct by putting an undefined symbol in the function body so it won't compile.\n//  template <typename T>\n//  struct SharedMemory\n//  {\n//      // Ensure that we won't compile any un-specialized types\n//      __device__ T *getPointer()\n//      {\n//          extern __device__ void error(void);\n//          error();\n//          return nullptr;\n//      }\n//  };\n// https://github.com/NVIDIA/apex/issues/246\ntemplate <typename T>\nstruct SharedMemory;\n\ntemplate <>\nstruct SharedMemory<float> {\n  __device__ float* getPointer() {\n    extern __shared__ float s_float[];\n    return s_float;\n  }\n};\n\ntemplate <>\nstruct SharedMemory<double> {\n  __device__ double* getPointer() {\n    extern __shared__ double s_double[];\n    return s_double;\n  }\n};\n}  // namespace\n\ntemplate <typename T, typename U, typename V>\n__device__ void cuApplyLayerNorm_(V* __restrict__ output_vals, U* __restrict__ mean, U* __restrict__ invvar,\n                                  const T* __restrict__ vals, const int n1, const int n2, const U 
epsilon,\n                                  const V* __restrict__ gamma, const V* __restrict__ beta, bool rms_only) {\n  // Assumptions:\n  // 1) blockDim.x == warpSize\n  // 2) Tensors are contiguous\n  //\n  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    SharedMemory<U> shared;\n    U* buf = shared.getPointer();\n    U mu, sigma2;\n    cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, rms_only);\n\n    const T* lvals = vals + i1 * n2;\n    V* ovals = output_vals + i1 * n2;\n    U c_invvar = rsqrt(sigma2 + epsilon);\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != nullptr && (beta != nullptr || rms_only)) {\n      for (int i = thrx; i < n2; i += numx) {\n        U curr = static_cast<U>(lvals[i]);\n        if (!rms_only) {\n          ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];\n        } else {\n          ovals[i] = gamma[i] * static_cast<V>(c_invvar * curr);\n        }\n      }\n    } else {\n      for (int i = thrx; i < n2; i += numx) {\n        U curr = static_cast<U>(lvals[i]);\n        if (!rms_only) {\n          ovals[i] = static_cast<V>(c_invvar * (curr - mu));\n        } else {\n          ovals[i] = static_cast<V>(c_invvar * curr);\n        }\n      }\n    }\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      if (!rms_only) {\n        mean[i1] = mu;\n      }\n      invvar[i1] = c_invvar;\n    }\n    __syncthreads();\n  }\n}\n\ntemplate <typename T, typename U, typename V = T>\n__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, U* __restrict__ mean, U* __restrict__ invvar,\n                                 const T* __restrict__ vals, const int n1, const int n2, const U epsilon,\n                                 const V* __restrict__ gamma, const V* __restrict__ beta) {\n  cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, false);\n}\n\ntemplate <typename T, typename U, typename V = 
T>\n__global__ void cuApplyRMSNorm(V* __restrict__ output_vals, U* __restrict__ invvar, const T* __restrict__ vals,\n                               const int n1, const int n2, const U epsilon, const V* __restrict__ gamma) {\n  cuApplyLayerNorm_<T, U, V>(output_vals, nullptr, invvar, vals, n1, n2, epsilon, gamma, nullptr, true);\n}\n\ntemplate <typename V>\n__device__ V clamp_by_magnitude(V curr_gamma, double eps) {\n  const V kMinGamma = V(eps);\n  if (curr_gamma >= 0) {\n    if (curr_gamma < kMinGamma) {\n      return kMinGamma;\n    } else {\n      return curr_gamma;\n    }\n  } else {\n    if (curr_gamma > -kMinGamma) {\n      return -kMinGamma;\n    } else {\n      return curr_gamma;\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename V, bool MemoryEfficient>\n__device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off,\n                                         const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,\n                                         const T* input_or_output, const V* dout, const int i1_end, const int n2,\n                                         const U* __restrict__ mean, const U* __restrict__ invvar,\n                                         const V* __restrict__ gamma, const V* __restrict__ beta, const double eps,\n                                         bool rms_only) {\n  int i1 = i1_block + thr_load_row_off;\n  if (i1 < i1_end) {\n    for (int k = 0; k < blockDim.y; ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1 * n2 + i2;\n      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;\n      if (i2 < n2) {\n        U c_h = static_cast<U>(input_or_output[load_idx]);\n        U curr_dout = static_cast<U>(dout[load_idx]);\n        if (!rms_only) {\n          warp_buf1[write_idx] = curr_dout;\n          if (MemoryEfficient) {\n            U curr_beta = static_cast<U>(beta[i2]);\n            warp_buf2[write_idx] = curr_dout * (c_h - 
curr_beta) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));\n          } else {\n            warp_buf2[write_idx] = curr_dout * (c_h - mean[i1]) * invvar[i1];\n          }\n        } else {\n          if (MemoryEfficient) {\n            warp_buf2[write_idx] = curr_dout * (c_h) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));\n          } else {\n            warp_buf2[write_idx] = curr_dout * (c_h)*invvar[i1];\n          }\n        }\n      } else {\n        if (!rms_only) {\n          warp_buf1[write_idx] = U(0);\n        }\n        warp_buf2[write_idx] = U(0);\n      }\n    }\n  } else {\n    for (int k = 0; k < blockDim.y; ++k) {\n      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;\n      if (!rms_only) {\n        warp_buf1[write_idx] = U(0);\n      }\n      warp_buf2[write_idx] = U(0);\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename V, bool MemoryEfficient>\n__device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off, const int thr_load_col_off,\n                                       const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,\n                                       const T* input_or_output, const V* dout, const int i1_end, const int n2,\n                                       const U* __restrict__ mean, const U* __restrict__ invvar,\n                                       const V* __restrict__ gamma, const V* __restrict__ beta, const double eps,\n                                       bool rms_only) {\n  int i1 = i1_block + thr_load_row_off;\n  if (i1 < i1_end) {\n    for (int k = 0; k < blockDim.y; ++k) {\n      int i2 = i2_off + k;\n      int load_idx = i1 * n2 + i2;\n      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;\n      if (i2 < n2) {\n        U c_h = static_cast<U>(input_or_output[load_idx]);\n        U curr_dout = static_cast<U>(dout[load_idx]);\n        if (!rms_only) {\n          U curr_beta = static_cast<U>(beta[i2]);\n          
warp_buf1[write_idx] += curr_dout;\n          if (MemoryEfficient) {\n            warp_buf2[write_idx] += curr_dout * (c_h - curr_beta) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));\n          } else {\n            warp_buf2[write_idx] += curr_dout * (c_h - mean[i1]) * invvar[i1];\n          }\n        } else {\n          if (MemoryEfficient) {\n            warp_buf2[write_idx] += curr_dout * (c_h) / static_cast<U>(clamp_by_magnitude(gamma[i2], eps));\n          } else {\n            warp_buf2[write_idx] += curr_dout * (c_h)*invvar[i1];\n          }\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename V, bool MemoryEfficient>\n__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, const T* __restrict__ input_or_output,\n                                           const int n1, const int n2, const U* __restrict__ mean,\n                                           const U* __restrict__ invvar, U epsilon, const V* __restrict__ gamma,\n                                           const V* __restrict__ beta, U* part_grad_gamma, U* part_grad_beta,\n                                           const double eps, bool rms_only) {\n  const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);\n  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;\n  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;\n  const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;\n  const int i1_end = i1_beg_plus_one < n1 ? 
i1_beg_plus_one : n1;\n  const int row_stride = blockDim.x + 1;\n  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);\n  const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;\n  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;\n  SharedMemory<U> shared;\n  U* buf = shared.getPointer();  // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y -\n                                 // 1)*(blockDim.x/blockDim.y) elements\n  U* warp_buf1 = (U*)buf;\n  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;\n  // compute partial sums from strided inputs\n  // do this to increase number of loads in flight\n  cuLoadWriteStridedInputs<T, U, V, MemoryEfficient>(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride,\n                                                     warp_buf1, warp_buf2, input_or_output, dout, i1_end, n2, mean,\n                                                     invvar, gamma, beta, eps, rms_only);\n  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; i1_block += blockDim.y * blockDim.y) {\n    cuLoadAddStridedInputs<T, U, V, MemoryEfficient>(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,\n                                                     warp_buf1, warp_buf2, input_or_output, dout, i1_end, n2, mean,\n                                                     invvar, gamma, beta, eps, rms_only);\n  }\n  __syncthreads();\n  // inter-warp reductions\n  // sum within each warp\n  U acc1 = U(0);\n  U acc2 = U(0);\n  for (int k = 0; k < blockDim.y; ++k) {\n    int row1 = threadIdx.y + k * blockDim.y;\n    int idx1 = row1 * row_stride + threadIdx.x;\n    if (!rms_only) {\n      acc1 += warp_buf1[idx1];\n    }\n    acc2 += warp_buf2[idx1];\n  }\n  if (!rms_only) {\n    warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;\n  }\n  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;\n  __syncthreads();\n  
// sum all warps\n  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {\n    if (threadIdx.y < offset) {\n      int row1 = threadIdx.y;\n      int row2 = threadIdx.y + offset;\n      int idx1 = row1 * row_stride + threadIdx.x;\n      int idx2 = row2 * row_stride + threadIdx.x;\n      if (!rms_only) {\n        warp_buf1[idx1] += warp_buf1[idx2];\n      }\n      warp_buf2[idx1] += warp_buf2[idx2];\n    }\n    __syncthreads();\n  }\n  int i2 = blockIdx.x * blockDim.x + threadIdx.x;\n  if (threadIdx.y == 0 && i2 < n2) {\n    int row1 = threadIdx.y;\n    int row2 = threadIdx.y + 1;\n    int idx1 = row1 * row_stride + threadIdx.x;\n    int idx2 = row2 * row_stride + threadIdx.x;\n    if (!rms_only) {\n      part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];\n    }\n    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];\n  }\n}\n\ntemplate <typename U, typename V>\n__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_grad_beta, const int part_size,\n                                       const int n1, const int n2, V* grad_gamma, V* grad_beta, bool rms_only) {\n  // sum partial gradients for gamma and beta\n  SharedMemory<U> shared;\n  U* buf = shared.getPointer();\n  int i2 = blockIdx.x * blockDim.x + threadIdx.x;\n  if (i2 < n2) {\n    // each warp does sequential reductions until reduced part_size is num_warps\n    int num_warp_reductions = part_size / blockDim.y;\n    U sum_gamma = U(0);\n    U sum_beta = U(0);\n    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;\n    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;\n    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {\n      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];\n      if (!rms_only) {\n        sum_beta += part_grad_beta_ptr[warp_offset * n2];\n      }\n    }\n    // inter-warp reductions\n    const int 
nbsize3 = blockDim.x * blockDim.y / 2;\n    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {\n      // top half write to shared memory\n      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n        buf[write_idx] = sum_gamma;\n        if (!rms_only) {\n          buf[write_idx + nbsize3] = sum_beta;\n        }\n      }\n      __syncthreads();\n      // bottom half sums\n      if (threadIdx.y < offset) {\n        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;\n        sum_gamma += buf[read_idx];\n        if (!rms_only) {\n          sum_beta += buf[read_idx + nbsize3];\n        }\n      }\n      __syncthreads();\n    }\n    // write out fully summed gradients\n    if (threadIdx.y == 0) {\n      grad_gamma[i2] = sum_gamma;\n      if (!rms_only) {\n        grad_beta[i2] = sum_beta;\n      }\n    }\n  }\n}\n\ntemplate <typename T, typename U, typename V, bool MemoryEfficient>\n__global__ void cuComputeGradInput(const V* __restrict__ dout, const T* __restrict__ input_or_output, const int n1,\n                                   const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon,\n                                   const V* gamma, const V* beta, T* grad_input, const double eps, bool rms_only) {\n  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {\n    U sum_loss1 = U(0);\n    U sum_loss2 = U(0);\n    const T* k_h = input_or_output + i1 * n2;\n    const V* k_dout = dout + i1 * n2;\n    const U c_invvar = invvar[i1];\n    const U c_mean = !MemoryEfficient ? 
mean[i1] : 0.;\n    const int numx = blockDim.x * blockDim.y;\n    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;\n    if (gamma != nullptr) {\n      int l = 4 * thrx;\n      for (; l + 3 < n2; l += 4 * numx) {\n        for (int k = 0; k < 4; ++k) {\n          const U c_h = static_cast<U>(k_h[l + k]);\n          const U c_loss = static_cast<U>(k_dout[l + k]);\n          if (!rms_only) {\n            sum_loss1 += c_loss * gamma[l + k];\n            if (MemoryEfficient) {\n              sum_loss2 += c_loss * (c_h - beta[l + k]);\n            } else {\n              sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;\n            }\n          } else {\n            if (MemoryEfficient) {\n              sum_loss2 += c_loss * c_h;\n            } else {\n              sum_loss2 += c_loss * gamma[l + k] * (c_h)*c_invvar;\n            }\n          }\n        }\n      }\n      for (; l < n2; ++l) {\n        const U c_h = static_cast<U>(k_h[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        if (!rms_only) {\n          sum_loss1 += c_loss * gamma[l];\n          if (MemoryEfficient) {\n            sum_loss2 += c_loss * (c_h - beta[l]);\n          } else {\n            sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;\n          }\n        } else {\n          if (MemoryEfficient) {\n            sum_loss2 += c_loss * c_h;\n          } else {\n            sum_loss2 += c_loss * gamma[l] * (c_h)*c_invvar;\n          }\n        }\n      }\n    } else {\n      int l = 4 * thrx;\n      for (; l + 3 < n2; l += 4 * numx) {\n        for (int k = 0; k < 4; ++k) {\n          const U c_h = static_cast<U>(k_h[l + k]);\n          const U c_loss = static_cast<U>(k_dout[l + k]);\n          if (!rms_only) {\n            sum_loss1 += c_loss;\n            if (MemoryEfficient) {\n              sum_loss2 += c_loss * c_h;\n            } else {\n              sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n            }\n          } else {\n            
if (MemoryEfficient) {\n              sum_loss2 += c_loss * c_h;\n            } else {\n              sum_loss2 += c_loss * (c_h)*c_invvar;\n            }\n          }\n        }\n      }\n      for (; l < n2; ++l) {\n        const U c_h = static_cast<U>(k_h[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        if (!rms_only) {\n          sum_loss1 += c_loss;\n          if (MemoryEfficient) {\n            sum_loss2 += c_loss * c_h;\n          } else {\n            sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;\n          }\n        } else {\n          if (MemoryEfficient) {\n            sum_loss2 += c_loss * c_h;\n          } else {\n            sum_loss2 += c_loss * (c_h)*c_invvar;\n          }\n        }\n      }\n    }\n    // intra-warp reductions\n    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {\n      if (!rms_only) {\n        sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);\n      }\n      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);\n    }\n    // inter-warp reductions\n    if (blockDim.y > 1) {\n      SharedMemory<U> shared;\n      U* buf = shared.getPointer();\n      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {\n        // upper half of warps write to shared\n        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {\n          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;\n          if (!rms_only) {\n            buf[2 * wrt_i] = sum_loss1;\n          }\n          buf[2 * wrt_i + 1] = sum_loss2;\n        }\n        __syncthreads();\n        // lower half merges\n        if (threadIdx.y < offset) {\n          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;\n          if (!rms_only) {\n            sum_loss1 += buf[2 * read_i];\n          }\n          sum_loss2 += buf[2 * read_i + 1];\n        }\n        __syncthreads();\n      }\n      if (threadIdx.y == 0) {\n        if (!rms_only) {\n          buf[2 * threadIdx.x] = sum_loss1;\n        }\n        buf[2 * threadIdx.x + 1] = 
sum_loss2;\n      }\n      __syncthreads();\n      if (threadIdx.y != 0) {\n        if (!rms_only) {\n          sum_loss1 = buf[2 * threadIdx.x];\n        }\n        sum_loss2 = buf[2 * threadIdx.x + 1];\n      }\n    }\n    // all threads now have the two sums over l\n    U fH = (U)n2;\n    U term1 = (U(1) / fH) * c_invvar;\n    T* k_grad_input = grad_input + i1 * n2;\n    if (gamma != nullptr) {\n      for (int l = thrx; l < n2; l += numx) {\n        const U c_h = static_cast<U>(k_h[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        const U k_gamma = static_cast<U>(clamp_by_magnitude(gamma[l], eps));\n        U f_grad_input = fH * c_loss * k_gamma;\n        if (!rms_only) {\n          const U k_beta = beta[l];\n          f_grad_input -= sum_loss1;\n          if (MemoryEfficient) {\n            f_grad_input -= (c_h - k_beta) / k_gamma * sum_loss2;\n          } else {\n            f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n          }\n        } else {\n          if (MemoryEfficient) {\n            f_grad_input -= c_h / k_gamma * sum_loss2;\n          } else {\n            f_grad_input -= c_h * c_invvar * sum_loss2;\n          }\n        }\n        f_grad_input *= term1;\n        k_grad_input[l] = static_cast<T>(f_grad_input);\n      }\n    } else {\n      for (int l = thrx; l < n2; l += numx) {\n        const U c_h = static_cast<U>(k_h[l]);\n        const U c_loss = static_cast<U>(k_dout[l]);\n        U f_grad_input = fH * c_loss;\n        if (!rms_only) {\n          f_grad_input -= sum_loss1;\n          if (MemoryEfficient) {\n            f_grad_input -= c_h * sum_loss2;\n          } else {\n            f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;\n          }\n        } else {\n          if (MemoryEfficient) {\n            f_grad_input -= c_h * sum_loss2;\n          } else {\n            f_grad_input -= c_h * c_invvar * sum_loss2;\n          }\n        }\n        f_grad_input *= term1;\n        k_grad_input[l] = 
static_cast<T>(f_grad_input);\n      }\n    }\n    // prevent race where buf is written again before reads are done\n    __syncthreads();\n  }\n}\n\ntemplate <typename T, typename U, typename V = T>\nvoid HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1, int n2, double epsilon, const V* gamma,\n                        const V* beta) {\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const dim3 threads(32, 4, 1);\n  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);\n  int nshared = threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;\n  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);\n}\n\ntemplate <typename T, typename U, typename V = T>\nvoid HostApplyRMSNorm(V* output, U* invvar, const T* input, int n1, int n2, double epsilon, const V* gamma) {\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n  const dim3 threads(32, 4, 1);\n  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);\n  int nshared = threads.y > 1 ? 
threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;\n  cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>(output, invvar, input, n1, n2, U(epsilon), gamma);\n}\n\nvoid cuda_layer_norm(at::Tensor& output, at::Tensor& mean, at::Tensor& invvar, const at::Tensor& input, int n1, int n2,\n                     at::IntArrayRef normalized_shape, const std::optional<at::Tensor>& gamma,\n                     const std::optional<at::Tensor>& beta, double epsilon) {\n  using namespace at;\n  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(\n      input.scalar_type(), output.scalar_type(), \"layer_norm_cuda_kernel\",\n      using accscalar_t = at::acc_type<scalar_t_in, true>;\n      HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(\n          output.data_ptr<scalar_t_out>(), mean.data_ptr<accscalar_t>(), invvar.data_ptr<accscalar_t>(),\n          input.data_ptr<scalar_t_in>(), n1, n2, epsilon, gamma.has_value() ? gamma->data_ptr<scalar_t_out>() : nullptr,\n          beta.has_value() ? beta->data_ptr<scalar_t_out>() : nullptr);)\n}\n\nvoid cuda_rms_norm(at::Tensor& output, at::Tensor& invvar, const at::Tensor& input, int n1, int n2,\n                   at::IntArrayRef normalized_shape, const std::optional<at::Tensor>& gamma, double epsilon) {\n  using namespace at;\n  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(\n      input.scalar_type(), output.scalar_type(), \"rms_norm_cuda_kernel\",\n      using accscalar_t = at::acc_type<scalar_t_in, true>;\n      HostApplyRMSNorm<scalar_t_in, accscalar_t, scalar_t_out>(\n          output.data_ptr<scalar_t_out>(), invvar.data_ptr<accscalar_t>(), input.data_ptr<scalar_t_in>(), n1, n2,\n          epsilon, gamma.has_value() ? 
gamma->data_ptr<scalar_t_out>() : nullptr);)\n}\n\ntemplate <typename T, typename U = float, typename V = T>\nvoid HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Tensor& input_or_output, int n1, int n2,\n                           const V* gamma, const V* beta, double epsilon, T* grad_input, V* grad_gamma, V* grad_beta,\n                           bool memory_efficient) {\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n  if (gamma != nullptr && beta != nullptr) {\n    // compute grad_gamma(j) and grad_beta(j)\n    const int part_size = 16;\n    const dim3 threads2(32, 4, 1);\n    const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);\n    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);\n    const int nshared2_b = threads2.x * threads2.y * sizeof(U);\n    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;\n    // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that\n    // the `cuda_layer_norm_gradient` doesn't support double.\n    const auto part_grad_dtype = (input_or_output.scalar_type() == at::ScalarType::Half ||\n                                  input_or_output.scalar_type() == at::ScalarType::BFloat16)\n                                     ? 
at::ScalarType::Float\n                                     : input_or_output.scalar_type();\n    at::Tensor part_grad_gamma = at::empty({part_size, n2}, input_or_output.options().dtype(part_grad_dtype));\n    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);\n    BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] {\n      auto kernel = &cuComputePartGradGammaBeta<T, U, V, MemoryEfficient>;\n      kernel<<<blocks2, threads2, nshared2, stream>>>(dout, input_or_output.data_ptr<T>(), n1, n2, mean, invvar,\n                                                      U(epsilon), gamma, beta, part_grad_gamma.data_ptr<U>(),\n                                                      part_grad_beta.data_ptr<U>(), epsilon, false);\n    });\n\n    const dim3 threads3(32, 8, 1);\n    const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);\n    const int nshared3 = threads3.x * threads3.y * sizeof(U);\n    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(\n        part_grad_gamma.data_ptr<U>(), part_grad_beta.data_ptr<U>(), part_size, n1, n2, grad_gamma, grad_beta, false);\n  }\n\n  // compute grad_input\n  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);\n  const dim3 threads1(32, 4, 1);\n  int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0;\n  BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] {\n    auto kernel = cuComputeGradInput<T, U, V, MemoryEfficient>;\n    kernel<<<blocks1, threads1, nshared, stream>>>(dout, input_or_output.data_ptr<T>(), n1, n2, mean, invvar,\n                                                   U(epsilon), gamma, beta, grad_input, epsilon, false);\n  });\n}\n\ntemplate <typename T, typename U = float, typename V = T>\nvoid HostRMSNormGradient(const V* dout, const U* invvar, at::Tensor& input_or_output, int n1, int n2, const V* gamma,\n                         double epsilon, T* grad_input, V* grad_gamma, bool memory_efficient) {\n  auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n  if (gamma != nullptr) {\n    const int part_size = 16;\n    const dim3 threads2(32, 4, 1);\n    const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);\n    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);\n    const int nshared2_b = threads2.x * threads2.y * sizeof(U);\n    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;\n    // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that\n    // the `cuda_layer_norm_gradient` doesn't support double.\n    const auto part_grad_dtype = (input_or_output.scalar_type() == at::ScalarType::Half ||\n                                  input_or_output.scalar_type() == at::ScalarType::BFloat16)\n                                     ? 
at::ScalarType::Float\n                                     : input_or_output.scalar_type();\n    at::Tensor part_grad_gamma = at::empty({part_size, n2}, input_or_output.options().dtype(part_grad_dtype));\n    BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] {\n      auto kernel = &cuComputePartGradGammaBeta<T, U, V, MemoryEfficient>;\n      kernel<<<blocks2, threads2, nshared2, stream>>>(dout, input_or_output.data_ptr<T>(), n1, n2, invvar, /* unused */\n                                                      invvar, U(epsilon), gamma, gamma,                    /* unused */\n                                                      part_grad_gamma.data_ptr<U>(),\n                                                      part_grad_gamma.data_ptr<U>(), /* unused */\n                                                      epsilon, true);\n    });\n\n    const dim3 threads3(32, 8, 1);\n    const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);\n    const int nshared3 = threads3.x * threads3.y * sizeof(U);\n    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(\n        part_grad_gamma.data_ptr<U>(), part_grad_gamma.data_ptr<U>(), /* unused */\n        part_size, n1, n2, grad_gamma, grad_gamma,                    /* unused */\n        true);\n  }\n\n  // compute grad_input\n  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];\n  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);\n  const dim3 threads1(32, 4, 1);\n  int nshared = threads1.y > 1 ? 
threads1.y * threads1.x * sizeof(U) : 0;\n  BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] {\n    auto kernel = cuComputeGradInput<T, U, V, MemoryEfficient>;\n    kernel<<<blocks1, threads1, nshared, stream>>>(dout, input_or_output.data_ptr<T>(), n1, n2, invvar, /* unused */\n                                                   invvar, U(epsilon), gamma, gamma,                    /* unused */\n                                                   grad_input, epsilon, true);\n  });\n}\n\nvoid cuda_layer_norm_gradient(at::Tensor& dout, const std::optional<at::Tensor>& mean, at::Tensor& invvar,\n                              at::Tensor& input_or_output, int n1, int n2, at::IntArrayRef normalized_shape,\n                              const std::optional<at::Tensor>& gamma, const std::optional<at::Tensor>& beta,\n                              double epsilon, at::Tensor& grad_input, const std::optional<at::Tensor>& grad_gamma,\n                              const std::optional<at::Tensor>& grad_beta, bool memory_efficient) {\n  using namespace at;\n  // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16\n  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(\n      input_or_output.scalar_type(), gamma.has_value() ? gamma->scalar_type() : input_or_output.scalar_type(),\n      \"cuComputeGradInput\", using accscalar_t = at::acc_type<scalar_t_in, true>;\n      HostLayerNormGradient(dout.data_ptr<scalar_t_out>(), mean.has_value() ? mean->data_ptr<accscalar_t>() : nullptr,\n                            invvar.data_ptr<accscalar_t>(), input_or_output, n1, n2,\n                            // TMJ pass nullptr argument for gamma, beta, grad_gamma and grad_beta\n                            // if gamma Tensor is nullptr on input.\n                            gamma.has_value() ? gamma->data_ptr<scalar_t_out>() : nullptr,\n                            gamma.has_value() ? 
beta->data_ptr<scalar_t_out>() : nullptr, epsilon,\n                            grad_input.data_ptr<scalar_t_in>(),\n                            gamma.has_value() ? grad_gamma->data_ptr<scalar_t_out>() : nullptr,\n                            gamma.has_value() ? grad_beta->data_ptr<scalar_t_out>() : nullptr, memory_efficient);)\n}\n\nvoid cuda_rms_norm_gradient(at::Tensor& dout, at::Tensor& invvar, at::Tensor& input_or_output, int n1, int n2,\n                            at::IntArrayRef normalized_shape, const std::optional<at::Tensor>& gamma, double epsilon,\n                            at::Tensor& grad_input, const std::optional<at::Tensor>& grad_gamma,\n                            bool memory_efficient) {\n  using namespace at;\n  // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16\n  // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(\n  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(\n      input_or_output.scalar_type(), gamma.has_value() ? gamma->scalar_type() : input_or_output.scalar_type(),\n      \"cuComputeGradInputRMS\", using accscalar_t = at::acc_type<scalar_t_in, true>;\n      HostRMSNormGradient(dout.data_ptr<scalar_t_out>(), invvar.data_ptr<accscalar_t>(), input_or_output, n1, n2,\n                          // TMJ pass nullptr argument for gamma, beta, grad_gamma and grad_beta\n                          // if gamma Tensor is nullptr on input.\n                          gamma.has_value() ? gamma->data_ptr<scalar_t_out>() : nullptr, epsilon,\n                          grad_input.data_ptr<scalar_t_in>(),\n                          gamma.has_value() ? grad_gamma->data_ptr<scalar_t_out>() : nullptr, memory_efficient);)\n}\n"
  },
  {
    "path": "csrc/megatron/fused_rotary_positional_embedding.cpp",
    "content": "/* coding=utf-8\n * Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <torch/extension.h>\n\nnamespace fused_rope {\n\ntorch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, const bool transpose_output);\n\ntorch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& freqs, const bool transpose_output);\n\ntorch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& cos, const torch::Tensor& sin,\n                              const bool transpose_output);\n\ntorch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos, const torch::Tensor& sin,\n                              const bool transpose_output);\n\ntorch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs);\n\ntorch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens,\n                           const torch::Tensor& freqs);\n\ntorch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h,\n                          const torch::Tensor& cos_w, const torch::Tensor& sin_w);\n\ntorch::Tensor bwd_2d_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h,\n                          const torch::Tensor& cos_w, const torch::Tensor& 
sin_w);\n\ntorch::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output) {\n  TORCH_CHECK(input.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(freqs.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(input.size(0) == freqs.size(0), \"expected input and freqs tensor have the same sequence length\");\n  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,\n              \"expected the second and third dims of the freqs tensor equal 1\");\n  TORCH_CHECK(input.size(3) >= freqs.size(3),\n              \"expected the last dim of the input tensor equals or is \"\n              \"greater than the freqs tensor\");\n  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, \"Dtype of the freqs tensor must be float\");\n\n  return fwd_cuda(input, freqs, transpose_output);\n}\n\ntorch::Tensor bwd(const torch::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output) {\n  TORCH_CHECK(output_grads.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(freqs.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(output_grads.size(0) == freqs.size(0),\n              \"expected output_grads and freqs tensor have the same sequence length\");\n  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,\n              \"expected the second and third dims of the freqs tensor equal 1\");\n  TORCH_CHECK(output_grads.size(3) >= freqs.size(3),\n              \"expected the last dim of the output_grads tensor equals or is \"\n              \"greater than the freqs tensor\");\n  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, \"Dtype of the freqs tensor must be float\");\n\n  return bwd_cuda(output_grads, freqs, transpose_output);\n}\n\ntorch::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin,\n                         const bool transpose_output) {\n  TORCH_CHECK(input.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(cos.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(sin.dim() == 4, 
\"expected 4D tensor\");\n  TORCH_CHECK(input.size(0) == cos.size(0), \"expected input and cos tensor have the same sequence length\");\n  TORCH_CHECK(input.size(0) == sin.size(0), \"expected input and sin tensor have the same sequence length\");\n  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, \"expected the second and third dims of the cos tensor equal 1\");\n  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, \"expected the second and third dims of the sin tensor equal 1\");\n  TORCH_CHECK(cos.size(3) == sin.size(3), \"expected cos and sin tensor have the same last dim\");\n  TORCH_CHECK(input.size(3) >= cos.size(3),\n              \"expected the last dim of the input tensor equals or is \"\n              \"greater than the cos tensor\");\n  TORCH_CHECK(cos.scalar_type() == sin.scalar_type(), \"expected cos and sin tensor have the same dtype\");\n\n  return fwd_cached_cuda(input, cos, sin, transpose_output);\n}\n\ntorch::Tensor bwd_cached(const torch::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin,\n                         const bool transpose_output) {\n  TORCH_CHECK(output_grads.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(cos.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(sin.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(output_grads.size(0) == cos.size(0),\n              \"expected output_grads and cos tensor have the same sequence length\");\n  TORCH_CHECK(output_grads.size(0) == sin.size(0),\n              \"expected output_grads and sin tensor have the same sequence length\");\n  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, \"expected the second and third dims of the cos tensor equal 1\");\n  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, \"expected the second and third dims of the sin tensor equal 1\");\n  TORCH_CHECK(cos.size(3) == sin.size(3), \"expected cos and sin tensor have the same last dim\");\n  TORCH_CHECK(output_grads.size(3) >= cos.size(3),\n              \"expected the last dim of the 
output_grads tensor equals or is \"\n              \"greater than the cos tensor\");\n  TORCH_CHECK(cos.scalar_type() == sin.scalar_type(), \"expected cos and sin tensor have the same dtype\");\n\n  return bwd_cached_cuda(output_grads, cos, sin, transpose_output);\n}\n\ntorch::Tensor fwd_thd(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) {\n  TORCH_CHECK(input.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(cu_seqlens.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(freqs.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,\n              \"expected the second and third dims of the freqs tensor equal 1\");\n  TORCH_CHECK(input.size(2) >= freqs.size(3),\n              \"expected the last dim of the input tensor equals or is \"\n              \"greater than the freqs tensor\");\n  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, \"Dtype of the freqs tensor must be float\");\n\n  return fwd_thd_cuda(input, cu_seqlens, freqs);\n}\n\ntorch::Tensor bwd_thd(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(cu_seqlens.dim() == 1, \"expected 1D tensor\");\n  TORCH_CHECK(freqs.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,\n              \"expected the second and third dims of the freqs tensor equal 1\");\n  TORCH_CHECK(output_grads.size(2) >= freqs.size(3),\n              \"expected the last dim of the output_grads tensor equals or is \"\n              \"greater than the freqs tensor\");\n  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, \"Dtype of the freqs tensor must be float\");\n\n  return bwd_thd_cuda(output_grads, cu_seqlens, freqs);\n}\n\ntorch::Tensor fwd_2d(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h,\n                     const torch::Tensor& 
cos_w, const torch::Tensor& sin_w) {\n  TORCH_CHECK(input.dim() == 5, \"expected input to be 5D tensor\");\n  TORCH_CHECK(cos_h.dim() == 4, \"expected cos_h to be 4D tensor\");\n  TORCH_CHECK(sin_h.dim() == 4, \"expected sin_h to be 4D tensor\");\n  TORCH_CHECK(cos_w.dim() == 4, \"expected cos_w to be 4D tensor\");\n  TORCH_CHECK(sin_w.dim() == 4, \"expected sin_w to be 4D tensor\");\n  TORCH_CHECK(cos_h.size(2) == 1, \"expected third dim of cos_h/sin_h equals 1\");\n  TORCH_CHECK(input.size(1) <= cos_h.size(1), \"expected input's height <= cos_h/sin_h's\");\n  TORCH_CHECK(input.size(4) / 2 == cos_h.size(3), \"expected cos_h/sin_h's head dim equals input's head dim / 2\");\n  TORCH_CHECK(cos_w.size(2) == 1, \"expected third dim of cos_w/sin_w equals 1\");\n  TORCH_CHECK(input.size(2) <= cos_w.size(1), \"expected input's width <= cos_w/sin_w's\");\n  TORCH_CHECK(input.size(4) / 2 == cos_w.size(3), \"expected cos_w/sin_w's head dim equals input's head dim / 2\");\n\n  return fwd_2d_cuda(input, cos_h, sin_h, cos_w, sin_w);\n}\n\ntorch::Tensor bwd_2d(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h,\n                     const torch::Tensor& cos_w, const torch::Tensor& sin_w) {\n  TORCH_CHECK(output_grads.dim() == 5, \"expected output_grads to be 5D tensor\");\n  TORCH_CHECK(cos_h.dim() == 4, \"expected cos_h to be 4D tensor\");\n  TORCH_CHECK(sin_h.dim() == 4, \"expected sin_h to be 4D tensor\");\n  TORCH_CHECK(cos_w.dim() == 4, \"expected cos_w to be 4D tensor\");\n  TORCH_CHECK(sin_w.dim() == 4, \"expected sin_w to be 4D tensor\");\n  TORCH_CHECK(cos_h.size(2) == 1, \"expected third dim of cos_h/sin_h equals 1\");\n  TORCH_CHECK(output_grads.size(1) <= cos_h.size(1), \"expected output_grads' height <= cos_h/sin_h's\");\n  TORCH_CHECK(output_grads.size(4) / 2 == cos_h.size(3),\n              \"expected cos_h/sin_h's head dim equals output_grads' head dim / 2\");\n  TORCH_CHECK(cos_w.size(2) == 1, \"expected third dim of 
cos_w/sin_w equals 1\");\n  TORCH_CHECK(output_grads.size(2) <= cos_w.size(1), \"expected output_grads' width <= cos_w/sin_w's\");\n  TORCH_CHECK(output_grads.size(4) / 2 == cos_w.size(3),\n              \"expected cos_w/sin_w's head dim equals output_grads' head dim / 2\");\n\n  return bwd_2d_cuda(output_grads, cos_h, sin_h, cos_w, sin_w);\n}\n\n}  // end namespace fused_rope\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &fused_rope::fwd, \"Fused Rotary Positional Embedding -- Forward.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &fused_rope::bwd, \"Fused Rotary Positional Embedding -- Backward.\",\n        py::call_guard<py::gil_scoped_release>());\n  // cache sin/cos\n  m.def(\"forward_cached\", &fused_rope::fwd_cached, \"Fused Rotary Positional Embedding Cached -- Forward.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_cached\", &fused_rope::bwd_cached, \"Fused Rotary Positional Embedding Cached -- Backward.\",\n        py::call_guard<py::gil_scoped_release>());\n  // thd\n  m.def(\"forward_thd\", &fused_rope::fwd_thd, \"Fused Rotary Positional Embedding for thd layout -- Forward.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_thd\", &fused_rope::bwd_thd, \"Fused Rotary Positional Embedding for thd layout -- Backward.\",\n        py::call_guard<py::gil_scoped_release>());\n  // 2d\n  m.def(\"forward_2d\", &fused_rope::fwd_2d, \"2D Fused Rotary Positional Embedding -- Forward.\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward_2d\", &fused_rope::bwd_2d, \"2D Fused Rotary Positional Embedding -- Backward.\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/megatron/fused_rotary_positional_embedding.h",
    "content": "/* coding=utf-8\n * Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/macros/Macros.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\nnamespace {\n\ntemplate <typename scalar_t>\n__device__ void fused_rope_block_forward(const scalar_t* src, const float* freqs, scalar_t* dst, const int offset_block,\n                                         const int offset_block_dst, const int h, const int d, const int d2,\n                                         const int stride_h, const int stride_d, const int o_stride_h,\n                                         const int o_stride_d) {\n  int s_id = blockIdx.x;\n#pragma unroll\n  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {\n    float v_cos, v_sin;\n    sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;\n      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;\n      scalar_t v_src = src[offset_src];\n      scalar_t v_src_rotate =\n          (d_id + d2 / 2 < d2) ? 
-src[offset_src + (d2 / 2) * stride_d] : src[offset_src + (d2 / 2 - d2) * stride_d];\n      dst[offset_dst] = v_src * (scalar_t)v_cos + v_src_rotate * (scalar_t)v_sin;\n    }\n  }\n\n  // copy the rest\n  if (d > d2) {\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_head = offset_block + h_id * stride_h;\n      int offset_head_dst = offset_block_dst + h_id * o_stride_h;\n#pragma unroll\n      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {\n        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__device__ void fused_rope_block_backward(const scalar_t* src, const float* freqs, scalar_t* dst,\n                                          const int offset_block, const int offset_block_dst, const int h, const int d,\n                                          const int d2, const int stride_h, const int stride_d, const int o_stride_h,\n                                          const int o_stride_d) {\n  int s_id = blockIdx.x;\n#pragma unroll\n  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {\n    scalar_t v_cos = cosf(freqs[s_id * d2 + d_id]);\n    scalar_t v_sin =\n        (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;\n      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;\n      scalar_t v_src = src[offset_src];\n      scalar_t v_src_rotate =\n          (d_id + d2 / 2 < d2) ? 
src[offset_src + (d2 / 2) * stride_d] : src[offset_src + (d2 / 2 - d2) * stride_d];\n      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;\n    }\n  }\n\n  // handle the tail\n  if (d > d2) {\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_head = offset_block + h_id * stride_h;\n      int offset_head_dst = offset_block_dst + h_id * o_stride_h;\n#pragma unroll\n      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {\n        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void fused_rope_forward(const int h, const int d, const int d2, const int stride_s, const int stride_b,\n                                   const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b,\n                                   const int o_stride_h, const int o_stride_d, const scalar_t* src, const float* freqs,\n                                   scalar_t* dst) {\n  int s_id = blockIdx.x, b_id = blockIdx.y;\n  int offset_block = s_id * stride_s + b_id * stride_b;\n  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;\n  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h,\n                           o_stride_d);\n}\n\ntemplate <typename scalar_t>\n__global__ void fused_rope_backward(const int h, const int d, const int d2, const int stride_s, const int stride_b,\n                                    const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b,\n                                    const int o_stride_h, const int o_stride_d, const scalar_t* src, const float* freqs,\n                                    scalar_t* dst) {\n  int s_id = blockIdx.x, b_id = blockIdx.y;\n  int offset_block = s_id * stride_s + b_id * stride_b;\n  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;\n  
fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h,\n                            o_stride_d);\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\n__device__ void fused_rope_cached_block_forward(const scalar_t_0* src, const scalar_t_1* cos, const scalar_t_1* sin,\n                                                scalar_t_0* dst, const int s_id, const int offset_block,\n                                                const int offset_block_dst, const int h, const int d, const int d2,\n                                                const int stride_h, const int stride_d, const int o_stride_h,\n                                                const int o_stride_d) {\n#pragma unroll\n  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {\n    scalar_t_0 v_cos = cos[s_id * d2 + d_id];\n    scalar_t_0 v_sin = sin[s_id * d2 + d_id];\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;\n      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;\n      scalar_t_0 v_src = src[offset_src];\n      scalar_t_0 v_src_rotate =\n          (d_id + d2 / 2 < d2) ? 
-src[offset_src + (d2 / 2) * stride_d] : src[offset_src + (d2 / 2 - d2) * stride_d];\n      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;\n    }\n  }\n\n  // copy the rest\n  if (d > d2) {\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_head = offset_block + h_id * stride_h;\n      int offset_head_dst = offset_block_dst + h_id * o_stride_h;\n#pragma unroll\n      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {\n        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\n__device__ void fused_rope_cached_block_backward(const scalar_t_0* src, const scalar_t_1* cos, const scalar_t_1* sin,\n                                                 scalar_t_0* dst, const int s_id, const int offset_block,\n                                                 const int offset_block_dst, const int h, const int d, const int d2,\n                                                 const int stride_h, const int stride_d, const int o_stride_h,\n                                                 const int o_stride_d) {\n#pragma unroll\n  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {\n    scalar_t_0 v_cos = cos[s_id * d2 + d_id];\n    scalar_t_0 v_sin = (d_id + d2 / 2 < d2) ? sin[s_id * d2 + d_id + d2 / 2] : -sin[s_id * d2 + d_id + d2 / 2 - d2];\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;\n      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;\n      scalar_t_0 v_src = src[offset_src];\n      scalar_t_0 v_src_rotate =\n          (d_id + d2 / 2 < d2) ? 
src[offset_src + (d2 / 2) * stride_d] : src[offset_src + (d2 / 2 - d2) * stride_d];\n      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;\n    }\n  }\n\n  // handle the tail\n  if (d > d2) {\n#pragma unroll\n    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {\n      int offset_head = offset_block + h_id * stride_h;\n      int offset_head_dst = offset_block_dst + h_id * o_stride_h;\n#pragma unroll\n      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {\n        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\n__global__ void fused_rope_cached_forward(const int h, const int d, const int d2, const int stride_s,\n                                          const int stride_b, const int stride_h, const int stride_d,\n                                          const int o_stride_s, const int o_stride_b, const int o_stride_h,\n                                          const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos,\n                                          const scalar_t_1* sin, scalar_t_0* dst) {\n  int s_id = blockIdx.x, b_id = blockIdx.y;\n  int offset_block = s_id * stride_s + b_id * stride_b;\n  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;\n  fused_rope_cached_block_forward(src, cos, sin, dst, s_id, offset_block, offset_block_dst, h, d, d2, stride_h,\n                                  stride_d, o_stride_h, o_stride_d);\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\n__global__ void fused_rope_cached_backward(const int h, const int d, const int d2, const int stride_s,\n                                           const int stride_b, const int stride_h, const int stride_d,\n                                           const int o_stride_s, const int o_stride_b, const int o_stride_h,\n                                           const int o_stride_d, const scalar_t_0* src, const 
scalar_t_1* cos,\n                                           const scalar_t_1* sin, scalar_t_0* dst) {\n  int s_id = blockIdx.x, b_id = blockIdx.y;\n  int offset_block = s_id * stride_s + b_id * stride_b;\n  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;\n  fused_rope_cached_block_backward(src, cos, sin, dst, s_id, offset_block, offset_block_dst, h, d, d2, stride_h,\n                                   stride_d, o_stride_h, o_stride_d);\n}\n\ntemplate <typename scalar_t>\n__global__ void fused_rope_thd_forward(const int h, const int d, const int d2, const int stride_t, const int stride_h,\n                                       const int stride_d, const int o_stride_t, const int o_stride_h,\n                                       const int o_stride_d, const scalar_t* src, const int* cu_seqlens,\n                                       const float* freqs, scalar_t* dst) {\n  int s_id = blockIdx.x, b_id = blockIdx.y;\n  int t_id = s_id + cu_seqlens[b_id];\n  if (t_id >= cu_seqlens[b_id + 1]) return;\n  int offset_block = t_id * stride_t;\n  int offset_block_dst = t_id * o_stride_t;\n  fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h,\n                           o_stride_d);\n}\n\ntemplate <typename scalar_t>\n__global__ void fused_rope_thd_backward(const int h, const int d, const int d2, const int stride_t, const int stride_h,\n                                        const int stride_d, const int o_stride_t, const int o_stride_h,\n                                        const int o_stride_d, const scalar_t* src, const int* cu_seqlens,\n                                        const float* freqs, scalar_t* dst) {\n  int s_id = blockIdx.x, b_id = blockIdx.y;\n  int t_id = s_id + cu_seqlens[b_id];\n  if (t_id >= cu_seqlens[b_id + 1]) return;\n  int offset_block = t_id * stride_t;\n  int offset_block_dst = t_id * o_stride_t;\n  fused_rope_block_backward(src, freqs, dst, offset_block, 
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h,\n                            o_stride_d);\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\n__global__ void fused_rope_2d_forward(const int ih, const int iw, const int h, const int d, const int stride_b,\n                                      const int stride_ih, const int stride_iw, const int stride_h, const int stride_d,\n                                      const int o_stride_b, const int o_stride_s, const int o_stride_h,\n                                      const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos_h,\n                                      const scalar_t_1* sin_h, const scalar_t_1* cos_w, const scalar_t_1* sin_w,\n                                      scalar_t_0* dst) {\n  int ih_id = blockIdx.x, iw_id = blockIdx.y, b_id = blockIdx.z;\n  // apply to height\n  int offset_block = b_id * stride_b + ih_id * stride_ih + iw_id * stride_iw;\n  int offset_block_dst = b_id * o_stride_b + (ih_id * iw + iw_id) * o_stride_s;\n  int s_id = ih_id;  // for cos_h and sin_h\n  fused_rope_cached_block_forward(src, cos_h, sin_h, dst, s_id, offset_block, offset_block_dst, h, d / 2, d / 2,\n                                  stride_h, stride_d, o_stride_h, o_stride_d);\n  // apply to width\n  offset_block += d / 2 * stride_d;\n  offset_block_dst += d / 2 * o_stride_d;\n  s_id = iw_id;  // for cos_w and sin_w\n  fused_rope_cached_block_forward(src, cos_w, sin_w, dst, s_id, offset_block, offset_block_dst, h, d / 2, d / 2,\n                                  stride_h, stride_d, o_stride_h, o_stride_d);\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\n__global__ void fused_rope_2d_backward(const int ih, const int iw, const int h, const int d, const int stride_b,\n                                       const int stride_ih, const int stride_iw, const int stride_h, const int stride_d,\n                                       const int o_stride_b, const int o_stride_s, const int 
o_stride_h,\n                                       const int o_stride_d, const scalar_t_0* src, const scalar_t_1* cos_h,\n                                       const scalar_t_1* sin_h, const scalar_t_1* cos_w, const scalar_t_1* sin_w,\n                                       scalar_t_0* dst) {\n  int ih_id = blockIdx.x, iw_id = blockIdx.y, b_id = blockIdx.z;\n  // apply to height\n  int offset_block = b_id * stride_b + ih_id * stride_ih + iw_id * stride_iw;\n  int offset_block_dst = b_id * o_stride_b + (ih_id * iw + iw_id) * o_stride_s;\n  int s_id = ih_id;  // for cos_h and sin_h\n  fused_rope_cached_block_backward(src, cos_h, sin_h, dst, s_id, offset_block, offset_block_dst, h, d / 2, d / 2,\n                                   stride_h, stride_d, o_stride_h, o_stride_d);\n  // apply to width\n  offset_block += d / 2 * stride_d;\n  offset_block_dst += d / 2 * o_stride_d;\n  s_id = iw_id;  // for cos_w and sin_w\n  fused_rope_cached_block_backward(src, cos_w, sin_w, dst, s_id, offset_block, offset_block_dst, h, d / 2, d / 2,\n                                   stride_h, stride_d, o_stride_h, o_stride_d);\n}\n\n}  // end of anonymous namespace\n\ntemplate <typename scalar_t>\nvoid dispatch_fused_rope_forward(const int s, const int b, const int h, const int d, const int d2, const int stride_s,\n                                 const int stride_b, const int stride_h, const int stride_d, const int o_stride_s,\n                                 const int o_stride_b, const int o_stride_h, const int o_stride_d,\n                                 const scalar_t* input, const float* freqs, scalar_t* output) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 16 ? 
4 : 8;\n  dim3 blocks(s, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_forward<<<blocks, threads, 0, stream>>>(h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,\n                                                     o_stride_b, o_stride_h, o_stride_d, input, freqs, output);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <typename scalar_t>\nvoid dispatch_fused_rope_backward(const int s, const int b, const int h, const int d, const int d2, const int stride_s,\n                                  const int stride_b, const int stride_h, const int stride_d, const int o_stride_s,\n                                  const int o_stride_b, const int o_stride_h, const int o_stride_d,\n                                  const scalar_t* output_grads, const float* freqs, scalar_t* input_grads) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 16 ? 4 : 8;\n  dim3 blocks(s, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_backward<<<blocks, threads, 0, stream>>>(h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,\n                                                      o_stride_b, o_stride_h, o_stride_d, output_grads, freqs,\n                                                      input_grads);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\nvoid dispatch_fused_rope_cached_forward(const int s, const int b, const int h, const int d, const int d2,\n                                        const int stride_s, const int stride_b, const int stride_h, const int stride_d,\n                                        const int o_stride_s, const int o_stride_b, const int o_stride_h,\n                                        const int o_stride_d, const scalar_t_0* input, const scalar_t_1* cos,\n                                        const scalar_t_1* sin, scalar_t_0* output) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 
16 ? 4 : 8;\n  dim3 blocks(s, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_cached_forward<<<blocks, threads, 0, stream>>>(h, d, d2, stride_s, stride_b, stride_h, stride_d,\n                                                            o_stride_s, o_stride_b, o_stride_h, o_stride_d, input, cos,\n                                                            sin, output);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\nvoid dispatch_fused_rope_cached_backward(const int s, const int b, const int h, const int d, const int d2,\n                                         const int stride_s, const int stride_b, const int stride_h, const int stride_d,\n                                         const int o_stride_s, const int o_stride_b, const int o_stride_h,\n                                         const int o_stride_d, const scalar_t_0* output_grads, const scalar_t_1* cos,\n                                         const scalar_t_1* sin, scalar_t_0* input_grads) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 16 ? 
4 : 8;\n  dim3 blocks(s, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_cached_backward<<<blocks, threads, 0, stream>>>(h, d, d2, stride_s, stride_b, stride_h, stride_d,\n                                                             o_stride_s, o_stride_b, o_stride_h, o_stride_d,\n                                                             output_grads, cos, sin, input_grads);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <typename scalar_t>\nvoid dispatch_fused_rope_thd_forward(const int max_s, const int b, const int h, const int d, const int d2,\n                                     const int stride_t, const int stride_h, const int stride_d, const int o_stride_t,\n                                     const int o_stride_h, const int o_stride_d, const scalar_t* input,\n                                     const int* cu_seqlens, const float* freqs, scalar_t* output) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 16 ? 4 : 8;\n  dim3 blocks(max_s, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_thd_forward<<<blocks, threads, 0, stream>>>(h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,\n                                                         o_stride_d, input, cu_seqlens, freqs, output);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <typename scalar_t>\nvoid dispatch_fused_rope_thd_backward(const int max_s, const int b, const int h, const int d, const int d2,\n                                      const int stride_t, const int stride_h, const int stride_d, const int o_stride_t,\n                                      const int o_stride_h, const int o_stride_d, const scalar_t* output_grads,\n                                      const int* cu_seqlens, const float* freqs, scalar_t* input_grads) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 16 ? 
4 : 8;\n  dim3 blocks(max_s, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_thd_backward<<<blocks, threads, 0, stream>>>(h, d, d2, stride_t, stride_h, stride_d, o_stride_t,\n                                                          o_stride_h, o_stride_d, output_grads, cu_seqlens, freqs,\n                                                          input_grads);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\nvoid dispatch_fused_rope_2d_forward(const int b, const int ih, const int iw, const int h, const int d,\n                                    const int stride_b, const int stride_ih, const int stride_iw, const int stride_h,\n                                    const int stride_d, const int o_stride_b, const int o_stride_s,\n                                    const int o_stride_h, const int o_stride_d, const scalar_t_0* input,\n                                    const scalar_t_1* cos_h, const scalar_t_1* sin_h, const scalar_t_1* cos_w,\n                                    const scalar_t_1* sin_w, scalar_t_0* output) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 16 ? 
4 : 8;\n  dim3 blocks(ih, iw, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_2d_forward<<<blocks, threads, 0, stream>>>(ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h,\n                                                        stride_d, o_stride_b, o_stride_s, o_stride_h, o_stride_d, input,\n                                                        cos_h, sin_h, cos_w, sin_w, output);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <typename scalar_t_0, typename scalar_t_1>\nvoid dispatch_fused_rope_2d_backward(const int b, const int ih, const int iw, const int h, const int d,\n                                     const int stride_b, const int stride_ih, const int stride_iw, const int stride_h,\n                                     const int stride_d, const int o_stride_b, const int o_stride_s,\n                                     const int o_stride_h, const int o_stride_d, const scalar_t_0* output_grads,\n                                     const scalar_t_1* cos_h, const scalar_t_1* sin_h, const scalar_t_1* cos_w,\n                                     const scalar_t_1* sin_w, scalar_t_0* input_grads) {\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  int warps_per_block = h < 16 ? 4 : 8;\n  dim3 blocks(ih, iw, b);\n  dim3 threads(C10_WARP_SIZE, warps_per_block);\n\n  fused_rope_2d_backward<<<blocks, threads, 0, stream>>>(ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h,\n                                                         stride_d, o_stride_b, o_stride_s, o_stride_h, o_stride_d,\n                                                         output_grads, cos_h, sin_h, cos_w, sin_w, input_grads);\n  C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n"
  },
  {
    "path": "csrc/megatron/fused_rotary_positional_embedding_cuda.cu",
    "content": "/* coding=utf-8\n * Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n\n#include \"fused_rotary_positional_embedding.h\"\n#include \"type_shim.h\"\n\nnamespace fused_rope {\n\ntorch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, const bool transpose_output) {\n  // input sizes: (s, b, h, d)\n  // s: sequence length\n  // b: batch size\n  // h: head num\n  // d: dim of each head\n  const int s = input.size(0);\n  const int b = input.size(1);\n  const int h = input.size(2);\n  const int d = input.size(3);\n  // input strides\n  const int stride_s = input.stride(0);\n  const int stride_b = input.stride(1);\n  const int stride_h = input.stride(2);\n  const int stride_d = input.stride(3);\n  // freqs' shape is always (s, 1, 1, d2), so the strides are same under\n  // different memory formats\n  const int d2 = freqs.size(3);\n\n  // output\n  auto act_options = input.options().requires_grad(false);\n  torch::Tensor output;\n  if (transpose_output) {\n    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);\n  } else {\n    output = torch::empty({s, b, h, d}, act_options);\n  }\n  // output strides\n  const int o_stride_s = output.stride(0);\n  const int o_stride_b = output.stride(1);\n  const int o_stride_h = output.stride(2);\n  const int o_stride_d = output.stride(3);\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n    
  input.scalar_type(), 0, \"dispatch_fused_rope_forward\",\n      dispatch_fused_rope_forward(s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,\n                                  o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(), freqs.data_ptr<float>(),\n                                  output.data_ptr<scalar_t_0>()););\n  return output;\n}\n\ntorch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& freqs, const bool transpose_output) {\n  // output_grads sizes: (s, b, h, d)\n  // s: sequence length\n  // b: batch size\n  // h: head num\n  // d: dim of each head\n  const int s = output_grads.size(0);\n  const int b = output_grads.size(1);\n  const int h = output_grads.size(2);\n  const int d = output_grads.size(3);\n  // output_grads strides\n  const int stride_s = output_grads.stride(0);\n  const int stride_b = output_grads.stride(1);\n  const int stride_h = output_grads.stride(2);\n  const int stride_d = output_grads.stride(3);\n  // freqs' shape is always (s, 1, 1, d2), so the strides are same under\n  // different memory formats\n  const int d2 = freqs.size(3);\n\n  auto act_options = output_grads.options().requires_grad(false);\n  torch::Tensor input_grads;\n  if (transpose_output) {\n    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);\n  } else {\n    input_grads = torch::empty({s, b, h, d}, act_options);\n  }\n  const int o_stride_s = input_grads.stride(0);\n  const int o_stride_b = input_grads.stride(1);\n  const int o_stride_h = input_grads.stride(2);\n  const int o_stride_d = input_grads.stride(3);\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      output_grads.scalar_type(), 0, \"dispatch_fused_rope_backward\",\n      dispatch_fused_rope_backward(s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,\n                                   o_stride_h, o_stride_d, output_grads.data_ptr<scalar_t_0>(), freqs.data_ptr<float>(),\n                                   
input_grads.data_ptr<scalar_t_0>()););\n  return input_grads;\n}\n\n#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...)                                                      \\\n  switch (TYPE1) {                                                                                              \\\n    case at::ScalarType::Float: {                                                                               \\\n      using scalar_t_0 = float;                                                                                 \\\n      switch (TYPE2) {                                                                                          \\\n        case at::ScalarType::Float: {                                                                           \\\n          using scalar_t_1 = float;                                                                             \\\n          __VA_ARGS__;                                                                                          \\\n          break;                                                                                                \\\n        }                                                                                                       \\\n        default:                                                                                                \\\n          TORCH_CHECK(false, #NAME, \" not supported for '\", toString(TYPE1), \"' with '\", toString(TYPE2), \"'\"); \\\n      }                                                                                                         \\\n      break;                                                                                                    \\\n    }                                                                                                           \\\n    case at::ScalarType::Half: {                                                                                \\\n      using scalar_t_0 = at::Half;                                    
                                          \\\n      switch (TYPE2) {                                                                                          \\\n        case at::ScalarType::Float: {                                                                           \\\n          using scalar_t_1 = float;                                                                             \\\n          __VA_ARGS__;                                                                                          \\\n          break;                                                                                                \\\n        }                                                                                                       \\\n        case at::ScalarType::Half: {                                                                            \\\n          using scalar_t_1 = at::Half;                                                                          \\\n          __VA_ARGS__;                                                                                          \\\n          break;                                                                                                \\\n        }                                                                                                       \\\n        default:                                                                                                \\\n          TORCH_CHECK(false, #NAME, \" not supported for '\", toString(TYPE1), \"' with '\", toString(TYPE2), \"'\"); \\\n      }                                                                                                         \\\n      break;                                                                                                    \\\n    }                                                                                                           \\\n    case at::ScalarType::BFloat16: {                                                        
                    \\\n      using scalar_t_0 = at::BFloat16;                                                                          \\\n      switch (TYPE2) {                                                                                          \\\n        case at::ScalarType::Float: {                                                                           \\\n          using scalar_t_1 = float;                                                                             \\\n          __VA_ARGS__;                                                                                          \\\n          break;                                                                                                \\\n        }                                                                                                       \\\n        case at::ScalarType::BFloat16: {                                                                        \\\n          using scalar_t_1 = at::BFloat16;                                                                      \\\n          __VA_ARGS__;                                                                                          \\\n          break;                                                                                                \\\n        }                                                                                                       \\\n        default:                                                                                                \\\n          TORCH_CHECK(false, #NAME, \" not supported for '\", toString(TYPE1), \"' with '\", toString(TYPE2), \"'\"); \\\n      }                                                                                                         \\\n      break;                                                                                                    \\\n    }                                                                                                           
\\\n    default:                                                                                                    \\\n      TORCH_CHECK(false, #NAME, \" not supported for '\", toString(TYPE1), \"' with '\", toString(TYPE2), \"'\");     \\\n  }\n\ntorch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& cos, const torch::Tensor& sin,\n                              const bool transpose_output) {\n  // input sizes: (s, b, h, d)\n  // s: sequence length\n  // b: batch size\n  // h: head num\n  // d: dim of each head\n  const int s = input.size(0);\n  const int b = input.size(1);\n  const int h = input.size(2);\n  const int d = input.size(3);\n  // input strides\n  const int stride_s = input.stride(0);\n  const int stride_b = input.stride(1);\n  const int stride_h = input.stride(2);\n  const int stride_d = input.stride(3);\n  // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under\n  // different memory formats\n  const int d2 = cos.size(3);\n\n  // output\n  auto act_options = input.options().requires_grad(false);\n  torch::Tensor output;\n  if (transpose_output) {\n    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);\n  } else {\n    output = torch::empty({s, b, h, d}, act_options);\n  }\n  // output strides\n  const int o_stride_s = output.stride(0);\n  const int o_stride_b = output.stride(1);\n  const int o_stride_h = output.stride(2);\n  const int o_stride_d = output.stride(3);\n\n  DISPATCH_FUSED_ROPE_TYPES(input.scalar_type(), cos.scalar_type(), \"dispatch_fused_rope_cached_forward\",\n                            dispatch_fused_rope_cached_forward(\n                                s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,\n                                o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_1>(),\n                                sin.data_ptr<scalar_t_1>(), output.data_ptr<scalar_t_0>()););\n  return output;\n}\n\ntorch::Tensor 
bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos, const torch::Tensor& sin,\n                              const bool transpose_output) {\n  // output_grads sizes: (s, b, h, d)\n  // s: sequence length\n  // b: batch size\n  // h: head num\n  // d: dim of each head\n  const int s = output_grads.size(0);\n  const int b = output_grads.size(1);\n  const int h = output_grads.size(2);\n  const int d = output_grads.size(3);\n  // output_grads strides\n  const int stride_s = output_grads.stride(0);\n  const int stride_b = output_grads.stride(1);\n  const int stride_h = output_grads.stride(2);\n  const int stride_d = output_grads.stride(3);\n  // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under\n  // different memory formats\n  const int d2 = cos.size(3);\n\n  auto act_options = output_grads.options().requires_grad(false);\n  torch::Tensor input_grads;\n  if (transpose_output) {\n    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);\n  } else {\n    input_grads = torch::empty({s, b, h, d}, act_options);\n  }\n  const int o_stride_s = input_grads.stride(0);\n  const int o_stride_b = input_grads.stride(1);\n  const int o_stride_h = input_grads.stride(2);\n  const int o_stride_d = input_grads.stride(3);\n\n  DISPATCH_FUSED_ROPE_TYPES(output_grads.scalar_type(), cos.scalar_type(), \"dispatch_fused_rope_cached_backward\",\n                            dispatch_fused_rope_cached_backward(\n                                s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,\n                                o_stride_h, o_stride_d, output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_1>(),\n                                sin.data_ptr<scalar_t_1>(), input_grads.data_ptr<scalar_t_0>()););\n  return input_grads;\n}\n\ntorch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) {\n  // input sizes: (t, h, d)\n  // t: cumulative 
sum of sequence lengths\n  // h: head num\n  // d: dim of each head\n  const int t = input.size(0);\n  const int h = input.size(1);\n  const int d = input.size(2);\n  // input strides\n  const int stride_t = input.stride(0);\n  const int stride_h = input.stride(1);\n  const int stride_d = input.stride(2);\n  // batch size\n  const int b = cu_seqlens.size(0) - 1;\n  // freqs' shape is (max_s, 1, 1, d2)\n  const int max_s = freqs.size(0);\n  const int d2 = freqs.size(3);\n\n  // output\n  auto act_options = input.options().requires_grad(false);\n  auto output = torch::empty({t, h, d}, act_options);\n  // output strides\n  const int o_stride_t = output.stride(0);\n  const int o_stride_h = output.stride(1);\n  const int o_stride_d = output.stride(2);\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      input.scalar_type(), 0, \"dispatch_fused_rope_thd_forward\",\n      dispatch_fused_rope_thd_forward(max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,\n                                      o_stride_d, input.data_ptr<scalar_t_0>(), cu_seqlens.data_ptr<int>(),\n                                      freqs.data_ptr<float>(), output.data_ptr<scalar_t_0>()););\n  return output;\n}\n\ntorch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens,\n                           const torch::Tensor& freqs) {\n  // output_grads sizes: (t, h, d)\n  // t: cumulative sum of sequence lengths\n  // h: head num\n  // d: dim of each head\n  const int t = output_grads.size(0);\n  const int h = output_grads.size(1);\n  const int d = output_grads.size(2);\n  // output_grads strides\n  const int stride_t = output_grads.stride(0);\n  const int stride_h = output_grads.stride(1);\n  const int stride_d = output_grads.stride(2);\n  // batch size\n  const int b = cu_seqlens.size(0) - 1;\n  // freqs' shape is (max_s, 1, 1, d2)\n  const int max_s = freqs.size(0);\n  const int d2 = freqs.size(3);\n\n  auto act_options = 
output_grads.options().requires_grad(false);\n  auto input_grads = torch::empty({t, h, d}, act_options);\n  const int o_stride_t = input_grads.stride(0);\n  const int o_stride_h = input_grads.stride(1);\n  const int o_stride_d = input_grads.stride(2);\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      output_grads.scalar_type(), 0, \"dispatch_fused_rope_thd_backward\",\n      dispatch_fused_rope_thd_backward(max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,\n                                       o_stride_d, output_grads.data_ptr<scalar_t_0>(), cu_seqlens.data_ptr<int>(),\n                                       freqs.data_ptr<float>(), input_grads.data_ptr<scalar_t_0>()););\n  return input_grads;\n}\n\ntorch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h,\n                          const torch::Tensor& cos_w, const torch::Tensor& sin_w) {\n  // input sizes: (b, ih, iw, h, d)\n  // b: batch size\n  // ih: image height\n  // iw: image width\n  // h: head num\n  // d: dim of each head\n  const int b = input.size(0);\n  const int ih = input.size(1);\n  const int iw = input.size(2);\n  const int h = input.size(3);\n  const int d = input.size(4);\n  // input strides\n  const int stride_b = input.stride(0);\n  const int stride_ih = input.stride(1);\n  const int stride_iw = input.stride(2);\n  const int stride_h = input.stride(3);\n  const int stride_d = input.stride(4);\n\n  // output\n  auto act_options = input.options().requires_grad(false);\n  auto output = torch::empty({b, ih * iw, h, d}, act_options);\n  // output strides\n  const int o_stride_b = output.stride(0);\n  const int o_stride_s = output.stride(1);\n  const int o_stride_h = output.stride(2);\n  const int o_stride_d = output.stride(3);\n\n  DISPATCH_FUSED_ROPE_TYPES(\n      input.scalar_type(), cos_h.scalar_type(), \"dispatch_fused_rope_2d_forward\",\n      dispatch_fused_rope_2d_forward(\n          b, ih, iw, h, d, stride_b, stride_ih, 
stride_iw, stride_h, stride_d, o_stride_b, o_stride_s, o_stride_h,\n          o_stride_d, input.data_ptr<scalar_t_0>(), cos_h.data_ptr<scalar_t_1>(), sin_h.data_ptr<scalar_t_1>(),\n          cos_w.data_ptr<scalar_t_1>(), sin_w.data_ptr<scalar_t_1>(), output.data_ptr<scalar_t_0>()););\n  return output;\n}\n\ntorch::Tensor bwd_2d_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h,\n                          const torch::Tensor& cos_w, const torch::Tensor& sin_w) {\n  // output_grads sizes: (b, ih, iw, h, d)\n  // b: batch size\n  // ih: image height\n  // iw: image width\n  // h: head num\n  // d: dim of each head\n  const int b = output_grads.size(0);\n  const int ih = output_grads.size(1);\n  const int iw = output_grads.size(2);\n  const int h = output_grads.size(3);\n  const int d = output_grads.size(4);\n  // output_grads strides\n  const int stride_b = output_grads.stride(0);\n  const int stride_ih = output_grads.stride(1);\n  const int stride_iw = output_grads.stride(2);\n  const int stride_h = output_grads.stride(3);\n  const int stride_d = output_grads.stride(4);\n\n  auto act_options = output_grads.options().requires_grad(false);\n  auto input_grads = torch::empty({b, ih * iw, h, d}, act_options);\n  const int o_stride_b = input_grads.stride(0);\n  const int o_stride_s = input_grads.stride(1);\n  const int o_stride_h = input_grads.stride(2);\n  const int o_stride_d = input_grads.stride(3);\n\n  DISPATCH_FUSED_ROPE_TYPES(\n      output_grads.scalar_type(), cos_h.scalar_type(), \"dispatch_fused_rope_2d_backward\",\n      dispatch_fused_rope_2d_backward(\n          b, ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h, stride_d, o_stride_b, o_stride_s, o_stride_h,\n          o_stride_d, output_grads.data_ptr<scalar_t_0>(), cos_h.data_ptr<scalar_t_1>(), sin_h.data_ptr<scalar_t_1>(),\n          cos_w.data_ptr<scalar_t_1>(), sin_w.data_ptr<scalar_t_1>(), input_grads.data_ptr<scalar_t_0>()););\n  return 
input_grads;\n}\n\n}  // end namespace fused_rope\n"
  },
  {
    "path": "csrc/megatron/fused_weight_gradient_dense.cpp",
    "content": "#include <torch/extension.h>\n\n#include <cstdio>\n#include <vector>\n\nvoid wgrad_gemm_accum_fp32_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_2d, at::Tensor& d_weight);\n\nvoid wgrad_gemm_accum_fp16_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_2d, at::Tensor& d_weight);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"wgrad_gemm_accum_fp32\", &wgrad_gemm_accum_fp32_cuda_stub, \"wgrad gemm accum in fp32\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"wgrad_gemm_accum_fp16\", &wgrad_gemm_accum_fp16_cuda_stub, \"wgrad gemm accum in fp16\",\n        py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\n#include <cassert>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n\n/* Includes, cuda */\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#include \"type_shim.h\"\n\n// BF16 inputs and BF16 accumulation\nvoid gemmex_wrapper_fp16(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                         const float* alpha, at::BFloat16* A, int lda, at::BFloat16* B, int ldb, const float* beta,\n                         at::BFloat16* C, int ldc) {\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb,\n                                    beta, C, CUDA_R_16BF, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n}\n\n// FP16 inputs and FP16 accumulation\nvoid gemmex_wrapper_fp16(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                         const float* alpha, at::Half* A, int lda, at::Half* B, int ldb, const float* beta, at::Half* C,\n                         int ldc) {\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb,\n                                    beta, C, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n}\n\ntemplate <typename T>\nvoid wgrad_gemm_accum_fp16_cuda(T* input, T* d_output, T* d_weight, int in_dim, int hidden_dim, int out_dim) {\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n  const float alpha = 1.0;\n  const float beta = 1.0;\n\n  gemmex_wrapper_fp16(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, out_dim, hidden_dim, &alpha, input, in_dim, d_output,\n                      out_dim, &beta, d_weight, in_dim);\n}\n\ntemplate void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half* input, at::Half* d_output, 
at::Half* d_weight, int in_dim,\n                                                   int hidden_dim, int out_dim);\ntemplate void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16* input, at::BFloat16* d_output,\n                                                       at::BFloat16* d_weight, int in_dim, int hidden_dim, int out_dim);\n\nvoid wgrad_gemm_accum_fp16_cuda_stub(at::Tensor& input, at::Tensor& d_output, at::Tensor& d_weight) {\n  at::Tensor input_2d, d_output_2d;\n  // input tensor: collapse to the first dim\n  auto in_sizes = input.sizes();\n  if (input.dim() > 2) {\n    input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});\n  } else {\n    input_2d = input;\n  }\n  // d_output tensor: collapse to the first dim\n  auto d_out_sizes = d_output.sizes();\n  if (d_output.dim() > 2) {\n    d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});\n  } else {\n    d_output_2d = d_output;\n  }\n\n  const int hidden_dim = input_2d.size(0);\n  const int in_dim = input_2d.size(1);\n  const int out_dim = d_weight.size(0);\n\n  DISPATCH_HALF_AND_BFLOAT(\n      input_2d.scalar_type(), \"wgrad_gemm_accum_fp16\",\n      wgrad_gemm_accum_fp16_cuda<scalar_t>(input_2d.data_ptr<scalar_t>(), d_output_2d.data_ptr<scalar_t>(),\n                                           d_weight.data_ptr<scalar_t>(), in_dim, hidden_dim, out_dim););\n}\n"
  },
  {
    "path": "csrc/megatron/fused_weight_gradient_dense_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\n#include <cassert>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n\n/* Includes, cuda */\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#include \"type_shim.h\"\n\n// BF16 Tensor core wrapper around cublas GEMMEx\nvoid gemmex_wrapper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                    const float* alpha, at::BFloat16* A, int lda, at::BFloat16* B, int ldb, const float* beta, float* C,\n                    int ldc) {\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb,\n                                    beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n}\n\n// FP16 Tensor core wrapper around cublas GEMMEx\nvoid gemmex_wrapper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                    const float* alpha, at::Half* A, int lda, at::Half* B, int ldb, const float* beta, float* C,\n                    int ldc) {\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb,\n                                    beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n}\n\n// FP32 wrapper around cublas GEMMEx\nvoid gemmex_wrapper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                    const float* alpha, float* A, int lda, float* B, int ldb, const float* beta, float* C, int ldc) {\n  TORCH_CUDABLAS_CHECK(cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_32F, lda, B, CUDA_R_32F, ldb,\n                                    beta, C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));\n}\n\ntemplate <typename T>\nvoid wgrad_gemm_accum_fp32_cuda(T* input, T* d_output, float* d_weight, int in_dim, 
int hidden_dim, int out_dim) {\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n  const float alpha = 1.0;\n  const float beta = 1.0;\n\n  gemmex_wrapper(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, out_dim, hidden_dim, &alpha, input, in_dim, d_output,\n                 out_dim, &beta, d_weight, in_dim);\n}\n\ntemplate void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half* input, at::Half* d_output, float* d_weight, int in_dim,\n                                                   int hidden_dim, int out_dim);\ntemplate void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16* input, at::BFloat16* d_output, float* d_weight,\n                                                       int in_dim, int hidden_dim, int out_dim);\ntemplate void wgrad_gemm_accum_fp32_cuda<float>(float* input, float* d_output, float* d_weight, int in_dim,\n                                                int hidden_dim, int out_dim);\n\nvoid wgrad_gemm_accum_fp32_cuda_stub(at::Tensor& input, at::Tensor& d_output, at::Tensor& d_weight) {\n  at::Tensor input_2d, d_output_2d;\n  // input tensor: collapse to the first dim\n  auto in_sizes = input.sizes();\n  if (input.dim() > 2) {\n    input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});\n  } else {\n    input_2d = input;\n  }\n  // d_output tensor: collapse to the first dim\n  auto d_out_sizes = d_output.sizes();\n  if (d_output.dim() > 2) {\n    d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});\n  } else {\n    d_output_2d = d_output;\n  }\n\n  const int hidden_dim = input_2d.size(0);\n  const int in_dim = input_2d.size(1);\n  const int out_dim = d_weight.size(0);\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      input_2d.scalar_type(), 0, \"wgrad_gemm_accum_fp32\",\n      wgrad_gemm_accum_fp32_cuda<scalar_t_0>(input_2d.data_ptr<scalar_t_0>(), d_output_2d.data_ptr<scalar_t_0>(),\n                                             d_weight.data_ptr<float>(), 
in_dim, hidden_dim, out_dim););\n}\n"
  },
  {
    "path": "csrc/megatron/generic_scaled_masked_softmax.cpp",
    "content": "/* coding=utf-8\n * Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cuda_fp16.h>\n#include <torch/extension.h>\n\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace generic_scaled_masked_softmax {\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor);\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor);\n\ntorch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) {\n  TORCH_CHECK(input.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16),\n              \"Only fp16 and bf16 are supported\");\n  TORCH_CHECK(mask.dim() == 4, \"expected 4D tensor\");\n\n  return fwd_cuda(input, mask, scale_factor);\n}\n\ntorch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) {\n  TORCH_CHECK(output_grads.dim() == 4, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 4, \"expected 3D tensor\");\n\n  TORCH_CHECK(\n      (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16),\n      \"Only fp16 and bf16 are supported\");\n  TORCH_CHECK((softmax_results.scalar_type() == 
at::ScalarType::Half) ||\n                  (softmax_results.scalar_type() == at::ScalarType::BFloat16),\n              \"Only fp16 and bf16 are supported\");\n\n  return bwd_cuda(output_grads, softmax_results, scale_factor);\n}\n\n}  // end namespace generic_scaled_masked_softmax\n}  // end namespace fused_softmax\n}  // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd,\n        \"Self Multihead Attention scaled, time masked softmax -- Forward.\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"backward\", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd,\n        \"Self Multihead Attention scaled, time masked softmax -- Backward.\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/megatron/generic_scaled_masked_softmax.h",
    "content": "/* coding=utf-8\n * Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <assert.h>\n#include <c10/macros/Macros.h>\n#include <cuda_fp16.h>\n#include <stdint.h>\n\n#include <cfloat>\n#include <limits>\n\nnamespace {\n\ntemplate <typename T>\nstruct Add {\n  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }\n};\n\ntemplate <typename T>\nstruct Max {\n  __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; }\n};\n\ntemplate <typename T>\n__device__ __forceinline__ T WARP_SHFL_DOWN_NATIVE(T value, int laneMask, int width = warpSize,\n                                                   unsigned int mask = 0xffffffff) {\n#if CUDA_VERSION >= 9000\n  return __shfl_down_sync(mask, value, laneMask, width);\n#else\n  return __shfl_down(value, laneMask, width);\n#endif\n}\n\ntemplate <typename acc_t, int WARP_SIZE, template <typename> class ReduceOp>\n__device__ __forceinline__ acc_t warp_reduce_new(acc_t val) {\n  ReduceOp<acc_t> r;\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n    val = r(val, WARP_SHFL_DOWN_NATIVE(val, offset, WARP_SIZE));\n  }\n  return val;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements>\n__global__ void scaled_masked_softmax_warp_backward_new(output_t* gradInput,  //[batches, attn_heads, q_len, k_len]\n                                                        input_t* grad,\n                                                        const input_t* output,  //[batches, attn_heads, q_len, k_len]\n                                                        acc_t scale, int element_count) {\n  int threads_per_block = blockDim.x;\n  // the first element_count*2 elements are used for cache, the last 128 is used for reduction\n  extern __shared__ acc_t shared_data[];\n  input_t* local_data = (input_t*)shared_data;\n  input_t* output_data = &local_data[element_count];\n  // maximum shared cached 128, enough for 4096 elements reduction into 4096/32= 128 elements\n  acc_t* shared = (acc_t*)(&(local_data[element_count * 2]));\n\n  int num_reductions = (element_count - 1) / threads_per_block + 1;\n\n  int offset = blockIdx.x * element_count;\n\n  int local_idx = threadIdx.x;\n  int lane = threadIdx.x % C10_WARP_SIZE;\n  int wid = threadIdx.x / C10_WARP_SIZE;\n  int warps_per_thread_block = threads_per_block / C10_WARP_SIZE;\n\n  // load the data to local data\n  acc_t val = 0.0;\n  for (int i = 
0; i < num_reductions; i++) {\n    if (i * threads_per_block + local_idx < element_count) {\n      val = output[offset + i * threads_per_block + local_idx];\n      output_data[i * threads_per_block + local_idx] = val;\n      local_data[i * threads_per_block + local_idx] = val * grad[offset + i * threads_per_block + local_idx];\n    }\n    __syncthreads();\n  }\n\n  // find the sum\n  for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block) {\n    shared[i] = 0.0;\n  }\n  __syncthreads();\n\n#pragma unroll\n  for (int i = 0; i < num_reductions; i++) {\n    if (i * threads_per_block + local_idx < element_count) {\n      val = local_data[i * threads_per_block + local_idx];\n    } else {\n      val = 0.0;\n    }\n    __syncthreads();\n    val = warp_reduce_new<acc_t, C10_WARP_SIZE, Add>(val);\n    if (lane == 0 && wid + warps_per_thread_block * i < (element_count - 1) / C10_WARP_SIZE + 1) {\n      shared[wid + warps_per_thread_block * i] = val;\n    }\n    __syncthreads();\n  }\n\n  // final shared reduction\n\n  int shared_mem_len = (element_count - 1) / C10_WARP_SIZE + 1;\n  int num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1;\n  while (shared_mem_len > 1) {\n#pragma unroll\n    for (int i = 0; i < num_reductions; i++) {\n      if (i * threads_per_block + local_idx < shared_mem_len) {\n        val = shared[i * threads_per_block + local_idx];\n      } else {\n        val = 0.0;\n      }\n      __syncthreads();\n      val = warp_reduce_new<acc_t, C10_WARP_SIZE, Add>(val);\n      if (lane == 0) {\n        shared[wid + warps_per_thread_block * i] = val;\n      }\n      __syncthreads();\n    }\n    shared_mem_len = num_warps;\n    num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1;\n  }\n  val = shared[0];\n#pragma unroll\n  for (int i = local_idx; i < element_count; i += threads_per_block) {\n    gradInput[offset + i] = (output_t)(scale * (local_data[i] - output_data[i] * val));\n  }\n}\n\n}  // end of anonymous 
namespace\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nvoid dispatch_scaled_masked_softmax_backward_new(output_t* grad_input, input_t* grad, const input_t* output,\n                                                 const acc_t scale, int query_seq_len, int key_seq_len, int batches,\n                                                 int attn_heads) {\n  if (key_seq_len == 0) {\n    return;\n  } else {\n    int batch_count = batches * attn_heads * query_seq_len;\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n    int num_warps = (key_seq_len - 1) / C10_WARP_SIZE + 1;\n    dim3 blocks(batch_count, 1, 1);\n    dim3 threads(threads_per_block, 1, 1);\n\n    scaled_masked_softmax_warp_backward_new<input_t, output_t, acc_t, 12>\n        <<<blocks, threads, sizeof(input_t) * key_seq_len * 2 + sizeof(acc_t) * num_warps,\n           at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, key_seq_len);\n  }\n}\n\n/*\n * Extended softmax (from native aten pytorch) with following additional features\n * 1) input scaling\n * 2) Explicit masking\n */\ntemplate <typename input_t, typename output_t, typename acc_t>\n__global__ void scaled_masked_softmax_warp_forward_new(output_t* dst, const input_t* src, const uint8_t* mask,\n                                                       const acc_t scale,\n                                                       int query_len,  // query_len\n                                                       int attn_heads,\n                                                       int element_count,  // key_len\n                                                       int pad_batches)    // mask batch size\n{\n  // min threads_per_block has to be bigger than 128\n  int threads_per_block = blockDim.x;\n  //  the first element_count is used for cache, the last 128 is used for reduction\n  extern __shared__ acc_t local_data[];\n  // maximum shared cached 128, enough for 
4096 elements reduction into 4096/32= 128 elements\n  acc_t* shared = &(local_data[element_count]);\n  // number of 1024 threads reductions\n  int num_reductions = (element_count - 1) / threads_per_block + 1;\n\n  int offset = blockIdx.x * element_count;\n  int mask_offset;\n  int query_id = blockIdx.x % query_len;\n  if (pad_batches == 1) {\n    // broadcast the mask tensor\n    mask_offset = query_id * element_count;\n  } else {\n    int mask_batch_id = blockIdx.x / attn_heads / query_len;\n    mask_offset = (mask_batch_id * query_len + query_id) * element_count;\n  }\n\n  int local_idx = threadIdx.x;\n  int lane = threadIdx.x % C10_WARP_SIZE;\n  int wid = threadIdx.x / C10_WARP_SIZE;\n  int warps_per_thread_block = threads_per_block / C10_WARP_SIZE;\n\n  // load the data to local data\n  for (int i = local_idx; i < element_count; i += threads_per_block) {\n    // TODO, use the copy vector method\n    if (mask[mask_offset + i] == 1) {\n      local_data[i] = -10000.0;\n    } else {\n      local_data[i] = src[offset + i] * scale;\n    }\n  }\n\n  // first find the max value\n  for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block) {\n    shared[i] = -10000.0;\n  }\n  __syncthreads();\n  acc_t val = -10000.0;\n#pragma unroll\n  for (int i = 0; i < num_reductions; i++) {\n    if (i * threads_per_block + local_idx < element_count) {\n      val = local_data[i * threads_per_block + local_idx];\n    } else {\n      val = -10000.0;\n    }\n    __syncthreads();\n    val = warp_reduce_new<acc_t, C10_WARP_SIZE, Max>(val);\n\n    if (lane == 0 && wid + warps_per_thread_block * i < (element_count - 1) / C10_WARP_SIZE + 1) {\n      shared[wid + warps_per_thread_block * i] = val;\n    }\n    __syncthreads();\n  }\n\n  // final shared reduction\n  int shared_mem_len = (element_count - 1) / C10_WARP_SIZE + 1;\n  int num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1;\n  while (shared_mem_len > 1) {\n#pragma unroll\n    for (int i = 0; i < 
num_reductions; i++) {\n      if (i * threads_per_block + local_idx < shared_mem_len) {\n        val = shared[i * threads_per_block + local_idx];\n      } else {\n        val = -10000.0;\n      }\n      __syncthreads();\n      val = warp_reduce_new<acc_t, C10_WARP_SIZE, Max>(val);\n      if (lane == 0) {\n        shared[wid + warps_per_thread_block * i] = val;\n      }\n      __syncthreads();\n    }\n    shared_mem_len = num_warps;\n    num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1;\n  }\n\n  acc_t reduced_val = shared[0];\n  if (reduced_val < -10000.0 + 0.1) {\n// if everything is masked, pay attention to nothing\n#pragma unroll\n    for (int i = local_idx; i < element_count; i += threads_per_block) {\n      dst[offset + i] = 0.0;\n    }\n    return;\n  }\n\n// update the values\n#pragma unroll\n  for (int i = local_idx; i < element_count; i += threads_per_block) {\n    local_data[i] = std::exp(local_data[i] - reduced_val);\n  }\n\n  // find the sum\n  for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block) {\n    shared[i] = 0.0;\n  }\n  __syncthreads();\n\n#pragma unroll\n  for (int i = 0; i < num_reductions; i++) {\n    if (i * threads_per_block + local_idx < element_count) {\n      val = local_data[i * threads_per_block + local_idx];\n    } else {\n      val = 0.0;\n    }\n    __syncthreads();\n\n    val = warp_reduce_new<acc_t, C10_WARP_SIZE, Add>(val);\n    if (lane == 0 && wid + warps_per_thread_block * i < (element_count - 1) / C10_WARP_SIZE + 1) {\n      shared[wid + warps_per_thread_block * i] = val;\n    }\n    __syncthreads();\n  }\n\n  shared_mem_len = (element_count - 1) / C10_WARP_SIZE + 1;\n  num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1;\n  while (shared_mem_len > 1) {\n#pragma unroll\n    for (int i = 0; i < num_reductions; i++) {\n      if (i * threads_per_block + local_idx < shared_mem_len) {\n        val = shared[i * threads_per_block + local_idx];\n      } else {\n        val = 0.0;\n      
}\n      __syncthreads();\n      val = warp_reduce_new<acc_t, C10_WARP_SIZE, Add>(val);\n      if (lane == 0) {\n        shared[wid + warps_per_thread_block * i] = val;\n      }\n      __syncthreads();\n    }\n    shared_mem_len = num_warps;\n    num_warps = (shared_mem_len - 1) / C10_WARP_SIZE + 1;\n  }\n\n  reduced_val = shared[0];\n\n#pragma unroll\n  for (int i = local_idx; i < element_count; i += threads_per_block) {\n    dst[offset + i] = local_data[i] / reduced_val;\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nvoid dispatch_scaled_masked_softmax_forward_new(output_t* dst, const input_t* src, const uint8_t* mask,\n                                                const input_t scale, int query_seq_len, int key_seq_len, int batches,\n                                                int attn_heads, int pad_batches) {\n  if (key_seq_len == 0) {\n    return;\n  } else {\n    int batch_count = batches * attn_heads * query_seq_len;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    // calculate the needed shared memory\n    int num_warps = (key_seq_len - 1) / C10_WARP_SIZE + 1;\n\n    dim3 blocks(batch_count, 1, 1);\n    dim3 threads(threads_per_block, 1, 1);\n    scaled_masked_softmax_warp_forward_new<input_t, output_t, acc_t>\n        <<<blocks, threads, sizeof(acc_t) * (key_seq_len + num_warps), at::cuda::getCurrentCUDAStream()>>>(\n            dst, src, mask, scale, query_seq_len, attn_heads, key_seq_len, pad_batches);\n  }\n}\n"
  },
  {
    "path": "csrc/megatron/generic_scaled_masked_softmax_cuda.cu",
    "content": "/* coding=utf-8\n * Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"generic_scaled_masked_softmax.h\"\n#include \"type_shim.h\"\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace generic_scaled_masked_softmax {\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) {\n  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]\n  const int batches = input.size(0);\n  const int pad_batches = mask.size(0);\n  const int attn_heads = input.size(1);\n  const int query_seq_len = input.size(2);\n  const int key_seq_len = input.size(3);\n  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);\n  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);\n  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);\n  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);\n\n  // Output\n  auto act_options = input.options().requires_grad(false);\n  torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);\n\n  // Softmax Intermediate Result Ptr\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* mask_ptr = 
static_cast<void*>(mask.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  DISPATCH_HALF_AND_BFLOAT(input.scalar_type(), \"dispatch_scaled_masked_softmax_forward\",\n                           dispatch_scaled_masked_softmax_forward_new<scalar_t, scalar_t, float>(\n                               reinterpret_cast<scalar_t*>(softmax_results_ptr),\n                               reinterpret_cast<const scalar_t*>(input_ptr), reinterpret_cast<const uint8_t*>(mask_ptr),\n                               scale_factor, query_seq_len, key_seq_len, batches, attn_heads, pad_batches););\n  return softmax_results;\n}\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) {\n  auto output_grads = output_grads_.contiguous();\n  auto softmax_results = softmax_results_.contiguous();\n\n  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]\n  const int batches = output_grads.size(0);\n  const int attn_heads = output_grads.size(1);\n  const int query_seq_len = output_grads.size(2);\n  const int key_seq_len = output_grads.size(3);\n\n  auto act_options = output_grads.options();\n  torch::Tensor input_grad = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);\n\n  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());\n\n  // Softmax Grad\n  DISPATCH_HALF_AND_BFLOAT(\n      output_grads_.scalar_type(), \"dispatch_scaled_masked_softmax_backward\",\n      dispatch_scaled_masked_softmax_backward_new<scalar_t, scalar_t, float>(\n          reinterpret_cast<scalar_t*>(static_cast<void*>(input_grad.data_ptr())),\n          reinterpret_cast<scalar_t*>(output_grads_ptr), reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),\n          scale_factor, query_seq_len, key_seq_len, batches, attn_heads););\n\n  // the backward pass writes into the newly allocated input_grad; it is not in-place\n  return input_grad;\n}\n}  // namespace 
generic_scaled_masked_softmax\n}  // namespace fused_softmax\n}  // namespace multihead_attn\n"
  },
  {
    "path": "csrc/megatron/scaled_masked_softmax.cpp",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cuda_fp16.h>\n#include <torch/extension.h>\n\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace scaled_masked_softmax {\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor);\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor);\n\nint get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads);\n\ntorch::Tensor fwd(torch::Tensor& input, torch::Tensor& mask, float scale_factor) {\n  TORCH_CHECK(input.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16),\n              \"Only fp16 and bf16 are supported\");\n  TORCH_CHECK(mask.dim() == 4, \"expected 4D tensor\");\n  if (!input.is_contiguous()) input = input.contiguous();\n  if (!mask.is_contiguous()) mask = mask.contiguous();\n\n  return fwd_cuda(input, mask, scale_factor);\n}\n\ntorch::Tensor bwd(torch::Tensor& output_grads, torch::Tensor& softmax_results, float scale_factor) {\n  TORCH_CHECK(output_grads.dim() == 4, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 4, \"expected 3D tensor\");\n\n  TORCH_CHECK(\n      (output_grads.scalar_type() == 
at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16),\n      \"Only fp16 and bf16 are supported\");\n  TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) ||\n                  (softmax_results.scalar_type() == at::ScalarType::BFloat16),\n              \"Only fp16 and bf16 are supported\");\n  if (!output_grads.is_contiguous()) output_grads = output_grads.contiguous();\n  if (!softmax_results.is_contiguous()) softmax_results = softmax_results.contiguous();\n\n  return bwd_cuda(output_grads, softmax_results, scale_factor);\n}\n\nint get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {\n  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);\n}\n\n}  // end namespace scaled_masked_softmax\n}  // end namespace fused_softmax\n}  // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,\n        \"Self Multihead Attention scaled, time masked softmax -- Forward.\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"backward\", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,\n        \"Self Multihead Attention scaled, time masked softmax -- Backward.\", py::call_guard<py::gil_scoped_release>());\n\n  m.def(\"get_batch_per_block\", &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,\n        \"Return Batch per block size.\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/megatron/scaled_masked_softmax.h",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <assert.h>\n#include <c10/macros/Macros.h>\n#include <cuda_fp16.h>\n#include <stdint.h>\n\n#include <cfloat>\n#include <limits>\n\nnamespace {\n\ntemplate <typename Datatype, int ELEMENTS_PER_LDG>\n__device__ __inline__ void copy_vector(Datatype* dst, const Datatype* src);\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16* dst, const c10::BFloat16* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16* dst, const c10::BFloat16* src) {\n  *((float2*)dst) = *((float2*)src);\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half* dst, const c10::Half* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half* dst, const c10::Half* src) {\n  *((float2*)dst) = *((float2*)src);\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t* dst, const uint8_t* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t* dst, const uint8_t* src) {\n  *((half2*)dst) = *((half2*)src);\n}\n\nint log2_ceil(int value) {\n  int log2_value = 0;\n  while ((1 << log2_value) < value) ++log2_value;\n  return log2_value;\n}\n\ntemplate <typename T>\nstruct Add {\n  
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }\n};\n\ntemplate <typename T>\nstruct Max {\n  __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }\n};\n\ntemplate <typename T>\n__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,\n                                                  unsigned int mask = 0xffffffff) {\n#if CUDA_VERSION >= 9000\n  return __shfl_xor_sync(mask, value, laneMask, width);\n#else\n  return __shfl_xor(value, laneMask, width);\n#endif\n}\n\ntemplate <typename acc_t, int WARP_BATCH, int WARP_SIZE, template <typename> class ReduceOp>\n__device__ __forceinline__ void warp_reduce(acc_t* sum) {\n  ReduceOp<acc_t> r;\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);\n      sum[i] = r(sum[i], b);\n    }\n  }\n}\n\n/*\n * Extended softmax (from native aten pytorch) with following additional features\n * 1) input scaling\n */\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements>\n__global__ void scaled_softmax_warp_forward(output_t* dst, const input_t* src, const acc_t scale, int micro_batch_size,\n                                            int element_count) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method warp_softmax_forward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4;\n\n  // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )\n  // gridDim/blockIdx = (seq_len, attn_heads, batches)\n  long int first_batch =\n      (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;\n\n  // micro_batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = micro_batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the batch\n  int local_idx = threadIdx.x;\n\n  long int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;\n  src += thread_offset;\n  dst += thread_offset;\n\n  // load data from global memory\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n  input_t temp_data[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n\n      if (element_index < batch_element_count) {\n        int itr_idx = i * element_count + it * WARP_SIZE;\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);\n\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          elements[i][it + element] = (acc_t)temp_data[element] * scale;\n        }\n      } else {\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();\n        }\n      }\n    }\n  }\n\n  // compute max_value\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);\n\n  acc_t sum[WARP_BATCH]{0.0f};\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = std::exp((elements[i][it] - max_value[i]));\n      sum[i] += elements[i][it];\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);\n\n  // store result\n  output_t out[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = elements[i][it + element] / sum[i];\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n      } else {\n        break;\n      }\n    }\n  }\n}\n\n/*\n * Extended softmax (from native aten pytorch) with following additional features\n * 1) input scaling\n * 2) Explicit masking\n */\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements>\n__global__ void scaled_masked_softmax_warp_forward(output_t* dst, const input_t* src, const uint8_t* mask,\n                                                   const acc_t scale, int micro_batch_size, int element_count,\n                                                   int pad_batches) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method warp_softmax_forward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;\n\n  // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )\n  // gridDim/blockIdx = (seq_len, attn_heads, batches)\n  long int first_batch =\n      (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;\n  long int pad_first_batch = 0;\n  if (pad_batches != 1) {  // bert style\n    pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;\n  } else {  // gpt2 style\n    pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n  }\n\n  // micro_batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = micro_batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the batch\n  int local_idx = threadIdx.x;\n\n  long int thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;\n  long int thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;\n  src += thread_offset_src_dst;\n  dst += thread_offset_src_dst;\n  mask += thread_offset_mask;\n\n  // load data from global memory\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n  input_t temp_data[ELEMENTS_PER_LDG_STG];\n  uint8_t temp_mask[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n\n      if (element_index < batch_element_count) {\n        int itr_idx = i * element_count + it * WARP_SIZE;\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);\n        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);\n\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          if (temp_mask[element] != 1) {\n            elements[i][it + element] = (acc_t)temp_data[element] * scale;\n          } else {\n            elements[i][it + element] = -10000.0;\n          }\n        }\n      } else {\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();\n        }\n      }\n    }\n  }\n\n  // compute max_value\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);\n\n  // compute scale value to account for full mask\n  acc_t scale_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    scale_value[i] = (max_value[i] == -10000.0) ? 
0.0 : 1.0;\n  }\n\n  acc_t sum[WARP_BATCH]{0.0f};\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      elements[i][it] = std::exp((elements[i][it] - max_value[i]));\n      sum[i] += elements[i][it];\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);\n\n  // store result\n  output_t out[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = elements[i][it + element] * scale_value[i] / sum[i];\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);\n      } else {\n        break;\n      }\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements>\n__global__ void scaled_masked_softmax_warp_backward(output_t* gradInput, input_t* grad, const input_t* output,\n                                                    acc_t scale, int micro_batch_size, int element_count) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method warp_softmax_backward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4;\n\n  // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )\n  // gridDim/blockIdx = (seq_len, attn_heads, batches)\n  long int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;\n\n  // micro_batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = micro_batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the batch\n  int local_idx = threadIdx.x;\n\n  // the first element to process by the current thread\n  long int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;\n  grad += thread_offset;\n  output += thread_offset;\n  gradInput += thread_offset;\n\n  // load data from global memory\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};\n  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};\n  input_t temp_grad[ELEMENTS_PER_LDG_STG];\n  input_t temp_output[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : element_count;\n\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < batch_element_count) {\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);\n\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          output_reg[i][it + element] = (acc_t)temp_output[element];\n        }\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];\n        }\n      }\n    }\n  }\n\n  acc_t sum[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    sum[i] = grad_reg[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      sum[i] += grad_reg[i][it];\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        output_t out[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);\n      }\n    }\n  }\n}\n}  // end of anonymous namespace\n\nint get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int 
attn_heads) {\n  int log2_elements = log2_ceil(key_seq_len);\n  const int next_power_of_two = 1 << log2_elements;\n\n  int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n  int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n  constexpr int threads_per_block = 128;\n  int warps_per_block = (threads_per_block / warp_size);\n  int batches_per_block = warps_per_block * batches_per_warp;\n\n  return batches_per_block;\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nvoid dispatch_scaled_softmax_forward(output_t* dst, const input_t* src, const input_t scale, int query_seq_len,\n                                     int key_seq_len, int batches, int attn_heads) {\n  TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 16384);\n  if (key_seq_len == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil(key_seq_len);\n    const int next_power_of_two = 1 << log2_elements;\n    int batch_count = batches * attn_heads * query_seq_len;\n\n    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);\n    dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 1:  // 2\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 2:  // 4\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 3:  // 8\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 4:  // 16\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 5:  // 32\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 6:  // 64\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>\n            <<<blocks, threads, 0, 
at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 7:  // 128\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 8:  // 256\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 9:  // 512\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 10:  // 1024\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 11:  // 2048\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 12:  // 4096\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 13:  // 8192\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 13>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      case 14:  // 16384\n        scaled_softmax_warp_forward<input_t, output_t, acc_t, 14>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);\n        break;\n      default:\n        break;\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nvoid 
dispatch_scaled_masked_softmax_forward(output_t* dst, const input_t* src, const uint8_t* mask, const input_t scale,\n                                            int query_seq_len, int key_seq_len, int batches, int attn_heads,\n                                            int pad_batches) {\n  TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);\n  if (key_seq_len == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil(key_seq_len);\n    const int next_power_of_two = 1 << log2_elements;\n    int batch_count = batches * attn_heads * query_seq_len;\n\n    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);\n    dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 1:  // 2\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n    
                                                                   pad_batches);\n        break;\n      case 2:  // 4\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 3:  // 8\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 4:  // 16\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 5:  // 32\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 6:  // 64\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 7:  // 128\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        
break;\n      case 8:  // 256\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 9:  // 512\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 10:  // 1024\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 11:  // 2048\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      case 12:  // 4096\n        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len,\n                                                                       pad_batches);\n        break;\n      default:\n        break;\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nvoid dispatch_scaled_masked_softmax_backward(output_t* grad_input, input_t* grad, const input_t* output,\n                                             const acc_t scale, int query_seq_len, int key_seq_len, int batches,\n                                             int 
attn_heads) {\n  TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);\n  if (key_seq_len == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil(key_seq_len);\n    const int next_power_of_two = 1 << log2_elements;\n    int batch_count = batches * attn_heads * query_seq_len;\n\n    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    int blocks = batch_count / batches_per_block;\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 1:  // 2\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 2:  // 4\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                          
             key_seq_len);\n        break;\n      case 3:  // 8\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 4:  // 16\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 5:  // 32\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 6:  // 64\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 7:  // 128\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 8:  // 256\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 9:  // 512\n        
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 10:  // 1024\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 11:  // 2048\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n      case 12:  // 4096\n        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       key_seq_len);\n        break;\n\n      default:\n        break;\n    }\n  }\n}\n"
  },
  {
    "path": "csrc/megatron/scaled_masked_softmax_cuda.cu",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"scaled_masked_softmax.h\"\n#include \"type_shim.h\"\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace scaled_masked_softmax {\n\nint get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {\n  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);\n}\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) {\n  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]\n  const int batches = input.size(0);\n  const int pad_batches = mask.size(0);\n  const int attn_heads = input.size(1);\n  const int query_seq_len = input.size(2);\n  const int key_seq_len = input.size(3);\n  TORCH_INTERNAL_ASSERT(key_seq_len <= 16384);\n  TORCH_INTERNAL_ASSERT(query_seq_len > 1);\n  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);\n  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);\n  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);\n  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);\n\n  // Output\n  auto act_options = input.options().requires_grad(false);\n 
 torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);\n\n  // Softmax Intermediate Result Ptr\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* mask_ptr = static_cast<void*>(mask.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  DISPATCH_HALF_AND_BFLOAT(input.scalar_type(), \"dispatch_scaled_masked_softmax_forward\",\n                           dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(\n                               reinterpret_cast<scalar_t*>(softmax_results_ptr),\n                               reinterpret_cast<const scalar_t*>(input_ptr), reinterpret_cast<const uint8_t*>(mask_ptr),\n                               scale_factor, query_seq_len, key_seq_len, batches, attn_heads, pad_batches););\n  return softmax_results;\n}\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) {\n  auto output_grads = output_grads_.contiguous();\n  auto softmax_results = softmax_results_.contiguous();\n\n  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]\n  const int batches = output_grads.size(0);\n  const int attn_heads = output_grads.size(1);\n  const int query_seq_len = output_grads.size(2);\n  const int key_seq_len = output_grads.size(3);\n\n  auto act_options = output_grads.options().requires_grad(false);\n  torch::Tensor input_grads = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);\n  void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());\n  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());\n\n  // Softmax Grad\n  DISPATCH_HALF_AND_BFLOAT(\n      output_grads_.scalar_type(), \"dispatch_scaled_masked_softmax_backward\",\n      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(\n          reinterpret_cast<scalar_t*>(input_grads_ptr), 
reinterpret_cast<scalar_t*>(output_grads_ptr),\n          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()), scale_factor, query_seq_len, key_seq_len,\n          batches, attn_heads););\n  return input_grads;\n}\n}  // namespace scaled_masked_softmax\n}  // namespace fused_softmax\n}  // namespace multihead_attn\n"
  },
  {
    "path": "csrc/megatron/scaled_softmax.cpp",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cuda_fp16.h>\n#include <torch/extension.h>\n\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace scaled_softmax {\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor);\n\ntorch::Tensor fwd(torch::Tensor const& input, float scale_factor) {\n  TORCH_CHECK(input.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16),\n              \"Only fp16 and bf16 are supported\");\n\n  return fwd_cuda(input, scale_factor);\n}\n\ntorch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) {\n  TORCH_CHECK(output_grads.dim() == 4, \"expected 4D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 4, \"expected 4D tensor\");\n\n  TORCH_CHECK(\n      (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16),\n      \"Only fp16 and bf16 are supported\");\n  TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) ||\n                  (softmax_results.scalar_type() == at::ScalarType::BFloat16),\n              \"Only fp16 and 
bf16 are supported\");\n\n  return bwd_cuda(output_grads, softmax_results, scale_factor);\n}\n\n}  // end namespace scaled_softmax\n}  // end namespace fused_softmax\n}  // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::scaled_softmax::fwd,\n        \"Self Multihead Attention scaled, softmax -- Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &multihead_attn::fused_softmax::scaled_softmax::bwd,\n        \"Self Multihead Attention scaled, softmax -- Backward.\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/megatron/scaled_softmax_cuda.cu",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"scaled_masked_softmax.h\"\n#include \"type_shim.h\"\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace scaled_softmax {\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {\n  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]\n  const int batches = input.size(0);\n  const int attn_heads = input.size(1);\n  const int query_seq_len = input.size(2);\n  const int key_seq_len = input.size(3);\n  TORCH_INTERNAL_ASSERT(key_seq_len <= 16384);\n  TORCH_INTERNAL_ASSERT(query_seq_len > 1);\n\n  // Output\n  auto act_options = input.options().requires_grad(false);\n  torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);\n\n  // Softmax Intermediate Result Ptr\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  DISPATCH_HALF_AND_BFLOAT(\n      input.scalar_type(), \"dispatch_scaled_softmax_forward\",\n      dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(\n          
reinterpret_cast<scalar_t*>(softmax_results_ptr), reinterpret_cast<const scalar_t*>(input_ptr), scale_factor,\n          query_seq_len, key_seq_len, batches, attn_heads););\n  return softmax_results;\n}\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) {\n  auto output_grads = output_grads_.contiguous();\n  auto softmax_results = softmax_results_.contiguous();\n\n  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]\n  const int batches = output_grads.size(0);\n  const int attn_heads = output_grads.size(1);\n  const int query_seq_len = output_grads.size(2);\n  const int key_seq_len = output_grads.size(3);\n\n  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());\n\n  // Softmax Grad\n  DISPATCH_HALF_AND_BFLOAT(\n      output_grads_.scalar_type(), \"dispatch_scaled_masked_softmax_backward\",\n      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(\n          reinterpret_cast<scalar_t*>(output_grads_ptr), reinterpret_cast<scalar_t*>(output_grads_ptr),\n          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()), scale_factor, query_seq_len, key_seq_len,\n          batches, attn_heads););\n\n  // backward pass is completely in-place\n  return output_grads;\n}\n}  // namespace scaled_softmax\n}  // namespace fused_softmax\n}  // namespace multihead_attn\n"
  },
  {
    "path": "csrc/megatron/scaled_upper_triang_masked_softmax.cpp",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cuda_fp16.h>\n#include <torch/extension.h>\n\n#include <vector>\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace scaled_upper_triang_masked_softmax {\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor);\n\ntorch::Tensor fwd(torch::Tensor const& input, float scale_factor) {\n  TORCH_CHECK(input.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16),\n              \"Only fp16 and bf16 are supported\");\n\n  return fwd_cuda(input, scale_factor);\n}\n\ntorch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) {\n  TORCH_CHECK(output_grads.dim() == 3, \"expected 3D tensor\");\n  TORCH_CHECK(softmax_results.dim() == 3, \"expected 3D tensor\");\n\n  TORCH_CHECK(\n      (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16),\n      \"Only fp16 and bf16 are supported\");\n  TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) ||\n                  (softmax_results.scalar_type() == at::ScalarType::BFloat16),\n             
 \"Only fp16 and bf16 are supported\");\n\n  return bwd_cuda(output_grads, softmax_results, scale_factor);\n}\n\n}  // end namespace scaled_upper_triang_masked_softmax\n}  // end namespace fused_softmax\n}  // end namespace multihead_attn\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,\n        \"Self Multihead Attention scaled, time masked softmax -- Forward.\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,\n        \"Self Multihead Attention scaled, time masked softmax -- Backward.\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/megatron/scaled_upper_triang_masked_softmax.h",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <assert.h>\n#include <c10/macros/Macros.h>\n#include <cuda_fp16.h>\n#include <stdint.h>\n\n#include <cfloat>\n#include <limits>\n\nnamespace {\n\ntemplate <typename Datatype, int ELEMENTS_PER_LDG>\n__device__ __inline__ void copy_vector(Datatype* dst, const Datatype* src);\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16* dst, const c10::BFloat16* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16* dst, const c10::BFloat16* src) {\n  *((float2*)dst) = *((float2*)src);\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half* dst, const c10::Half* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half* dst, const c10::Half* src) {\n  *((float2*)dst) = *((float2*)src);\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t* dst, const uint8_t* src) {\n  *dst = *src;\n}\n\ntemplate <>\n__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t* dst, const uint8_t* src) {\n  *((half2*)dst) = *((half2*)src);\n}\n\ntemplate <typename Datatype, int ELEMENTS_PER_LDG>\n__device__ __inline__ void copy_zero_vector(Datatype* dst);\n\ntemplate <>\n__device__ __inline__ void 
copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16* dst) {\n  *dst = 0.0;\n}\n\ntemplate <>\n__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16* dst) {\n  *((float2*)dst) = make_float2(0.0f, 0.0f);\n}\n\ntemplate <>\n__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half* dst) {\n  *dst = 0.0;\n}\n\ntemplate <>\n__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half* dst) {\n  *((float2*)dst) = make_float2(0.0f, 0.0f);\n}\n\nint log2_ceil(int value) {\n  int log2_value = 0;\n  while ((1 << log2_value) < value) ++log2_value;\n  return log2_value;\n}\n\ntemplate <typename T>\nstruct Add {\n  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }\n};\n\ntemplate <typename T>\nstruct Max {\n  __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }\n};\n\ntemplate <typename T>\n__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,\n                                                  unsigned int mask = 0xffffffff) {\n#if CUDA_VERSION >= 9000\n  return __shfl_xor_sync(mask, value, laneMask, width);\n#else\n  return __shfl_xor(value, laneMask, width);\n#endif\n}\n\ntemplate <typename acc_t, int WARP_BATCH, int WARP_SIZE, template <typename> class ReduceOp>\n__device__ __forceinline__ void warp_reduce(acc_t* sum) {\n  ReduceOp<acc_t> r;\n#pragma unroll\n  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {\n#pragma unroll\n    for (int i = 0; i < WARP_BATCH; ++i) {\n      acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);\n      sum[i] = r(sum[i], b);\n    }\n  }\n}\n\n/*\n * Extended softmax (from native aten pytorch) with following additional features\n * 1) input scaling\n * 2) Implicit time (diagonal masking)\n */\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements>\n__global__ void scaled_upper_triang_masked_softmax_warp_forward(output_t* dst, const input_t* src, const acc_t scale,\n          
                                                      int micro_batch_size, int stride, int element_count) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method warp_softmax_forward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;\n\n  long int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;\n  int local_seq = blockIdx.x + 1;\n  int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;\n\n  // micro_batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = micro_batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the batch\n  int local_idx = threadIdx.x;\n\n  long int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  src += thread_offset;\n  dst += thread_offset;\n\n  // load data from global memory\n  acc_t elements[WARP_BATCH][WARP_ITERATIONS];\n  input_t temp_data[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : local_seq;\n\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n\n      if (element_index < batch_element_count) {\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i * element_count * stride + it * WARP_SIZE);\n\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          if ((element_index + element) < batch_element_count) {\n            elements[i][it + element] = (acc_t)temp_data[element] * scale;\n          } else {\n            elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();\n          }\n        }\n      } else {\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();\n        }\n      }\n    }\n  }\n\n  // compute max_value\n  acc_t max_value[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    max_value[i] = elements[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it];\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);\n\n  acc_t sum[WARP_BATCH]{0.0f};\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; ++it) {\n      if (it < warp_iteration_limit) {\n        elements[i][it] = std::exp((elements[i][it] - max_value[i]));\n        sum[i] += elements[i][it];\n      }\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);\n\n  // store result\n  output_t out[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n\n      if (element_index < local_seq) {\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          if (element_index + element < local_seq) {\n            out[element] = elements[i][it + element] / sum[i];\n          } else {\n            out[element] = 0;\n          }\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);\n      } else if (element_index < element_count) {\n        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);\n      } else {\n        break;\n      }\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t, int log2_elements>\n__global__ void scaled_upper_triang_masked_softmax_warp_backward(output_t* gradInput, input_t* grad,\n                                                                 const input_t* output, acc_t scale,\n                                                                 int micro_batch_size, int stride, int element_count) {\n  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and\n  // warp_size of method 
warp_softmax_backward_kernel.\n  constexpr int next_power_of_two = 1 << log2_elements;\n  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;\n  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;\n  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;\n\n  long int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;\n  int local_seq = blockIdx.x + 1;\n\n  // micro_batch_size might not be a multiple of WARP_BATCH. Check how\n  // many batches have to be computed within this WARP.\n  int local_batches = micro_batch_size - first_batch;\n  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;\n\n  // there might be multiple batches per warp. compute the index within the batch\n  int local_idx = threadIdx.x;\n\n  // the first element to process by the current thread\n  long int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;\n  grad += thread_offset;\n  output += thread_offset;\n  gradInput += thread_offset;\n\n  // load data from global memory\n  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};\n  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};\n  input_t temp_grad[ELEMENTS_PER_LDG_STG];\n  input_t temp_output[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    int batch_element_count = (i >= local_batches) ? 
0 : local_seq;\n\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < batch_element_count) {\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);\n        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);\n\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          if (element_index + element < batch_element_count) {\n            output_reg[i][it + element] = (acc_t)temp_output[element];\n          }\n        }\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          if (element_index + element < batch_element_count) {\n            grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];\n          }\n        }\n      }\n    }\n  }\n\n  acc_t sum[WARP_BATCH];\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    sum[i] = grad_reg[i][0];\n#pragma unroll\n    for (int it = 1; it < WARP_ITERATIONS; ++it) {\n      sum[i] += grad_reg[i][it];\n    }\n  }\n  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);\n\n// store result\n#pragma unroll\n  for (int i = 0; i < WARP_BATCH; ++i) {\n    if (i >= local_batches) break;\n#pragma unroll\n    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {\n      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;\n      if (element_index < element_count) {\n        // compute gradients\n        output_t out[ELEMENTS_PER_LDG_STG];\n#pragma unroll\n        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {\n          out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));\n        }\n        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * 
element_count * stride + it * WARP_SIZE, out);\n      }\n    }\n  }\n}\n\n}  // end of anonymous namespace\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nvoid dispatch_scaled_upper_triang_masked_softmax_forward(output_t* dst, const input_t* src, const input_t scale,\n                                                         int softmax_elements, int softmax_elements_stride,\n                                                         int attn_batches) {\n  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 16384);\n  if (softmax_elements == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil(softmax_elements);\n    const int next_power_of_two = 1 << log2_elements;\n    int seq_len = softmax_elements;\n    int batch_count = attn_batches * seq_len;\n\n    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);\n\n    int blocks_per_seq = attn_batches / batches_per_block;\n    dim3 blocks(seq_len, blocks_per_seq, 1);\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 1:  // 2\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 2:  // 4\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 3:  // 8\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 4:  // 16\n        
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 5:  // 32\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 6:  // 64\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 7:  // 128\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 8:  // 256\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 9:  // 512\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, 
softmax_elements);\n        break;\n      case 10:  // 1024\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 11:  // 2048\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 12:  // 4096\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 13:  // 8192\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 14:  // 16384\n        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 14>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      default:\n        break;\n    }\n  }\n}\n\ntemplate <typename input_t, typename output_t, typename acc_t>\nvoid dispatch_scaled_upper_triang_masked_softmax_backward(output_t* grad_input, input_t* grad, const input_t* output,\n 
                                                         const acc_t scale, int softmax_elements,\n                                                          int softmax_elements_stride, int attn_batches) {\n  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 16384);\n  if (softmax_elements == 0) {\n    return;\n  } else {\n    int log2_elements = log2_ceil(softmax_elements);\n    const int next_power_of_two = 1 << log2_elements;\n    int seq_len = softmax_elements;\n    int batch_count = attn_batches * seq_len;\n\n    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.\n    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;\n\n    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.\n    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;\n\n    // use 128 threads per block to maximize gpu utilization\n    constexpr int threads_per_block = 128;\n\n    int warps_per_block = (threads_per_block / warp_size);\n    int batches_per_block = warps_per_block * batches_per_warp;\n    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);\n\n    int blocks_per_seq = attn_batches / batches_per_block;\n    dim3 blocks(seq_len, blocks_per_seq, 1);\n    dim3 threads(warp_size, warps_per_block, 1);\n    // Launch code would be more elegant if C++ supported FOR CONSTEXPR\n    switch (log2_elements) {\n      case 0:  // 1\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 1:  // 2\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>\n            <<<blocks, threads, 0, 
at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 2:  // 4\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 3:  // 8\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 4:  // 16\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 5:  // 32\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 6:  // 64\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        
break;\n      case 7:  // 128\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 8:  // 256\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 9:  // 512\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 10:  // 1024\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 11:  // 2048\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 12:  // 4096\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>\n            <<<blocks, threads, 0, 
at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 13:  // 8192\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      case 14:  // 16384\n        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 14>\n            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count,\n                                                                       softmax_elements_stride, softmax_elements);\n        break;\n      default:\n        break;\n    }\n  }\n}\n"
  },
  {
    "path": "csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu",
    "content": "/* coding=utf-8\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"scaled_upper_triang_masked_softmax.h\"\n#include \"type_shim.h\"\n\nnamespace multihead_attn {\nnamespace fused_softmax {\nnamespace scaled_upper_triang_masked_softmax {\n\ntorch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {\n  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]\n  const int attn_batches = input.size(0);\n  const int seq_len = input.size(1);\n  TORCH_INTERNAL_ASSERT(seq_len <= 16384);\n\n  // Output\n  auto act_options = input.options().requires_grad(false);\n  torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options);\n\n  // Softmax Intermediate Result Ptr\n  void* input_ptr = static_cast<void*>(input.data_ptr());\n  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());\n\n  DISPATCH_HALF_AND_BFLOAT(\n      input.scalar_type(), \"dispatch_scaled_upper_triang_masked_softmax_forward\",\n      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(\n          reinterpret_cast<scalar_t*>(softmax_results_ptr), reinterpret_cast<const scalar_t*>(input_ptr), 
scale_factor,\n          seq_len, seq_len, attn_batches););\n  return softmax_results;\n}\n\ntorch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) {\n  auto output_grads = output_grads_.contiguous();\n  auto softmax_results = softmax_results_.contiguous();\n\n  // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]\n  const int attn_batches = output_grads.size(0);\n  const int seq_len = output_grads.size(1);\n  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));\n\n  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());\n\n  // Softmax Grad\n  DISPATCH_HALF_AND_BFLOAT(\n      output_grads_.scalar_type(), \"dispatch_scaled_upper_triang_masked_softmax_backward\",\n      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(\n          reinterpret_cast<scalar_t*>(output_grads_ptr), reinterpret_cast<scalar_t*>(output_grads_ptr),\n          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()), scale_factor, seq_len, seq_len,\n          attn_batches););\n\n  // backward pass is completely in-place\n  return output_grads;\n}\n}  // namespace scaled_upper_triang_masked_softmax\n}  // namespace fused_softmax\n}  // namespace multihead_attn\n"
  },
  {
    "path": "csrc/mlp.cpp",
    "content": "#include <stdio.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\n#include <vector>\n\nsize_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);\n\ntemplate <typename T>\nsize_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);\n\ntemplate <typename T>\nint mlp_fp(T* X, int input_features, int batch_size, T** WPtr, int num_layers, int* output_features, T** BPtr, T* Y,\n           T* reserved_space, int use_bias, int activation, void* lt_workspace);\n\ntemplate <typename T>\nint mlp_bp(T* X, T* Y, int input_features, int batch_size, T** WPtr, int num_layers, int* output_features, T* dY,\n           T* reserved_space, T* work_space, T* dX, T** dwPtr, T** dbPtr, bool requires_grad, int use_bias,\n           int activation);\n\nstd::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {\n  auto num_layers = inputs.size() - 1;\n  if (use_bias) {\n    // inputs contains (input, weights, biases)\n    num_layers /= 2;\n  }\n  auto batch_size = inputs[0].size(0);\n  auto input_features = inputs[0].size(1);\n\n  std::vector<int> output_features;\n  for (int i = 0; i < num_layers; i++) {\n    output_features.push_back(inputs[i + 1].size(0));\n  }\n\n  auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());\n\n  // create output/workspace tensor\n  auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());\n  auto reserved_space = at::empty({static_cast<long>(reserved_size)}, inputs[0].type());\n  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB\n  auto lt_workspace = at::empty({1 << 22}, inputs[0].type());\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].scalar_type(), \"mlp_forward\", [&] {\n    std::vector<scalar_t*> w_ptr;\n    std::vector<scalar_t*> b_ptr;\n    for (int i = 0; i < num_layers; i++) {\n      w_ptr.push_back(inputs[i + 
1].data_ptr<scalar_t>());\n      if (use_bias) {\n        b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());\n      }\n    }\n    [[maybe_unused]] auto result = mlp_fp<scalar_t>(inputs[0].data_ptr<scalar_t>(), input_features, batch_size,\n                                                    w_ptr.data(), num_layers, output_features.data(), b_ptr.data(),\n                                                    out.data_ptr<scalar_t>(), reserved_space.data_ptr<scalar_t>(),\n                                                    use_bias, activation, (void*)(lt_workspace.data_ptr<scalar_t>()));\n  });\n\n  return {out, reserved_space};\n}\n\nstd::vector<at::Tensor> mlp_backward(int use_bias, int activation, at::Tensor grad_o,\n                                     std::vector<at::Tensor> fprop_outputs, std::vector<at::Tensor> inputs) {\n  auto num_layers = inputs.size() - 1;\n  if (use_bias) {\n    // inputs contains (input, weights, biases)\n    num_layers /= 2;\n  }\n\n  auto batch_size = inputs[0].size(0);\n  auto input_features = inputs[0].size(1);\n\n  bool requires_grad = inputs[0].requires_grad();\n\n  std::vector<int> output_features;\n  for (int i = 0; i < num_layers; i++) {\n    output_features.push_back(inputs[i + 1].size(0));\n  }\n  // create outputs, length of inputs\n  std::vector<at::Tensor> outputs;\n  for (int i = 0; i < inputs.size(); i++) {\n    outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));  // clone for testing now\n  }\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].scalar_type(), \"mlp_backward\", [&] {\n    std::vector<scalar_t*> w_ptr;\n    for (int i = 0; i < num_layers; i++) {\n      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());\n    }\n    std::vector<scalar_t*> outputs_ptr;\n    for (int i = 0; i < inputs.size(); i++) {\n      outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());\n    }\n\n    auto work_size = get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());\n\n 
   // auto work_space = at::empty({work_size*4}, at::kByte);\n    auto work_space = at::empty({static_cast<long>(work_size / sizeof(scalar_t))}, inputs[0].type());\n\n    [[maybe_unused]] auto result = mlp_bp<scalar_t>(\n        inputs[0].data_ptr<scalar_t>(), fprop_outputs[0].data_ptr<scalar_t>(), input_features, batch_size, w_ptr.data(),\n        num_layers, output_features.data(), grad_o.contiguous().data_ptr<scalar_t>(),\n        fprop_outputs[1].data_ptr<scalar_t>(), work_space.data_ptr<scalar_t>(), outputs_ptr[0], outputs_ptr.data() + 1,\n        outputs_ptr.data() + 1 + num_layers, requires_grad, use_bias, activation);\n  });\n\n  return outputs;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &mlp_forward, \"MLP forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"backward\", &mlp_backward, \"MLP backward\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/mlp_cuda.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <assert.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n#include <torch/torch.h>\n\n/* Includes, cuda */\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000\n// includes cublaslt\n#include <cublasLt.h>\n#endif\n// constants for fused bias+relu kernel\n#define BIAS_RELU_FW_NTHREADS 128    // forward number of thread per block\n#define BIAS_RELU_BW_NTHREADS_X 32   // backward number of thread in feature dim\n#define BIAS_RELU_BW_NTHREADS_Y 16   // backward number of thread in batch dim\n#define BIAS_RELU_RED_PER_THREAD 16  // backward minimal reduction length per thread\n\n// move to a header later on\n#define ILP 4\ntemplate <typename T>\n__host__ __device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, volatile T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\ntemplate <typename T>\n__device__ __forceinline__ void load_store(volatile T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\n// Keep ReLU in float only. When using half, cast to float before calling.\n__device__ __inline__ float relu(float a) {\n  float retf = max(a, 0.f);\n  return (retf);\n}\n\n// Keep Sigmoid in float only. 
When using half, cast to float before calling.\n__device__ __inline__ float sigmoid(float a) {\n  float retf = 1.f / (1.f + expf(-a));\n  return (retf);\n}\n\n// FP64 Wrapper around cublas GEMMEx\ncublasStatus_t mlp_gemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                        float* alpha, const double* A, int lda, const double* B, int ldb, const float* beta, double* C,\n                        int ldc) {\n  return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_64F, lda, B, CUDA_R_64F, ldb, beta, C,\n                      CUDA_R_64F, ldc, CUDA_R_64F, CUBLAS_GEMM_DEFAULT);\n}\n\n// FP32 Wrapper around cublas GEMMEx\ncublasStatus_t mlp_gemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                        float* alpha, const float* A, int lda, const float* B, int ldb, const float* beta, float* C,\n                        int ldc) {\n  return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_32F, lda, B, CUDA_R_32F, ldb, beta, C,\n                      CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);\n}\n\n// FP16 Tensor core wrapper around cublas GEMMEx\ncublasStatus_t mlp_gemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                        float* alpha, const at::Half* A, int lda, const at::Half* B, int ldb, float* beta, at::Half* C,\n                        int ldc) {\n  return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta, C,\n                      CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n}\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000\nint mlp_gemm_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                float* alpha,                                                        /* host pointer */\n                const 
at::Half* A, int lda, const at::Half* B, int ldb, float* beta, /* host pointer */\n                at::Half* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                bool use_relu, const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    if (use_relu) {\n      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;\n    } else {\n      epilogue = CUBLASLT_EPILOGUE_BIAS;\n    }\n  } else {\n    if (use_relu) {\n      epilogue = CUBLASLT_EPILOGUE_RELU;\n    }\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. 
Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          &heuristicResult.algo, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\nint mlp_gemm_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                float* alpha,                                                    /* host pointer */\n                const double* A, int lda, const double* B, int ldb, float* beta, /* host pointer */\n                double* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                bool use_relu, const void* bias) {\n  return 1;\n}\n\nint mlp_gemm_lt(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k,\n                float* alpha,                                                  /* host pointer */\n                const float* A, int lda, const float* B, int ldb, float* beta, /* host pointer */\n                float* C, int ldc, void* workspace, size_t workspaceSize, cudaStream_t stream, bool use_bias,\n                bool use_relu, const void* bias) {\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults = 
0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (use_bias) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    if (use_relu) {\n      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;\n    } else {\n      epilogue = CUBLASLT_EPILOGUE_BIAS;\n    }\n  } else {\n    if (use_relu) {\n      epilogue = CUBLASLT_EPILOGUE_RELU;\n    }\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status =\n      cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status =\n      cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? 
n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,\n                                                sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1,\n                                          &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n\n  status = cublasLtMatmul(ltHandle, &operationDesc, alpha, A, &Adesc, B, &Bdesc, beta, C, &Cdesc, C, &Cdesc,\n                          &heuristicResult.algo, workspace, workspaceSize, stream);\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n#endif\n\n// Bias ADD. 
Assume input X is [features x batch size], column major.\n// Bias is one 'features' long vector, with implicit broadcast.\ntemplate <typename T>\n__global__ void biasAdd_fprop(T* X, T* b, uint batch_size, uint features) {\n  T r_x[ILP];\n  T r_b[ILP];\n  if (is_aligned(X) && is_aligned(b) && features % ILP == 0) {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid * ILP < features * batch_size; tid += blockDim.x * gridDim.x) {\n      int row = tid % (features / ILP);\n      load_store(r_x, X, 0, tid);\n      load_store(r_b, b, 0, row);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);\n        r_x[ii] = bias_sum;\n      }\n      load_store(X, r_x, tid, 0);\n    }\n  } else {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          int row = idx % features;\n          r_x[ii] = X[idx];\n          r_b[ii] = b[row];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);\n        r_x[ii] = bias_sum;\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          X[idx] = r_x[ii];\n        }\n      }\n    }\n  }\n}\n\n// Bias ADD + ReLU. Assume input X is [features x batch size], column major.\n// Supports fused ReLU activation. 
Safe to call in-place.\ntemplate <typename T>\n__global__ void biasAddRelu_fprop(T* X, T* b, uint batch_size, uint features) {\n  T r_x[ILP];\n  T r_b[ILP];\n  if (is_aligned(X) && is_aligned(b) && features % ILP == 0) {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid * ILP < features * batch_size; tid += blockDim.x * gridDim.x) {\n      int row = tid % (features / ILP);\n      load_store(r_x, X, 0, tid);\n      load_store(r_b, b, 0, row);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);\n        r_x[ii] = relu(bias_sum);\n      }\n      load_store(X, r_x, tid, 0);\n    }\n  } else {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          int row = idx % features;\n          r_x[ii] = X[idx];\n          r_b[ii] = b[row];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);\n        r_x[ii] = relu(bias_sum);\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          X[idx] = r_x[ii];\n        }\n      }\n    }\n  }\n}\n\n// ReLU. 
Assume input X is [features x batch size], column major.\n// Safe to call in-place.\ntemplate <typename T>\n__global__ void Relu_fprop(T* X, uint batch_size, uint features) {\n  T r_x[ILP];\n  if (is_aligned(X) && features % ILP == 0) {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid * ILP < features * batch_size; tid += blockDim.x * gridDim.x) {\n      load_store(r_x, X, 0, tid);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_x[ii] = relu(static_cast<float>(r_x[ii]));\n      }\n      load_store(X, r_x, tid, 0);\n    }\n  } else {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          r_x[ii] = X[idx];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_x[ii] = relu(static_cast<float>(r_x[ii]));\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          X[idx] = r_x[ii];\n        }\n      }\n    }\n  }\n}\n\n// Sigmoid. 
Assume input X is [features x batch size], column major.\n// Safe to call in-place.\ntemplate <typename T>\n__global__ void Sigmoid_fprop(T* X, uint batch_size, uint features) {\n  T r_x[ILP];\n  if (is_aligned(X) && features % ILP == 0) {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid * ILP < features * batch_size; tid += blockDim.x * gridDim.x) {\n      load_store(r_x, X, 0, tid);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));\n      }\n      load_store(X, r_x, tid, 0);\n    }\n  } else {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          r_x[ii] = X[idx];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          X[idx] = r_x[ii];\n        }\n      }\n    }\n  }\n}\n\n// ReLU. 
Assume input X is [features x batch size], column major.\n// Safe to call in-place.\ntemplate <typename T>\n__global__ void Relu_bprop(T* dY, T* Y, uint batch_size, uint features, T* dX) {\n  T r_dy[ILP];\n  T r_y[ILP];\n  if (is_aligned(dY) && is_aligned(Y) && is_aligned(dX) && features % ILP == 0) {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid * ILP < features * batch_size; tid += blockDim.x * gridDim.x) {\n      load_store(r_dy, dY, 0, tid);\n      load_store(r_y, Y, 0, tid);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if ((float)r_y[ii] <= 0.f) r_dy[ii] = 0;\n      }\n      load_store(dX, r_dy, tid, 0);\n    }\n  } else {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          r_dy[ii] = dY[idx];\n          r_y[ii] = Y[idx];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if ((float)r_y[ii] <= 0.f) r_dy[ii] = 0;\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          dX[idx] = r_dy[ii];\n        }\n      }\n    }\n  }\n}\n\n// Sigmoid. 
Assume input X is [features x batch size], column major.\n// Safe to call in-place.\ntemplate <typename T>\n__global__ void Sigmoid_bprop(T* dY, T* Y, uint batch_size, uint features, T* dX) {\n  T r_dy[ILP];\n  T r_y[ILP];\n  if (is_aligned(dY) && is_aligned(Y) && is_aligned(dX) && features % ILP == 0) {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid * ILP < features * batch_size; tid += blockDim.x * gridDim.x) {\n      load_store(r_dy, dY, 0, tid);\n      load_store(r_y, Y, 0, tid);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float grad_out = r_dy[ii];\n        float out = r_y[ii];\n        float grad_i = out * (1.f - out) * grad_out;\n        r_dy[ii] = grad_i;\n      }\n      load_store(dX, r_dy, tid, 0);\n    }\n  } else {\n    int tid = blockIdx.x * blockDim.x + threadIdx.x;\n    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          r_dy[ii] = dY[idx];\n          r_y[ii] = Y[idx];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        float grad_out = r_dy[ii];\n        float out = r_y[ii];\n        float grad_i = out * (1.f - out) * grad_out;\n        r_dy[ii] = grad_i;\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int idx = tid + ii * blockDim.x * gridDim.x;\n        if (idx < features * batch_size) {\n          dX[idx] = r_dy[ii];\n        }\n      }\n    }\n  }\n}\n\n// Compute grid size for pointwise backward kernel.\n// block_x/y is the total number of elements handled per block, not the number of threads\nvoid get_biasAddRelu_bprop_grid_size(int yfeat, int batch_size, int block_x, int block_y, int* grid_x, int* grid_y) {\n  *grid_x = (yfeat + block_x - 1) / block_x;\n  // Get number of SMs for efficient reduction.\n  int num_SMs = 
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;\n  // can switch to occupancy calculation. use 4 below now for sm_70\n  int max_blocks_y = (num_SMs * 4 + (*grid_x) - 1) / (*grid_x);\n  // block_y should be from minimal work per thread\n  int nRedSplits = (batch_size + block_y - 1) / block_y;\n  // increase the number of elements per thread reduction so we do not launch more blocks than needed;\n  // the kernel adjusts the work, so here we just launch at most max_blocks_y blocks\n  *grid_y = std::min(nRedSplits, max_blocks_y);\n  return;\n}\n\n// Addition done deterministically via a 2-pass approach. Each CTA writes out partial\n// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.\ntemplate <typename T, int UNROLL_FACTOR>\n__global__ void biasAdd_bprop(T* dY, int features, int batch_size, volatile float* intermediate, int* semaphores,\n                              T* db) {\n  // The feature that this thread is responsible for\n  int f = blockIdx.x * blockDim.x + threadIdx.x;\n\n  // Compute the span this thread is responsible for\n  // For this block\n  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;\n  int b_nStart = blockIdx.y * b_chunkSize;\n  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;\n  // For this thread\n  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;\n  int nStart = threadIdx.y * chunkSize + b_nStart;\n  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;\n\n  volatile float* out = intermediate + blockIdx.y * features;\n\n  // Flag to trigger last reduction.\n  __shared__ bool isLastBlock;\n  // we know block size for now\n  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y];\n\n  // Accumulate db in FP32 always\n  float db_local = 0;\n  if (f < features) {\n    int nidx = 0;\n    // Handle non-multiple of UNROLL_FACTOR residue\n    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {\n      int64_t row, col, flat_idx;\n      row = f;\n      col = nStart + nidx;\n      
flat_idx = col * features + row;\n      db_local += (float)dY[flat_idx];\n    }\n\n    // Handle meat of work\n    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {\n      int64_t row, col, flat_idx;\n      row = f;\n      col = nStart + nidx;\n      flat_idx = col * features + row;\n#pragma unroll 4\n      for (int u = 0; u < UNROLL_FACTOR; u++) {\n        db_local += (float)dY[flat_idx];\n        flat_idx += features;\n      }\n    }\n\n    // naive block reduction on y-dim\n    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;\n    smem[linear_idx] = db_local;\n  }\n  __syncthreads();\n  if (f < features) {\n    if (threadIdx.y == 0) {\n      for (int yidx = 1; yidx < blockDim.y; yidx++) {\n        db_local += smem[yidx * blockDim.x + threadIdx.x];\n      }\n\n      // block result is in db_local now for all threadIdx.y == 0\n      // Write out partial result\n      out[f] = db_local;\n    }\n  }\n  __threadfence();\n  __syncthreads();\n\n  // Increment semaphore and check if this is the last CTA in the grid_y dimension.\n  // Only thread (0,0) calls this\n  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {\n    unsigned int sum_idx;\n    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);\n    isLastBlock = (sum_idx == (gridDim.y - 1));\n  }\n  __syncthreads();\n\n  db_local = 0;\n  // No block reduction for now, only thread (*,0) do grid reduction\n  if (isLastBlock && f < features) {\n    if (threadIdx.y == 0) {\n      for (int n = 0; n < gridDim.y; n++) {\n        int row, col;\n        row = f;\n        col = n;\n        db_local += (float)(intermediate[col * features + row]);\n      }\n      db[f] = (T)db_local;\n    }\n  }\n}\n\n// Addition done deterministically via a 2-pass approach. 
Each CTA writes out partial\n// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.\ntemplate <typename T, int UNROLL_FACTOR>\n__global__ void biasAddRelu_bprop(T* Y, T* dY, int features, int batch_size, T* dX, volatile float* intermediate,\n                                  int* semaphores, T* db) {\n  // The feature that this thread is responsible for\n  int f = blockIdx.x * blockDim.x + threadIdx.x;\n\n  // Compute the span this thread is responsible for\n  // For this block\n  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;\n  int b_nStart = blockIdx.y * b_chunkSize;\n  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;\n  // For this thread\n  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;\n  int nStart = threadIdx.y * chunkSize + b_nStart;\n  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;\n\n  volatile float* out = intermediate + blockIdx.y * features;\n\n  // Flag to trigger last reduction.\n  __shared__ bool isLastBlock;\n  // we know block size for now\n  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y];\n\n  // Accumulate db in FP32 always\n  float db_local = 0;\n  if (f < features) {\n    int nidx = 0;\n    // Handle non-multiple of UNROLL_FACTOR residue\n    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {\n      int row, col, flat_idx;\n      row = f;\n      col = nStart + nidx;\n      flat_idx = col * features + row;\n      T y_val = Y[flat_idx];\n      T dy_val = dY[flat_idx];\n      T dx_val;\n      if ((float)y_val > 0.f)\n        dx_val = dy_val;\n      else\n        dx_val = 0;\n      dX[flat_idx] = dx_val;\n      db_local += (float)dx_val;\n    }\n\n    // Handle meat of work\n    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {\n      int row, col, flat_idx;\n      row = f;\n      col = nStart + nidx;\n      flat_idx = col * features + row;\n#pragma unroll 4\n      for (int u = 0; u < UNROLL_FACTOR; u++) 
{\n        T y_val = Y[flat_idx];\n        T dy_val = dY[flat_idx];\n        T dx_val;\n        if ((float)y_val > 0.f)\n          dx_val = dy_val;\n        else\n          dx_val = 0;\n        dX[flat_idx] = dx_val;\n        db_local += (float)dx_val;\n        flat_idx += features;\n      }\n    }\n\n    // naive block reduction on y-dim\n    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;\n    smem[linear_idx] = db_local;\n  }\n  __syncthreads();\n  if (f < features) {\n    if (threadIdx.y == 0) {\n      for (int yidx = 1; yidx < blockDim.y; yidx++) {\n        db_local += smem[yidx * blockDim.x + threadIdx.x];\n      }\n\n      // block result is in db_local now for all threadIdx.y == 0\n      // Write out partial result\n      out[f] = db_local;\n    }\n  }\n  __threadfence();\n  __syncthreads();\n\n  // Increment semaphore and check if this is the last CTA in the grid_y dimension.\n  // Only thread (0,0) calls this\n  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {\n    unsigned int sum_idx;\n    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);\n    isLastBlock = (sum_idx == (gridDim.y - 1));\n  }\n  __syncthreads();\n\n  db_local = 0;\n  // No block reduction for now, only thread (*,0) do grid reduction\n  if (isLastBlock && f < features) {\n    if (threadIdx.y == 0) {\n      for (int n = 0; n < gridDim.y; n++) {\n        int row, col;\n        row = f;\n        col = n;\n        db_local += (float)(intermediate[col * features + row]);\n      }\n      db[f] = (T)db_local;\n    }\n  }\n}\n\n// Addition done deterministically via a 2-pass approach. 
Each CTA writes out partial\n// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.\ntemplate <typename T, int UNROLL_FACTOR>\n__global__ void biasAddRelu_bprop_aligned(T* Y, T* dY, int features, int batch_size, T* dX,\n                                          volatile float* intermediate, int* semaphores, T* db) {\n  // The feature that this thread is responsible for\n  int f = blockIdx.x * blockDim.x + threadIdx.x;\n\n  // Compute the span this thread is responsible for\n  // For this block\n  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;\n  int b_nStart = blockIdx.y * b_chunkSize;\n  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;\n  // For this thread\n  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;\n  int nStart = threadIdx.y * chunkSize + b_nStart;\n  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;\n\n  volatile float* out = intermediate + blockIdx.y * features;\n\n  // Flag to trigger last reduction.\n  __shared__ bool isLastBlock;\n\n  // Accumulate db in FP32 always\n  float db_local[ILP];\n  T r_y[ILP];\n  T r_dy[ILP];\n#pragma unroll\n  for (int ii = 0; ii < ILP; ii++) {\n    db_local[ii] = 0.f;\n  }\n\n  // f always <= features in this case\n  // if (f < features) {\n  int nidx = 0;\n\n  // Handle non-multiple of UNROLL_FACTOR residue\n  for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {\n    int row, col, flat_idx;\n    row = f;\n    col = nStart + nidx;\n    flat_idx = col * features / ILP + row;\n\n    load_store(r_y, Y, 0, flat_idx);\n    load_store(r_dy, dY, 0, flat_idx);\n#pragma unroll\n    for (int ii = 0; ii < ILP; ii++) {\n      if ((float)r_y[ii] <= 0.f) r_dy[ii] = 0;\n      db_local[ii] += (float)r_dy[ii];\n    }\n    load_store(dX, r_dy, flat_idx, 0);\n  }\n\n  // Handle meat of work\n  for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {\n    int row, col, flat_idx;\n    row = f;\n    col = nStart + nidx;\n    flat_idx = 
col * features / ILP + row;  // total threads in x == features/ILP\n#pragma unroll\n    for (int u = 0; u < UNROLL_FACTOR; u++) {\n      load_store(r_y, Y, 0, flat_idx);\n      load_store(r_dy, dY, 0, flat_idx);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if ((float)r_y[ii] <= 0.f) r_dy[ii] = 0;\n        db_local[ii] += (float)r_dy[ii];\n      }\n      load_store(dX, r_dy, flat_idx, 0);\n      flat_idx += features / ILP;\n    }\n  }\n\n  // we know block size for now\n  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X * BIAS_RELU_BW_NTHREADS_Y * ILP];\n  // naive block reduction on y-dim\n  int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;\n  float* smem_out = smem + ILP * linear_idx;\n#pragma unroll\n  for (int ii = 0; ii < ILP; ii++) {\n    smem_out[ii] = db_local[ii];  // reuse local dy buffer\n  }\n  __syncthreads();\n  if (threadIdx.y == 0) {\n    for (int yidx = 1; yidx < blockDim.y; yidx++) {\n      float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        db_local[ii] += smem_in[ii];  // reuse local dy buffer\n      }\n    }\n\n    // block result is in db_local now for all threadIdx.y == 0\n    if (gridDim.y == 1) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_dy[ii] = db_local[ii];  // reuse local dy buffer\n      }\n      load_store(db, r_dy, f, 0);\n      return;\n    }\n\n    // Write out partial result\n    load_store(out, db_local, f, 0);\n  }\n  __threadfence();\n  __syncthreads();\n\n  // Increment semaphore and check if this is the last CTA in the grid_y dimension.\n  // Only thread (0,0) calls this\n  if (threadIdx.x == 0 && threadIdx.y == 0) {\n    unsigned int sum_idx;\n    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);\n    isLastBlock = (sum_idx == (gridDim.y - 1));\n  }\n  __syncthreads();\n\n#pragma unroll\n  for (int ii = 0; ii < ILP; ii++) {\n    db_local[ii] = 0.f;\n  }\n  float r_db[ILP];\n\n  // No block 
reduction for now, only thread (*,0) do grid reduction\n  if (isLastBlock) {\n    if (threadIdx.y == 0) {\n      for (int n = 0; n < gridDim.y; n++) {\n        int row, col;\n        row = f;\n        col = n;\n        load_store(r_db, intermediate, 0, col * features / ILP + row);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          db_local[ii] += r_db[ii];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_dy[ii] = db_local[ii];  // reuse local dy buffer\n      }\n      load_store(db, r_dy, f, 0);\n    }\n  }\n}\n\n// Lists where the num_layers-1 intermediate Y buffers start in reserved space on fprop, starting\n// offset 0. The last Y value is, of course, stored in the user provided output buffer.\nvoid get_y_offsets(int batch_size, int num_layers, const int* output_features, int* y_start_offsets) {\n  y_start_offsets[0] = 0;\n  for (int i = 1; i < num_layers; i++) {\n    y_start_offsets[i] = y_start_offsets[i - 1] + batch_size * output_features[i - 1];\n  }\n}\n\n// Returns the reserved space (in elements) needed for the MLP\nsize_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {\n  size_t res_space = 0;\n  // Need to store output of every intermediate MLP - size equal to output_features[i] * batch_size\n  // for all 'i' in [0, num_layers-1)\n  for (int l = 0; l < num_layers; l++) {\n    res_space += output_features[l] * batch_size;\n  }\n  return res_space;\n}\n\n// Returns the size of all fprop activations combined\nsize_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {\n  size_t acts_size = 0;\n  for (int l = 0; l < num_layers; l++) {\n    acts_size += output_features[l] * batch_size;\n  }\n  return acts_size;\n}\n\n#if 0\n// Returns the work space (in elements) needed for the MLP bprop.\nsize_t get_mlp_bp_workspace (int batch_size, int num_layers, const int* output_features) {\n    /*\n       Workspace is partitioned 
as\n       DY_GEMMs : DX_GEMMs\n    */\n    size_t work_space = 0;\n\n    // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p\n    // of biasReLU_bp and one for o/p of dgrad GEMM).\n    work_space += 2*get_all_activations_size(batch_size, num_layers, output_features);\n\n    return work_space;\n}\n#endif\n\n// Scratch space needed for reductions in number of elements\nsize_t get_reduction_scratch_space(int batch_size, int num_layers, const int* output_features) {\n  size_t max_scratch_space = 0;\n  // Loop over all layers to see which one needs the max scratch space\n  for (int l = 0; l < num_layers; l++) {\n    // need to find max(aligned, not_aligned)\n    int tmp, res0, res1;\n\n    int block_x = BIAS_RELU_BW_NTHREADS_X;\n    int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;\n    get_biasAddRelu_bprop_grid_size(output_features[l], batch_size, block_x, block_y, &tmp, &res0);\n\n    block_x = ILP * BIAS_RELU_BW_NTHREADS_X;\n    get_biasAddRelu_bprop_grid_size(output_features[l], batch_size, block_x, block_y, &tmp, &res1);\n\n    max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res0));\n    max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res1));\n  }\n\n  return max_scratch_space;\n}\n\n// Buffer for semaphores\nsize_t get_semaphores_size(int num_layers, const int* output_features) {\n  // Upper bound on semaphores is one per feature for the layer\n  // with the most features.\n  int max_features = 0;\n  for (int l = 0; l < num_layers; l++) {\n    max_features = std::max(max_features, output_features[l]);\n  }\n  return (size_t)max_features;\n}\n\n// Returns the work space (in bytes) needed for the MLP bprop.\ntemplate <typename T>\nsize_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features) {\n  size_t work_space = 0;\n\n  // Store each intermediate dY explicitly. 
Need 2 dYs per MLP layer (one for o/p\n  // of biasReLU_bp and one for o/p of dgrad GEMM).\n  work_space += 2 * get_all_activations_size(batch_size, num_layers, output_features) * sizeof(T);\n  work_space += get_reduction_scratch_space(batch_size, num_layers, output_features) * sizeof(float);\n  work_space += get_semaphores_size(num_layers, output_features) * sizeof(int);\n\n  return work_space;\n}\n\n// Returns pointers to each segment of the workspace\ntemplate <typename T>\nvoid partition_mlp_bp_workspace(int batch_size, int num_layers, const int* output_features, void* work_space,\n                                T** dy_gemms, T** dx_gemms, float** db_scratch, int** semaphores) {\n  /*\n     Workspace is partitioned as\n     DY_GEMMs : DX_GEMMs : DB_SCRATCH : SEMAPHORES\n  */\n  // Start address where dy_gemm tensors are stored\n  *dy_gemms = reinterpret_cast<T*>(work_space);\n  // Start address where dx_gemm tensors are stored\n  *dx_gemms = *dy_gemms + get_all_activations_size(batch_size, num_layers, output_features);\n  // Start address where db intermediate tensors are stored\n  *db_scratch = reinterpret_cast<float*>(*dx_gemms + get_all_activations_size(batch_size, num_layers, output_features));\n  // Start address of semaphores\n  *semaphores =\n      reinterpret_cast<int*>(*db_scratch + get_reduction_scratch_space(batch_size, num_layers, output_features));\n\n  return;\n}\n\n// Does a simple MLP fprop (GEMM+bias+ReLU).\n// Can handle num_layers number of layers, each with its own shape. Output of layer i is assumed\n// to be input of layer i+1. output_features, WPtr and BPtr are arrays of length num_layers, and\n// must be in the same order i.e. 
WPtr[i] and BPtr[i] are respectively the weight and bias of layer\n// 'i'.\ntemplate <typename T>\nint mlp_fp(T* X, int input_features, int batch_size, T** WPtr, int num_layers, int* output_features, T** BPtr, T* Y,\n           T* reserved_space, int use_bias, int activation, void* lt_workspace) {\n  T *weight, *input, *output, *bias;\n  T *reserved_space_x, *reserved_space_y;\n  reserved_space_x = NULL;\n  reserved_space_y = reserved_space;\n\n  // Get cublas handle from Pytorch\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  // Get the stream from cublas handle to reuse for biasReLU kernel.\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n\n  for (int layer = 0; layer < num_layers; layer++) {\n    weight = WPtr[layer];\n    input = (layer == 0) ? X : reserved_space_x;\n    output = (layer == num_layers - 1) ? Y : reserved_space_y;\n    if (use_bias) {\n      bias = BPtr[layer];\n    }\n    int ifeat = (layer == 0) ? input_features : output_features[layer - 1];\n    int ofeat = output_features[layer];\n\n    float one = 1.f;\n    float zero = 0.f;\n\n    // try with cublaslt first for supported case with valid handle\n    int cublaslt_status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000\n    if (activation < 1) {\n      cublaslt_status = mlp_gemm_lt(\n          // ltHandle,\n          (cublasLtHandle_t)handle, CUBLAS_OP_T, CUBLAS_OP_N, ofeat, batch_size, ifeat, &one, weight, ifeat, input,\n          ifeat, &zero, output, ofeat, lt_workspace, 1 << 22, stream, use_bias == 1, activation == 1, bias);\n    }\n#endif\n\n    // if cublaslt failed or not executed, fallback to cublas\n    if (cublaslt_status != 0) {\n      cublasStatus_t cublas_status;\n      // Call GEMM: fprop is Y = W'X\n      cublas_status = mlp_gemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, ofeat, batch_size, ifeat, &one, weight, ifeat, input,\n                               ifeat, &zero, output, ofeat);\n\n      if (cublas_status != CUBLAS_STATUS_SUCCESS) {\n   
     printf(\"GEMM fprop failed with %d\\n\", cublas_status);\n        return 1;\n      }\n\n      const uint& input_size = ofeat;\n      int num_blocks = 0;\n      int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;\n      // Call biasReLU\n      if (use_bias == 1) {\n        if (activation == 0) {  // no activation\n          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n          biasAdd_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size,\n                                                                                    input_size);\n        } else if (activation == 1) {  // relu\n          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n          biasAddRelu_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size,\n                                                                                        input_size);\n        } else if (activation == 2) {  // sigmoid\n          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n          biasAdd_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size,\n                                                                                    input_size);\n          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n          Sigmoid_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);\n        }\n      } else {\n        // don't need to do anything in case of no activation and no bias\n        if (activation == 1) {  // relu\n          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n          Relu_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, 
stream>>>(output, batch_size, input_size);\n        } else if (activation == 2) {  // sigmoid\n          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n          Sigmoid_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);\n        }\n      }\n    }\n    // Set current output as next layer input\n    reserved_space_x = reserved_space_y;\n    // Set next layer output\n    reserved_space_y += ofeat * batch_size;\n  }\n\n  return 0;\n}\n\n// Does a simple MLP bprop (GEMM+bias+ReLU).\n// Needs reserved space to come back exactly as it was populated in fprop.\n// Does dgrad and wgrad sequentially.\ntemplate <typename T>\nint mlp_bp(T* X, T* Y, int input_features, int batch_size, T** WPtr, int num_layers, int* output_features, T* dY,\n           T* reserved_space, T* work_space, T* dX, T** dwPtr, T** dbPtr, bool requires_grad, int use_bias,\n           int activation) {\n  T* weight;\n  T *dweight, *dx, *dy, *dbias;\n  T *x, *y;\n\n  // Where the dx of the biasReLU (== dy of gemm) is stored. 
Can be thrown away\n  // after bp call.\n  T* dy_gemm_base;\n  // Where the dx after GEMM is stored.\n  T* dx_gemm_base;\n  // Where partial reduction results are stored.\n  float* db_scratch;\n  // Semaphores for reduction.\n  int* semaphores;\n\n  partition_mlp_bp_workspace<T>(batch_size, num_layers, output_features, work_space, &dy_gemm_base, &dx_gemm_base,\n                                &db_scratch, &semaphores);\n\n  size_t semaphore_size = get_semaphores_size(num_layers, output_features) * sizeof(int);\n\n  // Get cublas handle from Pytorch\n  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n  // Get the stream from cublas handle to reuse for biasReLU kernel.\n  cudaStream_t stream;\n  cublasGetStream(handle, &stream);\n\n  int* y_offsets = (int*)malloc(num_layers * sizeof(int));\n  get_y_offsets(batch_size, num_layers, output_features, y_offsets);\n\n  for (int layer = num_layers - 1; layer >= 0; layer--) {\n    weight = WPtr[layer];\n    dweight = dwPtr[layer];\n\n    // x is read from reserved space\n    x = (layer == 0) ? X : reserved_space + y_offsets[layer - 1];\n    // dx is written in workspace for all but layer==0\n    dx = (layer == 0) ? dX : dx_gemm_base + y_offsets[layer - 1];\n\n    // y is read from reserved space\n    y = (layer == num_layers - 1) ? Y : reserved_space + y_offsets[layer];\n    // dx from layer+1\n    dy = (layer == num_layers - 1) ? dY : dx_gemm_base + y_offsets[layer];\n    // dy_gemm is written to and read immediately\n    T* dy_gemm = dy_gemm_base + y_offsets[layer];\n\n    dbias = dbPtr[layer];\n    int xfeat = (layer == 0) ? 
input_features : output_features[layer - 1];\n    int yfeat = output_features[layer];\n\n    float one = 1.f;\n    float zero = 0.f;\n\n    if (use_bias == 1) {\n      if (activation == 0) {  // no activation\n        // bgrad\n        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);\n        int grid_x, grid_y;\n        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);\n\n        int block_x = BIAS_RELU_BW_NTHREADS_X;\n        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;\n        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);\n        dim3 grid(grid_x, grid_y);\n        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(dy, yfeat, batch_size, db_scratch, semaphores, dbias);\n        // bypass dgrad through reset pointer\n        dy_gemm = dy;\n      } else if (activation == 1) {  // relu\n        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);\n        int grid_x, grid_y;\n        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);\n\n        if (yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 && is_aligned(y) && is_aligned(dy) && is_aligned(dy_gemm) &&\n            is_aligned(dbias)) {\n          int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;\n          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;\n          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);\n          dim3 grid(grid_x, grid_y);\n          biasAddRelu_bprop_aligned<T, 4>\n              <<<grid, block, 0, stream>>>(y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);\n        } else {\n          int block_x = BIAS_RELU_BW_NTHREADS_X;\n          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;\n          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);\n          dim3 grid(grid_x, grid_y);\n          biasAddRelu_bprop<T, 4>\n              <<<grid, block, 0, stream>>>(y, dy, yfeat, 
batch_size, dy_gemm, db_scratch, semaphores, dbias);\n        }\n      } else if (activation == 2) {  // sigmoid\n        // activation backward\n        int num_blocks = 0;\n        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;\n        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n        Sigmoid_bprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);\n\n        // bgrad, from dy_gemm\n        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);\n        int grid_x, grid_y;\n        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);\n\n        int block_x = BIAS_RELU_BW_NTHREADS_X;\n        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;\n        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);\n        dim3 grid(grid_x, grid_y);\n        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);\n      }\n    } else {  // no bias below\n      if (activation == 0) {\n        // bypass dgrad through reset pointer\n        dy_gemm = dy;\n      } else if (activation == 1) {  // relu\n        int num_blocks = 0;\n        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;\n        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n        Relu_bprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);\n      } else if (activation == 2) {  // sigmoid\n        int num_blocks = 0;\n        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;\n        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);\n        Sigmoid_bprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);\n      }\n    
}\n\n    cublasStatus_t cublas_status;\n    // Call GEMM dgrad\n    if (layer > 0 || requires_grad == 1) {\n      cublas_status = mlp_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, xfeat, batch_size, yfeat, &one, weight, xfeat, dy_gemm,\n                               yfeat, &zero, dx, xfeat);\n\n      if (cublas_status != CUBLAS_STATUS_SUCCESS) {\n        printf(\"GEMM dgrad failed with %d\\n\", cublas_status);\n        free(y_offsets);  // release the host-side offset table before bailing out\n        return 1;\n      }\n    }\n\n    // Call GEMM wgrad\n    cublas_status = mlp_gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, xfeat, yfeat, batch_size, &one, x, xfeat, dy_gemm, yfeat,\n                             &zero, dweight, xfeat);\n\n    if (cublas_status != CUBLAS_STATUS_SUCCESS) {\n      printf(\"GEMM wgrad failed with %d\\n\", cublas_status);\n      free(y_offsets);  // release the host-side offset table before bailing out\n      return 1;\n    }\n  }\n\n  // y_offsets was malloc'ed at the top of mlp_bp; free it to avoid leaking on every call\n  free(y_offsets);\n  return 0;\n}\n\n// Instantiate for floating point types\ntemplate int mlp_fp<float>(float* X, int input_features, int batch_size, float** WPtr, int num_layers,\n                           int* output_features, float** BPtr, float* Y, float* reserved_space, int use_bias,\n                           int activation, void* lt_workspace);\n\ntemplate int mlp_bp<float>(float* X, float* Y, int input_features, int batch_size, float** WPtr, int num_layers,\n                           int* output_features, float* dY, float* reserved_space, float* work_space, float* dX,\n                           float** dwPtr, float** dbPtr, bool requires_grad, int use_bias, int activation);\n\ntemplate int mlp_fp<at::Half>(at::Half* X, int input_features, int batch_size, at::Half** WPtr, int num_layers,\n                              int* output_features, at::Half** BPtr, at::Half* Y, at::Half* reserved_space,\n                              int use_bias, int activation, void* lt_workspace);\n\ntemplate int mlp_bp<at::Half>(at::Half* X, at::Half* Y, int input_features, int batch_size, at::Half** WPtr,\n                              int num_layers, int* output_features, at::Half* dY, at::Half* reserved_space,\n 
                             at::Half* work_space, at::Half* dX, at::Half** dwPtr, at::Half** dbPtr,\n                              bool requires_grad, int use_bias, int activation);\n\ntemplate int mlp_fp<double>(double* X, int input_features, int batch_size, double** WPtr, int num_layers,\n                            int* output_features, double** BPtr, double* Y, double* reserved_space, int use_bias,\n                            int activation, void* lt_workspace);\n\ntemplate int mlp_bp<double>(double* X, double* Y, int input_features, int batch_size, double** WPtr, int num_layers,\n                            int* output_features, double* dY, double* reserved_space, double* work_space, double* dX,\n                            double** dwPtr, double** dbPtr, bool requires_grad, int use_bias, int activation);\n\ntemplate size_t get_mlp_bp_workspace_in_bytes<float>(int batch_size, int num_layers, const int* output_features);\ntemplate size_t get_mlp_bp_workspace_in_bytes<at::Half>(int batch_size, int num_layers, const int* output_features);\ntemplate size_t get_mlp_bp_workspace_in_bytes<double>(int batch_size, int num_layers, const int* output_features);\n"
  },
  {
    "path": "csrc/multi_tensor_adagrad.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 1024\n#define ILP 4\n\ntypedef enum {\n  ADAGRAD_MODE_0 = 0,  // L2 regularization mode.\n  ADAGRAD_MODE_1 = 1,  // AdamW-style weight decay.\n\n} adagradMode_t;\n\nusing MATH_T = float;\n\ntemplate <typename T>\nstruct AdagradFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<3>& tl,\n                                             const float epsilon, const float lr, adagradMode_t mode,\n                                             const float weight_decay) {\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    T* h = (T*)tl.addresses[2][tensor_loc];\n    h += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_h[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_g[ii] = g[i];\n          r_p[ii] = p[i];\n          r_h[ii] = h[i];\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_h[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if (mode == ADAGRAD_MODE_0) {  // L2\n          r_g[ii] = r_g[ii] + weight_decay * 
r_p[ii];\n          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];\n          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon));\n        } else {  // AdamW-style\n          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];\n          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon) + weight_decay * r_p[ii]);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          p[i] = r_p[ii];\n          h[i] = r_h[ii];\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_adagrad_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                               const float lr, const float epsilon, const int mode, const float weight_decay) {\n  using namespace at;\n\n  // Assume single type across p,g,h now\n  DISPATCH_DOUBLE_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"adagrad\",\n      multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdagradFunctor<scalar_t_0>(), epsilon, lr,\n                            (adagradMode_t)mode, weight_decay);)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_adam.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntypedef enum {\n  ADAM_MODE_0 = 0,  // L2 regularization mode\n  ADAM_MODE_1 = 1   // Decoupled weight decay mode(AdamW)\n} adamMode_t;\n\nusing MATH_T = float;\n\ntemplate <typename T, typename FULL_T, typename index_t>\nstruct AdamFunctor {\n  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int* noop_gmem, TensorListMetadata<4>& tl,\n                                             const float beta1, const float beta2, const float beta1_correction,\n                                             const float beta2_correction, const float epsilon, const float lr,\n                                             adamMode_t mode, const float decay) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    index_t tensor_loc = tl.block_to_tensor[blockIdx.x];\n\n    // potentially use to pass in list of scalar\n    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n\n    index_t chunk_idx = tl.block_to_chunk[blockIdx.x];\n    index_t n = tl.sizes[tensor_loc];\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_m[ILP];\n      MATH_T 
r_v[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_g[ii] = g[i];\n          r_p[ii] = p[i];\n          r_m[ii] = m[i];\n          r_v[ii] = v[i];\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_m[ii] = MATH_T(0);\n          r_v[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if (mode == ADAM_MODE_0) {  // L2\n          r_g[ii] = r_g[ii] + (decay * r_p[ii]);\n          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];\n          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          MATH_T update = next_m_unbiased / denom;\n          r_p[ii] = r_p[ii] - (lr * update);\n        } else {  // weight decay\n          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];\n          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          r_p[ii] = r_p[ii] - (lr * update);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          p[i] = r_p[ii];\n          m[i] = r_m[ii];\n          v[i] = r_v[ii];\n        }\n      }\n    }\n  }\n};\n\ntemplate <typename T, typename FULL_T>\nstruct AdamCapturableFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<4>& tl,\n                              
               const float beta1, const float beta2, const int* step,\n                                             const int bias_correction, const float epsilon, const float* lr,\n                                             adamMode_t mode, const float decay, const float* inv_scale) {\n    if (*noop_gmem == 1) return;\n\n    float beta1_correction = 1.0f, beta2_correction = 1.0f;\n    if (bias_correction == 1) {\n      beta1_correction = 1 - pow(beta1, *step);\n      beta2_correction = 1 - pow(beta2, *step);\n    }\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n\n    // potentially use to pass in list of scalar\n    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_m[ILP];\n      MATH_T r_v[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_g[ii] = static_cast<MATH_T>(g[i]) * (*inv_scale);\n          g[i] = static_cast<T>(r_g[ii]);\n          r_p[ii] = static_cast<MATH_T>(p[i]);\n          r_m[ii] = static_cast<MATH_T>(m[i]);\n          r_v[ii] = static_cast<MATH_T>(v[i]);\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_m[ii] = MATH_T(0);\n          r_v[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n 
     for (int ii = 0; ii < ILP; ii++) {\n        if (mode == ADAM_MODE_0) {  // L2\n          r_g[ii] = r_g[ii] + (decay * r_p[ii]);\n          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];\n          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          MATH_T update = next_m_unbiased / denom;\n          r_p[ii] = r_p[ii] - (*lr * update);\n        } else {  // weight decay\n          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];\n          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          r_p[ii] = r_p[ii] - (*lr * update);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          p[i] = static_cast<T>(r_p[ii]);\n          m[i] = static_cast<T>(r_m[ii]);\n          v[i] = static_cast<T>(r_v[ii]);\n        }\n      }\n    }\n  }\n};\n\ntemplate <typename T, typename FULL_T>\nstruct AdamCapturableMasterFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl,\n                                             const float beta1, const float beta2, const int* step,\n                                             const int bias_correction, const float epsilon, const float* lr,\n                                             adamMode_t mode, const float decay, const float* inv_scale) {\n    if (*noop_gmem == 1) return;\n\n    float beta1_correction = 1.0f, beta2_correction = 1.0f;\n    if 
(bias_correction == 1) {\n      beta1_correction = 1 - pow(beta1, *step);\n      beta2_correction = 1 - pow(beta2, *step);\n    }\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n\n    // potentially use to pass in list of scalar\n    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    FULL_T* p_master = (FULL_T*)tl.addresses[4][tensor_loc];\n    p_master += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_m[ILP];\n      MATH_T r_v[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_g[ii] = static_cast<MATH_T>(g[i]) * (*inv_scale);\n          g[i] = static_cast<T>(r_g[ii]);\n          r_p[ii] = static_cast<MATH_T>(p_master[i]);\n          r_m[ii] = static_cast<MATH_T>(m[i]);\n          r_v[ii] = static_cast<MATH_T>(v[i]);\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_m[ii] = MATH_T(0);\n          r_v[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if (mode == ADAM_MODE_0) {  // L2\n          r_g[ii] = r_g[ii] + (decay * r_p[ii]);\n          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];\n          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];\n          
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          MATH_T update = next_m_unbiased / denom;\n          r_p[ii] = r_p[ii] - (*lr * update);\n        } else {  // weight decay\n          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];\n          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          r_p[ii] = r_p[ii] - (*lr * update);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          p[i] = static_cast<T>(r_p[ii]);\n          p_master[i] = static_cast<FULL_T>(r_p[ii]);\n          m[i] = static_cast<FULL_T>(r_m[ii]);\n          v[i] = static_cast<FULL_T>(r_v[ii]);\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                            const float lr, const float beta1, const float beta2, const float epsilon, const int step,\n                            const int mode, const int bias_correction, const float weight_decay) {\n  using namespace at;\n\n  // Handle bias correction mode\n  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;\n  if (bias_correction == 1) {\n    bias_correction1 = 1 - std::pow(beta1, step);\n    bias_correction2 = 1 - std::pow(beta2, step);\n  }\n\n  size_t max_size = 0;\n  bool requires_64bit_indexing = false;\n  for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {\n    for (auto it2 = it->begin(); it2 != it->end(); it2++) {\n      if (it2->numel() > max_size) {\n   
     max_size = it2->numel();\n        if (max_size >= INT_MAX) {\n          requires_64bit_indexing = true;\n          break;\n        }\n      }\n    }\n    if (requires_64bit_indexing) {\n      break;\n    }\n  }\n\n  if (requires_64bit_indexing) {\n    // Assume single type across p,g,m1,m2 now\n    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(\n        tensor_lists[0][0].scalar_type(), 0, \"adam\",\n        multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,\n                              AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2, bias_correction1,\n                              bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)\n  } else {\n    // Assume single type across p,g,m1,m2 now\n    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(\n        tensor_lists[0][0].scalar_type(), 0, \"adam\",\n        multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                              AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2, bias_correction1,\n                              bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)\n  }\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,\n                                       std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor lr,\n                                       const float beta1, const float beta2, const float epsilon, at::Tensor step,\n                                       const int mode, const int bias_correction, const float weight_decay,\n                                       at::Tensor inv_scale) {\n  using namespace at;\n\n  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(\n      tensor_lists[0][0].scalar_type(), 0, \"adam\",\n      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamCapturableFunctor<scalar_t_0, float>(),\n                            beta1, beta2, step.data_ptr<int>(), bias_correction, epsilon, 
lr.data_ptr<float>(),\n                            (adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n\nvoid multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,\n                                              std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor lr,\n                                              const float beta1, const float beta2, const float epsilon,\n                                              at::Tensor step, const int mode, const int bias_correction,\n                                              const float weight_decay, at::Tensor inv_scale) {\n  using namespace at;\n\n  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(\n      tensor_lists[0][0].scalar_type(), 0, \"adam\",\n      multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                            AdamCapturableMasterFunctor<scalar_t_0, float>(), beta1, beta2, step.data_ptr<int>(),\n                            bias_correction, epsilon, lr.data_ptr<float>(), (adamMode_t)mode, weight_decay,\n                            inv_scale.data_ptr<float>());)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_apply.cuh",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <assert.h>\n#include <c10/cuda/CUDAGuard.h>\n\n// #include <iostream>\n\n// This header is the one-stop shop for all your multi-tensor apply needs.\n\n// TODO:  Kernel arg size limit may be <4KB for some other cards (ie Jetson)\nconstexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};\nconstexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};\n\ntemplate <int n>\nstruct TensorListMetadata {\n  void* addresses[n][depth_to_max_tensors[n - 1]];\n  int64_t sizes[depth_to_max_tensors[n - 1]];\n  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];\n  int block_to_chunk[depth_to_max_blocks[n - 1]];  // I fear this needs to be a full int.\n  int start_tensor_this_launch;\n};\n\ntemplate <typename T, typename U, typename... ArgTypes>\n__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop_flag, T tl, U callable,\n                                          ArgTypes... args) {\n  // Hand the chunk information to the user-supplied functor to process however it likes.\n  callable(chunk_size, noop_flag, tl, args...);\n}\n\ntemplate <int depth, typename T, typename... ArgTypes>\nvoid multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor& noop_flag,\n                        const std::vector<std::vector<at::Tensor>>& tensor_lists, T callable, ArgTypes... 
args) {\n  TORCH_CHECK(tensor_lists.size() == depth, \"tensor_lists.size() != depth\");\n  int len0 = tensor_lists[0].size();\n  TORCH_CHECK(len0 > 0, \"tensor_lists[0].size() is not > 0\");\n  auto ref_device = tensor_lists[0][0].device();\n  TORCH_CHECK(ref_device.type() == at::kCUDA, \"expected input to be on cuda\");\n  for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices\n  {\n    TORCH_CHECK(tensor_lists[l].size() == len0, \"Size mismatch among tensor lists\");\n    for (int t = 0; t < tensor_lists[l].size(); t++) {\n      // TODO:  Print which tensor fails.\n      bool contiguous_memory = tensor_lists[l][t].is_contiguous();\n      contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||\n                           tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));\n      TORCH_CHECK(contiguous_memory, \"A tensor was not contiguous.\");\n      TORCH_CHECK(tensor_lists[l][t].device() == ref_device, \"A tensor was not on the same device as the first tensor\");\n      TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), \"Size mismatch\");\n    }\n  }\n\n  int ntensors = tensor_lists[0].size();\n\n  TensorListMetadata<depth> tl;\n\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  tl.start_tensor_this_launch = 0;\n  int loc_block_info = 0;\n  int loc_tensor_info = 0;\n  for (int t = 0; t < ntensors; t++) {\n    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();\n    for (int d = 0; d < depth; d++) tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();\n    loc_tensor_info++;\n\n    auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;\n\n    for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {\n      // std::cout << chunks_this_tensor << std::endl;\n      tl.block_to_tensor[loc_block_info] = 
loc_tensor_info - 1;\n      tl.block_to_chunk[loc_block_info] = chunk;\n      loc_block_info++;\n\n      bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1);\n      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);\n      bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);\n      if (tensors_full || blocks_full || last_chunk) {\n        // using accscalar_t = acc_type<scalar_t, true>;\n        multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(chunk_size, noop_flag.data_ptr<int>(), tl,\n                                                                             callable, args...);\n\n        AT_CUDA_CHECK(cudaGetLastError());\n\n        // Reset.  The control flow possibilities here make my brain hurt.\n        loc_block_info = 0;\n        if (chunk == chunks_this_tensor - 1) {\n          // std::cout << \"Hit case 1 \" << cond1 << \" \" << cond2 << \" \" << cond3 << std::endl;\n          loc_tensor_info = 0;\n          tl.start_tensor_this_launch = t + 1;\n        } else {\n          // std::cout << \"Hit case 2 \" << cond1 << \" \" << cond2 << \" \" << cond3 << std::endl;\n          tl.sizes[0] = tl.sizes[loc_tensor_info - 1];\n          for (int d = 0; d < depth; d++) tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];\n          loc_tensor_info = 1;\n          tl.start_tensor_this_launch = t;\n        }\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "csrc/multi_tensor_axpby_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename x_t, typename y_t, typename out_t>\nstruct AxpbyFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<3>& tl,\n                                             float a, float b, int arg_to_check) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    x_t* x = (x_t*)tl.addresses[0][tensor_loc];\n    x += chunk_idx * chunk_size;\n\n    y_t* y = (y_t*)tl.addresses[1][tensor_loc];\n    y += chunk_idx * chunk_size;\n\n    out_t* out = (out_t*)tl.addresses[2][tensor_loc];\n    out += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    bool finite = true;\n    x_t r_x[ILP];\n    y_t r_y[ILP];\n    out_t r_out[ILP];\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_x, x, 0, i_start);\n        load_store(r_y, 
y, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_out[ii] = a * static_cast<float>(r_x[ii]) + b * static_cast<float>(r_y[ii]);\n          if (arg_to_check == -1) finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));\n          if (arg_to_check == 0) finite = finite && isfinite(r_x[ii]);\n          if (arg_to_check == 1) finite = finite && isfinite(r_y[ii]);\n        }\n        // store\n        load_store(out, r_out, i_start, 0);\n      }\n    } else {\n      // Non-divergent exit condition for __syncthreads, not necessary here\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_x[ii] = 0;\n          r_y[ii] = 0;\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_x[ii] = x[i];\n            r_y[ii] = y[i];\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_out[ii] = a * static_cast<float>(r_x[ii]) + b * static_cast<float>(r_y[ii]);\n          if (arg_to_check == -1) finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));\n          if (arg_to_check == 0) finite = finite && isfinite(r_x[ii]);\n          if (arg_to_check == 1) finite = finite && isfinite(r_y[ii]);\n        }\n        // see note in multi_tensor_scale_kernel.cu\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) out[i] = r_out[ii];\n        }\n      }\n    }\n    if (!finite) *noop_gmem = 1;  // Blindly fire off a write.  
These will race but that's ok.\n  }\n};\n\nvoid multi_tensor_axpby_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                             float a, float b, int arg_to_check) {\n  using namespace at;\n  // The output (downscaled) type is always float.\n  // If build times suffer, think about where to put this dispatch,\n  // and what logic should be moved out of multi_tensor_apply.\n\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_axpby_cuda\",\n      DISPATCH_FLOAT_AND_HALF(\n          tensor_lists[1][0].scalar_type(), 1, \"multi_tensor_axpby_cuda\",\n          DISPATCH_FLOAT_AND_HALF(\n              tensor_lists[2][0].scalar_type(), 2, \"multi_tensor_axpby_cuda\",\n              multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                    AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(), a, b, arg_to_check);)))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_l2norm_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <c10/cuda/CUDAGuard.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename x_t>\nstruct L2NormFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,\n                                             float* output, float* output_per_tensor, bool per_tensor,\n                                             int max_chunks_per_tensor) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    x_t* x = (x_t*)tl.addresses[0][tensor_loc];\n    x += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    __shared__ float s_vals[512];\n\n    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...\n    x_t r_x[ILP];\n    for (int i = 0; i < ILP; i++) {\n      vals[i] = 0.f;\n      r_x[i] = 0;\n    }\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_x, x, 0, 
i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          float next = static_cast<float>(r_x[ii]);\n          vals[ii] += next * next;\n        }\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            float next = static_cast<float>(x[i]);\n            vals[ii] += next * next;\n          }\n        }\n      }\n    }\n\n    float val = 0.f;\n    for (int i = 0; i < ILP; i++) val += vals[i];\n\n    float final = reduce_block_into_lanes(s_vals, val);\n\n    if (threadIdx.x == 0) {\n      if (!isfinite(final)) *noop_gmem = 1;  // Blindly fire off a write.  These will race but that's ok.\n      output[blockIdx.x] += final;\n      if (per_tensor)\n        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;\n    }\n  }\n};\n\ntemplate <typename x_t>\nstruct UnscaleL2NormFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,\n                                             const float* inv_scale, float* output, float* output_per_tensor,\n                                             bool per_tensor, int max_chunks_per_tensor) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    x_t* x = (x_t*)tl.addresses[0][tensor_loc];\n    x += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    __shared__ float s_vals[512];\n\n    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...\n    x_t r_x[ILP];\n    for (int i = 0; i < ILP; i++) {\n      vals[i] = 0.f;\n      r_x[i] 
= 0;\n    }\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_x, x, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          float next = static_cast<float>(r_x[ii]) * (*inv_scale);\n          vals[ii] += next * next;\n        }\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            float next = static_cast<float>(x[i]) * (*inv_scale);\n            vals[ii] += next * next;\n          }\n        }\n      }\n    }\n\n    float val = 0.f;\n    for (int i = 0; i < ILP; i++) val += vals[i];\n\n    float final = reduce_block_into_lanes(s_vals, val);\n\n    if (threadIdx.x == 0) {\n      if (!isfinite(final)) *noop_gmem = 1;  // Blindly fire off a write.  
These will race but that's ok.\n      output[blockIdx.x] += final;\n      if (per_tensor)\n        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;\n    }\n  }\n};\n\n// Probably better to template, but since we are not likely to support other norm\ntemplate <typename x_t>\nstruct MaxNormFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,\n                                             float* output, float* output_per_tensor, bool per_tensor,\n                                             int max_chunks_per_tensor) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    x_t* x = (x_t*)tl.addresses[0][tensor_loc];\n    x += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    __shared__ float s_vals[512];\n\n    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...\n    x_t r_x[ILP];\n    for (int i = 0; i < ILP; i++) {\n      vals[i] = 0.f;\n      r_x[i] = 0;\n    }\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_x, x, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          float next = static_cast<float>(r_x[ii]);\n          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));\n        }\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < 
n && i < chunk_size) {\n            float next = static_cast<float>(x[i]);\n            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));\n          }\n        }\n      }\n    }\n\n    float val = 0.f;\n    for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));\n\n    float final = reduce_block_into_lanes_max_op(s_vals, val);\n\n    if (threadIdx.x == 0) {\n      if (!isfinite(final)) *noop_gmem = 1;  // Blindly fire off a write.  These will race but that's ok.\n      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));\n      if (per_tensor)\n        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;\n    }\n  }\n};\n\n__global__ void cleanup(float* output, float* output_per_tensor, float* ret, float* ret_per_tensor, bool per_tensor,\n                        int max_chunks_per_tensor) {\n  __shared__ float vals[512];\n\n  if (blockIdx.x == 0) {\n    float val = 0;\n    if (threadIdx.x < 320) val = output[threadIdx.x];\n\n    float final = reduce_block_into_lanes(vals, val);\n\n    if (threadIdx.x == 0) *ret = sqrt(final);\n  }\n\n  if (per_tensor) {\n    float* output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;\n\n    float val = 0;\n    for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) val += output_this_tensor[i];\n\n    float final = reduce_block_into_lanes(vals, val);\n\n    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);\n  }\n}\n\n__global__ void cleanup_v2(float* output, float* output_per_tensor, float* ret, float* ret_per_tensor, bool per_tensor,\n                           int max_chunks_per_tensor, int norm_type, float alpha, float beta) {\n  __shared__ float vals[512];\n\n  if (blockIdx.x == 0) {\n    float val = 0;\n    if (threadIdx.x < 320) val = output[threadIdx.x];\n\n    if (norm_type == 0) {\n      float final = reduce_block_into_lanes_max_op(vals, val);\n      if (threadIdx.x == 0) *ret = alpha * 
(*ret) + beta * final;\n    } else {\n      float final = reduce_block_into_lanes(vals, val);\n      if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);\n    }\n  }\n\n  if (per_tensor) {\n    float* output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;\n\n    if (norm_type == 0) {\n      float val = 0;\n      for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)\n        val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));\n\n      float final = reduce_block_into_lanes_max_op(vals, val);\n\n      if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;\n    } else {\n      float val = 0;\n      for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) val += output_this_tensor[i];\n\n      float final = reduce_block_into_lanes(vals, val);\n\n      if (threadIdx.x == 0)\n        ret_per_tensor[blockIdx.x] =\n            sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);\n    }\n  }\n}\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,\n                                                            std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                            at::optional<bool> per_tensor_python) {\n  bool per_tensor = per_tensor_python.has_value() ? 
per_tensor_python.value() : false;\n\n  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);\n  auto output = at::zeros({320}, float_options);\n\n  at::Tensor output_per_tensor;\n  at::Tensor ret_per_tensor;\n\n  int ntensors = tensor_lists[0].size();\n  int max_chunks_per_tensor = -1;\n\n  if (per_tensor) {\n    for (int t = 0; t < ntensors; t++) {\n      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;\n      if (max_chunks_this_tensor > max_chunks_per_tensor) max_chunks_per_tensor = max_chunks_this_tensor;\n    }\n    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);\n    ret_per_tensor = at::empty({ntensors}, float_options);\n  } else {\n    ret_per_tensor = at::empty({0}, float_options);\n  }\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_l2norm_cuda\",\n      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),\n                            output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,\n                            per_tensor, max_chunks_per_tensor);)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n\n  // This involves one more small kernel launches, but will be negligible end to end.\n  // I could get rid of these by hacking the functor + multi tensor harness with persistence\n  // logic, but keeping it simple for now\n  auto ret = at::empty({1}, output.options());\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));\n  auto stream = at::cuda::getCurrentCUDAStream();\n  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(\n      output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, ret.data_ptr<float>(),\n      per_tensor ? 
ret_per_tensor.data_ptr<float>() : nullptr, per_tensor, max_chunks_per_tensor);\n\n  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);\n}\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(int chunk_size, at::Tensor noop_flag,\n                                                                    std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                                    at::Tensor inv_scale,\n                                                                    at::optional<bool> per_tensor_python) {\n  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;\n\n  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);\n  auto output = at::zeros({320}, float_options);\n\n  at::Tensor output_per_tensor;\n  at::Tensor ret_per_tensor;\n\n  int ntensors = tensor_lists[0].size();\n  int max_chunks_per_tensor = -1;\n\n  if (per_tensor) {\n    for (int t = 0; t < ntensors; t++) {\n      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;\n      if (max_chunks_this_tensor > max_chunks_per_tensor) max_chunks_per_tensor = max_chunks_this_tensor;\n    }\n    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);\n    ret_per_tensor = at::empty({ntensors}, float_options);\n  } else {\n    ret_per_tensor = at::empty({0}, float_options);\n  }\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_unscale_l2norm_cuda\",\n      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<scalar_t_0>(),\n                            inv_scale.data_ptr<float>(), output.data_ptr<float>(),\n                            per_tensor ? 
output_per_tensor.data_ptr<float>() : nullptr, per_tensor,\n                            max_chunks_per_tensor);)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n\n  // This involves one more small kernel launches, but will be negligible end to end.\n  // I could get rid of these by hacking the functor + multi tensor harness with persistence\n  // logic, but keeping it simple for now\n  auto ret = at::empty({1}, output.options());\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));\n  auto stream = at::cuda::getCurrentCUDAStream();\n  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(\n      output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, ret.data_ptr<float>(),\n      per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor, max_chunks_per_tensor);\n\n  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);\n}\n\n// Compute and update grad norm\n// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by\n// L-2: gn = sqrt(a * gn^2 + b * n^2)\n// L-inf: gn = a * gn + b * n\nvoid multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                                at::Tensor out, const float alpha, const float beta, const int norm_type) {\n  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);\n  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), \"noop flag should be on the same device as tensors\");\n  // we don't need global thus uses empty here\n  auto output = at::empty({320}, float_options);\n\n  at::Tensor output_per_tensor;\n  at::Tensor ret_per_tensor;\n\n  int ntensors = tensor_lists[0].size();\n  int max_chunks_per_tensor = -1;\n\n  for (int t = 0; t < ntensors; t++) {\n    int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;\n    if (max_chunks_this_tensor > max_chunks_per_tensor) 
max_chunks_per_tensor = max_chunks_this_tensor;\n  }\n\n  // Although each entry is written once and then read, the buffer must still be zeroed,\n  // since trailing (unwritten) elements also participate in cleanup\n  output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);\n\n  if (norm_type == 0) {\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_maxnorm_cuda\",\n                            multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                  MaxNormFunctor<scalar_t_0>(), output.data_ptr<float>(),\n                                                  output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);)\n  } else {\n    DISPATCH_FLOAT_HALF_AND_BFLOAT(\n        tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_l2norm_cuda\",\n        multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),\n                              output.data_ptr<float>(), output_per_tensor.data_ptr<float>(), true,\n                              max_chunks_per_tensor);)\n  }\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n\n  // This involves one more small kernel launch, but it will be negligible end to end.\n  // I could get rid of it by hacking the functor + multi tensor harness with persistence\n  // logic, but keeping it simple for now\n  auto ret = at::empty({1}, output.options());\n\n  // Adding the following device guard since it sometimes happens that the\n  // tensors are on one device and the cuda stream is on another device, which\n  // results in an ILLEGAL MEM ACCESS error.\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));\n  auto stream = at::cuda::getCurrentCUDAStream();\n  cleanup_v2<<<ntensors, 512, 0, stream>>>(output.data_ptr<float>(), output_per_tensor.data_ptr<float>(),\n                                           ret.data_ptr<float>(), out.data_ptr<float>(), true, 
max_chunks_per_tensor,\n                                           norm_type, alpha, beta);\n\n  return;\n}\n"
  },
  {
    "path": "csrc/multi_tensor_l2norm_kernel_mp.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <c10/cuda/CUDAGuard.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename x_t>\nstruct L2NormFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,\n                                             float* output, float* output_per_tensor, bool per_tensor,\n                                             int max_chunks_per_tensor) {\n    if (*noop_gmem) {\n      return;\n    }\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    x_t* x = (x_t*)tl.addresses[0][tensor_loc];\n    x += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    __shared__ float s_vals[512];\n\n    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...\n    x_t r_x[ILP];\n    for (int i = 0; i < ILP; i++) {\n      vals[i] = 0.f;\n      r_x[i] = 0;\n    }\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_x, x, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; 
ii++) {\n          float next = static_cast<float>(r_x[ii]);\n          vals[ii] += next * next;\n        }\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            float next = static_cast<float>(x[i]);\n            vals[ii] += next * next;\n          }\n        }\n      }\n    }\n\n    float val = 0.f;\n    for (int i = 0; i < ILP; i++) val += vals[i];\n\n    float final = reduce_block_into_lanes(s_vals, val);\n\n    if (threadIdx.x == 0) {\n      if (!isfinite(final)) *noop_gmem = 1;  // Blindly fire off a write.  These will race but that's ok.\n      output[blockIdx.x] += final;\n      if (per_tensor)\n        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;\n    }\n  }\n};\n\n__global__ void cleanup(float* output, float* output_per_tensor, float* ret, float* ret_per_tensor, bool per_tensor,\n                        int max_chunks_per_tensor, volatile int* noop_gmem) {\n  if (*noop_gmem) {\n    return;\n  }\n  __shared__ float vals[512];\n\n  if (blockIdx.x == 0) {\n    float val = 0;\n    if (threadIdx.x < 320) val = output[threadIdx.x];\n\n    float final = reduce_block_into_lanes(vals, val);\n\n    if (threadIdx.x == 0) *ret = sqrt(final);\n  }\n\n  if (per_tensor) {\n    float* output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;\n\n    float val = 0;\n    for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) val += output_this_tensor[i];\n\n    float final = reduce_block_into_lanes(vals, val);\n\n    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);\n  }\n}\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(int chunk_size, at::Tensor noop_flag,\n                                                      
         std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                               at::optional<bool> per_tensor_python) {\n  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;\n\n  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);\n  auto output = at::zeros({320}, float_options);\n\n  at::Tensor output_per_tensor;\n  at::Tensor ret_per_tensor;\n\n  int ntensors = tensor_lists[0].size();\n  int max_chunks_per_tensor = -1;\n\n  if (per_tensor) {\n    for (int t = 0; t < ntensors; t++) {\n      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;\n      if (max_chunks_this_tensor > max_chunks_per_tensor) max_chunks_per_tensor = max_chunks_this_tensor;\n    }\n    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);\n    ret_per_tensor = at::empty({ntensors}, float_options);\n  } else {\n    ret_per_tensor = at::empty({0}, float_options);\n  }\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_l2norm_mp_cuda\",\n      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0>(),\n                            output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,\n                            per_tensor, max_chunks_per_tensor);)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n\n  // This involves one more small kernel launch, but its cost will be negligible end to end.\n  // I could get rid of it by hacking the functor + multi tensor harness with persistence\n  // logic, but keeping it simple for now\n  auto ret = at::empty({1}, output.options());\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));\n  auto stream = at::cuda::getCurrentCUDAStream();\n  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(\n      output.data_ptr<float>(), per_tensor ? 
output_per_tensor.data_ptr<float>() : nullptr, ret.data_ptr<float>(),\n      per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor, max_chunks_per_tensor,\n      noop_flag.data_ptr<int>());\n\n  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);\n}\n"
  },
  {
    "path": "csrc/multi_tensor_l2norm_scale_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <c10/cuda/CUDAGuard.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename in_t, typename out_t>\nstruct L2NormScaleFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl,\n                                             float* output, float* output_per_tensor, float scale, bool per_tensor,\n                                             int max_chunks_per_tensor) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    in_t* in = (in_t*)tl.addresses[0][tensor_loc];\n    in += chunk_idx * chunk_size;\n\n    out_t* out = (out_t*)tl.addresses[1][tensor_loc];\n    out += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    __shared__ float s_vals[512];\n\n    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...\n    in_t r_in[ILP];\n    for (int i = 0; i < ILP; i++) {\n      vals[i] = 0.f;\n      r_in[i] = 0;\n    }\n    // bool finite = true;\n    out_t r_out[ILP];\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && 
is_aligned(in) && is_aligned(out)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_in, in, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          float next = static_cast<float>(r_in[ii]);\n          r_out[ii] = next * scale;\n          vals[ii] += next * next;\n          // finite = finite && isfinite(r_in[ii]);\n        }\n        load_store(out, r_out, i_start, 0);\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_in[ii] = 0;\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_in[ii] = in[i];\n            float next = static_cast<float>(in[i]);\n            vals[ii] += next * next;\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_out[ii] = static_cast<float>(r_in[ii]) * scale;\n          // finite = finite && isfinite(r_in[ii]);\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) out[i] = r_out[ii];\n        }\n      }\n    }\n\n    float val = 0.f;\n    for (int i = 0; i < ILP; i++) val += vals[i];\n\n    float final = reduce_block_into_lanes(s_vals, val);\n\n    if (threadIdx.x == 0) {\n      if (!isfinite(final)) *noop_gmem = 1;  // Blindly fire off a write.  
These will race but that's ok.\n      output[blockIdx.x] += final;\n      if (per_tensor)\n        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;\n    }\n  }\n};\n// Probably better to template, but since we are not likely to support other norm\ntemplate <typename x_t>\nstruct MaxNormFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<1>& tl,\n                                             float* output, float* output_per_tensor, bool per_tensor,\n                                             int max_chunks_per_tensor) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    x_t* x = (x_t*)tl.addresses[0][tensor_loc];\n    x += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    __shared__ float s_vals[512];\n\n    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...\n    x_t r_x[ILP];\n    for (int i = 0; i < ILP; i++) {\n      vals[i] = 0.f;\n      r_x[i] = 0;\n    }\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_x, x, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          float next = static_cast<float>(r_x[ii]);\n          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));\n        }\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n 
&& i < chunk_size) {\n            float next = static_cast<float>(x[i]);\n            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));\n          }\n        }\n      }\n    }\n\n    float val = 0.f;\n    for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));\n\n    float final = reduce_block_into_lanes_max_op(s_vals, val);\n\n    if (threadIdx.x == 0) {\n      if (!isfinite(final)) *noop_gmem = 1;  // Blindly fire off a write.  These will race but that's ok.\n      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));\n      if (per_tensor)\n        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;\n    }\n  }\n};\n\n__global__ void cleanup_v3(float* output, float* output_per_tensor, float* ret, float* ret_per_tensor, bool per_tensor,\n                           int max_chunks_per_tensor) {\n  __shared__ float vals[512];\n\n  if (blockIdx.x == 0) {\n    float val = 0;\n    if (threadIdx.x < 320) val = output[threadIdx.x];\n\n    float final = reduce_block_into_lanes(vals, val);\n\n    if (threadIdx.x == 0) *ret = sqrt(final);\n  }\n\n  if (per_tensor) {\n    float* output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;\n\n    float val = 0;\n    for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) val += output_this_tensor[i];\n\n    float final = reduce_block_into_lanes(vals, val);\n\n    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);\n  }\n}\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(int chunk_size, at::Tensor noop_flag,\n                                                                  std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                                  float scale, at::optional<bool> per_tensor_python) {\n  bool per_tensor = per_tensor_python.has_value() ? 
per_tensor_python.value() : false;\n\n  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);\n  auto output = at::zeros({320}, float_options);\n\n  at::Tensor output_per_tensor;\n  at::Tensor ret_per_tensor;\n\n  int ntensors = tensor_lists[0].size();\n  int max_chunks_per_tensor = -1;\n\n  if (per_tensor) {\n    for (int t = 0; t < ntensors; t++) {\n      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;\n      if (max_chunks_this_tensor > max_chunks_per_tensor) max_chunks_per_tensor = max_chunks_this_tensor;\n    }\n    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);\n    ret_per_tensor = at::empty({ntensors}, float_options);\n  } else {\n    ret_per_tensor = at::empty({0}, float_options);\n  }\n\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_l2norm_scale_cuda\",\n      DISPATCH_FLOAT_AND_HALF(\n          tensor_lists[1][0].scalar_type(), 1, \"multi_tensor_l2norm_scale_cuda\",\n          multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                L2NormScaleFunctor<scalar_t_0, scalar_t_1>(), output.data_ptr<float>(),\n                                per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, scale, per_tensor,\n                                max_chunks_per_tensor);))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n\n  // This involves one more small kernel launch, but its cost will be negligible end to end.\n  // I could get rid of it by hacking the functor + multi tensor harness with persistence\n  // logic, but keeping it simple for now\n  auto ret = at::empty({1}, output.options());\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));\n  auto stream = at::cuda::getCurrentCUDAStream();\n  cleanup_v3<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(\n      output.data_ptr<float>(), per_tensor ? 
output_per_tensor.data_ptr<float>() : nullptr, ret.data_ptr<float>(),\n      per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor, max_chunks_per_tensor);\n\n  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);\n}\n"
  },
  {
    "path": "csrc/multi_tensor_lamb.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntypedef enum {\n  MOMENT_MODE_0 = 0,  // L2 regularization mode\n  MOMENT_MODE_1 = 1   // Decoupled weight decay mode\n} adamMode_t;\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,\n                                                            std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                            at::optional<bool> per_tensor_python);\n\nusing MATH_T = float;\n\ntemplate <typename T>\nstruct LAMBStage1Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<4>& tl,\n                                             const float beta1, const float beta2, const float beta3,\n                                             const float beta1_correction, const float beta2_correction,\n                                             const float epsilon, adamMode_t mode, const float decay,\n                                             const float* global_grad_norm, const float max_global_grad_norm) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n 
= tl.sizes[tensor_loc];\n\n    float clipped_global_grad_norm =\n        (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    MATH_T r_g[ILP];\n    MATH_T r_p[ILP];\n    MATH_T r_m[ILP];\n    MATH_T r_v[ILP];\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && is_aligned(p) && is_aligned(m) && is_aligned(v)) {\n      T l_g[ILP];\n      T l_p[ILP];\n      T l_m[ILP];\n      T l_v[ILP];\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(l_g, g, 0, i_start);\n        if (decay != 0) load_store(l_p, p, 0, i_start);\n        load_store(l_m, m, 0, i_start);\n        load_store(l_v, v, 0, i_start);\n        // unpack\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_g[ii] = l_g[ii];\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          } else {\n            r_p[ii] = l_p[ii];\n          }\n          r_m[ii] = l_m[ii];\n          r_v[ii] = l_v[ii];\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay * r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T 
next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          } else {\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          l_p[ii] = r_p[ii];\n          l_m[ii] = r_m[ii];\n          l_v[ii] = r_v[ii];\n        }\n        // store\n        load_store(g, l_p, i_start, 0);\n        load_store(m, l_m, i_start, 0);\n        load_store(v, l_v, i_start, 0);\n      }\n    } else {\n      // see note in multi_tensor_scale_kernel.cu\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n        MATH_T r_g[ILP];\n        MATH_T r_p[ILP];\n        MATH_T r_m[ILP];\n        MATH_T r_v[ILP];\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_g[ii] = g[i];\n            // special ?optimization? 
for lamb stage 1\n            if (decay == 0) {\n              r_p[ii] = MATH_T(0);\n            } else {\n              r_p[ii] = p[i];\n            }\n            r_m[ii] = m[i];\n            r_v[ii] = v[i];\n          } else {\n            r_g[ii] = MATH_T(0);\n            r_p[ii] = MATH_T(0);\n            r_m[ii] = MATH_T(0);\n            r_v[ii] = MATH_T(0);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay * r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          } else {\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            g[i] = r_p[ii];\n            m[i] = r_m[ii];\n            v[i] = r_v[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter value.\ntemplate <typename 
T>\nstruct LAMBStage2Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl,\n                                             const float* per_tensor_param_norm, const float* per_tensor_update_norm,\n                                             const float learning_rate, const float decay, bool use_nvlamb) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T ratio = learning_rate;\n    // nvlamb: apply adaptive learning rate to all parameters\n    // otherwise, only apply to those with non-zero weight decay\n    if (use_nvlamb || (decay != 0.0)) {\n      float param_norm = per_tensor_param_norm[tensor_num];\n      float update_norm = per_tensor_update_norm[tensor_num];\n      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? 
learning_rate * (param_norm / update_norm) : learning_rate;\n    }\n\n    T* update = (T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update)) {\n      T r_p[ILP];\n      T r_update[ILP];\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_p, p, 0, i_start);\n        load_store(r_update, update, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));\n        }\n        load_store(p, r_p, i_start, 0);\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n        MATH_T r_p[ILP];\n        MATH_T r_update[ILP];\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_p[ii] = p[i];\n            r_update[ii] = update[i];\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            p[i] = r_p[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                            const float lr, const float beta1, const float beta2, const float epsilon, const int step,\n                            const int 
bias_correction, const float weight_decay, const int grad_averaging,\n                            const int mode, at::Tensor global_grad_norm, const float max_grad_norm,\n                            at::optional<bool> use_nvlamb_python) {\n  using namespace at;\n  // Master weights and 32-bit momentum (potentially changing) are not handled by this,\n  // so we assume all tensors have the same type\n\n  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;\n\n  // Handle bias correction mode\n  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;\n  if (bias_correction == 1) {\n    bias_correction1 = 1 - std::pow(beta1, step);\n    bias_correction2 = 1 - std::pow(beta2, step);\n  }\n\n  // Handle grad averaging mode\n  float beta3 = 1.0f;\n  if (grad_averaging == 1) beta3 = 1 - beta1;\n\n  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1);\n  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1, tensor_lists.begin() + 2);\n\n  // Compute per tensor param norm\n  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);\n\n  // We now modify grad in place to store the update before computing its norm\n  // Generally this is not an issue since people modify grad in the step() method all the time\n  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/cpu code\n  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_1\",\n                          multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                LAMBStage1Functor<scalar_t_0>(), beta1, beta2,\n                                                beta3,  // 1-beta1 or 1, depending on averaging mode\n                                                bias_correction1, bias_correction2, epsilon, (adamMode_t)mode,\n                                                weight_decay, 
global_grad_norm.data_ptr<float>(), max_grad_norm);)\n\n  // Compute update norms\n  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);\n\n  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin() + 2);\n\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_2\",\n      multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, LAMBStage2Functor<scalar_t_0>(),\n                            std::get<1>(param_norm_tuple).data_ptr<float>(),\n                            std::get<1>(update_norm_tuple).data_ptr<float>(), lr, weight_decay, use_nvlamb);)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_lamb_mp.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntypedef enum {\n  MOMENT_MODE_0 = 0,  // L2 regularization mode\n  MOMENT_MODE_1 = 1   // Decoupled weight decay mode\n} adamMode_t;\n\nstd::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(int chunk_size, at::Tensor noop_flag,\n                                                               std::vector<std::vector<at::Tensor>> tensor_lists,\n                                                               at::optional<bool> per_tensor_python);\n\nusing MATH_T = float;\n\ntemplate <typename T, typename param_t>\nstruct LAMBStage1Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<4>& tl,\n                                             const float beta1, const float beta2, const float beta3,\n                                             const int* step_ptr, const int bias_correction, const float epsilon,\n                                             adamMode_t mode, const float decay, const float* global_grad_norm,\n                                             const float* max_global_grad_norm, const float* found_inf,\n                                             const float* inv_scale) {\n    if (*noop_gmem) {\n      return;\n    }\n\n    float beta1_correction = 1.0f;\n    float 
beta2_correction = 1.0f;\n    if (bias_correction == 1) {\n      int step = *step_ptr;\n      beta1_correction = 1 - std::pow(beta1, step);\n      beta2_correction = 1 - std::pow(beta2, step);\n    }\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float clipped_global_grad_norm =\n        (*global_grad_norm) > (*max_global_grad_norm) ? (*global_grad_norm) / (*max_global_grad_norm) : 1.0f;\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    param_t* p = (param_t*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    param_t* m = (param_t*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    param_t* v = (param_t*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    MATH_T r_g[ILP];\n    MATH_T r_p[ILP];\n    MATH_T r_m[ILP];\n    MATH_T r_v[ILP];\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && is_aligned(p) && is_aligned(m) && is_aligned(v)) {\n      T l_g[ILP];\n      param_t l_p[ILP];\n      param_t l_m[ILP];\n      param_t l_v[ILP];\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(l_g, g, 0, i_start);\n        if (decay != 0) load_store(l_p, p, 0, i_start);\n        load_store(l_m, m, 0, i_start);\n        load_store(l_v, v, 0, i_start);\n        // unpack\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_g[ii] = l_g[ii] * (*inv_scale);\n          if (decay == 0) {\n            r_p[ii] = MATH_T(0);\n          } else {\n            r_p[ii] = l_p[ii];\n          }\n          r_m[ii] = l_m[ii];\n          r_v[ii] = l_v[ii];\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          if (mode == MOMENT_MODE_0) 
{\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay * r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          } else {\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          l_p[ii] = r_p[ii];\n          // Difference from APEX's LAMB kernel. 
`g` and `p` can be different dtypes.\n          l_g[ii] = r_p[ii];\n          l_m[ii] = r_m[ii];\n          l_v[ii] = r_v[ii];\n        }\n        // store\n        load_store(g, l_g, i_start, 0);\n        load_store(m, l_m, i_start, 0);\n        load_store(v, l_v, i_start, 0);\n      }\n    } else {\n      // see note in multi_tensor_scale_kernel.cu\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n        MATH_T r_g[ILP];\n        MATH_T r_p[ILP];\n        MATH_T r_m[ILP];\n        MATH_T r_v[ILP];\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_g[ii] = g[i] * (*inv_scale);\n            // special ?optimization? for lamb stage 1\n            if (decay == 0) {\n              r_p[ii] = MATH_T(0);\n            } else {\n              r_p[ii] = p[i];\n            }\n            r_m[ii] = m[i];\n            r_v[ii] = v[i];\n          } else {\n            r_g[ii] = MATH_T(0);\n            r_p[ii] = MATH_T(0);\n            r_m[ii] = MATH_T(0);\n            r_v[ii] = MATH_T(0);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          if (mode == MOMENT_MODE_0) {\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            // L2 on scaled grad\n            scaled_grad = scaled_grad + decay * r_p[ii];\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n            r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = next_m_unbiased / denom;\n          } else {\n            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;\n          
  r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;\n            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;\n            r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          }\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            g[i] = r_p[ii];\n            m[i] = r_m[ii];\n            v[i] = r_v[ii];\n          }\n        }\n      }\n    }\n  }\n};\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter value.\n// N == 2: FP32 params, no master params\n// N == 3: FP16 params, FP32 master params.\ntemplate <typename T, int N, typename param_t>\nstruct LAMBStage2Functor {\n  static_assert((N == 2 && std::is_same<T, param_t>::value) || (N == 3 && std::is_same<param_t, float>::value), \"\");\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<N>& tl,\n                                             const float* per_tensor_param_norm, const float* per_tensor_update_norm,\n                                             const float* learning_rate, const float decay, bool use_nvlamb) {\n    if (*noop_gmem) {\n      return;\n    }\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T ratio = *learning_rate;\n    // nvlamb: apply adaptive learning rate to all parameters\n    // otherwise, only apply to those with non-zero weight decay\n    if (use_nvlamb || (decay != 0.0)) {\n      float param_norm = per_tensor_param_norm[tensor_num];\n      float update_norm = per_tensor_update_norm[tensor_num];\n      ratio 
=\n          (update_norm != 0.0f && param_norm != 0.0f) ? *learning_rate * (param_norm / update_norm) : *learning_rate;\n    }\n\n    T* update = (T*)tl.addresses[0][tensor_loc];\n    update += chunk_idx * chunk_size;\n\n    param_t* p = (param_t*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    T* out_p;\n    if (N == 3) {\n      out_p = (T*)tl.addresses[2][tensor_loc];\n      out_p += chunk_idx * chunk_size;\n    }\n\n    n -= chunk_idx * chunk_size;\n\n    // to make things simple, we put aligned case in a different code path\n    bool can_use_aligned_path = n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update);\n    if (N == 3) {\n      can_use_aligned_path = can_use_aligned_path && is_aligned(out_p);\n    }\n    if (can_use_aligned_path) {\n      param_t r_p[ILP];\n      T r_update[ILP];\n      T r_out_p[ILP];\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_p, p, 0, i_start);\n        load_store(r_update, update, 0, i_start);\n        if (N == 3) {\n          load_store(r_out_p, out_p, 0, i_start);\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));\n          if (N == 3) {\n            r_out_p[ii] = r_p[ii];\n          }\n        }\n        load_store(p, r_p, i_start, 0);\n        if (N == 3) {\n          load_store(out_p, r_out_p, i_start, 0);\n        }\n      }\n    } else {\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n        MATH_T r_p[ILP];\n        MATH_T r_update[ILP];\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            r_p[ii] = p[i];\n            r_update[ii] = update[i];\n          }\n        }\n#pragma unroll\n     
   for (int ii = 0; ii < ILP; ii++) {\n          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) {\n            p[i] = r_p[ii];\n            if (N == 3) {\n              out_p[i] = r_p[ii];\n            }\n          }\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_lamb_mp_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                               at::Tensor lr, const float beta1, const float beta2, const float epsilon,\n                               at::Tensor step, const int bias_correction, const float weight_decay,\n                               const int grad_averaging, const int mode, at::Tensor global_grad_norm,\n                               at::Tensor max_grad_norm, at::optional<bool> use_nvlamb_python, at::Tensor found_inf,\n                               at::Tensor inv_scale) {\n  // n_tensors == 5: FP16 model params & FP32 master params\n  // n_tensors == 4: FP32 model params & NO FP32 master params\n  const auto n_tensors = tensor_lists.size();\n  assert(n_tensors == 4 || n_tensors == 5);\n  using namespace at;\n\n  bool use_nvlamb = use_nvlamb_python.has_value() ? 
use_nvlamb_python.value() : false;\n\n  // note(mkozuki): move bias handling below to functor\n  // Handle bias correction mode\n  // float bias_correction1 = 1.0f, bias_correction2 = 1.0f;\n  // if (bias_correction == 1) {\n  //   bias_correction1 = 1 - std::pow(beta1, step);\n  //   bias_correction2 = 1 - std::pow(beta2, step);\n  // }\n\n  // Handle grad averaging mode\n  float beta3 = 1.0f;\n  if (grad_averaging == 1) beta3 = 1 - beta1;\n\n  std::vector<std::vector<at::Tensor>> stage1_tensor_lists(tensor_lists.begin(), tensor_lists.begin() + 4);\n  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1);\n  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1, tensor_lists.begin() + 2);\n\n  // Compute per tensor param norm\n  auto param_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, param_list, true);\n\n  // We now modify grad in place to store the update before computing its norm\n  // Generally this is not an issue since people modify grad in the step() method all the time\n  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/cpu code\n  if (n_tensors == 4) {\n    DISPATCH_FLOAT_AND_HALF(\n        tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_1\",\n        multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, stage1_tensor_lists,\n                              LAMBStage1Functor<scalar_t_0, scalar_t_0>(), beta1, beta2,\n                              beta3,  // 1-beta1 or 1 depends on averaging mode\n                              // bias_correction1,\n                              // bias_correction2,\n                              step.data_ptr<int>(), bias_correction, epsilon, (adamMode_t)mode, weight_decay,\n                              global_grad_norm.data_ptr<float>(), max_grad_norm.data_ptr<float>(),\n                              found_inf.data_ptr<float>(), inv_scale.data_ptr<float>());)\n  } else {\n    DISPATCH_FLOAT_HALF_AND_BFLOAT(\n        
tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_1\",\n        multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, stage1_tensor_lists,\n                              LAMBStage1Functor<scalar_t_0, float>(), beta1, beta2,\n                              beta3,  // 1-beta1 or 1 depends on averaging mode\n                              // bias_correction1,\n                              // bias_correction2,\n                              step.data_ptr<int>(), bias_correction, epsilon, (adamMode_t)mode, weight_decay,\n                              global_grad_norm.data_ptr<float>(), max_grad_norm.data_ptr<float>(),\n                              found_inf.data_ptr<float>(), inv_scale.data_ptr<float>());)\n  }\n\n  // Compute update norms\n  auto update_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, grad_list, true);\n\n  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin() + 2);\n  if (n_tensors == 4) {\n    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_2\",\n                            multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,\n                                                  LAMBStage2Functor<scalar_t_0, 2, scalar_t_0>(),\n                                                  std::get<1>(param_norm_tuple).data_ptr<float>(),\n                                                  std::get<1>(update_norm_tuple).data_ptr<float>(),\n                                                  lr.data_ptr<float>(), weight_decay, use_nvlamb);)\n  } else {\n    grad_param_list.push_back(tensor_lists[4]);\n    DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_2\",\n                                   multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,\n                                                         LAMBStage2Functor<scalar_t_0, 3, float>(),\n                                                         
std::get<1>(param_norm_tuple).data_ptr<float>(),\n                                                         std::get<1>(update_norm_tuple).data_ptr<float>(),\n                                                         lr.data_ptr<float>(), weight_decay, use_nvlamb);)\n  }\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_lamb_stage_1.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\n// Step 1 computes the 'update' value of regular Adam optimizer.\ntemplate <typename GRAD_T, typename T, typename UPD_T>\nstruct LAMBStage1Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<5>& tl,\n                                             const float* per_tensor_decay, const float beta1, const float beta2,\n                                             const float beta1_correction, const float beta2_correction,\n                                             const float epsilon, const float clipped_global_grad_norm) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float decay = per_tensor_decay[tensor_num];\n\n    GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    T* v = (T*)tl.addresses[3][tensor_loc];\n    v += chunk_idx * chunk_size;\n\n    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];\n    update += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    // see note in multi_tensor_scale_kernel.cu\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      GRAD_T r_g[ILP];\n      T r_p[ILP];\n      T r_m[ILP];\n      T r_v[ILP];\n#pragma unroll\n      for 
(int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_g[ii] = g[i];\n          r_p[ii] = p[i];\n          r_m[ii] = m[i];\n          r_v[ii] = v[i];\n        } else {\n          r_g[ii] = GRAD_T(0);\n          r_p[ii] = T(0);\n          r_m[ii] = T(0);\n          r_v[ii] = T(0);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        T scaled_grad = r_g[ii] / clipped_global_grad_norm;\n        r_m[ii] = r_m[ii] * beta1 + (1 - beta1) * scaled_grad;\n        r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;\n        T next_m_unbiased = r_m[ii] / beta1_correction;\n        T next_v_unbiased = r_v[ii] / beta2_correction;\n        T denom = std::sqrt(next_v_unbiased) + epsilon;\n        r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          update[i] = (UPD_T)r_p[ii];\n          m[i] = r_m[ii];\n          v[i] = r_v[ii];\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_lamb_stage1_cuda(int chunk_size, at::Tensor noop_flag,\n                                   std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor per_tensor_decay,\n                                   const int step, const float beta1, const float beta2, const float epsilon,\n                                   at::Tensor global_grad_norm, const float max_global_grad_norm) {\n  using namespace at;\n\n  const float* g_grad_norm = global_grad_norm.data_ptr<float>();\n  float clipped_global_grad_norm = *(g_grad_norm) > max_global_grad_norm ? 
*(g_grad_norm) / max_global_grad_norm : 1.0f;\n  float next_step = float(step + 1);\n  float beta1_correction = 1.0f - std::pow(beta1, next_step);\n  float beta2_correction = 1.0f - std::pow(beta2, next_step);\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_1\",\n      DISPATCH_FLOAT_AND_HALF(\n          tensor_lists[1][0].scalar_type(), 1, \"lamb_stage_1\",\n          DISPATCH_FLOAT_AND_HALF(\n              tensor_lists[4][0].scalar_type(), 2, \"lamb_stage_1\",\n              multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                    LAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),\n                                    per_tensor_decay.data_ptr<float>(), beta1, beta2, beta1_correction,\n                                    beta2_correction, epsilon, clipped_global_grad_norm);)))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_lamb_stage_2.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\nusing MATH_T = float;\n\n// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.\n// It computes new parameter value.\ntemplate <typename T, typename UPD_T>\nstruct LAMBStage2Functor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl,\n                                             const float* per_tensor_param_norm, const float* per_tensor_update_norm,\n                                             const float learning_rate, const float decay, bool use_nvlamb) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    MATH_T ratio = learning_rate;\n    // nvlamb: apply adaptive learning rate to all parameters\n    // otherwise, only apply to those with non-zero weight decay\n    if (use_nvlamb || (decay != 0.0)) {\n      float param_norm = per_tensor_param_norm[tensor_num];\n      float update_norm = per_tensor_update_norm[tensor_num];\n      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? 
learning_rate * (param_norm / update_norm) : learning_rate;\n    }\n\n    T* p = (T*)tl.addresses[0][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];\n    update += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      T r_p[ILP];\n      UPD_T r_update[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_p[ii] = p[i];\n          r_update[ii] = update[i];\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        r_p[ii] = r_p[ii] - (ratio * (T)r_update[ii]);\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          p[i] = r_p[ii];\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_lamb_stage2_cuda(int chunk_size, at::Tensor noop_flag,\n                                   std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor per_tensor_param_norm,\n                                   at::Tensor per_tensor_update_norm, const float lr, const float weight_decay,\n                                   at::optional<bool> use_nvlamb_python) {\n  bool use_nvlamb = use_nvlamb_python.has_value() ? 
use_nvlamb_python.value() : false;\n\n  using namespace at;\n\n  DISPATCH_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"lamb_stage_2\",\n      DISPATCH_FLOAT_AND_HALF(\n          tensor_lists[1][0].scalar_type(), 1, \"lamb_stage_2\",\n          multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                LAMBStage2Functor<scalar_t_0, scalar_t_1>(), per_tensor_param_norm.data_ptr<float>(),\n                                per_tensor_update_norm.data_ptr<float>(), lr, weight_decay, use_nvlamb);))\n\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_novograd.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntypedef enum {\n  MOMENT_MODE_0 = 0,  // Novograd paper mode, momentum caculation with denom then decay inside\n  MOMENT_MODE_1 = 1   // Decoupled weight decay mode\n} momentMode_t;\n\nvoid multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                                at::Tensor out, const float alpha, const float beta, const int norm_type);\n\nusing MATH_T = float;\n\ntemplate <typename T>\nstruct NovoGradFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<3>& tl,\n                                             const float beta1, const float beta2, const float beta3,\n                                             const float beta1_correction, const float beta2_correction,\n                                             const float epsilon, const float lr, momentMode_t m_mode,\n                                             const float decay, const float* per_tensor_grad_norm) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int tensor_num = tl.start_tensor_this_launch + tensor_loc;\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    float grad_norm = per_tensor_grad_norm[tensor_num];\n\n    T* g = (T*)tl.addresses[0][tensor_loc];\n    g += chunk_idx * chunk_size;\n\n    T* p = (T*)tl.addresses[1][tensor_loc];\n    p += chunk_idx * chunk_size;\n\n    T* m = (T*)tl.addresses[2][tensor_loc];\n    m += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    
// see note in multi_tensor_scale_kernel.cu\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n      MATH_T r_g[ILP];\n      MATH_T r_p[ILP];\n      MATH_T r_m[ILP];\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          r_g[ii] = g[i];\n          r_p[ii] = p[i];\n          r_m[ii] = m[i];\n        } else {\n          r_g[ii] = MATH_T(0);\n          r_p[ii] = MATH_T(0);\n          r_m[ii] = MATH_T(0);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        if (m_mode == MOMENT_MODE_0) {\n          MATH_T next_v_unbiased = grad_norm / beta2_correction;\n          MATH_T denom = next_v_unbiased + epsilon;\n          r_g[ii] = (r_g[ii] / denom) + (decay * r_p[ii]);\n          r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          r_p[ii] = r_p[ii] - (lr * next_m_unbiased);\n        } else {\n          r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];\n          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;\n          MATH_T next_v_unbiased = grad_norm / beta2_correction;\n          MATH_T denom = next_v_unbiased + epsilon;\n          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);\n          r_p[ii] = r_p[ii] - (lr * update);\n        }\n      }\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          p[i] = r_p[ii];\n          m[i] = r_m[ii];\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_novograd_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                                at::Tensor grad_norms, const float lr, const float beta1, const float beta2,\n                                const float epsilon, const int step, const int 
bias_correction,\n                                const float weight_decay, const int grad_averaging, const int moment_mode,\n                                const int norm_type) {\n  using namespace at;\n\n  // Handle bias correction mode\n  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;\n  if (bias_correction == 1) {\n    bias_correction1 = 1 - std::pow(beta1, step);\n    bias_correction2 = std::sqrt(1 - std::pow(beta2, step));\n  }\n\n  // Handle grad averaging mode\n  float beta3 = 1;\n  if (grad_averaging == 1) beta3 = 1 - beta1;\n\n  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1);\n\n  // Compute and update grad norm\n  // Here use a per tensor norm, and blend new norm(n) and old norm(gn) by\n  // L-2: gn = sqrt(a * gn^2 + b * n^2)\n  // L-inf: gn = a * gn + b * n\n  multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);\n\n  // Assume single type across p,g,m1,m2 now\n  DISPATCH_DOUBLE_FLOAT_AND_HALF(\n      tensor_lists[0][0].scalar_type(), 0, \"novograd\",\n      multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, NovoGradFunctor<scalar_t_0>(), beta1,\n                            beta2,\n                            beta3,  // 1-beta1 or 1 depends on averaging mode\n                            bias_correction1, bias_correction2, epsilon, lr, (momentMode_t)moment_mode, weight_decay,\n                            grad_norms.data_ptr<float>());)\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_scale_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n// Another possibility:\n// #include <torch/all.h>\n\n#include <assert.h>\n// Stringstream is a big hammer, but I want to rely on operator<< for dtype.\n#include <sstream>\n\n#include \"multi_tensor_apply.cuh\"\n#include \"type_shim.h\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\ntemplate <typename T>\n__device__ __forceinline__ bool is_aligned(T* p) {\n  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {\n  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;\n  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];\n}\n\ntemplate <typename in_t, typename out_t>\nstruct ScaleFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<2>& tl,\n                                             float scale) {\n    // I'd like this kernel to propagate infs/nans.\n    // if(*noop_gmem == 1)\n    //   return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    in_t* in = (in_t*)tl.addresses[0][tensor_loc];\n    in += chunk_idx * chunk_size;\n\n    out_t* out = (out_t*)tl.addresses[1][tensor_loc];\n    out += chunk_idx * chunk_size;\n\n    n -= chunk_idx * chunk_size;\n\n    bool finite = true;\n    in_t r_in[ILP];\n    out_t r_out[ILP];\n\n    // to make things simple, we put aligned case in a different code path\n    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {\n      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {\n        // load\n        load_store(r_in, in, 0, i_start);\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          
r_out[ii] = static_cast<float>(r_in[ii]) * scale;\n          finite = finite && isfinite(r_in[ii]);\n        }\n        // store\n        load_store(out, r_out, i_start, 0);\n      }\n    } else {\n      // Non-divergent exit condition for __syncthreads, not necessary here\n      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_in[ii] = 0;\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) r_in[ii] = in[i];\n        }\n        // note for clarification to future michael:\n        // From a pure memory dependency perspective, there's likely no point unrolling\n        // the write loop, since writes just fire off once their LDGs arrive.\n        // Put another way, the STGs are dependent on the LDGs, but not on each other.\n        // There is still compute ILP benefit from unrolling the loop though.\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          r_out[ii] = static_cast<float>(r_in[ii]) * scale;\n          finite = finite && isfinite(r_in[ii]);\n        }\n#pragma unroll\n        for (int ii = 0; ii < ILP; ii++) {\n          int i = i_start + threadIdx.x + ii * blockDim.x;\n          if (i < n && i < chunk_size) out[i] = r_out[ii];\n        }\n      }\n    }\n    if (!finite) *noop_gmem = 1;  // Blindly fire off a write.  
These will race but that's ok.\n  }\n};\n\nvoid multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                             float scale) {\n  using namespace at;\n  // The output (downscaled) type is always float.\n  // If build times suffer, think about where to put this dispatch,\n  // and what logic should be moved out of multi_tensor_apply.\n\n  DISPATCH_FLOAT_HALF_AND_BFLOAT(\n      tensor_lists[0][0].scalar_type(), 0, \"multi_tensor_scale_cuda\",\n      DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[1][0].scalar_type(), 1, \"multi_tensor_scale_cuda\",\n                                     multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,\n                                                           ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);))\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  // AT_CUDA_CHECK(cudaDeviceSynchronize());\n}\n"
  },
  {
    "path": "csrc/multi_tensor_sgd_kernel.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n#include <assert.h>\n#include <cuda_runtime.h>\n\n#include \"multi_tensor_apply.cuh\"\n\n#define BLOCK_SIZE 512\n#define ILP 4\n\n/**\n * Perform fused SGD on multiple buffers\n * N: number of tensors\n * tl[0] : gradients\n * tl[1] : weights\n * tl[2] : momentum buffers\n * tl[3] : fp16 weights (if appropriate)\n * wd : weight_decay (scalar)\n * momentum : momentum (scalar)\n * dampening : momentum dampening (scalar)\n * lr : learning rate (scalar)\n * nesterov : enable nesterov (bool)\n * first run : necessary for proper momentum handling & init\n * wd_after_momentum : apply weight decay _after_ momentum instead of before\n **/\ntemplate <int N, typename T_grad, typename T_weight>\nstruct SGDFunctor {\n  __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<N>& tl,\n                                             float wd, float momentum, float dampening, float lr, bool nesterov,\n                                             bool first_run, bool wd_after_momentum, float scale) {\n    // Early exit if we don't need to do anything\n    if (*noop_gmem) return;\n\n    int tensor_loc = tl.block_to_tensor[blockIdx.x];\n    int chunk_idx = tl.block_to_chunk[blockIdx.x];\n    int n = tl.sizes[tensor_loc];\n\n    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];\n    grad_in += chunk_idx * chunk_size;\n\n    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];\n    weight_in += chunk_idx * chunk_size;\n\n    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];\n    mom_in += chunk_idx * chunk_size;\n\n    at::Half* model_weights_out = nullptr;\n    if (N == 4) {\n      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];\n      model_weights_out += chunk_idx * chunk_size;\n    }\n\n    n -= chunk_idx * chunk_size;\n\n    // Non-divergent exit condition for 
the __syncthreads\n    float incoming_grads[ILP];\n    float incoming_weights[ILP];\n    float incoming_moms[ILP];\n    for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        incoming_grads[ii] = 0;\n        incoming_weights[ii] = 0;\n        incoming_moms[ii] = 0;\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;\n          incoming_weights[ii] = static_cast<float>(weight_in[i]);\n          incoming_moms[ii] = static_cast<float>(mom_in[i]);\n        }\n      }\n\n// note for clarification to future michael:\n// From a pure memory dependency perspective, there's likely no point unrolling\n// the write loop, since writes just fire off once their LDGs arrive.\n// Put another way, the STGs are dependent on the LDGs, but not on each other.\n// There is still compute ILP benefit from unrolling the loop though.\n#pragma unroll\n      for (int ii = 0; ii < ILP; ii++) {\n        int i = i_start + threadIdx.x + ii * blockDim.x;\n        if (i < n && i < chunk_size) {\n          // apply weight decay before momentum if necessary\n          if (wd != 0.f && !wd_after_momentum) incoming_grads[ii] += wd * incoming_weights[ii];\n\n          if (momentum != 0.f) {\n            if (!first_run)\n              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];\n            else  // initialize momentums to current incoming grads\n              incoming_moms[ii] = incoming_grads[ii];\n\n            if (nesterov)\n              incoming_grads[ii] += momentum * incoming_moms[ii];\n            else\n              incoming_grads[ii] = incoming_moms[ii];\n          }\n\n          // Apply WD after momentum if desired\n          if (wd != 0.f && wd_after_momentum) incoming_grads[ii] += wd * incoming_weights[ii];\n\n          // 
adjust the weight and write out\n          weight_in[i] += (-lr * incoming_grads[ii]);\n\n          // if necessary, write out an fp16 copy of the weights\n          if (N == 4) model_weights_out[i] = static_cast<at::Half>(weight_in[i]);\n\n          // also write out the new momentum\n          if (momentum != 0.f) mom_in[i] = incoming_moms[ii];\n        }\n      }\n    }\n  }\n};\n\nvoid multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,\n                           float wd, float momentum, float dampening, float lr, bool nesterov, bool first_run,\n                           bool wd_after_momentum, float scale) {\n  auto num_tensors = tensor_lists.size();\n  auto grad_type = tensor_lists[0][0].scalar_type();\n  auto weight_type = tensor_lists[1][0].scalar_type();\n\n  if (num_tensors == 4)\n    for (int i = 0; i < tensor_lists[3].size(); i++)\n      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,\n                  \"Additional output tensors should always be fp16.\");\n\n  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),\n              \"expected noop flag to be on the same device as tensors\");\n\n  // We have 4 possibilities to handle here, in terms of\n  // grad_type, param_type, momentum_type, requires_fp16_copy\n  // 1. fp16, fp16, fp16, No\n  // 2. fp32, fp32, fp32, No\n  // 3. fp16, fp32, fp32, Yes\n  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case\n  // It's easier to hardcode these possibilities than to use\n  // switches etc. to handle the cross-product of cases where\n  // we don't want the majority of them.\n\n  // Case 1. 
fp16, fp16, fp16, No\n  if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half && num_tensors == 3) {\n    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, SGDFunctor<3, at::Half, at::Half>(), wd,\n                          momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale);\n  }\n  // Case 2. fp16, fp32, fp32, No\n  // else if (grad_type == at::ScalarType::Half &&\n  //          weight_type == at::ScalarType::Float &&\n  //          num_tensors == 3) {\n  //   multi_tensor_apply<3>(\n  //       BLOCK_SIZE,\n  //       chunk_size,\n  //       noop_flag,\n  //       tensor_lists,\n  //       SGDFunctor<3, at::Half, float>(),\n  //       wd,\n  //       momentum,\n  //       dampening,\n  //       lr,\n  //       nesterov,\n  //       first_run,\n  //       wd_after_momentum);\n  // }\n  // Case 2. fp32, fp32, fp32, No\n  else if (grad_type == at::ScalarType::Float && weight_type == at::ScalarType::Float && num_tensors == 3) {\n    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, SGDFunctor<3, float, float>(), wd, momentum,\n                          dampening, lr, nesterov, first_run, wd_after_momentum, scale);\n  }\n  // Case 3. fp16, fp32, fp32, Yes\n  else if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Float && num_tensors == 4) {\n    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, SGDFunctor<4, at::Half, float>(), wd,\n                          momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale);\n  }\n  // Case 4. 
fp32, fp32, fp32, Yes\n  else if (grad_type == at::ScalarType::Float && weight_type == at::ScalarType::Float && num_tensors == 4) {\n    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, SGDFunctor<4, float, float>(), wd, momentum,\n                          dampening, lr, nesterov, first_run, wd_after_momentum, scale);\n  } else {\n    AT_ERROR(\"multi_tensor_sgd only supports some combinations of gradient & weight types. Given: \", \"gradient: \",\n             grad_type, \", weight: \", weight_type, \", num_lists: \", num_tensors);\n  }\n\n  AT_CUDA_CHECK(cudaGetLastError());\n}\n"
  },
  {
    "path": "csrc/static_switch.h",
    "content": "// From\n// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n\n#pragma once\n\n/// @param COND       - a boolean expression to switch by\n/// @param CONST_NAME - a name given for the constexpr bool variable.\n/// @param ...       - code to execute for true and false\n///\n/// Usage:\n/// ```\n/// BOOL_SWITCH(flag, BoolConst, [&] {\n///     some_function<BoolConst>(...);\n/// });\n/// ```\n#define BOOL_SWITCH(COND, CONST_NAME, ...)      \\\n  [&] {                                         \\\n    if (COND) {                                 \\\n      constexpr static bool CONST_NAME = true;  \\\n      return __VA_ARGS__();                     \\\n    } else {                                    \\\n      constexpr static bool CONST_NAME = false; \\\n      return __VA_ARGS__();                     \\\n    }                                           \\\n  }()\n"
  },
  {
    "path": "csrc/syncbn.cpp",
    "content": "#include <ATen/ATen.h>\n#include <torch/extension.h>\n\n#include <vector>\n\n// returns {mean,biased_var}\n// implemented using welford\nstd::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);\n\n// reduces array of mean/var across processes\n// returns global {mean,inv_std,biased_var}\n// implemented using welford\nstd::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,\n                                              const at::Tensor var_biased_feature_nodes, const at::Tensor numel,\n                                              const float eps);\n\n// elementwise BN operation, returns output\n// input/weight/shift should have identical data type;\n// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)\nat::Tensor batchnorm_forward_CUDA(const at::Tensor input, const at::Tensor mean, const at::Tensor inv_std,\n                                  const at::optional<at::Tensor> weight, const at::optional<at::Tensor> shift);\n\n// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}\n// grad_output/input should have identical data type;\n// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)\n// implemented using kahan summation\nstd::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output, const at::Tensor input, const at::Tensor mean,\n                                       const at::Tensor inv_std, const at::optional<at::Tensor> weight);\n\n// elementwise backward BN operation, returns grad_input\n// grad_output/input/weight precision could be fp16/fp32;\n// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32\nat::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output, const at::Tensor input, const at::Tensor mean,\n                                   const at::Tensor inv_std, const at::optional<at::Tensor> weight,\n                                   const at::Tensor sum_dy, const at::Tensor sum_dy_xmu, const at::Tensor count);\n\n// returns {mean, 
biased_var}\n// implemented using welford\n// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL\nstd::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);\n\n// elementwise BN operation, returns output\n// input/weight/shift should have identical data type;\n// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)\n// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL\nat::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input, const at::optional<at::Tensor> z,\n                                         const at::Tensor mean, const at::Tensor inv_std,\n                                         const at::optional<at::Tensor> weight, const at::optional<at::Tensor> shift,\n                                         const bool fuse_relu);\n\n// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}\n// grad_output/input should have identical data type;\n// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)\n// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL\nstd::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output, const at::Tensor input,\n                                              const at::Tensor mean, const at::Tensor inv_std,\n                                              const at::optional<at::Tensor> weight);\n\n// elementwise backward BN operation, returns grad_input\n// grad_output/input/weight precision could be fp16/fp32;\n// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32\n// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL\nat::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output, const at::Tensor input, const at::Tensor mean,\n                                          const at::Tensor inv_std, const at::optional<at::Tensor> weight,\n                                          const at::Tensor sum_dy, const at::Tensor sum_dy_xmu, 
const at::Tensor count);\n\nat::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output, const at::Tensor input,\n                                     const at::optional<at::Tensor> z, const at::Tensor mean, const at::Tensor inv_std,\n                                     const at::optional<at::Tensor> weight, const at::optional<at::Tensor> shift);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"welford_mean_var\", &welford_mean_var_CUDA, \"welford mean variance\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"welford_parallel\", &welford_parallel_CUDA, \"welford parallel reduce mean variance\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"batchnorm_forward\", &batchnorm_forward_CUDA, \"batchnorm forward\", py::call_guard<py::gil_scoped_release>());\n  m.def(\"reduce_bn\", &reduce_bn_CUDA, \"batchnorm backward reduce grad sum and bias/weight grad\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"batchnorm_backward\", &batchnorm_backward_CUDA, \"batchnorm backward dgrad\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"welford_mean_var_c_last\", &welford_mean_var_c_last_CUDA, \"welford mean variance nhwc\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"batchnorm_forward_c_last\", &batchnorm_forward_c_last_CUDA, \"batchnorm forward nhwc\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"reduce_bn_c_last\", &reduce_bn_c_last_CUDA, \"batchnorm backwards reduce grad sum and bias/weight grad nhwc\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"batchnorm_backward_c_last\", &batchnorm_backward_c_last_CUDA, \"batchnorm backward dgrad nhwc\",\n        py::call_guard<py::gil_scoped_release>());\n  m.def(\"relu_bw_c_last\", &relu_backward_c_last_CUDA, \"relu_bw_c_last\", py::call_guard<py::gil_scoped_release>());\n}\n"
  },
  {
    "path": "csrc/type_shim.h",
    "content": "#include <ATen/ATen.h>\n\n// Forward/backward compatiblity hack around\n// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288\n// pending more future-proof guidance from upstream.\n// struct TypeShim\n// {\n//   const at::Type& payload;\n//   TypeShim(const at::Type& type) : payload(type) {}\n//   // Enable trivial conversion to a const at::Type& for pre-3aeb78\n//   operator const at::Type&(){ return payload; };\n//   // Enable dispatch switch statements to take *this directly for  post-3aeb78\n//   //operator at::ScalarType(){ return payload.; };\n// };\n\n#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)               \\\n  switch (TYPE) {                                                     \\\n    case at::ScalarType::Float: {                                     \\\n      using scalar_t_##LEVEL = float;                                 \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Half: {                                      \\\n      using scalar_t_##LEVEL = at::Half;                              \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    default:                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n  }\n\n#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...)        
\\\n  switch (TYPE) {                                                     \\\n    case at::ScalarType::Float: {                                     \\\n      using scalar_t_##LEVEL = float;                                 \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Half: {                                      \\\n      using scalar_t_##LEVEL = at::Half;                              \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::BFloat16: {                                  \\\n      using scalar_t_##LEVEL = at::BFloat16;                          \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    default:                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n  }\n\n#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...)          
\\\n  switch (TYPE) {                                                     \\\n    case at::ScalarType::Float: {                                     \\\n      using scalar_t_##LEVEL = float;                                 \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Half: {                                      \\\n      using scalar_t_##LEVEL = at::Half;                              \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Byte: {                                      \\\n      using scalar_t_##LEVEL = uint8_t;                               \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    default:                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n  }\n\n#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)        
\\\n  switch (TYPE) {                                                     \\\n    case at::ScalarType::Double: {                                    \\\n      using scalar_t_##LEVEL = double;                                \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Float: {                                     \\\n      using scalar_t_##LEVEL = float;                                 \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Half: {                                      \\\n      using scalar_t_##LEVEL = at::Half;                              \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    default:                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n  }\n\n#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) 
\\\n  switch (TYPE) {                                                     \\\n    case at::ScalarType::Double: {                                    \\\n      using scalar_t_##LEVEL = double;                                \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Float: {                                     \\\n      using scalar_t_##LEVEL = float;                                 \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Half: {                                      \\\n      using scalar_t_##LEVEL = at::Half;                              \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::BFloat16: {                                  \\\n      using scalar_t_##LEVEL = at::BFloat16;                          \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    default:                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n  }\n\n#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)             
\\\n  switch (TYPE) {                                                     \\\n    case at::ScalarType::Double: {                                    \\\n      using scalar_t_##LEVEL = double;                                \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::Float: {                                     \\\n      using scalar_t_##LEVEL = float;                                 \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    default:                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n  }\n\n#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                     
\\\n  switch (TYPE) {                                                     \\\n    case at::ScalarType::Half: {                                      \\\n      using scalar_t = at::Half;                                      \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    case at::ScalarType::BFloat16: {                                  \\\n      using scalar_t = at::BFloat16;                                  \\\n      __VA_ARGS__;                                                    \\\n      break;                                                          \\\n    }                                                                 \\\n    default:                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\"); \\\n  }\n\n#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\\\n  switch (TYPEIN) {                                                            \\\n    case at::ScalarType::Float: {                                              \\\n      using scalar_t_in = float;                                               \\\n      switch (TYPEOUT) {                                                       \\\n        case at::ScalarType::Float: {                                          \\\n          using scalar_t_out = float;                                          \\\n          __VA_ARGS__;                                                         \\\n          break;                                                               \\\n        }                                                                      \\\n        case at::ScalarType::Half: {                                           \\\n          using scalar_t_out = at::Half;                                       \\\n          __VA_ARGS__;                                                         \\\n          break;                                                               \\\n        }                                                                      \\\n        case at::ScalarType::BFloat16: {                                       \\\n          using scalar_t_out = at::BFloat16;                                   \\\n          __VA_ARGS__;                                                         \\\n          break;                                                               \\\n        }                                                                      \\\n        default:                                                               \\\n          AT_ERROR(#NAME, \" not implemented for '\", toString(TYPEOUT), \"'\");   \\\n      }                                                                        \\\n      break;                                                                   \\\n    }                                                                          
\\\n    case at::ScalarType::Half: {                                               \\\n      using scalar_t_in = at::Half;                                            \\\n      using scalar_t_out = at::Half;                                           \\\n      __VA_ARGS__;                                                             \\\n      break;                                                                   \\\n    }                                                                          \\\n    case at::ScalarType::BFloat16: {                                           \\\n      using scalar_t_in = at::BFloat16;                                        \\\n      using scalar_t_out = at::BFloat16;                                       \\\n      __VA_ARGS__;                                                             \\\n      break;                                                                   \\\n    }                                                                          \\\n    default:                                                                   \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPEIN), \"'\");        \\\n  }\n\n#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\\\n  switch (TYPEIN) {                                                                   \\\n    case at::ScalarType::Double: {                                                    \\\n      using scalar_t_in = double;                                                     \\\n      switch (TYPEOUT) {                                                              \\\n        case at::ScalarType::Double: {                                                \\\n          using scalar_t_out = double;                                                \\\n          __VA_ARGS__;                                                                \\\n          break;                                                                      \\\n        }                                                                             \\\n        case at::ScalarType::Float: {                                                 \\\n          using scalar_t_out = float;                                                 \\\n          __VA_ARGS__;                                                                \\\n          break;                                                                      \\\n        }                                                                             \\\n        case at::ScalarType::Half: {                                                  \\\n          using scalar_t_out = at::Half;                                              \\\n          __VA_ARGS__;                                                                \\\n          break;                                                                      \\\n        }                                                                             \\\n        case at::ScalarType::BFloat16: {                                              \\\n          using scalar_t_out = at::BFloat16;                                          \\\n          __VA_ARGS__;                                                                \\\n          
break;                                                                      \\\n        }                                                                             \\\n        default:                                                                      \\\n          AT_ERROR(#NAME, \" not implemented for '\", toString(TYPEOUT), \"'\");          \\\n      }                                                                               \\\n      break;                                                                          \\\n    }                                                                                 \\\n    case at::ScalarType::Float: {                                                     \\\n      using scalar_t_in = float;                                                      \\\n      switch (TYPEOUT) {                                                              \\\n        case at::ScalarType::Float: {                                                 \\\n          using scalar_t_out = float;                                                 \\\n          __VA_ARGS__;                                                                \\\n          break;                                                                      \\\n        }                                                                             \\\n        case at::ScalarType::Half: {                                                  \\\n          using scalar_t_out = at::Half;                                              \\\n          __VA_ARGS__;                                                                \\\n          break;                                                                      \\\n        }                                                                             \\\n        case at::ScalarType::BFloat16: {                                              \\\n          using scalar_t_out = at::BFloat16;                                          \\\n          __VA_ARGS__;    
                                                            \\\n          break;                                                                      \\\n        }                                                                             \\\n        default:                                                                      \\\n          AT_ERROR(#NAME, \" not implemented for '\", toString(TYPEOUT), \"'\");          \\\n      }                                                                               \\\n      break;                                                                          \\\n    }                                                                                 \\\n    case at::ScalarType::Half: {                                                      \\\n      using scalar_t_in = at::Half;                                                   \\\n      using scalar_t_out = at::Half;                                                  \\\n      __VA_ARGS__;                                                                    \\\n      break;                                                                          \\\n    }                                                                                 \\\n    case at::ScalarType::BFloat16: {                                                  \\\n      using scalar_t_in = at::BFloat16;                                               \\\n      using scalar_t_out = at::BFloat16;                                              \\\n      __VA_ARGS__;                                                                    \\\n      break;                                                                          \\\n    }                                                                                 \\\n    default:                                                                          \\\n      AT_ERROR(#NAME, \" not implemented for '\", toString(TYPEIN), \"'\");               \\\n  }\n\ntemplate <typename 
T>\n__device__ __forceinline__ T reduce_block_into_lanes(T* x, T val, int lanes = 1,\n                                                     bool share_result = false)  // lanes is intended to be <= 32.\n{\n  int tid = threadIdx.x + threadIdx.y * blockDim.x;\n  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.\n\n  if (blockSize >= 64) {\n    x[tid] = val;\n    __syncthreads();\n  }\n\n#pragma unroll\n  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {\n    if (tid < i) x[tid] = x[tid] + x[tid + i];\n    __syncthreads();\n  }\n\n  T final;\n\n  if (tid < 32) {\n    if (blockSize >= 64)\n      final = x[tid] + x[tid + 32];\n    else\n      final = val;\n    // __SYNCWARP();\n\n#pragma unroll\n    for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);\n  }\n\n  if (share_result) {\n    if (tid < lanes) x[tid] = final;  // EpilogueOp\n    // Make sure the smem result is visible to all warps.\n  }\n  __syncthreads();\n  // Avoid potential write before read race when reduce_block_into_lanes is called back to back\n\n  return final;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ T\nreduce_block_into_lanes_max_op(T* x, T val, int lanes = 1,\n                               bool share_result = false)  // lanes is intended to be <= 32.\n{\n  int tid = threadIdx.x + threadIdx.y * blockDim.x;\n  int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.\n\n  if (blockSize >= 64) {\n    x[tid] = val;\n    __syncthreads();\n  }\n\n#pragma unroll\n  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {\n    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));\n    __syncthreads();\n  }\n\n  T final;\n\n  if (tid < 32) {\n    if (blockSize >= 64)\n      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));\n    else\n      final = val;\n    // __SYNCWARP();\n\n#pragma unroll\n    for (int i = 16; i >= lanes; i >>= 1) final = fmaxf(fabsf(final), 
fabsf(__shfl_down_sync(0xffffffff, final, i)));\n  }\n\n  if (share_result) {\n    if (tid < lanes) x[tid] = final;  // EpilogueOp\n    // Make sure the smem result is visible to all warps.\n    __syncthreads();\n  }\n\n  return final;\n}\n"
  },
  {
    "path": "csrc/update_scale_hysteresis.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/Exceptions.h>\n\n__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, int* growth_tracker, int* hysteresis_tracker,\n                                                    const float* found_inf, double growth_factor, double backoff_factor,\n                                                    int growth_interval, int hysteresis) {\n  if (*found_inf > 0) {\n    *hysteresis_tracker -= 1;\n\n    // Only reset the growth tracker when hysteresis is larger than zero\n    if (*hysteresis_tracker > 0) {\n      *growth_tracker = 0;\n      return;\n    }\n  }\n\n  if (*found_inf) {\n    *current_scale = (*current_scale) * backoff_factor;\n    *growth_tracker = 0;\n  } else {\n    // Entering this branch means we just carried out a successful step,\n    // so growth_tracker is incremented before comparing to growth_interval.\n    auto successful = (*growth_tracker) + 1;\n    if (successful == growth_interval) {\n      auto new_scale = static_cast<float>((*current_scale) * growth_factor);\n      // Do not grow the scale past fp32 bounds to inf.\n      if (isfinite(new_scale)) {\n        *current_scale = new_scale;\n      }\n      *growth_tracker = 0;\n    } else {\n      *growth_tracker = successful;\n    }\n  }\n\n  // Reset the hysteresis tracker if no infs are found\n  if (*found_inf <= 0) {\n    *hysteresis_tracker = hysteresis;\n  }\n}\n\nat::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, at::Tensor growth_tracker,\n                                        at::Tensor hysteresis_tracker, at::Tensor found_inf, const double growth_factor,\n                                        const double backoff_factor, const int64_t growth_interval,\n                                        const int hysteresis) {\n  update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(\n      current_scale.mutable_data_ptr<float>(), 
growth_tracker.mutable_data_ptr<int>(),\n      hysteresis_tracker.mutable_data_ptr<int>(), found_inf.const_data_ptr<float>(), growth_factor, backoff_factor,\n      growth_interval, hysteresis);\n\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  return current_scale;\n}\n"
  },
  {
    "path": "csrc/welford.cu",
    "content": "#include <ATen/ATen.h>\n#include <ATen/AccumulateType.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <iostream>\n#include <vector>\n\n#include \"type_shim.h\"\n\n__device__ __forceinline__ int lastpow2(int n) {\n  int out = 1 << (31 - __clz(n));\n  if (n == out) out >>= 1;\n  return out;\n}\n\n__host__ __forceinline__ int h_next_pow2(unsigned int n) {\n  n--;\n  n |= (n >> 1);\n  n |= (n >> 2);\n  n |= (n >> 4);\n  n |= (n >> 8);\n  n |= (n >> 16);\n  return ++n;\n}\n\n__host__ __forceinline__ int h_last_pow2(unsigned int n) {\n  n |= (n >> 1);\n  n |= (n >> 2);\n  n |= (n >> 4);\n  n |= (n >> 8);\n  n |= (n >> 16);\n  return n - (n >> 1);\n}\n\n#define WARP_SIZE 32\n\ntemplate <typename T>\n__device__ __forceinline__ T warp_reduce_sum(T val) {\n#pragma unroll\n  for (int i = WARP_SIZE / 2; i > 0; i >>= 1) val = val + __shfl_down_sync(0xffffffff, val, i);\n  return val;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ T reduce_block(T* x, T val) {\n  int tid = threadIdx.y * blockDim.x + threadIdx.x;\n  int blockSize = blockDim.x * blockDim.y;\n\n  if (blockSize > 32) {\n    val = warp_reduce_sum(val);\n    if (tid % WARP_SIZE == 0) x[tid / WARP_SIZE] = val;\n\n    __syncthreads();\n\n    val = (tid < blockSize / WARP_SIZE ? 
x[tid % WARP_SIZE] : T(0));\n  }\n\n  if (tid / WARP_SIZE == 0) val = warp_reduce_sum(val);\n\n  return val;\n}\n\n#define ELEMENTS_PER_ITER 4  // enables concurrency within each thread to hide latency\n#define ELEMENTS_PER_THREAD 16\n#define OPTIMAL_TILE_W 32\n#define MAX_H_BLOCK 128\n#define MAX_BLOCK_SIZE 512\n\n__host__ int div_ru(int x, int y) { return h_last_pow2(1 + (x - 1) / y); }\n\n__host__ void flexible_launch_configs(const int reduction, const int stride, dim3& block, dim3& grid,\n                                      const bool coop_flag = false) {\n  int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);\n  int block_y = std::min(h_last_pow2(div_ru(reduction, ELEMENTS_PER_THREAD)), MAX_BLOCK_SIZE / block_x);\n  if (block_x * block_y != MAX_BLOCK_SIZE) {\n    block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);\n  }\n\n  int grid_x = div_ru(stride, block_x);\n  int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);\n  if (coop_flag) {\n    // it's not worth having a grid reduction if the reduction dimension is not big enough\n    grid_y = grid_y < 8 ? 
1 : grid_y;\n  }\n\n  block.x = block_x;\n  block.y = block_y;\n  block.z = 1;\n  grid.x = grid_x;\n  grid.y = grid_y;\n  grid.z = 1;\n}\n\ntemplate <typename T, typename C>\n__device__ __forceinline__ void welford_merge_element(C& count, T& mean, T& m2n, const C& num_new, const T& mean_new,\n                                                      const T& m2n_new) {\n  T factor = T(1.0) / max(1, (count + num_new));\n  T delta0 = mean - mean_new;\n  mean = (mean_new * num_new + mean * count) * factor;\n  m2n += m2n_new + delta0 * delta0 * num_new * count * factor;\n  count += num_new;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void warp_reduce_mean_m2n(T& mean, T& m2n, int& num) {\n#pragma unroll\n  for (int i = WARP_SIZE / 2; i > 0; i >>= 1) {\n    auto num_new = __shfl_down_sync(0xffffffff, num, i);\n    auto mean_new = __shfl_down_sync(0xffffffff, mean, i);\n    auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);\n    welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);\n  }\n}\n\ntemplate <typename T>\n__device__ void welford_reduce_mean_m2n(T* __restrict__ x, int* __restrict__ count, T& mean, T& m2n, int& num,\n                                        int block_size, int thread_id) {\n  int lane = thread_id % WARP_SIZE;\n  int wid = thread_id / WARP_SIZE;\n\n  if (block_size > 32) {\n    warp_reduce_mean_m2n(mean, m2n, num);\n    if (lane == 0) {\n      x[wid * 2] = mean;\n      x[wid * 2 + 1] = m2n;\n      count[wid] = num;\n    }\n    __syncthreads();\n\n    if (wid == 0) {\n      mean = (thread_id < block_size / WARP_SIZE) ? x[lane * 2] : T(0);\n      m2n = (thread_id < block_size / WARP_SIZE) ? x[lane * 2 + 1] : T(0);\n      num = (thread_id < block_size / WARP_SIZE) ? 
count[lane] : int(0);\n    }\n  }\n\n  if (wid == 0) warp_reduce_mean_m2n(mean, m2n, num);\n\n  return;\n}\n\n// return spatial size for NC+ Tensors\n__host__ int get_tensor_spatial_size(const at::Tensor& input) {\n  auto space_size = input.size(2);\n  for (int i = 3; i < input.ndimension(); i++) {\n    space_size *= input.size(i);\n  }\n  return space_size;\n}\n\n// promote accumulation scalar type. promote half to float.\n__host__ at::ScalarType promote_scalartype(const at::Tensor& input) {\n  return input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type();\n}\n\n// return single element size, optional accumulation type promotion.\n__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false) {\n  auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type();\n  return at::elementSize(scalar_type);\n}\n\ntemplate <typename T, typename C>\n__device__ __forceinline__ void welford_merge_block_vertical(C& count, T& mean, T& m2n, C* shmem_count, T* shmem_mean,\n                                                             T* shmem_m2n) {\n  // write to shared memory\n  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;\n  shmem_mean[address_base] = mean;\n  shmem_m2n[address_base] = m2n;\n  shmem_count[address_base] = count;\n\n#pragma unroll\n  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {\n    __syncthreads();\n    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {\n      auto address = address_base + offset * blockDim.x;\n      // read shared memory back to register for reduction\n      auto num_new = shmem_count[address];\n      auto mean_new = shmem_mean[address];\n      auto m2n_new = shmem_m2n[address];\n\n      welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new);\n\n      // last write is not necessary\n      shmem_mean[address_base] = mean;\n      shmem_m2n[address_base] = m2n;\n      shmem_count[address_base] = count;\n    
}\n  }\n}\n\ntemplate <typename T>\n__device__ __forceinline__ void merge_block_vertical(T& sum_dy, T& sum_dy_xmu, T* shmem_sum_dy, T* shmem_sum_dy_xmu) {\n  // write to shared memory\n  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;\n  shmem_sum_dy[address_base] = sum_dy;\n  shmem_sum_dy_xmu[address_base] = sum_dy_xmu;\n\n#pragma unroll\n  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {\n    __syncthreads();\n    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {\n      auto address = address_base + offset * blockDim.x;\n\n      sum_dy += shmem_sum_dy[address];\n      sum_dy_xmu += shmem_sum_dy_xmu[address];\n\n      // last write is not necessary\n      shmem_sum_dy[address_base] = sum_dy;\n      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;\n    }\n  }\n}\n\n// welford kernel calculating mean/biased_variance/unbiased_variance\ntemplate <typename scalar_t, typename accscalar_t, typename outscalar_t>\n__global__ void welford_kernel(const scalar_t* __restrict__ input, outscalar_t* __restrict__ out_mean,\n                               outscalar_t* __restrict__ out_var_biased, const int bs, const int fs, const int ss) {\n  int block_size = blockDim.x * blockDim.y;\n  int count = 0;\n  accscalar_t x_mean = accscalar_t(0);\n  accscalar_t m_2_n = accscalar_t(0);\n\n  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;\n\n  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {\n    int input_base = blockIdx.x * ss + batch_id * ss * fs;\n    // sequential welford\n    for (int offset = threadIdx.x; offset < ss; offset += blockDim.x) {\n      count++;\n      auto x_n = static_cast<accscalar_t>(input[offset + input_base]);\n      auto d = x_n - x_mean;\n      x_mean += d / count;\n      m_2_n += d * (x_n - x_mean);\n    }\n  }\n\n  static __shared__ int s_mem[160];\n  accscalar_t* s_mem_ac = (accscalar_t*)&s_mem[32];\n\n  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, 
thread_id);\n\n  if (thread_id == 0) {\n    out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);\n    out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n / count);\n  }\n}\n\n// elementwise BN kernel\ntemplate <typename scalar_t, typename accscalar_t, typename layerscalar_t>\n__global__ void batchnorm_forward_kernel(const scalar_t* __restrict__ input, const accscalar_t* __restrict__ mean,\n                                         const accscalar_t* __restrict__ inv_std,\n                                         const layerscalar_t* __restrict__ weight,\n                                         const layerscalar_t* __restrict__ shift, scalar_t* __restrict__ out,\n                                         const int ss, const int bs) {\n  auto m_c = mean[blockIdx.x];\n  auto inv_std_c = inv_std[blockIdx.x];\n  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);\n  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);\n\n  for (int batch_offset = blockIdx.y * blockDim.y + threadIdx.y; batch_offset < bs;\n       batch_offset += gridDim.y * blockDim.y) {\n    int address_base = blockIdx.x * ss + batch_offset * gridDim.x * ss;\n    for (int offset = threadIdx.x + blockIdx.z * blockDim.x; offset < ss; offset += gridDim.z * blockDim.x) {\n      out[address_base + offset] =\n          static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base + offset]) - m_c) * inv_std_c + s_c);\n    }\n  }\n}\n\n// Backward BN kernel, calculates grad_bias, grad_weight as well as intermediate\n// results to calculating grad_input.\n// Breaking the grad_input to two step to support sync BN, which requires all\n// reduce of the intermediate results across processes.\ntemplate <typename scalar_t, typename accscalar_t, typename layerscalar_t>\n__global__ void reduce_bn_kernel(const scalar_t* __restrict__ input, const scalar_t* __restrict__ grad_output,\n                                 const 
accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std,\n                                 accscalar_t* __restrict__ sum_dy_o, accscalar_t* __restrict__ sum_dy_xmu_o,\n                                 layerscalar_t* __restrict__ grad_weight, layerscalar_t* __restrict__ grad_bias,\n                                 const int bs, const int fs, const int ss) {\n  static __shared__ int s_mem[64];\n  // int total_item_num = bs * ss;\n\n  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;\n\n  auto r_mean = mean[blockIdx.x];\n  auto factor = inv_std[blockIdx.x];\n\n  // Kahan sum\n  accscalar_t sum_dy = 0.0;\n  accscalar_t sum_dy_xmu = 0.0;\n  accscalar_t sum_dy_c = 0.0;\n  accscalar_t sum_dy_xmu_c = 0.0;\n  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {\n    int input_base = blockIdx.x * ss + batch_id * ss * fs;\n    for (int offset = threadIdx.x; offset < ss; offset += blockDim.x) {\n      auto e_grad = static_cast<accscalar_t>(grad_output[offset + input_base]);\n      auto e_input = static_cast<accscalar_t>(input[offset + input_base]);\n      // calculating sum_dy\n      auto sum_dy_y = e_grad - sum_dy_c;\n      auto sum_dy_t = sum_dy + sum_dy_y;\n      sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y;\n      sum_dy = sum_dy_t;\n\n      // calculating sum_dy_xmu\n      auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c;\n      auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y;\n      sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y;\n      sum_dy_xmu = sum_dy_xmu_t;\n    }\n  }\n\n  sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);\n  __syncthreads();\n  sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);\n\n  if (thread_id == 0) {\n    if (grad_bias != NULL) {\n      grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);\n    }\n    if (grad_weight != NULL) {\n      grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);\n    }\n    // mean_dy[blockIdx.x] = sum_dy / 
total_item_num;\n    // mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;\n    sum_dy_o[blockIdx.x] = sum_dy;\n    sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;\n  }\n}\n\n// elementwise backward BN kernel\ntemplate <typename scalar_t, typename accscalar_t, typename layerscalar_t>\n__global__ void batchnorm_backward_kernel(const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input,\n                                          const accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std,\n                                          const layerscalar_t* __restrict__ weight,\n                                          const accscalar_t* __restrict__ sum_dy,\n                                          const accscalar_t* __restrict__ sum_dy_xmu, const int* __restrict__ numel,\n                                          scalar_t* __restrict__ grad_input, const int64_t world_size, const int ss,\n                                          const int bs) {\n  int64_t div = 0;\n  for (int i = 0; i < world_size; i++) {\n    div += numel[i];\n  }\n  auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);\n  // auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);\n  auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;\n  auto factor_1_c = inv_std[blockIdx.x];\n  auto factor_2_c = (weight == NULL ? 
accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;\n  // factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];\n  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;\n\n  for (int batch_offset = blockIdx.y * blockDim.y + threadIdx.y; batch_offset < bs;\n       batch_offset += gridDim.y * blockDim.y) {\n    int address_base = blockIdx.x * ss + batch_offset * gridDim.x * ss;\n    for (int offset = threadIdx.x + blockIdx.z * blockDim.x; offset < ss; offset += gridDim.z * blockDim.x) {\n      grad_input[address_base + offset] =\n          (static_cast<accscalar_t>(grad_output[address_base + offset]) - m_dy_c -\n           (static_cast<accscalar_t>(input[address_base + offset]) - m_c) * factor_1_c) *\n          factor_2_c;\n    }\n  }\n}\n\n// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance\ntemplate <typename scalar_t, typename accscalar_t, typename outscalar_t, int PARALLEL_LOADS>\n__global__ void welford_kernel_c_last(const scalar_t* __restrict__ input, outscalar_t* __restrict__ out_mean,\n                                      outscalar_t* __restrict__ out_var_biased, volatile accscalar_t* staging_data,\n                                      int* semaphores, const int reduction_size, const int stride) {\n  // hide latency with concurrency\n  accscalar_t x_mean[PARALLEL_LOADS];\n  accscalar_t m_2_n[PARALLEL_LOADS];\n  int count[PARALLEL_LOADS];\n\n#pragma unroll\n  for (int i = 0; i < PARALLEL_LOADS; i++) {\n    x_mean[i] = accscalar_t(0);\n    m_2_n[i] = accscalar_t(0);\n    count[i] = accscalar_t(0);\n  }\n  // tensor dimension (m,c)\n\n  // loop along m dimension\n  int inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  int address_base = m_offset * 
stride + c_offset;\n  int address_increment = inner_loop_stride * stride;\n\n  for (int i = 0; i < loop_count; i++) {\n    accscalar_t x_math[PARALLEL_LOADS];\n    accscalar_t x_count_inv[PARALLEL_LOADS];\n    accscalar_t is_valid[PARALLEL_LOADS];\n\n    // load multiple data in\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        x_math[j] = input[address_base];\n        count[j]++;\n        x_count_inv[j] = accscalar_t(1) / count[j];\n        is_valid[j] = accscalar_t(1);\n      } else {\n        x_math[j] = accscalar_t(0);\n        x_count_inv[j] = accscalar_t(0);\n        is_valid[j] = accscalar_t(0);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n\n    // calculate mean/m2n with welford\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      accscalar_t delta0 = x_math[j] - x_mean[j];\n      x_mean[j] += delta0 * x_count_inv[j];\n      accscalar_t delta1 = x_math[j] - x_mean[j];\n      m_2_n[j] += delta0 * delta1 * is_valid[j];\n    }\n  }\n\n  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS\n#pragma unroll\n  for (int j = 1; j < PARALLEL_LOADS; j++) {\n    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);\n  }\n\n  // release x_mean / m_2_n\n  auto mean_th = x_mean[0];\n  auto m2_th = m_2_n[0];\n  auto count_th = count[0];\n\n  // block-wise reduction with shared memory (since reduction cannot be done within a warp)\n  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];\n  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];\n  static __shared__ int shmem_count[MAX_BLOCK_SIZE];\n\n  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);\n\n  // grid reduction if needed (coop launch used at the first place)\n  if (gridDim.y > 1) {\n    volatile accscalar_t* staging_mean = staging_data;\n    volatile 
accscalar_t* staging_m2n = &staging_data[stride * gridDim.y];\n    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride * gridDim.y]);\n\n    address_base = c_offset + blockIdx.y * stride;\n    // write data to staging_data\n    if (threadIdx.y == 0 && c_offset < stride) {\n      staging_mean[address_base] = mean_th;\n      staging_m2n[address_base] = m2_th;\n      staging_count[address_base] = count_th;\n    }\n\n    __threadfence();\n    __syncthreads();  // ensure writes to staging_data are visible to all blocks\n\n    __shared__ bool is_last_block_done;\n    // mark block done\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      int old = atomicAdd(&semaphores[blockIdx.x], 1);\n      is_last_block_done = (old == (gridDim.y - 1));\n    }\n\n    __syncthreads();\n\n    // check that all data is now available in global memory\n    if (is_last_block_done) {\n      count_th = 0;\n      mean_th = accscalar_t(0.0);\n      m2_th = accscalar_t(0.0);\n\n      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {\n        address_base = c_offset + y * stride;\n        int num_new = c_offset < stride ? staging_count[address_base] : 0;\n        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);\n        accscalar_t m2n_new = c_offset < stride ? 
staging_m2n[address_base] : accscalar_t(0.0);\n\n        welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);\n      }\n\n      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);\n      if (threadIdx.y == 0 && c_offset < stride) {\n        out_mean[c_offset] = static_cast<outscalar_t>(mean_th);\n        out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);\n      }\n    }\n  } else {\n    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {\n      out_mean[c_offset] = static_cast<outscalar_t>(mean_th);\n      out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);\n    }\n  }\n}\n\n// parallel welford kernel to further reduce mean / biased_var\n// into mean / unbiased_var / inv_std across multiple processes.\ntemplate <typename scalar_t>\n__global__ void welford_kernel_parallel(const scalar_t* __restrict__ mean, const scalar_t* __restrict__ var_biased,\n                                        const int* __restrict__ numel, scalar_t* __restrict__ out_mean,\n                                        scalar_t* __restrict__ out_var, scalar_t* __restrict__ inv_std,\n                                        const int world_size, const int feature_size, const float eps) {\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {\n    // load data;\n    int address = i;\n    scalar_t x_mean = 0;\n    scalar_t m_2_n = 0;\n    int count = 0;\n    for (int j = 0; j < world_size; j++) {\n      welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address] * numel[j]);\n      address += feature_size;\n    }\n    out_mean[i] = x_mean;\n    out_var[i] = m_2_n / (count - 1);\n    inv_std[i] = scalar_t(1) / sqrt(m_2_n / count + eps);\n  }\n}\n\n// elementwise BN kernel\ntemplate <typename scalar_t, typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS>\n__global__ void 
batchnorm_forward_c_last_kernel(const scalar_t* __restrict__ input, const scalar_t* __restrict__ z,\n                                                const accscalar_t* __restrict__ mean,\n                                                const accscalar_t* __restrict__ inv_std,\n                                                const layerscalar_t* __restrict__ weight,\n                                                const layerscalar_t* __restrict__ shift, scalar_t* __restrict__ out,\n                                                const int reduction_size, const int stride, const bool fuse_relu) {\n  // tensor dimension (m,c)\n  // loop along m dimension\n  int inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  auto m_c = mean[c_offset];\n  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);\n  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);\n  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);\n\n  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  int address_base = m_offset * stride + c_offset;\n  int address_increment = inner_loop_stride * stride;\n\n  for (int i = 0; i < loop_count; i++) {\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c) * inv_std_c + s_c;\n        if (z != NULL) {\n          tmp += z[address_base];\n        }\n        out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? 
scalar_t(0.0) : static_cast<scalar_t>(tmp));\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n  }\n}\n\n// elementwise ReLU backward kernel for c last tensor: recomputes the BN output (plus z)\n// and zeroes the gradient wherever that output is <= 0\ntemplate <typename scalar_t, typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS>\n__global__ void relu_backward_c_last_kernel(const scalar_t* __restrict__ grad_output,\n                                            const scalar_t* __restrict__ input, const scalar_t* __restrict__ z,\n                                            const accscalar_t* __restrict__ mean,\n                                            const accscalar_t* __restrict__ inv_std,\n                                            const layerscalar_t* __restrict__ weight,\n                                            const layerscalar_t* __restrict__ shift, scalar_t* __restrict__ out,\n                                            const int reduction_size, const int stride) {\n  // tensor dimension (m,c)\n  // loop along m dimension\n  int inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  auto m_c = mean[c_offset];\n  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);\n  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);\n  auto s_c = shift == NULL ? 
accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);\n\n  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  int address_base = m_offset * stride + c_offset;\n  int address_increment = inner_loop_stride * stride;\n\n  for (int i = 0; i < loop_count; i++) {\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c) * inv_std_c + s_c;\n        if (z != NULL) {\n          tmp += z[address_base];\n        }\n        out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n  }\n}\n\n// batchnorm backward kernel for c last tensor\ntemplate <typename scalar_t, typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS>\n__global__ void reduce_bn_c_last_kernel(const scalar_t* __restrict__ input, const scalar_t* __restrict__ grad_output,\n                                        const accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std,\n                                        accscalar_t* __restrict__ sum_dy_o, accscalar_t* __restrict__ sum_dy_xmu_o,\n                                        layerscalar_t* __restrict__ grad_weight, layerscalar_t* __restrict__ grad_bias,\n                                        volatile accscalar_t* staging_data, int* semaphores, const int reduction_size,\n                                        const int stride) {\n  // hide latency with concurrency\n  accscalar_t sum_dy[PARALLEL_LOADS];\n  accscalar_t sum_dy_xmu[PARALLEL_LOADS];\n\n#pragma unroll\n  for (int i = 0; i < PARALLEL_LOADS; i++) {\n    sum_dy[i] = accscalar_t(0);\n    sum_dy_xmu[i] = accscalar_t(0);\n  }\n  // tensor dimension (m,c)\n\n  // loop along m dimension\n  int inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m 
dimension\n  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  int address_base = m_offset * stride + c_offset;\n  int address_increment = inner_loop_stride * stride;\n\n  auto r_mean = mean[c_offset];\n  auto factor = inv_std[c_offset];\n\n  for (int i = 0; i < loop_count; i++) {\n    accscalar_t x_input[PARALLEL_LOADS];\n    accscalar_t x_grad_output[PARALLEL_LOADS];\n\n    // load multiple data in\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        x_input[j] = input[address_base];\n        x_grad_output[j] = grad_output[address_base];\n      } else {\n        x_input[j] = accscalar_t(0);\n        x_grad_output[j] = accscalar_t(0);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n\n    // calculate sum_dy / sum_dy_xmu\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      sum_dy[j] += x_grad_output[j];\n      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);\n    }\n  }\n\n  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS\n#pragma unroll\n  for (int j = 1; j < PARALLEL_LOADS; j++) {\n    sum_dy[0] += sum_dy[j];\n    sum_dy_xmu[0] += sum_dy_xmu[j];\n  }\n\n  // release array of registers\n  auto sum_dy_th = sum_dy[0];\n  auto sum_dy_xmu_th = sum_dy_xmu[0];\n\n  // block-wise reduction with shared memory (since reduction cannot be done within a warp)\n  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];\n  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];\n\n  merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);\n\n  // grid reduction if needed (coop launch used at the first place)\n  if (gridDim.y > 1) {\n    volatile accscalar_t* staging_sum_dy = staging_data;\n    volatile 
accscalar_t* staging_sum_dy_xmu = &staging_data[stride * gridDim.y];\n\n    address_base = c_offset + blockIdx.y * stride;\n    // write data to staging_data\n    if (threadIdx.y == 0 && c_offset < stride) {\n      staging_sum_dy[address_base] = sum_dy_th;\n      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;\n    }\n\n    __threadfence();\n    __syncthreads();  // ensure writes to staging_data are visible to all blocks\n\n    __shared__ bool is_last_block_done;\n    // mark block done\n    if (threadIdx.x == 0 && threadIdx.y == 0) {\n      int old = atomicAdd(&semaphores[blockIdx.x], 1);\n      is_last_block_done = (old == (gridDim.y - 1));\n    }\n\n    __syncthreads();\n\n    // check that all data is now available in global memory\n    if (is_last_block_done) {\n      sum_dy_th = accscalar_t(0.0);\n      sum_dy_xmu_th = accscalar_t(0.0);\n\n      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {\n        address_base = c_offset + y * stride;\n        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));\n        sum_dy_xmu_th += (c_offset < stride ? 
staging_sum_dy_xmu[address_base] : accscalar_t(0.0));\n      }\n\n      merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);\n      if (threadIdx.y == 0 && c_offset < stride) {\n        if (grad_bias != NULL) {\n          grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);\n        }\n        if (grad_weight != NULL) {\n          grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);\n        }\n        // mean_dy[c_offset] = sum_dy_th / reduction_size;\n        // mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;\n        sum_dy_o[c_offset] = sum_dy_th;\n        sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;\n      }\n    }\n  } else {\n    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {\n      if (grad_bias != NULL) {\n        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);\n      }\n      if (grad_weight != NULL) {\n        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);\n      }\n      // mean_dy[c_offset] = sum_dy_th / reduction_size;\n      // mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;\n      sum_dy_o[c_offset] = sum_dy_th;\n      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;\n    }\n  }\n}\n\n// elementwise BN kernel\ntemplate <typename scalar_t, typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS>\n__global__ void batchnorm_backward_c_last_kernel(\n    const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input, const accscalar_t* __restrict__ mean,\n    const accscalar_t* __restrict__ inv_std, const layerscalar_t* __restrict__ weight,\n    const accscalar_t* __restrict__ sum_dy, const accscalar_t* __restrict__ sum_dy_xmu, const int* __restrict__ numel,\n    scalar_t* __restrict__ grad_input, const int64_t world_size, const int reduction_size, const int stride) {\n  int64_t div = 0;\n  for (int i = 0; i < world_size; i++) {\n    div += numel[i];\n  }\n  // tensor dimension (m,c)\n  // loop along m dimension\n  
int inner_loop_stride = blockDim.y * gridDim.y;\n\n  // offset along m dimension\n  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;\n  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;\n\n  auto m_c = mean[c_offset];\n  auto m_dy_c = sum_dy[c_offset] / div;\n  auto factor_1_c = inv_std[c_offset];\n  auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;\n  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;\n\n  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);\n  int address_base = m_offset * stride + c_offset;\n  int address_increment = inner_loop_stride * stride;\n\n  for (int i = 0; i < loop_count; i++) {\n#pragma unroll\n    for (int j = 0; j < PARALLEL_LOADS; j++) {\n      if (c_offset < stride && m_offset < reduction_size) {\n        grad_input[address_base] =\n            static_cast<scalar_t>((static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -\n                                   (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c) *\n                                  factor_2_c);\n      }\n      m_offset += inner_loop_stride;\n      address_base += address_increment;\n    }\n  }\n}\n\nstd::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {\n  const auto batch_size = input.size(0);\n  const auto feature_size = input.size(1);\n\n  auto space_size = get_tensor_spatial_size(input);\n  auto scalar_type = promote_scalartype(input);\n\n  at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));\n  at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));\n\n  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));\n  int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));\n  const dim3 block(block_x, block_y);\n  const dim3 grid(feature_size);\n\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  {\n    using 
namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"welford_mean_var_kernel\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        welford_kernel<scalar_t_0, accscalar_t, accscalar_t>\n        <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t_0>(), out_mean.data_ptr<accscalar_t>(),\n                                     out_var_biased.data_ptr<accscalar_t>(), batch_size, feature_size, space_size););\n  }\n\n  return {out_mean, out_var_biased};\n}\n\nat::Tensor batchnorm_forward_CUDA(const at::Tensor input, const at::Tensor mean, const at::Tensor inv_std,\n                                  const at::optional<at::Tensor> weight, const at::optional<at::Tensor> shift) {\n  const auto batch_size = input.size(0);\n  const auto feature_size = input.size(1);\n  at::Tensor out = at::empty_like(input);\n\n  auto space_size = get_tensor_spatial_size(input);\n\n  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size) / 4));\n  int block_y = max(1, min(MAX_BLOCK_SIZE / block_x, h_last_pow2(batch_size) / 4));\n  const dim3 block(block_x, block_y);\n  int grid_z = max(1, min(65535, h_last_pow2(space_size) / 4 / block_x));\n  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size) / block_y));\n  const dim3 grid(feature_size, batch_group_size, grid_z);\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  if (input.scalar_type() == at::ScalarType::Half && weight.has_value() &&\n      weight.value().scalar_type() == at::ScalarType::Float) {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_forward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(\n            input.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(), inv_std.data_ptr<accscalar_t>(),\n            weight.has_value() ? weight.value().data_ptr<accscalar_t>() : NULL,\n            shift.has_value() ? 
shift.value().data_ptr<accscalar_t>() : NULL, out.data_ptr<scalar_t_0>(), space_size,\n            batch_size););\n  } else {\n    if (weight.has_value()) {\n      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),\n                  \"input.scalar_type() is not supported with weight.scalar_type()\");\n    }\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_forward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(\n            input.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(), inv_std.data_ptr<accscalar_t>(),\n            weight.has_value() ? weight.value().data_ptr<scalar_t_0>() : NULL,\n            shift.has_value() ? shift.value().data_ptr<scalar_t_0>() : NULL, out.data_ptr<scalar_t_0>(), space_size,\n            batch_size););\n  }\n  return out;\n}\n\nstd::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output, const at::Tensor input, const at::Tensor mean,\n                                       const at::Tensor inv_std, const at::optional<at::Tensor> weight) {\n  const auto batch_size = input.size(0);\n  const auto feature_size = input.size(1);\n\n  auto scalar_type = promote_scalartype(input);\n\n  at::Tensor sum_dy = at::empty({feature_size}, mean.options());\n  at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());\n\n  at::Tensor grad_weight;\n  at::Tensor grad_bias;\n  if (weight.has_value()) {\n    grad_weight = at::empty({feature_size}, weight.value().options());\n    grad_bias = at::empty({feature_size}, weight.value().options());\n  } else {\n    grad_weight = at::empty({0}, mean.options());\n    grad_bias = at::empty({0}, mean.options());\n  }\n\n  auto space_size = get_tensor_spatial_size(input);\n\n  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));\n  int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));\n  
const dim3 block(block_x, block_y);\n  const dim3 grid(feature_size);\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  if (input.scalar_type() == at::ScalarType::Half && weight.has_value() &&\n      weight.value().scalar_type() == at::ScalarType::Float) {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward_reduce\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(\n            input.data_ptr<scalar_t_0>(), grad_output.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), sum_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(),\n            weight.has_value() ? grad_weight.data_ptr<accscalar_t>() : NULL,\n            weight.has_value() ? grad_bias.data_ptr<accscalar_t>() : NULL, batch_size, feature_size, space_size););\n  } else {\n    if (weight.has_value()) {\n      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),\n                  \"input.scalar_type() is not supported with weight.scalar_type()\");\n    }\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward_reduce\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(\n            input.data_ptr<scalar_t_0>(), grad_output.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), sum_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(),\n            weight.has_value() ? grad_weight.data_ptr<scalar_t_0>() : NULL,\n            weight.has_value() ? 
grad_bias.data_ptr<scalar_t_0>() : NULL, batch_size, feature_size, space_size););\n  }\n\n  return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};\n}\n\nat::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output, const at::Tensor input, const at::Tensor mean,\n                                   const at::Tensor inv_std, const at::optional<at::Tensor> weight,\n                                   const at::Tensor sum_dy, const at::Tensor sum_dy_xmu, const at::Tensor count) {\n  const auto batch_size = input.size(0);\n  const auto feature_size = input.size(1);\n\n  at::Tensor grad_input = at::empty_like(input);\n\n  auto space_size = get_tensor_spatial_size(input);\n\n  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size) / 4));\n  int block_y = max(1, min(MAX_BLOCK_SIZE / block_x, h_last_pow2(batch_size) / 4));\n  const dim3 block(block_x, block_y);\n  int grid_z = max(1, min(65535, h_last_pow2(space_size) / 4 / block_x));\n  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size) / block_y));\n  const dim3 grid(feature_size, batch_group_size, grid_z);\n\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  if (input.scalar_type() == at::ScalarType::Half && weight.has_value() &&\n      weight.value().scalar_type() == at::ScalarType::Float) {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(\n            grad_output.data_ptr<scalar_t_0>(), input.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), weight.has_value() ? 
weight.value().data_ptr<accscalar_t>() : NULL,\n            sum_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(), count.data_ptr<int>(),\n            grad_input.data_ptr<scalar_t_0>(), count.numel(), space_size, batch_size););\n  } else {\n    if (weight.has_value()) {\n      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),\n                  \"input.scalar_type() is not supported with weight.scalar_type()\");\n    }\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(\n            grad_output.data_ptr<scalar_t_0>(), input.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), weight.has_value() ? weight.value().data_ptr<scalar_t_0>() : NULL,\n            sum_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(), count.data_ptr<int>(),\n            grad_input.data_ptr<scalar_t_0>(), count.numel(), space_size, batch_size););\n  }\n\n  return grad_input;\n}\n\nstd::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased,\n                                              const at::Tensor numel, const float eps) {\n  const auto world_size = mean_feature_nodes.size(0);\n  const auto feature_size = mean_feature_nodes.size(1);\n\n  at::Tensor out_var = at::empty({feature_size}, var_biased.options());\n  at::Tensor inv_std = at::empty_like(out_var);\n  at::Tensor out_mean = at::empty_like(out_var);\n\n  at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous();\n  at::Tensor var_biased_ = var_biased.contiguous();\n  at::Tensor numel_ = numel.contiguous();\n\n  // TODO(jie): tile this for memory coalescing!\n  const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);\n  const int grid = std::max<int>(1, feature_size / 
block);\n\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, \"welford_parallel_kernel\",\n                            welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(\n                                mean_feature_nodes_.data_ptr<scalar_t_0>(), var_biased_.data_ptr<scalar_t_0>(),\n                                numel_.data_ptr<int>(), out_mean.data_ptr<scalar_t_0>(), out_var.data_ptr<scalar_t_0>(),\n                                inv_std.data_ptr<scalar_t_0>(), world_size, feature_size, eps););\n  }\n\n  return {out_mean, out_var, inv_std};\n}\n\nstd::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {\n  const auto stride = input.size(input.ndimension() - 1);\n  const auto reduction_size = input.numel() / stride;\n\n  auto scalar_type = promote_scalartype(input);\n  auto option = input.options().dtype(scalar_type);\n\n  at::Tensor out_var_biased = at::empty({stride}, option);\n  at::Tensor out_mean = at::empty({stride}, option);\n\n  dim3 block;\n  dim3 grid;\n  flexible_launch_configs(reduction_size, stride, block, grid, true);\n\n  at::Tensor staging_data;\n  at::Tensor semaphores;\n  if (grid.y > 1) {\n    staging_data = at::empty({4 * stride * grid.y}, option);\n    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));\n  }\n\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"welford_mean_var_c_last\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;\n        int* semaphores_ptr = grid.y > 1 ? 
semaphores.data_ptr<int>() : nullptr;\n        welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER><<<grid, block, 0, stream>>>(\n            input.data_ptr<scalar_t_0>(), out_mean.data_ptr<accscalar_t>(), out_var_biased.data_ptr<accscalar_t>(),\n            staging_data_ptr, semaphores_ptr, reduction_size, stride););\n  }\n\n  return {out_mean, out_var_biased};\n}\n\nat::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input, const at::optional<at::Tensor> z,\n                                         const at::Tensor mean, const at::Tensor inv_std,\n                                         const at::optional<at::Tensor> weight, const at::optional<at::Tensor> shift,\n                                         const bool fuse_relu) {\n  const auto stride = input.size(input.ndimension() - 1);\n  const auto reduction_size = input.numel() / stride;\n\n  at::Tensor out = at::empty_like(input);\n\n  dim3 block;\n  dim3 grid;\n  flexible_launch_configs(reduction_size, stride, block, grid);\n\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  if (input.scalar_type() == at::ScalarType::Half && weight.has_value() &&\n      weight.value().scalar_type() == at::ScalarType::Float) {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_forward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>\n        <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t_0>(),\n                                     z.has_value() ? z.value().data_ptr<scalar_t_0>() : NULL,\n                                     mean.data_ptr<accscalar_t>(), inv_std.data_ptr<accscalar_t>(),\n                                     weight.has_value() ? weight.value().data_ptr<accscalar_t>() : NULL,\n                                     shift.has_value() ? 
shift.value().data_ptr<accscalar_t>() : NULL,\n                                     out.data_ptr<scalar_t_0>(), reduction_size, stride, fuse_relu););\n  } else {\n    if (weight.has_value()) {\n      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),\n                  \"input.scalar_type() is not supported with weight.scalar_type()\");\n    }\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_forward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>\n        <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t_0>(),\n                                     z.has_value() ? z.value().data_ptr<scalar_t_0>() : NULL,\n                                     mean.data_ptr<accscalar_t>(), inv_std.data_ptr<accscalar_t>(),\n                                     weight.has_value() ? weight.value().data_ptr<scalar_t_0>() : NULL,\n                                     shift.has_value() ? 
shift.value().data_ptr<scalar_t_0>() : NULL,\n                                     out.data_ptr<scalar_t_0>(), reduction_size, stride, fuse_relu););\n  }\n  return out;\n}\n\nstd::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output, const at::Tensor input,\n                                              const at::Tensor mean, const at::Tensor inv_std,\n                                              const at::optional<at::Tensor> weight) {\n  const auto stride = input.size(input.ndimension() - 1);\n  const auto reduction_size = input.numel() / stride;\n\n  at::Tensor sumn_dy = at::empty({stride}, mean.options());\n  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());\n\n  at::Tensor grad_weight;\n  at::Tensor grad_bias;\n  if (weight.has_value()) {\n    grad_weight = at::empty({stride}, weight.value().options());\n    grad_bias = at::empty({stride}, weight.value().options());\n  } else {\n    // because I cannot return an uninitialized at::Tensor\n    grad_weight = at::empty({0}, mean.options());\n    grad_bias = at::empty({0}, mean.options());\n  }\n\n  dim3 block;\n  dim3 grid;\n  flexible_launch_configs(reduction_size, stride, block, grid, true);\n\n  at::Tensor staging_data;\n  at::Tensor semaphores;\n  if (grid.y > 1) {\n    staging_data = at::empty({2 * stride * grid.y}, mean.options());\n    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));\n  }\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  if (input.scalar_type() == at::ScalarType::Half && weight.has_value() &&\n      weight.value().scalar_type() == at::ScalarType::Float) {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward_reduce\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;\n        int* semaphores_ptr = grid.y > 1 ? 
semaphores.data_ptr<int>() : nullptr;\n        reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER><<<grid, block, 0, stream>>>(\n            input.data_ptr<scalar_t_0>(), grad_output.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), sumn_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(),\n            weight.has_value() ? grad_weight.data_ptr<accscalar_t>() : NULL,\n            weight.has_value() ? grad_bias.data_ptr<accscalar_t>() : NULL, staging_data_ptr, semaphores_ptr,\n            reduction_size, stride););\n  } else {\n    if (weight.has_value()) {\n      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),\n                  \"input.scalar_type() is not supported with weight.scalar_type()\");\n    }\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward_reduce\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;\n        int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;\n        reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER><<<grid, block, 0, stream>>>(\n            input.data_ptr<scalar_t_0>(), grad_output.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), sumn_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(),\n            weight.has_value() ? grad_weight.data_ptr<scalar_t_0>() : NULL,\n            weight.has_value() ? 
grad_bias.data_ptr<scalar_t_0>() : NULL, staging_data_ptr, semaphores_ptr,\n            reduction_size, stride););\n  }\n\n  return {sumn_dy, sum_dy_xmu, grad_weight, grad_bias};\n}\n\nat::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output, const at::Tensor input, const at::Tensor mean,\n                                          const at::Tensor inv_std, const at::optional<at::Tensor> weight,\n                                          const at::Tensor sum_dy, const at::Tensor sum_dy_xmu,\n                                          const at::Tensor count) {\n  const auto stride = input.size(input.ndimension() - 1);\n  const auto reduction_size = input.numel() / stride;\n\n  at::Tensor grad_input = at::empty_like(input);\n\n  dim3 block;\n  dim3 grid;\n  flexible_launch_configs(reduction_size, stride, block, grid);\n\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  if (input.scalar_type() == at::ScalarType::Half && weight.has_value() &&\n      weight.value().scalar_type() == at::ScalarType::Float) {\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>\n        <<<grid, block, 0, stream>>>(\n            grad_output.data_ptr<scalar_t_0>(), input.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), weight.has_value() ? 
weight.value().data_ptr<accscalar_t>() : NULL,\n            sum_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(), count.data_ptr<int>(),\n            grad_input.data_ptr<scalar_t_0>(), count.numel(), reduction_size, stride););\n  } else {\n    if (weight.has_value()) {\n      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),\n                  \"input.scalar_type() is not supported with weight.scalar_type()\");\n    }\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"batchnorm_backward\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>\n        <<<grid, block, 0, stream>>>(\n            grad_output.data_ptr<scalar_t_0>(), input.data_ptr<scalar_t_0>(), mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), weight.has_value() ? weight.value().data_ptr<scalar_t_0>() : NULL,\n            sum_dy.data_ptr<accscalar_t>(), sum_dy_xmu.data_ptr<accscalar_t>(), count.data_ptr<int>(),\n            grad_input.data_ptr<scalar_t_0>(), count.numel(), reduction_size, stride););\n  }\n\n  return grad_input;\n}\n\nat::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output, const at::Tensor input,\n                                     const at::optional<at::Tensor> z, const at::Tensor mean, const at::Tensor inv_std,\n                                     const at::optional<at::Tensor> weight, const at::optional<at::Tensor> shift) {\n  const auto stride = input.size(input.ndimension() - 1);\n  const auto reduction_size = input.numel() / stride;\n\n  at::Tensor out = at::empty_like(input);\n\n  dim3 block;\n  dim3 grid;\n  flexible_launch_configs(reduction_size, stride, block, grid);\n\n  auto stream = at::cuda::getCurrentCUDAStream();\n\n  if (input.scalar_type() == at::ScalarType::Half && weight.has_value() &&\n      weight.value().scalar_type() == at::ScalarType::Float) {\n    using 
namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"relu_backward_c_last\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>\n        <<<grid, block, 0, stream>>>(grad_output.data_ptr<scalar_t_0>(), input.data_ptr<scalar_t_0>(),\n                                     z.has_value() ? z.value().data_ptr<scalar_t_0>() : NULL,\n                                     mean.data_ptr<accscalar_t>(), inv_std.data_ptr<accscalar_t>(),\n                                     weight.has_value() ? weight.value().data_ptr<accscalar_t>() : NULL,\n                                     shift.has_value() ? shift.value().data_ptr<accscalar_t>() : NULL,\n                                     out.data_ptr<scalar_t_0>(), reduction_size, stride););\n  } else {\n    if (weight.has_value()) {\n      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),\n                  \"input.scalar_type() is not supported with weight.scalar_type()\");\n    }\n    using namespace at;\n    DISPATCH_FLOAT_AND_HALF(\n        input.scalar_type(), 0, \"relu_backward_c_last\", using accscalar_t = at::acc_type<scalar_t_0, true>;\n        relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER><<<grid, block, 0, stream>>>(\n            grad_output.data_ptr<scalar_t_0>(), input.data_ptr<scalar_t_0>(),\n            z.has_value() ? z.value().data_ptr<scalar_t_0>() : NULL, mean.data_ptr<accscalar_t>(),\n            inv_std.data_ptr<accscalar_t>(), weight.has_value() ? weight.value().data_ptr<scalar_t_0>() : NULL,\n            shift.has_value() ? shift.value().data_ptr<scalar_t_0>() : NULL, out.data_ptr<scalar_t_0>(), reduction_size,\n            stride););\n  }\n  return out;\n}\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSPHINXPROJ    = NVIDIAAPEX\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\ngh-pages:\n\tgit checkout gh-pages\n\trm -rf build\n\trm -rf source\n\tgit checkout master -- .\n\tmake html\n\trm -rf ../_modules ../_sources ../_static\n\tmv -fv build/html/* ../\n\trm -rf build\n\tgit add -A\n\tgit commit -m \"Generated gh-pages for `git log master -1 --pretty=short --abbrev-commit`\" && git push origin gh-pages ; git checkout master\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/source/_static/css/pytorch_theme.css",
    "content": "body {\n    font-family: \"Lato\",\"proxima-nova\",\"Helvetica Neue\",Arial,sans-serif;\n}\n\n/* Default header fonts are ugly */\nh1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {\n    font-family: \"Lato\",\"proxima-nova\",\"Helvetica Neue\",Arial,sans-serif;\n}\n\n/* Use white for docs background */\n.wy-side-nav-search {\n    background-color: #fff;\n}\n\n.wy-nav-content-wrap, .wy-menu li.current > a  {\n    background-color: #fff;\n}\n\n@media screen and (min-width: 1400px) {\n    .wy-nav-content-wrap {\n        background-color: rgba(0, 0, 0, 0.0470588);\n    }\n\n    .wy-nav-content {\n        background-color: #fff;\n    }\n}\n\n/* Fixes for mobile */\n.wy-nav-top {\n    background-color: #fff;\n    background-image: url('../img/apex.jpg');\n    background-repeat: no-repeat;\n    background-position: center;\n    padding: 0;\n    margin: 0.4045em 0.809em;\n    color: #333;\n}\n\n.wy-nav-top > a {\n    display: none;\n}\n\n@media screen and (max-width: 768px) {\n    .wy-side-nav-search>a img.logo {\n        height: 60px;\n    }\n}\n\n/* This is needed to ensure that logo above search scales properly */\n.wy-side-nav-search a {\n    display: block;\n}\n\n/* This ensures that multiple constructors will remain in separate lines. 
*/\n.rst-content dl:not(.docutils) dt {\n    display: table;\n}\n\n/* Use our red for literals (it's very similar to the original color) */\n.rst-content tt.literal, .rst-content code.literal {\n    color: #F05732;\n}\n\n.rst-content tt.xref, .rst-content code.xref,\na .rst-content tt, a .rst-content code {\n    color: #404040;\n}\n\n/* Change link colors (except for the menu) */\n\na {\n    color: #F05732;\n}\n\na:hover {\n    color: #F05732;\n}\n\n\na:visited {\n    color: #D44D2C;\n}\n\n.wy-menu a {\n    color: #b3b3b3;\n}\n\n.wy-menu a:hover {\n    color: #b3b3b3;\n}\n\n/* Default footer text is quite big */\nfooter {\n    font-size: 80%;\n}\n\nfooter .rst-footer-buttons {\n    font-size: 125%; /* revert footer settings - 1/80% = 125% */\n}\n\nfooter p {\n    font-size: 100%;\n}\n\n/* For hidden headers that appear in TOC tree */\n/* see http://stackoverflow.com/a/32363545/3343043 */\n.rst-content .hidden-section {\n    display: none;\n}\n\nnav .hidden-section {\n    display: inherit;\n}\n\n.wy-side-nav-search>div.version {\n    color: #000;\n}\n"
  },
  {
    "path": "docs/source/_templates/layout.html",
    "content": "{% extends \"!layout.html\" %}\n  {% block sidebartitle %} {{ super() }}\n\n  <style>\n    /* Sidebar header (and topbar for mobile) */\n    .wy-side-nav-search, .wy-nav-top {\n      background: #76b900;\n    }\n\n    .wy-side-nav-search a:link, .wy-nav-top a:link {\n      color: #fff;\n    }\n    .wy-side-nav-search a:visited, .wy-nav-top a:visited {\n      color: #fff;\n    }\n    .wy-side-nav-search a:hover, .wy-nav-top a:hover {\n      color: #fff;\n    }\n\n    .wy-menu-vertical a:link, .wy-menu-vertical a:visited {\n      color: #d9d9d9\n    }\n\n    .wy-menu-vertical a:active {\n      background-color: #76b900\n    }\n\n    .wy-side-nav-search>div.version {\n      color: rgba(0, 0, 0, 0.3)\n    }\n  </style>\n  {% endblock %}\n\n  {% block footer %} {{ super() }}\n\n  <style>\n  a:link, a:visited {\n    color: #76b900;\n  }\n\n  a:hover {\n    color: #8c0;\n  }\n\n  .rst-content dl:not(.docutils) dt {\n    background: rgba(118, 185, 0, 0.1);\n    color: rgba(59,93,0,1);\n    border-top: solid 3px rgba(59,93,0,1);\n  }\n  </style>\n  {% endblock %}\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n#\n# PyTorch documentation build configuration file, created by\n# sphinx-quickstart on Fri Dec 23 13:31:47 2016.\n#\n# This file is execfile()d with the current directory set to its\n# containing dir.\n#\n# Note that not all possible configuration values are present in this\n# autogenerated file.\n#\n# All configuration values have a default; values that are commented out\n# serve to show the default.\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\n\nsys.path.insert(0, os.path.abspath(\".\"))\n# sys.path.insert(0, os.path.abspath('../../apex/parallel/'))\n# import multiproc\nimport sphinx_rtd_theme\n\n\n# -- General configuration ------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.doctest\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.todo\",\n    \"sphinx.ext.coverage\",\n    \"sphinx.ext.mathjax\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx.ext.extlinks\",\n]\n\nnapoleon_use_ivar = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\n# source_suffix = ['.rst', '.md']\nsource_suffix = \".rst\"\n\n# The master toctree document.\nmaster_doc = \"index\"\n\n# General information about the project.\nproject = \"Apex\"\ncopyright = \"2018\"\nauthor = \"Christian Sarofeen, Natalia Gimelshein, Michael Carilli, Raul Puri\"\n\n# The version info for the project you're documenting, acts as replacement for\n# |version| and |release|, also used in various other places throughout the\n# built documents.\n#\n# The short X.Y version.\n# TODO: change to [:2] at v1.0\n# version = 'master (' + torch.__version__ + ' )'\nversion = \"0.1\"\n# The full version, including alpha/beta/rc tags.\n# TODO: verify this works as expected\nrelease = \"0.1.0\"\n\n# The language for content autogenerated by Sphinx. 
Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = None\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path\nexclude_patterns = []\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = \"sphinx\"\n\n# If true, `todo` and `todoList` produce output, else they produce nothing.\ntodo_include_todos = True\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"sphinx_rtd_theme\"\nhtml_theme_path = [sphinx_rtd_theme.get_html_theme_path()]\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  For a list of options available for each theme, see the\n# documentation.\n#\nhtml_theme_options = {\n    \"collapse_navigation\": False,\n    \"display_version\": True,\n    \"logo_only\": True,\n}\n\n# html_logo = '_static/img/nv-pytorch2.png'\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n\n# html_style_path = 'css/pytorch_theme.css'\nhtml_context = {\n    \"css_files\": [\n        \"https://fonts.googleapis.com/css?family=Lato\",\n        \"_static/css/pytorch_theme.css\",\n    ],\n}\n\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = \"PyTorchdoc\"\n\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, \"apex.tex\", \"Apex Documentation\", \"Torch Contributors\", \"manual\"),\n]\n\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [(master_doc, \"Apex\", \"Apex Documentation\", [author], 1)]\n\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. 
List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (\n        master_doc,\n        \"Apex\",\n        \"Apex Documentation\",\n        author,\n        \"Apex\",\n        \"One line description of project.\",\n        \"Miscellaneous\",\n    ),\n]\n\n\n# Example configuration for intersphinx: refer to the Python standard library.\nintersphinx_mapping = {\n    \"python\": (\"https://docs.python.org/\", None),\n    \"numpy\": (\"http://docs.scipy.org/doc/numpy/\", None),\n}\n\n# -- A patch that prevents Sphinx from cross-referencing ivar tags -------\n# See http://stackoverflow.com/a/41184353/3343043\n\nfrom docutils import nodes\nfrom sphinx.util.docfields import TypedField\nfrom sphinx import addnodes\n\n\ndef patched_make_field(self, types, domain, items, **kw):\n    # `kw` catches `env=None` needed for newer sphinx while maintaining\n    #  backwards compatibility when passed along further down!\n\n    # type: (List, unicode, Tuple) -> nodes.field\n    def handle_item(fieldarg, content):\n        par = nodes.paragraph()\n        par += addnodes.literal_strong(\"\", fieldarg)  # Patch: this line added\n        # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,\n        #                           addnodes.literal_strong))\n        if fieldarg in types:\n            par += nodes.Text(\" (\")\n            # NOTE: using .pop() here to prevent a single type node to be\n            # inserted twice into the doctree, which leads to\n            # inconsistencies later when references are resolved\n            fieldtype = types.pop(fieldarg)\n            if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):\n                typename = \"\".join(n.astext() for n in fieldtype)\n                typename = typename.replace(\"int\", \"python:int\")\n                typename = typename.replace(\"long\", \"python:long\")\n                typename = 
typename.replace(\"float\", \"python:float\")\n                typename = typename.replace(\"type\", \"python:type\")\n                par.extend(\n                    self.make_xrefs(\n                        self.typerolename,\n                        domain,\n                        typename,\n                        addnodes.literal_emphasis,\n                        **kw,\n                    )\n                )\n            else:\n                par += fieldtype\n            par += nodes.Text(\")\")\n        par += nodes.Text(\" -- \")\n        par += content\n        return par\n\n    fieldname = nodes.field_name(\"\", self.label)\n    if len(items) == 1 and self.can_collapse:\n        fieldarg, content = items[0]\n        bodynode = handle_item(fieldarg, content)\n    else:\n        bodynode = self.list_type()\n        for fieldarg, content in items:\n            bodynode += nodes.list_item(\"\", handle_item(fieldarg, content))\n    fieldbody = nodes.field_body(\"\", bodynode)\n    return nodes.field(\"\", fieldname, fieldbody)\n\n\nTypedField.make_field = patched_make_field\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": ".. PyTorch documentation master file, created by\n   sphinx-quickstart on Fri Dec 23 13:31:47 2016.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\n:github_url: https://github.com/nvidia/apex\n\nApex (A PyTorch Extension)\n===================================\n\nThis site contains the API documentation for Apex (https://github.com/nvidia/apex),\na Pytorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training.  Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible.\n\nInstallation instructions can be found here:  https://github.com/NVIDIA/apex#quick-start.\n\nSome other useful material, including GTC 2019 and Pytorch DevCon 2019 Slides, can be found here:  https://github.com/mcarilli/mixed_precision_references.\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Fused Optimizers\n\n   optimizers\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Fused Layer Norm\n\n   layernorm\n\n..   .. toctree::\n     :maxdepth: 1\n     :caption: Deprecated mixed precision API\n     fp16_util\n\n..   RNN\n   \nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n"
  },
  {
    "path": "docs/source/layernorm.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\napex.normalization.fused_layer_norm\n===================================\n\n.. automodule:: apex.normalization\n.. currentmodule:: apex.normalization\n\n.. FusedAdam\n   ----------\n\n.. autoclass:: FusedLayerNorm\n    :members:\n\n.. autoclass:: FusedRMSNorm\n    :members:\n"
  },
  {
    "path": "docs/source/optimizers.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\napex.optimizers\n===================================\n\n.. automodule:: apex.optimizers\n.. currentmodule:: apex.optimizers\n\n.. FusedAdam\n   ----------\n\n.. autoclass:: FusedAdam\n    :members:\n\n.. autoclass:: FusedLAMB\n    :members:\n\n.. autoclass:: FusedNovoGrad\n    :members:\n\n.. autoclass:: FusedSGD\n    :members:\n"
  },
  {
    "path": "examples/README.md",
    "content": "This directory contains examples illustrating Apex mixed precision and distributed tools.\n\n**Note for users of the pre-unification API**:\n`deprecated_api` contains examples illustrating the old (pre-unified) APIs.  These APIs will be removed soon, and users are strongly encouraged to switch.  The separate mixed precision tools called `Amp` and `FP16_Optimizer` in the old API are exposed via different flags/optimization levels in the new API.\n"
  },
  {
    "path": "examples/dcgan/README.md",
    "content": "# Mixed Precision DCGAN Training in PyTorch\n\n`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/dcgan](https://github.com/pytorch/examples/tree/master/dcgan).\nIt implements Automatic Mixed Precision (Amp) training of the DCGAN example for different datasets. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision \"optimization levels\" or `opt_level`s.  For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).\n\nWe introduce these changes to the PyTorch DCGAN example as described in the [Multiple models/optimizers/losses](https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses) section of the documentation::\n```\n# Added after models and optimizers construction\n[netD, netG], [optimizerD, optimizerG] = amp.initialize(\n    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)\n...\n# loss.backward() changed to:\nwith amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:\n    errD_real_scaled.backward()\n...\nwith amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:\n    errD_fake_scaled.backward()\n...\nwith amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:\n    errG_scaled.backward()\n```\n\nNote that we use different `loss_scalers` for each computed loss.\nUsing a separate loss scaler per loss is [optional, not required](https://nvidia.github.io/apex/advanced.html#optionally-have-amp-use-a-different-loss-scaler-per-loss).\n\nTo improve the numerical stability, we swapped `nn.Sigmoid() + nn.BCELoss()` to `nn.BCEWithLogitsLoss()`.\n\nWith the new Amp API **you never need to explicitly convert your model, or the input data, to half().**\n\n\"Pure FP32\" training:\n```\n$ python main_amp.py --opt_level O0\n```\nRecommended mixed precision training:\n```\n$ python main_amp.py --opt_level O1\n```\n\nHave a 
look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the used arguments.\n\nTo enable mixed precision training, we introduce the `--opt_level` argument.\n"
  },
  {
    "path": "examples/dcgan/main_amp.py",
    "content": "from __future__ import print_function\nimport argparse\nimport os\nimport random\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim as optim\nimport torch.utils.data\nimport torchvision.datasets as dset\nimport torchvision.transforms as transforms\nimport torchvision.utils as vutils\n\ntry:    \n    from apex import amp\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake')\nparser.add_argument('--dataroot', default='./', help='path to dataset')\nparser.add_argument('--workers', type=int, help='number of data loading workers', default=2)\nparser.add_argument('--batchSize', type=int, default=64, help='input batch size')\nparser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')\nparser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')\nparser.add_argument('--ngf', type=int, default=64)\nparser.add_argument('--ndf', type=int, default=64)\nparser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')\nparser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')\nparser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5')\nparser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')\nparser.add_argument('--netG', default='', help=\"path to netG (to continue training)\")\nparser.add_argument('--netD', default='', help=\"path to netD (to continue training)\")\nparser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')\nparser.add_argument('--manualSeed', type=int, help='manual seed')\nparser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')\nparser.add_argument('--opt_level', default='O1', help='amp opt_level, default=\"O1\"')\n\nopt = parser.parse_args()\nprint(opt)\n\n\ntry:\n    os.makedirs(opt.outf)\nexcept OSError:\n    pass\n\nif opt.manualSeed is None:\n    opt.manualSeed = 2809\nprint(\"Random Seed: \", opt.manualSeed)\nrandom.seed(opt.manualSeed)\ntorch.manual_seed(opt.manualSeed)\n\ncudnn.benchmark = True\n\n\nif opt.dataset in ['imagenet', 'folder', 'lfw']:\n    # folder dataset\n    dataset = dset.ImageFolder(root=opt.dataroot,\n                               transform=transforms.Compose([\n                                   transforms.Resize(opt.imageSize),\n                                   transforms.CenterCrop(opt.imageSize),\n                                   transforms.ToTensor(),\n                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                               ]))\n    nc=3\nelif opt.dataset == 'lsun':\n    classes = [ c + '_train' for c in opt.classes.split(',')]\n    dataset = dset.LSUN(root=opt.dataroot, classes=classes,\n                        transform=transforms.Compose([\n                            transforms.Resize(opt.imageSize),\n                            transforms.CenterCrop(opt.imageSize),\n                            transforms.ToTensor(),\n                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                        ]))\n    nc=3\nelif opt.dataset 
== 'cifar10':\n    dataset = dset.CIFAR10(root=opt.dataroot, download=True,\n                           transform=transforms.Compose([\n                               transforms.Resize(opt.imageSize),\n                               transforms.ToTensor(),\n                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                           ]))\n    nc=3\n\nelif opt.dataset == 'mnist':\n        dataset = dset.MNIST(root=opt.dataroot, download=True,\n                           transform=transforms.Compose([\n                               transforms.Resize(opt.imageSize),\n                               transforms.ToTensor(),\n                               transforms.Normalize((0.5,), (0.5,)),\n                           ]))\n        nc=1\n\nelif opt.dataset == 'fake':\n    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),\n                            transform=transforms.ToTensor())\n    nc=3\n\nassert dataset\ndataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,\n                                         shuffle=True, num_workers=int(opt.workers))\n\ndevice = torch.device(\"cuda:0\")\nngpu = int(opt.ngpu)\nnz = int(opt.nz)\nngf = int(opt.ngf)\nndf = int(opt.ndf)\n\n\n# custom weights initialization called on netG and netD\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        m.weight.data.normal_(0.0, 0.02)\n    elif classname.find('BatchNorm') != -1:\n        m.weight.data.normal_(1.0, 0.02)\n        m.bias.data.fill_(0)\n\n\nclass Generator(nn.Module):\n    def __init__(self, ngpu):\n        super(Generator, self).__init__()\n        self.ngpu = ngpu\n        self.main = nn.Sequential(\n            # input is Z, going into a convolution\n            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),\n            nn.BatchNorm2d(ngf * 8),\n            nn.ReLU(True),\n            # state size. 
(ngf*8) x 4 x 4\n            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf * 4),\n            nn.ReLU(True),\n            # state size. (ngf*4) x 8 x 8\n            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf * 2),\n            nn.ReLU(True),\n            # state size. (ngf*2) x 16 x 16\n            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf),\n            nn.ReLU(True),\n            # state size. (ngf) x 32 x 32\n            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),\n            nn.Tanh()\n            # state size. (nc) x 64 x 64\n        )\n\n    def forward(self, input):\n        if input.is_cuda and self.ngpu > 1:\n            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))\n        else:\n            output = self.main(input)\n        return output\n\n\nnetG = Generator(ngpu).to(device)\nnetG.apply(weights_init)\nif opt.netG != '':\n    netG.load_state_dict(torch.load(opt.netG))\nprint(netG)\n\n\nclass Discriminator(nn.Module):\n    def __init__(self, ngpu):\n        super(Discriminator, self).__init__()\n        self.ngpu = ngpu\n        self.main = nn.Sequential(\n            # input is (nc) x 64 x 64\n            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. (ndf) x 32 x 32\n            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ndf * 2),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. (ndf*2) x 16 x 16\n            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ndf * 4),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. 
(ndf*4) x 8 x 8\n            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ndf * 8),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. (ndf*8) x 4 x 4\n            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),\n        )\n\n    def forward(self, input):\n        if input.is_cuda and self.ngpu > 1:\n            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))\n        else:\n            output = self.main(input)\n\n        return output.view(-1, 1).squeeze(1)\n\n\nnetD = Discriminator(ngpu).to(device)\nnetD.apply(weights_init)\nif opt.netD != '':\n    netD.load_state_dict(torch.load(opt.netD))\nprint(netD)\n\ncriterion = nn.BCEWithLogitsLoss()\n\nfixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)\nreal_label = 1\nfake_label = 0\n\n# setup optimizer\noptimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))\noptimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))\n\n[netD, netG], [optimizerD, optimizerG] = amp.initialize(\n    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)\n\nfor epoch in range(opt.niter):\n    for i, data in enumerate(dataloader, 0):\n        ############################\n        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n        ###########################\n        # train with real\n        netD.zero_grad()\n        real_cpu = data[0].to(device)\n        batch_size = real_cpu.size(0)\n        # float targets: recent PyTorch versions infer an integer dtype from an\n        # integer fill value, which BCEWithLogitsLoss rejects\n        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)\n\n        output = netD(real_cpu)\n        errD_real = criterion(output, label)\n        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:\n            errD_real_scaled.backward()\n        D_x = output.mean().item()\n\n        # train with fake\n        noise = torch.randn(batch_size, nz, 1, 1, device=device)\n        fake = netG(noise)\n        label.fill_(fake_label)\n        output = 
netD(fake.detach())\n        errD_fake = criterion(output, label)\n        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:\n            errD_fake_scaled.backward()\n        D_G_z1 = output.mean().item()\n        errD = errD_real + errD_fake\n        optimizerD.step()\n\n        ############################\n        # (2) Update G network: maximize log(D(G(z)))\n        ###########################\n        netG.zero_grad()\n        label.fill_(real_label)  # fake labels are real for generator cost\n        output = netD(fake)\n        errG = criterion(output, label)\n        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:\n            errG_scaled.backward()\n        D_G_z2 = output.mean().item()\n        optimizerG.step()\n\n        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'\n              % (epoch, opt.niter, i, len(dataloader),\n                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))\n        if i % 100 == 0:\n            vutils.save_image(real_cpu,\n                    '%s/real_samples.png' % opt.outf,\n                    normalize=True)\n            fake = netG(fixed_noise)\n            vutils.save_image(fake.detach(),\n                    '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),\n                    normalize=True)\n\n    # do checkpointing\n    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))\n    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))\n\n\n"
  },
  {
    "path": "examples/docker/Dockerfile",
    "content": "# Base image must at least have pytorch and CUDA installed.\nARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.03-py3\nFROM $BASE_IMAGE\nARG BASE_IMAGE\nRUN echo \"Installing Apex on top of ${BASE_IMAGE}\"\n# make sure we don't overwrite some existing directory called \"apex\"\nWORKDIR /tmp/unique_for_apex\n# uninstall Apex if present, twice to make absolutely sure :)\nRUN pip uninstall -y apex || :\nRUN pip uninstall -y apex || :\n# SHA is something the user can touch to force recreation of this Docker layer,\n# and therefore force cloning of the latest version of Apex\nRUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git\nWORKDIR /tmp/unique_for_apex/apex\nRUN pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" .\nWORKDIR /workspace\n"
  },
  {
    "path": "examples/docker/README.md",
    "content": "## Option 1:  Create a new container with Apex\n\n**Dockerfile** installs the latest Apex on top of an existing image.  Run\n```\ndocker build -t new_image_with_apex .\n```\nBy default, **Dockerfile** uses NVIDIA's Pytorch container as the base image,\nwhich requires an NVIDIA GPU Cloud (NGC) account.  If you don't have an NGC account, you can sign up for free by following the instructions [here](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html#generating-api-key).\n\nAlternatively, you can supply your own base image via the `BASE_IMAGE` build-arg.\n`BASE_IMAGE` must have Pytorch and Cuda installed.  For example, any\n`-devel` image for Pytorch 1.0 and later from the\n[official Pytorch Dockerhub](https://hub.docker.com/r/pytorch/pytorch) may be used:\n```\ndocker build --build-arg BASE_IMAGE=1.3-cuda10.1-cudnn7-devel -t new_image_with_apex .\n```\n\nIf you want to rebuild your image, and force the latest Apex to be cloned and installed, make any small change to the `SHA` variable in **Dockerfile**.\n\n**Warning:**\nCurrently, the non-`-devel` images on Pytorch Dockerhub do not contain the Cuda compiler `nvcc`.  
Therefore,\nimages whose names do not contain `-devel` are not eligible candidates for `BASE_IMAGE`.\n\n### Running your Apex container\n\nLike any CUDA-enabled PyTorch container, a container with Apex should be run via [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), for example:\n```\ndocker run --runtime=nvidia -it --rm --ipc=host new_image_with_apex\n```\n\n## Option 2:  Install Apex in a running container\n\nInstead of building a new container, you can `git clone https://github.com/NVIDIA/apex.git` on bare metal, then mount the Apex repo into your container at launch by running, for example,\n```\ndocker run --runtime=nvidia -it --rm --ipc=host -v /bare/metal/apex:/apex/in/container <base image>\n```\nthen go to `/apex/in/container` within the running container and run\n```\npip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" .\n```\n"
  },
  {
    "path": "examples/imagenet/README.md",
    "content": "# Mixed Precision ImageNet Training in PyTorch\n\n`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/imagenet](https://github.com/pytorch/examples/tree/master/imagenet).\nIt implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and VGG, on the ImageNet dataset.  Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision \"optimization levels\" or `opt_level`s.  For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).\n\nThree lines enable Amp:\n```\n# Added after model and optimizer construction\nmodel, optimizer = amp.initialize(model, optimizer, flags...)\n...\n# loss.backward() changed to:\nwith amp.scale_loss(loss, optimizer) as scaled_loss:\n    scaled_loss.backward()\n```\n\nWith the new Amp API **you never need to explicitly convert your model, or the input data, to half().**\n\n## Requirements\n\n- Download the ImageNet dataset and move validation images to labeled subfolders\n    - The following script may be helpful: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh\n\n## Training\n\nTo train a model, create softlinks to the Imagenet dataset, then run `main.py` with the desired model architecture, as shown in `Example commands` below.\n\nThe default learning rate schedule is set for ResNet50.  `main_amp.py` script rescales the learning rate according to the global batch size (number of distributed processes \\* per-process minibatch size).\n\n## Example commands\n\n**Note:**  batch size `--b 224` assumes your GPUs have >=16GB of onboard memory.  
You may be able to increase this to 256, but that's cutting it close, so it may run out of memory with some PyTorch versions.\n\n**Note:**  All of the following use 4 dataloader subprocesses (`--workers 4`) to reduce potential\nCPU data loading bottlenecks.\n\n**Note:**  `--opt-level` `O1` and `O2` both use dynamic loss scaling by default unless manually overridden.\n`--opt-level` `O0` and `O3` (the \"pure\" training modes) do not use loss scaling by default.\n`O0` and `O3` can be told to use loss scaling via manual overrides, but using loss scaling with `O0`\n(pure FP32 training) does not really make sense, and will trigger a warning.\n\nSoftlink training and validation datasets into the current directory:\n```\n$ ln -sf /data/imagenet/train-jpeg/ train\n$ ln -sf /data/imagenet/val-jpeg/ val\n```\n\n### Summary\n\nAmp allows easy experimentation with various pure and mixed precision options.\n```\n$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 ./\n$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./\n$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./\n```\nOptions are explained below.  
Again, the [updated API guide](https://nvidia.github.io/apex/amp.html) provides more detail.\n\n#### `--opt-level O0` (FP32 training) and `O3` (FP16 training)\n\n\"Pure FP32\" training:\n```\n$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./\n```\n\"Pure FP16\" training:\n```\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./\n```\nFP16 training with FP32 batchnorm:\n```\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./\n```\nKeeping the batchnorms in FP32 improves stability and allows PyTorch\nto use cudnn batchnorms, which significantly increases speed in ResNet50.\n\nThe `O3` options might not converge, because they are not true mixed precision.\nHowever, they can be useful to establish \"speed of light\" performance for\nyour model, which provides a baseline for comparison with `O1` and `O2`.\nFor ResNet50 in particular, `--opt-level O3 --keep-batchnorm-fp32 True` establishes\nthe \"speed of light.\"  (Without `--keep-batchnorm-fp32`, it's slower, because it does\nnot use cudnn batchnorm.)\n\n#### `--opt-level O1` (Official Mixed Precision recipe, recommended for typical use)\n\n`O1` patches Torch functions to cast inputs according to a whitelist-blacklist model.\nFP16-friendly (Tensor Core) ops like gemms and convolutions run in FP16, while ops\nthat benefit from FP32, like batchnorm and softmax, run in FP32.\nAlso, dynamic loss scaling is used by default.\n```\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./\n```\n`O1` overridden to use static loss scaling:\n```\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 ./\n```\nDistributed training with 2 processes (1 GPU per process, see **Distributed training** below\nfor more detail):\n```\n$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./\n```\nFor best performance, set `--nproc_per_node` equal to the 
total number of GPUs on the node\nto use all available resources.\n\n#### `--opt-level O2` (\"Almost FP16\" mixed precision.  More dangerous than O1.)\n\n`O2` exists mainly to support some internal use cases.  Please prefer `O1`.\n\n`O2` casts the model to FP16, keeps batchnorms in FP32,\nmaintains master weights in FP32, and implements\ndynamic loss scaling by default. (Unlike `--opt-level O1`, `--opt-level O2`\ndoes not patch Torch functions.)\n```\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./\n```\n\"Fast mixed precision\" overridden to use static loss scaling:\n```\n$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./\n```\nDistributed training with 2 processes (1 GPU per process):\n```\n$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./\n```\n\n## Distributed training\n\n`main_amp.py` optionally uses `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.\n```\nmodel = apex.parallel.DistributedDataParallel(model)\n```\nis a drop-in replacement for\n```\nmodel = torch.nn.parallel.DistributedDataParallel(model,\n                                                  device_ids=[args.local_rank],\n                                                  output_device=args.local_rank)\n```\n(Because Torch DDP permits multiple GPUs per process, it requires you to\nmanually specify the device to run on and the output device.\nApex DDP uses only the current device by default.)\n\nThe choice of DDP wrapper (Torch or Apex) is orthogonal to the use of Amp and other Apex tools.  It is safe to use `apex.amp` with either `torch.nn.parallel.DistributedDataParallel` or `apex.parallel.DistributedDataParallel`.  
In the future, I may add some features that permit optional tighter integration between `Amp` and `apex.parallel.DistributedDataParallel` for marginal performance benefits, but currently, there's no compelling reason to use Apex DDP versus Torch DDP for most models.\n\nTo use DDP with `apex.amp`, the only gotcha is that\n```\nmodel, optimizer = amp.initialize(model, optimizer, flags...)\n```\nmust precede\n```\nmodel = DDP(model)\n```\nIf DDP wrapping occurs before `amp.initialize`, `amp.initialize` will raise an error.\n\nWith both Apex DDP and Torch DDP, you must also call `torch.cuda.set_device(args.local_rank)` within\neach process prior to initializing your model or any other tensors.\nMore information can be found in the docs for the\nPytorch multiprocess launcher module [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility).\n\n`main_amp.py` is written to interact with \n[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility),\nwhich spawns multiprocess jobs using the following syntax:\n```\npython -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_amp.py args...\n```\n`NUM_GPUS` should be less than or equal to the number of visible GPU devices on the node.  The use of `torch.distributed.launch` is unrelated to the choice of DDP wrapper.  
It is safe to use either apex DDP or torch DDP with `torch.distributed.launch`.\n\nOptionally, one can run imagenet with synchronized batch normalization across processes by adding\n`--sync_bn` to the `args...`\n\n## Deterministic training (for debugging purposes)\n\nRunning with the `--deterministic` flag should produce bitwise identical outputs run-to-run,\nregardless of what other options are used (see [Pytorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)).\nSince `--deterministic` disables `torch.backends.cudnn.benchmark`, `--deterministic` may\ncause a modest performance decrease.\n\n## Profiling\n\nIf you're curious how the network actually looks on the CPU and GPU timelines (for example, how good is the overall utilization?\nIs the prefetcher really overlapping data transfers?) try profiling `main_amp.py`.\n[Detailed instructions can be found here](https://gist.github.com/mcarilli/213a4e698e4a0ae2234ddee56f4f3f95).\n"
  },
  {
    "path": "examples/imagenet/main_amp.py",
    "content": "import argparse\nimport os\nimport shutil\nimport time\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\nimport torch.optim\nimport torch.utils.data\nimport torch.utils.data.distributed\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport torchvision.models as models\n\nimport numpy as np\n\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\ndef to_python_float(scalar_tensor: torch.Tensor):\n    return scalar_tensor.float().item()\n\ndef fast_collate(batch, memory_format):\n\n    imgs = [img[0] for img in batch]\n    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)\n    w = imgs[0].size[0]\n    h = imgs[0].size[1]\n    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)\n    for i, img in enumerate(imgs):\n        nump_array = np.asarray(img, dtype=np.uint8)\n        if(nump_array.ndim < 3):\n            nump_array = np.expand_dims(nump_array, axis=-1)\n        nump_array = np.rollaxis(nump_array, 2)\n        tensor[i] += torch.from_numpy(nump_array)\n    return tensor, targets\n\n\ndef parse():\n    model_names = sorted(name for name in models.__dict__\n                     if name.islower() and not name.startswith(\"__\")\n                     and callable(models.__dict__[name]))\n\n    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')\n    parser.add_argument('data', metavar='DIR',\n                        help='path to dataset')\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',\n                        choices=model_names,\n                        help='model architecture: ' +\n                        ' | '.join(model_names) +\n                        ' (default: resnet18)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        
help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=90, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='manual epoch number (useful on restarts)')\n    parser.add_argument('-b', '--batch-size', default=256, type=int,\n                        metavar='N', help='mini-batch size per process (default: 256)')\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='Initial learning rate.  Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256.  A warmup schedule will also be applied over the first 5 epochs.')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--print-freq', '-p', default=10, type=int,\n                        metavar='N', help='print frequency (default: 10)')\n    parser.add_argument('--resume', default='', type=str, metavar='PATH',\n                        help='path to latest checkpoint (default: none)')\n    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',\n                        help='evaluate model on validation set')\n    parser.add_argument('--pretrained', dest='pretrained', action='store_true',\n                        help='use pre-trained model')\n\n    parser.add_argument('--prof', default=-1, type=int,\n                        help='Only run 10 iterations for profiling.')\n    parser.add_argument('--deterministic', action='store_true')\n\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n    parser.add_argument('--sync_bn', 
action='store_true',\n                        help='enabling apex sync BN.')\n\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n    parser.add_argument('--loss-scale', type=str, default=None)\n    # argparse's type=bool treats any non-empty string (including the string False) as True,\n    # so parse the value explicitly.\n    parser.add_argument('--channels-last', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False)\n    args = parser.parse_args()\n    return args\n\ndef main():\n    global best_prec1, args\n\n    args = parse()\n    print(\"opt_level = {}\".format(args.opt_level))\n    print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n    print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n    print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.local_rank)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # create model\n    if args.pretrained:\n        print(\"=> using pre-trained model '{}'\".format(args.arch))\n        model = models.__dict__[args.arch](pretrained=True)\n    else:\n        print(\"=> creating model '{}'\".format(args.arch))\n        model = 
models.__dict__[args.arch]()\n\n    if args.sync_bn:\n        import apex\n        print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr*float(args.batch_size*args.world_size)/256.\n    optimizer = torch.optim.SGD(model.parameters(), args.lr,\n                                momentum=args.momentum,\n                                weight_decay=args.weight_decay)\n\n    if args.distributed:\n        model = DDP(model)\n    scaler = torch.amp.GradScaler(\"cuda\")\n\n    # define loss function (criterion) and optimizer\n    criterion = nn.CrossEntropyLoss().cuda()\n\n    # Optionally resume from a checkpoint\n    if args.resume:\n        # Use a local scope to avoid dangling references\n        def resume():\n            if os.path.isfile(args.resume):\n                print(\"=> loading checkpoint '{}'\".format(args.resume))\n                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))\n                args.start_epoch = checkpoint['epoch']\n                global best_prec1\n                best_prec1 = checkpoint['best_prec1']\n                model.load_state_dict(checkpoint['state_dict'])\n                optimizer.load_state_dict(checkpoint['optimizer'])\n                print(\"=> loaded checkpoint '{}' (epoch {})\"\n                      .format(args.resume, checkpoint['epoch']))\n            else:\n                print(\"=> no checkpoint found at '{}'\".format(args.resume))\n        resume()\n\n    # Data loading code\n    traindir = os.path.join(args.data, 'train')\n    valdir = os.path.join(args.data, 'val')\n\n    if(args.arch == \"inception_v3\"):\n        raise RuntimeError(\"Currently, inception_v3 is not supported by this example.\")\n        # crop_size = 299\n        # val_size = 320 # I chose this value arbitrarily, we can 
adjust.\n    else:\n        crop_size = 224\n        val_size = 256\n\n    train_dataset = datasets.ImageFolder(\n        traindir,\n        transforms.Compose([\n            transforms.RandomResizedCrop(crop_size),\n            transforms.RandomHorizontalFlip(),\n            # transforms.ToTensor(), Too slow\n            # normalize,\n        ]))\n    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([\n            transforms.Resize(val_size),\n            transforms.CenterCrop(crop_size),\n        ]))\n\n    train_sampler = None\n    val_sampler = None\n    if args.distributed:\n        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)\n        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)\n\n    collate_fn = lambda b: fast_collate(b, memory_format)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)\n\n    val_loader = torch.utils.data.DataLoader(\n        val_dataset,\n        batch_size=args.batch_size, shuffle=False,\n        num_workers=args.workers, pin_memory=True,\n        sampler=val_sampler,\n        collate_fn=collate_fn)\n\n    if args.evaluate:\n        validate(val_loader, model, criterion)\n        return\n\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            train_sampler.set_epoch(epoch)\n\n        # train for one epoch\n        train(train_loader, model, criterion, optimizer, scaler, epoch)\n\n        # evaluate on validation set\n        prec1 = validate(val_loader, model, criterion)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            save_checkpoint({\n                'epoch': epoch + 1,\n                'arch': 
args.arch,\n                'state_dict': model.state_dict(),\n                'best_prec1': best_prec1,\n                'optimizer' : optimizer.state_dict(),\n            }, is_best)\n\nclass data_prefetcher():\n    def __init__(self, loader):\n        self.loader = iter(loader)\n        self.stream = torch.cuda.Stream()\n        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)\n        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)\n        # With Amp, it isn't necessary to manually convert data to half.\n        # if args.fp16:\n        #     self.mean = self.mean.half()\n        #     self.std = self.std.half()\n        self.preload()\n\n    def preload(self):\n        try:\n            self.next_input, self.next_target = next(self.loader)\n        except StopIteration:\n            self.next_input = None\n            self.next_target = None\n            return\n        # if record_stream() doesn't work, another option is to make sure device inputs are created\n        # on the main stream.\n        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')\n        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')\n        # Need to make sure the memory allocated for next_* is not still in use by the main stream\n        # at the time we start copying to next_*:\n        # self.stream.wait_stream(torch.cuda.current_stream())\n        with torch.cuda.stream(self.stream):\n            self.next_input = self.next_input.cuda(non_blocking=True)\n            self.next_target = self.next_target.cuda(non_blocking=True)\n            # more code for the alternative if record_stream() doesn't work:\n            # copy_ will record the use of the pinned source tensor in this side stream.\n            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)\n            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)\n            # 
self.next_input = self.next_input_gpu\n            # self.next_target = self.next_target_gpu\n\n            # With Amp, it isn't necessary to manually convert data to half.\n            # if args.fp16:\n            #     self.next_input = self.next_input.half()\n            # else:\n            self.next_input = self.next_input.float()\n            self.next_input = self.next_input.sub_(self.mean).div_(self.std)\n\n    def next(self):\n        torch.cuda.current_stream().wait_stream(self.stream)\n        input = self.next_input\n        target = self.next_target\n        if input is not None:\n            input.record_stream(torch.cuda.current_stream())\n        if target is not None:\n            target.record_stream(torch.cuda.current_stream())\n        self.preload()\n        return input, target\n\n\ndef train(train_loader, model, criterion, optimizer, scaler, epoch):\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    prefetcher = data_prefetcher(train_loader)\n    input, target = prefetcher.next()\n    i = 0\n    while input is not None:\n        i += 1\n        if args.prof >= 0 and i == args.prof:\n            print(\"Profiling begun at iteration {}\".format(i))\n            torch.cuda.cudart().cudaProfilerStart()\n\n        if args.prof >= 0: torch.cuda.nvtx.range_push(\"Body of iteration {}\".format(i))\n\n        adjust_learning_rate(optimizer, epoch, i, len(train_loader))\n\n        # compute output\n        with torch.autocast(device_type=\"cuda\"):\n            if args.prof >= 0: torch.cuda.nvtx.range_push(\"forward\")\n            output = model(input)\n            if args.prof >= 0: torch.cuda.nvtx.range_pop()\n            loss = criterion(output, target)\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n\n        if args.prof >= 0: torch.cuda.nvtx.range_push(\"backward\")\n        
scaler.scale(loss).backward()\n        if args.prof >= 0: torch.cuda.nvtx.range_pop()\n\n        # for param in model.parameters():\n        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())\n\n        if args.prof >= 0: torch.cuda.nvtx.range_push(\"optimizer.step()\")\n        scaler.step(optimizer)\n        scaler.update()\n        if args.prof >= 0: torch.cuda.nvtx.range_pop()\n\n        if i%args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss = reduce_tensor(loss.data)\n                prec1 = reduce_tensor(prec1)\n                prec5 = reduce_tensor(prec5)\n            else:\n                reduced_loss = loss.data\n\n            # to_python_float incurs a host<->device sync\n            losses.update(to_python_float(reduced_loss), input.size(0))\n            top1.update(to_python_float(prec1), input.size(0))\n            top5.update(to_python_float(prec5), input.size(0))\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end)/args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n                      'Loss {loss.val:.10f} ({loss.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                       epoch, i, 
len(train_loader),\n                       args.world_size*args.batch_size/batch_time.val,\n                       args.world_size*args.batch_size/batch_time.avg,\n                       batch_time=batch_time,\n                       loss=losses, top1=top1, top5=top5))\n        if args.prof >= 0: torch.cuda.nvtx.range_push(\"prefetcher.next()\")\n        input, target = prefetcher.next()\n        if args.prof >= 0: torch.cuda.nvtx.range_pop()\n\n        # Pop range \"Body of iteration {}\".format(i)\n        if args.prof >= 0: torch.cuda.nvtx.range_pop()\n\n        if args.prof >= 0 and i == args.prof + 10:\n            print(\"Profiling ended at iteration {}\".format(i))\n            torch.cuda.cudart().cudaProfilerStop()\n            quit()\n\n\ndef validate(val_loader, model, criterion):\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to evaluate mode\n    model.eval()\n\n    end = time.time()\n\n    prefetcher = data_prefetcher(val_loader)\n    input, target = prefetcher.next()\n    i = 0\n    while input is not None:\n        i += 1\n\n        # compute output\n        with torch.no_grad():\n            output = model(input)\n            loss = criterion(output, target)\n\n        # measure accuracy and record loss\n        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))\n\n        if args.distributed:\n            reduced_loss = reduce_tensor(loss.data)\n            prec1 = reduce_tensor(prec1)\n            prec5 = reduce_tensor(prec5)\n        else:\n            reduced_loss = loss.data\n\n        losses.update(to_python_float(reduced_loss), input.size(0))\n        top1.update(to_python_float(prec1), input.size(0))\n        top5.update(to_python_float(prec5), input.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        # TODO:  Change timings to mirror train().\n        if args.local_rank == 0 
and i % args.print_freq == 0:\n            print('Test: [{0}/{1}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Speed {2:.3f} ({3:.3f})\\t'\n                  'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                   i, len(val_loader),\n                   args.world_size * args.batch_size / batch_time.val,\n                   args.world_size * args.batch_size / batch_time.avg,\n                   batch_time=batch_time, loss=losses,\n                   top1=top1, top5=top5))\n\n        input, target = prefetcher.next()\n\n    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'\n          .format(top1=top1, top5=top5))\n\n    return top1.avg\n\n\ndef save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):\n    torch.save(state, filename)\n    if is_best:\n        shutil.copyfile(filename, 'model_best.pth.tar')\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef adjust_learning_rate(optimizer, epoch, step, len_epoch):\n    \"\"\"LR schedule that should yield 76% converged accuracy with batch size 256\"\"\"\n    factor = epoch // 30\n\n    if epoch >= 80:\n        factor = factor + 1\n\n    lr = args.lr*(0.1**factor)\n\n    \"\"\"Warmup\"\"\"\n    if epoch < 5:\n        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)\n\n    # if(args.local_rank == 0):\n    #     print(\"epoch = {}, step = {}, lr = {}\".format(epoch, step, lr))\n\n    for param_group in optimizer.param_groups:\n        
param_group['lr'] = lr\n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the precision@k for the specified values of k\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n        res.append(correct_k.mul_(100.0 / batch_size))\n    return res\n\n\ndef reduce_tensor(tensor):\n    \"\"\"Averages a tensor across all processes\"\"\"\n    rt = tensor.clone()\n    # dist.reduce_op is deprecated; dist.ReduceOp is the supported enum.\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= args.world_size\n    return rt\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/simple/distributed/README.md",
    "content": "**distributed_data_parallel.py** and **run.sh** show an example using Amp with\n[apex.parallel.DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) or\n[torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel)\nand the Pytorch multiprocess launcher script,\n[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility).\nThe use of `Amp` with DistributedDataParallel does not need to change from ordinary \nsingle-process use.  The only gotcha is that wrapping your model with `DistributedDataParallel` must\ncome after the call to `amp.initialize`.  Test via\n```bash\nbash run.sh\n```\n\n**This is intended purely as an instructional example, not a performance showcase.**\n"
  },
  {
    "path": "examples/simple/distributed/distributed_data_parallel.py",
    "content": "import torch\nimport argparse\nimport os\nfrom apex import amp\n# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)\nfrom apex.parallel import DistributedDataParallel\n\nparser = argparse.ArgumentParser()\n# FOR DISTRIBUTED:  Parse for the local_rank argument, which will be supplied\n# automatically by torch.distributed.launch.\nparser.add_argument(\"--local_rank\", default=0, type=int)\nargs = parser.parse_args()\n\n# FOR DISTRIBUTED:  If we are running under torch.distributed.launch,\n# the 'WORLD_SIZE' environment variable will also be set automatically.\nargs.distributed = False\nif 'WORLD_SIZE' in os.environ:\n    args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\nif args.distributed:\n    # FOR DISTRIBUTED:  Set the device according to local_rank.\n    torch.cuda.set_device(args.local_rank)\n\n    # FOR DISTRIBUTED:  Initialize the backend.  torch.distributed.launch will provide\n    # environment variables, and requires that you use init_method=`env://`.\n    torch.distributed.init_process_group(backend='nccl',\n                                         init_method='env://')\n\ntorch.backends.cudnn.benchmark = True\n\nN, D_in, D_out = 64, 1024, 16\n\n# Each process receives its own batch of \"fake input data\" and \"fake target data.\"\n# The \"training loop\" in each process just uses this fake batch over and over.\n# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic\n# example of distributed data sampling for both training and validation.\nx = torch.randn(N, D_in, device='cuda')\ny = torch.randn(N, D_out, device='cuda')\n\nmodel = torch.nn.Linear(D_in, D_out).cuda()\noptimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n\nmodel, optimizer = amp.initialize(model, optimizer, opt_level=\"O1\")\n\nif args.distributed:\n    # FOR DISTRIBUTED:  After amp.initialize, wrap the model with\n    # apex.parallel.DistributedDataParallel.\n    model = 
DistributedDataParallel(model)\n    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:\n    # model = torch.nn.parallel.DistributedDataParallel(model,\n    #                                                   device_ids=[args.local_rank],\n    #                                                   output_device=args.local_rank)\n\nloss_fn = torch.nn.MSELoss()\n\nfor t in range(500):\n    optimizer.zero_grad()\n    y_pred = model(x)\n    loss = loss_fn(y_pred, y)\n    with amp.scale_loss(loss, optimizer) as scaled_loss:\n        scaled_loss.backward()\n    optimizer.step()\n\nif args.local_rank == 0:\n    print(\"final loss = \", loss)\n"
  },
  {
    "path": "examples/simple/distributed/run.sh",
    "content": "#!/bin/bash\npython -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\n    \"setuptools\",\n    \"wheel\",\n]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.ruff]\nline-length = 100\nignore = [\n    # Sorted by occurrence count (ascending) - easier to fix first\n    \"E731\",  # lambda assignment (6 occurrences)\n    \"E721\",  # type comparison should use isinstance (8 occurrences)\n    \"E741\",  # ambiguous variable name (8 occurrences)\n    \"E712\",  # comparison to True/False (9 occurrences)\n    \"F403\",  # star imports used (9 occurrences)\n    \"E701\",  # multiple statements on one line (10 occurrences)\n    \"E711\",  # comparison to None should be `cond is None` (11 occurrences)\n    \"F821\",  # undefined name (14 occurrences)\n    \"E722\",  # bare except (15 occurrences)\n    \"E402\",  # module level import not at top of file (41 occurrences)\n    \"F401\",  # imported but unused (45 occurrences)\n    \"F841\",  # local variable assigned but never used (52 occurrences)\n    \"F405\",  # star imports (80 occurrences)\n]\n"
  },
  {
    "path": "requirements.txt",
    "content": "cxxfilt>=0.2.0\ntqdm>=4.28.1\nnumpy>=1.15.3\nPyYAML>=5.1\npytest>=3.5.1\npackaging>=14.0\ntorch>=2.6.0\n"
  },
  {
    "path": "requirements_dev.txt",
    "content": "-r requirements.txt\nflake8>=3.7.9\nSphinx>=3.0.3"
  },
  {
    "path": "setup.py",
    "content": "import sys\nimport warnings\nimport os\nimport threading\nimport glob\nfrom packaging.version import parse, Version\n\nfrom setuptools import setup, find_packages\nimport subprocess\n\nimport torch\nfrom torch.utils.cpp_extension import (\n    BuildExtension,\n    CppExtension,\n    CUDAExtension,\n    CUDA_HOME,\n    load,\n)\n\n# ninja build does not work unless include_dirs are abs path\nthis_dir = os.path.dirname(os.path.abspath(__file__))\n\n# Allow environment variables to specify build flags for PEP 517 compatibility\nENV_TO_FLAG = {\n    \"APEX_CPP_EXT\": \"--cpp_ext\",\n    \"APEX_CUDA_EXT\": \"--cuda_ext\",\n    \"APEX_XENTROPY\": \"--xentropy\",\n    \"APEX_FAST_LAYER_NORM\": \"--fast_layer_norm\",\n    \"APEX_DISTRIBUTED_ADAM\": \"--distributed_adam\",\n    \"APEX_DISTRIBUTED_LAMB\": \"--distributed_lamb\",\n    \"APEX_BNP\": \"--bnp\",\n    \"APEX_GROUP_NORM\": \"--group_norm\",\n    \"APEX_INDEX_MUL_2D\": \"--index_mul_2d\",\n    \"APEX_DEPRECATED_FUSED_ADAM\": \"--deprecated_fused_adam\",\n    \"APEX_DEPRECATED_FUSED_LAMB\": \"--deprecated_fused_lamb\",\n    \"APEX_FAST_MULTIHEAD_ATTN\": \"--fast_multihead_attn\",\n    \"APEX_FMHA\": \"--fmha\",\n    \"APEX_PERMUTATION_SEARCH\": \"--permutation_search\",\n    \"APEX_FOCAL_LOSS\": \"--focal_loss\",\n    \"APEX_TRANSDUCER\": \"--transducer\",\n    \"APEX_CUDNN_GBN\": \"--cudnn_gbn\",\n    \"APEX_PEER_MEMORY\": \"--peer_memory\",\n    \"APEX_NCCL_P2P\": \"--nccl_p2p\",\n    \"APEX_FAST_BOTTLENECK\": \"--fast_bottleneck\",\n    \"APEX_FUSED_CONV_BIAS_RELU\": \"--fused_conv_bias_relu\",\n    \"APEX_NCCL_ALLOCATOR\": \"--nccl_allocator\",\n    \"APEX_GPU_DIRECT_STORAGE\": \"--gpu_direct_storage\",\n}\nfor env_var, flag in ENV_TO_FLAG.items():\n    if os.environ.get(env_var, \"0\") == \"1\" and flag not in sys.argv:\n        print(f\"[apex] Detected {env_var}=1, adding {flag} to build flags.\")\n        sys.argv.append(flag)\n\n\nFLAG_TO_ENV = {v: k for k, v in 
ENV_TO_FLAG.items()}\nCORE_FLAGS = {\"--cpp_ext\", \"--cuda_ext\"}\nCONTRIB_FLAGS = set(FLAG_TO_ENV.keys()) - CORE_FLAGS\n\n\ndef has_flag(flag, env_var):\n    if flag in sys.argv or os.environ.get(env_var, \"0\") == \"1\":\n        return True\n    if flag in CONTRIB_FLAGS and os.environ.get(\"APEX_ALL_CONTRIB_EXT\", \"0\") == \"1\":\n        return True\n    return False\n\n\ndef get_cuda_bare_metal_version(cuda_dir):\n    raw_output = subprocess.check_output([cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n    output = raw_output.split()\n    release_idx = output.index(\"release\") + 1\n    bare_metal_version = parse(output[release_idx].split(\",\")[0])\n\n    return raw_output, bare_metal_version\n\n\ndef check_cuda_torch_binary_vs_bare_metal(cuda_dir):\n    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)\n    torch_binary_version = parse(torch.version.cuda)\n\n    print(\"\\nCompiling cuda extensions with\")\n    print(raw_output + \"from \" + cuda_dir + \"/bin\\n\")\n\n    if bare_metal_version != torch_binary_version:\n        raise RuntimeError(\n            \"Cuda extensions are being compiled with a version of Cuda that does \"\n            \"not match the version used to compile Pytorch binaries.  \"\n            \"Pytorch binaries were compiled with Cuda {}.\\n\".format(torch.version.cuda)\n            + \"In some cases, a minor-version mismatch will not cause later errors:  \"\n            \"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  \"\n            \"You can try commenting out this check (at your own risk).\"\n        )\n\n\ndef raise_if_cuda_home_none(global_option: str) -> None:\n    if CUDA_HOME is not None:\n        return\n    raise RuntimeError(\n        f\"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  
\"\n        \"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, \"\n        \"only images whose names contain 'devel' will provide nvcc.\"\n    )\n\n\ndef check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:\n    cudnn_available = torch.backends.cudnn.is_available()\n    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None\n    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):\n        warnings.warn(\n            f\"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, \"\n            f\"but {'cuDNN is not available' if not cudnn_available else cudnn_version}\"\n        )\n        return False\n    return True\n\n\nif not torch.cuda.is_available():\n    # https://github.com/NVIDIA/apex/issues/486\n    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),\n    # which will fail if you are compiling in an environment without visible GPUs (e.g. 
during an nvidia-docker build command).\n    print(\n        \"\\nWarning: Torch did not find available GPUs on this system.\\n\",\n        \"If your intention is to cross-compile, this is not an error.\\n\"\n        \"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2) (until CUDA 12.8),\\n\"\n        \"Volta (compute capability 7.0), Turing (compute capability 7.5),\\n\"\n        \"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0, 8.6), and,\\n\"\n        \"if the CUDA version is >= 12.8, Blackwell (compute capability 10.0, 12.0).\\n\"\n        \"If you wish to cross-compile for a single specific architecture,\\n\"\n        'export TORCH_CUDA_ARCH_LIST=\"compute capability\" before running setup.py.\\n',\n    )\n    if os.environ.get(\"TORCH_CUDA_ARCH_LIST\", None) is None and CUDA_HOME is not None:\n        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n        if bare_metal_version >= Version(\"13.0\"):\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"7.5;8.0;8.6;9.0;10.0;11.0;12.0\"\n        elif bare_metal_version >= Version(\"12.8\"):\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"7.0;7.5;8.0;8.6;9.0;10.0;12.0\"\n        elif bare_metal_version >= Version(\"11.8\"):\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0\"\n        elif bare_metal_version >= Version(\"11.1\"):\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0;8.6\"\n        elif bare_metal_version == Version(\"11.0\"):\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0\"\n        else:\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5\"\n\nprint(\"\\n\\ntorch.__version__  = {}\\n\\n\".format(torch.__version__))\nTORCH_MAJOR = int(torch.__version__.split(\".\")[0])\nTORCH_MINOR = int(torch.__version__.split(\".\")[1])\n\nif TORCH_MAJOR == 0 and TORCH_MINOR < 4:\n    raise RuntimeError(\n        \"Apex 
requires Pytorch 0.4 or newer.\\nThe latest stable release can be obtained from https://pytorch.org/\"\n    )\n\ncmdclass = {}\next_modules = []\n\nextras = {}\n\nif \"--cpp_ext\" in sys.argv or \"--cuda_ext\" in sys.argv:\n    if TORCH_MAJOR == 0:\n        raise RuntimeError(\n            \"--cpp_ext requires Pytorch 1.0 or later, found torch.__version__ = {}\".format(\n                torch.__version__\n            )\n        )\n\nif has_flag(\"--cpp_ext\", \"APEX_CPP_EXT\"):\n    if \"--cpp_ext\" in sys.argv:\n        sys.argv.remove(\"--cpp_ext\")\n    ext_modules.append(CppExtension(\"apex_C\", [\"csrc/flatten_unflatten.cpp\"]))\n\n\n# Only query nvcc when CUDA_HOME is set; Python-only and --cpp_ext builds must not fail here.\n# Every use of bare_metal_version below is preceded by raise_if_cuda_home_none().\nbare_metal_version = None\nif CUDA_HOME is not None:\n    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n\nif has_flag(\"--distributed_adam\", \"APEX_DISTRIBUTED_ADAM\"):\n    if \"--distributed_adam\" in sys.argv:\n        sys.argv.remove(\"--distributed_adam\")\n    raise_if_cuda_home_none(\"--distributed_adam\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"distributed_adam_cuda\",\n            sources=[\n                \"apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp\",\n                \"apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\", \"--use_fast_math\"],\n            },\n        )\n    )\n\nif has_flag(\"--distributed_lamb\", \"APEX_DISTRIBUTED_LAMB\"):\n    if \"--distributed_lamb\" in sys.argv:\n        sys.argv.remove(\"--distributed_lamb\")\n    raise_if_cuda_home_none(\"--distributed_lamb\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"distributed_lamb_cuda\",\n            sources=[\n                \"apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp\",\n                \"apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu\",\n            ],\n            
include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\", \"--use_fast_math\"],\n            },\n        )\n    )\n\nif has_flag(\"--cuda_ext\", \"APEX_CUDA_EXT\"):\n    if \"--cuda_ext\" in sys.argv:\n        sys.argv.remove(\"--cuda_ext\")\n    raise_if_cuda_home_none(\"--cuda_ext\")\n    check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"amp_C\",\n            sources=[\n                \"csrc/amp_C_frontend.cpp\",\n                \"csrc/multi_tensor_sgd_kernel.cu\",\n                \"csrc/multi_tensor_scale_kernel.cu\",\n                \"csrc/multi_tensor_axpby_kernel.cu\",\n                \"csrc/multi_tensor_l2norm_kernel.cu\",\n                \"csrc/multi_tensor_l2norm_kernel_mp.cu\",\n                \"csrc/multi_tensor_l2norm_scale_kernel.cu\",\n                \"csrc/multi_tensor_lamb_stage_1.cu\",\n                \"csrc/multi_tensor_lamb_stage_2.cu\",\n                \"csrc/multi_tensor_adam.cu\",\n                \"csrc/multi_tensor_adagrad.cu\",\n                \"csrc/multi_tensor_novograd.cu\",\n                \"csrc/multi_tensor_lamb.cu\",\n                \"csrc/multi_tensor_lamb_mp.cu\",\n                \"csrc/update_scale_hysteresis.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-lineinfo\",\n                    \"-O3\",\n                    # '--resource-usage',\n                    \"--use_fast_math\",\n                ],\n            },\n        )\n    )\n    ext_modules.append(\n        CUDAExtension(\n            name=\"syncbn\",\n            sources=[\"csrc/syncbn.cpp\", \"csrc/welford.cu\"],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\"],\n            },\n        )\n    )\n\n    ext_modules.append(\n        
CUDAExtension(\n            name=\"fused_layer_norm_cuda\",\n            sources=[\"csrc/layer_norm_cuda.cpp\", \"csrc/layer_norm_cuda_kernel.cu\"],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-maxrregcount=50\", \"-O3\", \"--use_fast_math\"],\n            },\n        )\n    )\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"mlp_cuda\",\n            sources=[\"csrc/mlp.cpp\", \"csrc/mlp_cuda.cu\"],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\"],\n            },\n        )\n    )\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fused_dense_cuda\",\n            sources=[\"csrc/fused_dense.cpp\", \"csrc/fused_dense_cuda.cu\"],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\"],\n            },\n        )\n    )\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"scaled_upper_triang_masked_softmax_cuda\",\n            sources=[\n                \"csrc/megatron/scaled_upper_triang_masked_softmax.cpp\",\n                \"csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                ],\n            },\n        )\n    )\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"generic_scaled_masked_softmax_cuda\",\n            sources=[\n                \"csrc/megatron/generic_scaled_masked_softmax.cpp\",\n                \"csrc/megatron/generic_scaled_masked_softmax_cuda.cu\",\n            ],\n       
     include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                ],\n            },\n        )\n    )\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"scaled_masked_softmax_cuda\",\n            sources=[\n                \"csrc/megatron/scaled_masked_softmax.cpp\",\n                \"csrc/megatron/scaled_masked_softmax_cuda.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                ],\n            },\n        )\n    )\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"scaled_softmax_cuda\",\n            sources=[\n                \"csrc/megatron/scaled_softmax.cpp\",\n                \"csrc/megatron/scaled_softmax_cuda.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                ],\n            },\n        )\n    )\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fused_rotary_positional_embedding\",\n          
  sources=[\n                \"csrc/megatron/fused_rotary_positional_embedding.cpp\",\n                \"csrc/megatron/fused_rotary_positional_embedding_cuda.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                ],\n            },\n        )\n    )\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fused_weight_gradient_mlp_cuda\",\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            sources=[\n                \"csrc/megatron/fused_weight_gradient_dense.cpp\",\n                \"csrc/megatron/fused_weight_gradient_dense_cuda.cu\",\n                \"csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                    \"--use_fast_math\",\n                ],\n            },\n        )\n    )\n\nif has_flag(\"--permutation_search\", \"APEX_PERMUTATION_SEARCH\"):\n    if \"--permutation_search\" in sys.argv:\n        sys.argv.remove(\"--permutation_search\")\n\n    if CUDA_HOME is None:\n        raise RuntimeError(\n            \"--permutation_search was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  
If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.\"\n        )\n    else:\n        cc_flag = [\"-Xcompiler\", \"-fPIC\", \"-shared\"]\n        ext_modules.append(\n            CUDAExtension(\n                name=\"permutation_search_cuda\",\n                sources=[\n                    \"apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu\"\n                ],\n                include_dirs=[\n                    os.path.join(\n                        this_dir,\n                        \"apex\",\n                        \"contrib\",\n                        \"sparsity\",\n                        \"permutation_search_kernels\",\n                        \"CUDA_kernels\",\n                    )\n                ],\n                extra_compile_args={\"cxx\": [\"-O3\"], \"nvcc\": [\"-O3\"] + cc_flag},\n            )\n        )\n\nif has_flag(\"--bnp\", \"APEX_BNP\"):\n    if \"--bnp\" in sys.argv:\n        sys.argv.remove(\"--bnp\")\n    raise_if_cuda_home_none(\"--bnp\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"bnp\",\n            sources=[\n                \"apex/contrib/csrc/groupbn/batch_norm.cu\",\n                \"apex/contrib/csrc/groupbn/ipc.cu\",\n                \"apex/contrib/csrc/groupbn/interface.cpp\",\n                \"apex/contrib/csrc/groupbn/batch_norm_add_relu.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [],\n                \"nvcc\": [\n                    \"-DCUDA_HAS_FP16=1\",\n                    \"-D__CUDA_NO_HALF_OPERATORS__\",\n                    \"-D__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"-D__CUDA_NO_HALF2_OPERATORS__\",\n                ],\n            },\n        )\n    )\n\nif has_flag(\"--xentropy\", \"APEX_XENTROPY\"):\n    from datetime import datetime\n\n   
 if \"--xentropy\" in sys.argv:\n        sys.argv.remove(\"--xentropy\")\n    raise_if_cuda_home_none(\"--xentropy\")\n    xentropy_ver = datetime.today().strftime(\"%y.%m.%d\")\n    print(f\"`--xentropy` setting version of {xentropy_ver}\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"xentropy_cuda\",\n            sources=[\n                \"apex/contrib/csrc/xentropy/interface.cpp\",\n                \"apex/contrib/csrc/xentropy/xentropy_kernel.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"] + [f'-DXENTROPY_VER=\"{xentropy_ver}\"'],\n                \"nvcc\": [\"-O3\"],\n            },\n        )\n    )\n\nif has_flag(\"--focal_loss\", \"APEX_FOCAL_LOSS\"):\n    if \"--focal_loss\" in sys.argv:\n        sys.argv.remove(\"--focal_loss\")\n    raise_if_cuda_home_none(\"--focal_loss\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"focal_loss_cuda\",\n            sources=[\n                \"apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp\",\n                \"apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\", \"--use_fast_math\", \"--ftz=false\"],\n            },\n        )\n    )\n\nif has_flag(\"--group_norm\", \"APEX_GROUP_NORM\"):\n    if \"--group_norm\" in sys.argv:\n        sys.argv.remove(\"--group_norm\")\n    raise_if_cuda_home_none(\"--group_norm\")\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"group_norm_cuda\",\n            sources=[\n                \"apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp\",\n            ]\n            + glob.glob(\"apex/contrib/csrc/group_norm/*.cu\"),\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n        
        \"cxx\": [\"-O3\", \"-std=c++17\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-std=c++17\",\n                    \"--use_fast_math\",\n                    \"--ftz=false\",\n                ],\n            },\n        )\n    )\n\n    # CUDA group norm V2 is tested on SM100\n    if bare_metal_version >= Version(\"12.4\"):\n        if bare_metal_version >= Version(\"12.8\"):\n            arch_flags = [\n                \"-gencode=arch=compute_90,code=sm_90\",\n                \"-gencode=arch=compute_100,code=sm_100\",\n                \"-gencode=arch=compute_120,code=compute_120\",\n            ]\n        else:\n            arch_flags = [\"-gencode=arch=compute_90,code=compute_90\"]\n\n        ext_modules.append(\n            CUDAExtension(\n                name=\"group_norm_v2_cuda\",\n                sources=[\n                    \"apex/contrib/csrc/group_norm_v2/gn.cpp\",\n                    \"apex/contrib/csrc/group_norm_v2/gn_cuda.cu\",\n                    \"apex/contrib/csrc/group_norm_v2/gn_utils.cpp\",\n                ]\n                + glob.glob(\"apex/contrib/csrc/group_norm_v2/gn_cuda_inst_*.cu\"),\n                extra_compile_args={\n                    \"cxx\": [\"-O2\"],\n                    \"nvcc\": [\n                        \"-O2\",\n                        \"--use_fast_math\",\n                        \"--ftz=false\",\n                        \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                        \"-U__CUDA_NO_HALF_OPERATORS__\",\n                        \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n                        \"-U__CUDA_NO_BFLOAT16_OPERATORS__\",\n                    ]\n                    + arch_flags,\n                },\n            )\n        )\n\nif has_flag(\"--index_mul_2d\", \"APEX_INDEX_MUL_2D\"):\n    if \"--index_mul_2d\" in sys.argv:\n        sys.argv.remove(\"--index_mul_2d\")\n    raise_if_cuda_home_none(\"--index_mul_2d\")\n    ext_modules.append(\n        
CUDAExtension(\n            name=\"fused_index_mul_2d\",\n            sources=[\n                \"apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp\",\n                \"apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\", \"--use_fast_math\", \"--ftz=false\"],\n            },\n        )\n    )\n\nif has_flag(\"--deprecated_fused_adam\", \"APEX_DEPRECATED_FUSED_ADAM\"):\n    if \"--deprecated_fused_adam\" in sys.argv:\n        sys.argv.remove(\"--deprecated_fused_adam\")\n    raise_if_cuda_home_none(\"--deprecated_fused_adam\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fused_adam_cuda\",\n            sources=[\n                \"apex/contrib/csrc/optimizers/fused_adam_cuda.cpp\",\n                \"apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\", \"--use_fast_math\"],\n            },\n        )\n    )\n\nif has_flag(\"--deprecated_fused_lamb\", \"APEX_DEPRECATED_FUSED_LAMB\"):\n    if \"--deprecated_fused_lamb\" in sys.argv:\n        sys.argv.remove(\"--deprecated_fused_lamb\")\n    raise_if_cuda_home_none(\"--deprecated_fused_lamb\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fused_lamb_cuda\",\n            sources=[\n                \"apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp\",\n                \"apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu\",\n                \"csrc/multi_tensor_l2norm_kernel.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\", \"--use_fast_math\"],\n   
         },\n        )\n    )\n\n# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h\n# See https://github.com/pytorch/pytorch/pull/70650\ngenerator_flag = []\ntorch_dir = torch.__path__[0]\nif os.path.exists(os.path.join(torch_dir, \"include\", \"ATen\", \"CUDAGeneratorImpl.h\")):\n    generator_flag = [\"-DOLD_GENERATOR_PATH\"]\n\nif has_flag(\"--fast_layer_norm\", \"APEX_FAST_LAYER_NORM\"):\n    if \"--fast_layer_norm\" in sys.argv:\n        sys.argv.remove(\"--fast_layer_norm\")\n    raise_if_cuda_home_none(\"--fast_layer_norm\")\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fast_layer_norm\",\n            sources=[\n                \"apex/contrib/csrc/layer_norm/ln_api.cpp\",\n                \"apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu\",\n                \"apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"] + generator_flag,\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"-U__CUDA_NO_BFLOAT16_OPERATORS__\",\n                    \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n                    \"-U__CUDA_NO_BFLOAT162_OPERATORS__\",\n                    \"-U__CUDA_NO_BFLOAT162_CONVERSIONS__\",\n                    \"-I./apex/contrib/csrc/layer_norm/\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                    \"--use_fast_math\",\n                ]\n                + generator_flag,\n            },\n            include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/layer_norm\")],\n        )\n    )\n\nif has_flag(\"--fmha\", \"APEX_FMHA\"):\n    if \"--fmha\" in sys.argv:\n        sys.argv.remove(\"--fmha\")\n    raise_if_cuda_home_none(\"--fmha\")\n\n    if bare_metal_version < 
Version(\"11.0\"):\n        raise RuntimeError(\"--fmha only supported on sm_80 and sm_90 GPUs\")\n\n    cc_flag = []\n    cc_flag.append(\"-gencode\")\n    cc_flag.append(\"arch=compute_80,code=sm_80\")\n    if bare_metal_version >= Version(\"11.8\"):\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_90,code=sm_90\")\n    if bare_metal_version >= Version(\"12.8\"):\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_100,code=sm_100\")\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_120,code=sm_120\")\n    if bare_metal_version >= Version(\"13.0\"):\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_110,code=sm_110\")\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fmhalib\",\n            sources=[\n                \"apex/contrib/csrc/fmha/fmha_api.cpp\",\n                \"apex/contrib/csrc/fmha/src/fmha_fill.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu\",\n                \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"] + generator_flag,\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    
\"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                    \"--use_fast_math\",\n                ]\n                + generator_flag\n                + cc_flag,\n            },\n            include_dirs=[\n                os.path.join(this_dir, \"apex/contrib/csrc\"),\n                os.path.join(this_dir, \"apex/contrib/csrc/fmha/src\"),\n            ],\n        )\n    )\n\n\nif has_flag(\"--fast_multihead_attn\", \"APEX_FAST_MULTIHEAD_ATTN\"):\n    if \"--fast_multihead_attn\" in sys.argv:\n        sys.argv.remove(\"--fast_multihead_attn\")\n    raise_if_cuda_home_none(\"--fast_multihead_attn\")\n\n    subprocess.run(\n        [\n            \"git\",\n            \"submodule\",\n            \"update\",\n            \"--init\",\n            \"apex/contrib/csrc/multihead_attn/cutlass\",\n        ]\n    )\n    ext_modules.append(\n        CUDAExtension(\n            name=\"fast_multihead_attn\",\n            sources=[\n                \"apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp\",\n                \"apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu\",\n                \"apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu\",\n                \"apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu\",\n                \"apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu\",\n                \"apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu\",\n                \"apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu\",\n                \"apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu\",\n                \"apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"] + generator_flag,\n                \"nvcc\": [\n                    
\"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                    \"--use_fast_math\",\n                ]\n                + generator_flag,\n            },\n            include_dirs=[\n                os.path.join(this_dir, \"apex/contrib/csrc/multihead_attn/cutlass/include/\"),\n                os.path.join(\n                    this_dir,\n                    \"apex/contrib/csrc/multihead_attn/cutlass/tools/util/include\",\n                ),\n            ],\n        )\n    )\n\nif has_flag(\"--transducer\", \"APEX_TRANSDUCER\"):\n    if \"--transducer\" in sys.argv:\n        sys.argv.remove(\"--transducer\")\n    raise_if_cuda_home_none(\"--transducer\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"transducer_joint_cuda\",\n            sources=[\n                \"apex/contrib/csrc/transducer/transducer_joint.cpp\",\n                \"apex/contrib/csrc/transducer/transducer_joint_kernel.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"] + generator_flag,\n                \"nvcc\": [\"-O3\"] + generator_flag,\n            },\n            include_dirs=[\n                os.path.join(this_dir, \"csrc\"),\n                os.path.join(this_dir, \"apex/contrib/csrc/multihead_attn\"),\n            ],\n        )\n    )\n    ext_modules.append(\n        CUDAExtension(\n            name=\"transducer_loss_cuda\",\n            sources=[\n                \"apex/contrib/csrc/transducer/transducer_loss.cpp\",\n                \"apex/contrib/csrc/transducer/transducer_loss_kernel.cu\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"csrc\")],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\"-O3\"],\n            },\n        )\n    )\n\nif has_flag(\"--cudnn_gbn\", 
\"APEX_CUDNN_GBN\"):\n    if \"--cudnn_gbn\" in sys.argv:\n        sys.argv.remove(\"--cudnn_gbn\")\n    raise_if_cuda_home_none(\"--cudnn_gbn\")\n    if check_cudnn_version_and_warn(\"--cudnn_gbn\", 8500):\n        subprocess.run(\n            [\n                \"git\",\n                \"submodule\",\n                \"update\",\n                \"--init\",\n                \"apex/contrib/csrc/cudnn-frontend/\",\n            ]\n        )\n        ext_modules.append(\n            CUDAExtension(\n                name=\"cudnn_gbn_lib\",\n                sources=[\n                    \"apex/contrib/csrc/cudnn_gbn/norm_sample.cpp\",\n                    \"apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp\",\n                ],\n                include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/cudnn-frontend/include\")],\n                extra_compile_args={\"cxx\": [\"-O3\", \"-g\"] + generator_flag},\n            )\n        )\n\nif has_flag(\"--peer_memory\", \"APEX_PEER_MEMORY\"):\n    if \"--peer_memory\" in sys.argv:\n        sys.argv.remove(\"--peer_memory\")\n    raise_if_cuda_home_none(\"--peer_memory\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"peer_memory_cuda\",\n            sources=[\n                \"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu\",\n                \"apex/contrib/csrc/peer_memory/peer_memory.cpp\",\n            ],\n            extra_compile_args={\"cxx\": [\"-O3\"] + generator_flag},\n        )\n    )\n\n# NOTE: Requires NCCL >= 2.10.3\nif has_flag(\"--nccl_p2p\", \"APEX_NCCL_P2P\"):\n    if \"--nccl_p2p\" in sys.argv:\n        sys.argv.remove(\"--nccl_p2p\")\n    raise_if_cuda_home_none(\"--nccl_p2p\")\n    # Check NCCL version.\n    _nccl_version_getter = load(\n        name=\"_nccl_version_getter\",\n        sources=[\n            \"apex/contrib/csrc/nccl_p2p/nccl_version.cpp\",\n            \"apex/contrib/csrc/nccl_p2p/nccl_version_check.cu\",\n        ],\n    )\n    _available_nccl_version = 
_nccl_version_getter.get_nccl_version()\n    if _available_nccl_version >= (2, 10):\n        ext_modules.append(\n            CUDAExtension(\n                name=\"nccl_p2p_cuda\",\n                sources=[\n                    \"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu\",\n                    \"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp\",\n                ],\n                extra_compile_args={\"cxx\": [\"-O3\"] + generator_flag},\n            )\n        )\n    else:\n        warnings.warn(\n            f\"Skip `--nccl_p2p` as it requires NCCL 2.10.3 or later, but found NCCL {_available_nccl_version[0]}.{_available_nccl_version[1]}\"\n        )\n\n# note (mkozuki): The `--fast_bottleneck` option (i.e. apex/contrib/bottleneck) now depends on `--peer_memory` and `--nccl_p2p`.\nif has_flag(\"--fast_bottleneck\", \"APEX_FAST_BOTTLENECK\"):\n    if \"--fast_bottleneck\" in sys.argv:\n        sys.argv.remove(\"--fast_bottleneck\")\n    raise_if_cuda_home_none(\"--fast_bottleneck\")\n    if check_cudnn_version_and_warn(\"--fast_bottleneck\", 8400):\n        subprocess.run(\n            [\n                \"git\",\n                \"submodule\",\n                \"update\",\n                \"--init\",\n                \"apex/contrib/csrc/cudnn-frontend/\",\n            ]\n        )\n        ext_modules.append(\n            CUDAExtension(\n                name=\"fast_bottleneck\",\n                sources=[\"apex/contrib/csrc/bottleneck/bottleneck.cpp\"],\n                include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/cudnn-frontend/include\")],\n                extra_compile_args={\"cxx\": [\"-O3\"] + generator_flag},\n            )\n        )\n\n\nif has_flag(\"--fused_conv_bias_relu\", \"APEX_FUSED_CONV_BIAS_RELU\"):\n    if \"--fused_conv_bias_relu\" in sys.argv:\n        sys.argv.remove(\"--fused_conv_bias_relu\")\n    raise_if_cuda_home_none(\"--fused_conv_bias_relu\")\n    if check_cudnn_version_and_warn(\"--fused_conv_bias_relu\", 8400):\n        subprocess.run(\n  
          [\n                \"git\",\n                \"submodule\",\n                \"update\",\n                \"--init\",\n                \"apex/contrib/csrc/cudnn-frontend/\",\n            ]\n        )\n        ext_modules.append(\n            CUDAExtension(\n                name=\"fused_conv_bias_relu\",\n                sources=[\"apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp\"],\n                include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/cudnn-frontend/include\")],\n                extra_compile_args={\"cxx\": [\"-O3\"] + generator_flag},\n            )\n        )\n\n\nif has_flag(\"--nccl_allocator\", \"APEX_NCCL_ALLOCATOR\"):\n    if \"--nccl_allocator\" in sys.argv:\n        sys.argv.remove(\"--nccl_allocator\")\n    raise_if_cuda_home_none(\"--nccl_allocator\")\n    _nccl_version_getter = load(\n        name=\"_nccl_version_getter\",\n        sources=[\n            \"apex/contrib/csrc/nccl_p2p/nccl_version.cpp\",\n            \"apex/contrib/csrc/nccl_p2p/nccl_version_check.cu\",\n        ],\n    )\n    _available_nccl_version = _nccl_version_getter.get_nccl_version()\n    if _available_nccl_version >= (2, 19):\n        ext_modules.append(\n            CUDAExtension(\n                name=\"_apex_nccl_allocator\",\n                sources=[\n                    \"apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp\",\n                ],\n                include_dirs=[os.path.join(this_dir, \"apex/apex/contrib/csrc/nccl_allocator\")],\n                libraries=[\"nccl\"],\n                extra_compile_args={\"cxx\": [\"-O3\"] + generator_flag},\n            )\n        )\n    else:\n        warnings.warn(\n            f\"Skip `--nccl_allocator` as it requires NCCL 2.19 or later, but {_available_nccl_version[0]}.{_available_nccl_version[1]}\"\n        )\n\n\nif has_flag(\"--gpu_direct_storage\", \"APEX_GPU_DIRECT_STORAGE\"):\n    if \"--gpu_direct_storage\" in sys.argv:\n        sys.argv.remove(\"--gpu_direct_storage\")\n    
raise_if_cuda_home_none(\"--gpu_direct_storage\")\n    ext_modules.append(\n        CUDAExtension(\n            name=\"_apex_gpu_direct_storage\",\n            sources=[\n                \"apex/contrib/csrc/gpu_direct_storage/gds.cpp\",\n                \"apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp\",\n            ],\n            include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/gpu_direct_storage\")],\n            libraries=[\"cufile\"],\n            extra_compile_args={\"cxx\": [\"-O3\"] + generator_flag},\n        )\n    )\n\n\n# Patch because `setup.py bdist_wheel` and `setup.py develop` do not support the `parallel` option\nparallel: int | None = None\nif \"--parallel\" in sys.argv:\n    idx = sys.argv.index(\"--parallel\")\n    parallel = int(sys.argv[idx + 1])\n    sys.argv.pop(idx + 1)\n    sys.argv.pop(idx)\nelse:\n    # Check if APEX_PARALLEL_BUILD environment variable is set\n    apex_parallel_build = os.environ.get(\"APEX_PARALLEL_BUILD\", None)\n    if apex_parallel_build is not None:\n        try:\n            parallel = int(apex_parallel_build)\n            print(\n                f\"[apex] Using parallel build with {parallel} jobs from APEX_PARALLEL_BUILD environment variable\"\n            )\n        except ValueError:\n            print(\n                f\"[apex] Warning: APEX_PARALLEL_BUILD environment variable '{apex_parallel_build}' is not a valid integer, ignoring\"\n            )\n\n\n# Prevent file conflicts when multiple extensions are compiled simultaneously\nclass BuildExtensionSeparateDir(BuildExtension):\n    build_extension_patch_lock = threading.Lock()\n    thread_ext_name_map = {}\n\n    def finalize_options(self):\n        if parallel is not None:\n            self.parallel = parallel\n        super().finalize_options()\n\n    def build_extension(self, ext):\n        with self.build_extension_patch_lock:\n            if not getattr(self.compiler, \"_compile_separate_output_dir\", False):\n                compile_orig 
= self.compiler.compile\n\n                def compile_new(*args, **kwargs):\n                    return compile_orig(\n                        *args,\n                        **{\n                            **kwargs,\n                            \"output_dir\": os.path.join(\n                                kwargs[\"output_dir\"],\n                                self.thread_ext_name_map[threading.current_thread().ident],\n                            ),\n                        },\n                    )\n\n                self.compiler.compile = compile_new\n                self.compiler._compile_separate_output_dir = True\n        self.thread_ext_name_map[threading.current_thread().ident] = ext.name\n        objects = super().build_extension(ext)\n        return objects\n\n\nsetup(\n    name=\"apex\",\n    version=\"0.1\",\n    packages=find_packages(\n        exclude=(\n            \"build\",\n            \"csrc\",\n            \"include\",\n            \"tests\",\n            \"dist\",\n            \"docs\",\n            \"examples\",\n            \"apex.egg-info\",\n        )\n    ),\n    install_requires=[\"packaging>20.6\"],\n    description=\"PyTorch Extensions written by NVIDIA\",\n    ext_modules=ext_modules,\n    cmdclass={\"build_ext\": BuildExtensionSeparateDir} if ext_modules else {},\n    extras_require=extras,\n)\n"
  },
  {
    "path": "tests/L0/run_fused_layer_norm/test_fused_layer_norm.py",
    "content": "import importlib.util\n\nimport torch\nfrom apex.normalization import FusedLayerNorm\nfrom apex.normalization import FusedRMSNorm\nfrom apex.normalization import MixedFusedLayerNorm\nfrom apex.normalization import MixedFusedRMSNorm\n\nfrom torch.testing._internal import common_utils\nfrom torch.testing._internal.common_device_type import instantiate_device_type_tests\n\nfrom itertools import product\n\n\ndef _prep_inputs(batch_size, normalized_shape, dtype):\n    shape = (batch_size, *normalized_shape)\n    fused = torch.randn(shape).cuda().requires_grad_(True)\n    with torch.no_grad():\n        native = fused.clone().to(dtype).requires_grad_(True)\n    return native, fused\n\n\nautocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)\n\n\nclass TestFusedLayerNorm(common_utils.TestCase):\n    def _test_fused_layer_norm(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n        fwd_thresholds=dict(rtol=None, atol=None),\n        bwd_thresholds=dict(rtol=None, atol=None),\n    ):\n        normalized_shape = [32, 16]\n\n        if not mixed_fused:\n            module_cpu_ = FusedLayerNorm(\n                normalized_shape=normalized_shape,\n                elementwise_affine=elementwise_affine,\n                memory_efficient=memory_efficient,\n            ).cpu()\n            module_cuda_ = FusedLayerNorm(\n                normalized_shape=normalized_shape,\n                elementwise_affine=elementwise_affine,\n                memory_efficient=memory_efficient,\n            ).to(device=\"cuda\", dtype=dtype)\n        else:\n            assert elementwise_affine\n            module_cpu_ = MixedFusedLayerNorm(\n                normalized_shape=normalized_shape, memory_efficient=memory_efficient\n            ).cpu()\n            module_cuda_ = MixedFusedLayerNorm(\n                
normalized_shape=normalized_shape, memory_efficient=memory_efficient\n            ).to(device=\"cuda\", dtype=dtype)\n\n        torch.cuda.manual_seed(42)\n        if contiguous:\n            input_shape = [batch_size] + normalized_shape\n            input_ = torch.randn(input_shape, device=\"cpu\").requires_grad_(True)\n            input_cuda_ = input_.to(device=\"cuda\", dtype=dtype).detach().requires_grad_(True)\n            self.assertTrue(input_.is_contiguous())\n            self.assertTrue(input_cuda_.is_contiguous())\n        else:\n            input_shape = [batch_size] + normalized_shape\n            input_shape = [batch_size * 3] + [\n                normalized_shape[0] * 5,\n                normalized_shape[1] * 3,\n            ]\n            input_src_ = torch.randn(input_shape, device=\"cpu\")\n            input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True)\n            input_cuda_ = (\n                input_src_.to(device=\"cuda\", dtype=dtype)[::3, ::5, ::3]\n                .detach()\n                .requires_grad_(True)\n            )\n            # make sure that tensors are NOT contiguous.\n            self.assertFalse(input_.is_contiguous())\n            self.assertFalse(input_cuda_.is_contiguous())\n        out_cpu_ = module_cpu_(input_)\n        gO = torch.rand_like(out_cpu_)\n        out_cpu_.backward(gO)\n        out_cuda_ = module_cuda_(input_cuda_)\n\n        gO = gO.to(device=\"cuda\", dtype=dtype)\n        out_cuda_.backward(gO)\n        self.assertFalse(out_cpu_.is_cuda)\n        self.assertTrue(out_cuda_.is_cuda)\n        torch.testing.assert_close(\n            out_cpu_.to(device=\"cuda\", dtype=dtype), out_cuda_, **fwd_thresholds\n        )\n        torch.testing.assert_close(\n            input_.grad.to(device=\"cuda\", dtype=dtype),\n            input_cuda_.grad,\n            **bwd_thresholds,\n        )\n\n    def _test_fused_rms_norm(\n        self,\n        batch_size,\n        contiguous,\n        
elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n        fwd_thresholds=dict(rtol=None, atol=None),\n        bwd_thresholds=dict(rtol=None, atol=None),\n    ):\n        normalized_shape = [32, 16]\n\n        if not mixed_fused:\n            module_cpu_ = FusedRMSNorm(\n                normalized_shape=normalized_shape,\n                elementwise_affine=elementwise_affine,\n                memory_efficient=memory_efficient,\n            ).cpu()\n            module_cuda_ = FusedRMSNorm(\n                normalized_shape=normalized_shape,\n                elementwise_affine=elementwise_affine,\n                memory_efficient=memory_efficient,\n            ).to(device=\"cuda\", dtype=dtype)\n        else:\n            assert elementwise_affine\n            module_cpu_ = MixedFusedRMSNorm(normalized_shape=normalized_shape).cpu()\n            module_cuda_ = MixedFusedRMSNorm(normalized_shape=normalized_shape).to(\n                device=\"cuda\", dtype=dtype\n            )\n\n        torch.cuda.manual_seed(42)\n        if contiguous:\n            input_shape = [batch_size] + normalized_shape\n            input_ = torch.randn(input_shape, device=\"cpu\").requires_grad_(True)\n            input_cuda_ = input_.to(device=\"cuda\", dtype=dtype).detach().requires_grad_(True)\n            self.assertTrue(input_.is_contiguous())\n            self.assertTrue(input_cuda_.is_contiguous())\n        else:\n            input_shape = [batch_size] + normalized_shape\n            input_shape = [batch_size * 3] + [\n                normalized_shape[0] * 5,\n                normalized_shape[1] * 3,\n            ]\n            input_src_ = torch.randn(input_shape, device=\"cpu\")\n            input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True)\n            input_cuda_ = (\n                input_src_.to(device=\"cuda\", dtype=dtype)[::3, ::5, ::3]\n                .detach()\n                .requires_grad_(True)\n            )\n           
 # make sure that tensors are NOT contiguous.\n            self.assertFalse(input_.is_contiguous())\n            self.assertFalse(input_cuda_.is_contiguous())\n        out_cpu_ = module_cpu_(input_)\n        gO = torch.rand_like(out_cpu_)\n        out_cpu_.backward(gO)\n        out_cuda_ = module_cuda_(input_cuda_)\n\n        torch.testing.assert_close(\n            out_cpu_.to(device=\"cuda\", dtype=dtype),\n            out_cuda_.clone().detach(),\n            **fwd_thresholds,\n        )\n        gO = gO.to(device=\"cuda\", dtype=dtype)\n        out_cuda_.backward(gO)\n        self.assertFalse(out_cpu_.is_cuda)\n        self.assertTrue(out_cuda_.is_cuda)\n        torch.testing.assert_close(\n            input_.grad.to(device=\"cuda\", dtype=dtype),\n            input_cuda_.grad,\n            **bwd_thresholds,\n        )\n        if elementwise_affine:\n            torch.testing.assert_close(\n                module_cpu_.weight.grad.to(device=\"cuda\", dtype=dtype),\n                module_cuda_.weight.grad,\n                **bwd_thresholds,\n            )\n\n    # layer norm tests\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(\n            product(\n                (16, 65536),\n                (True, False),\n                (False,),\n                (False,),\n                (torch.float,),\n                (True, False),\n            )\n        ),\n    )\n    def test_layer_norm_regular(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_layer_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, 
memory_efficient\",\n        list(\n            product(\n                (16, 65536),\n                (True, False),\n                (True,),\n                (False,),\n                (torch.float,),\n                (True, False),\n            )\n        ),\n    )\n    def test_layer_norm_elemwise(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_layer_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(\n            product(\n                (16, 65536),\n                (True, False),\n                (True,),\n                (True,),\n                (torch.float,),\n                (True, False),\n            )\n        ),\n    )\n    def test_layer_norm_mixed(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_layer_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))),\n    )\n    def test_layer_norm_half(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_layer_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n    
        mixed_fused,\n            dtype,\n            memory_efficient,\n            fwd_thresholds=dict(rtol=1e-3, atol=1e-3),\n            bwd_thresholds=dict(rtol=1e-3, atol=1e-3),\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(\n            product(\n                (16,),\n                (True, False),\n                (True,),\n                (False,),\n                (torch.bfloat16,),\n                (True, False),\n            )\n        ),\n    )\n    def test_layer_norm_bfloat16(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_layer_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n            fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4),\n            bwd_thresholds=dict(rtol=1.6e-2, atol=3e-3),\n        )\n\n    # rms norm tests\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(\n            product(\n                (16, 65536),\n                (True, False),\n                (False,),\n                (False,),\n                (torch.float,),\n                (True, False),\n            )\n        ),\n    )\n    def test_rms_norm_regular(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_rms_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, 
dtype, memory_efficient\",\n        list(\n            product(\n                (16, 65536),\n                (True, False),\n                (True,),\n                (False,),\n                (torch.float,),\n                (True, False),\n            )\n        ),\n    )\n    def test_rms_norm_elemwise(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_rms_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n            bwd_thresholds=dict(rtol=2e-3, atol=2e-4),\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(\n            product(\n                (16, 65536),\n                (True, False),\n                (True,),\n                (True,),\n                (torch.float,),\n                (True, False),\n            )\n        ),\n    )\n    def test_rms_norm_mixed(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_rms_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n            bwd_thresholds=dict(rtol=2e-3, atol=2e-4),\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))),\n    )\n    def test_rms_norm_half(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        
self._test_fused_rms_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n            bwd_thresholds=dict(rtol=1.6e-2, atol=3e-3),\n        )\n\n    @common_utils.parametrize(\n        \"batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient\",\n        list(\n            product(\n                (16,),\n                (True, False),\n                (True,),\n                (False,),\n                (torch.bfloat16,),\n                (True, False),\n            )\n        ),\n    )\n    def test_rms_norm_bfloat16(\n        self,\n        batch_size,\n        contiguous,\n        elementwise_affine,\n        mixed_fused,\n        dtype,\n        memory_efficient,\n    ):\n        self._test_fused_rms_norm(\n            batch_size,\n            contiguous,\n            elementwise_affine,\n            mixed_fused,\n            dtype,\n            memory_efficient,\n            fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4),\n            bwd_thresholds=dict(rtol=1.6e-2, atol=3e-2),\n        )\n\n    @common_utils.parametrize(\n        \"dtype, elementwise_affine, memory_efficient\",\n        list(product(autocast_dtypes, (True, False), (True, False))),\n    )\n    def test_autocast_fused_layer_norm(self, dtype, elementwise_affine, memory_efficient):\n        bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)\n        bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)\n        batch_size = 16\n        normalized_shape = [32, 16]\n        native = torch.nn.LayerNorm(\n            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine\n        ).to(device=\"cuda\", dtype=dtype)\n        fused = FusedLayerNorm(\n            normalized_shape=normalized_shape,\n            elementwise_affine=elementwise_affine,\n            memory_efficient=memory_efficient,\n        ).cuda()\n        native_x, fused_x = 
_prep_inputs(batch_size, normalized_shape, dtype)\n\n        expected = native(native_x)\n        with torch.amp.autocast(\"cuda\", dtype=dtype):\n            actual = fused(fused_x)\n        tols = {\"rtol\": None, \"atol\": None} if dtype == torch.half else bf16_fwd_thresholds\n        # original tests used torch.testing.assert_allclose, which disables dtype checking by default.\n        # link to issue here: https://github.com/pytorch/pytorch/issues/61844\n        torch.testing.assert_close(actual, expected, **tols, check_dtype=False)\n\n        g_native = torch.rand_like(expected)\n        with torch.no_grad():\n            g_fused = g_native.clone()\n        expected.backward(g_native)\n        actual.backward(g_fused)\n\n        if dtype != torch.half:\n            tols = bf16_bwd_thresholds\n        elif memory_efficient:\n            tols = {\"rtol\": 1e-3, \"atol\": 1e-4}\n        else:\n            tols = {\"rtol\": None, \"atol\": None}\n        torch.testing.assert_close(native_x.grad, fused_x.grad, **tols, check_dtype=False)\n\n    @common_utils.parametrize(\n        \"dtype, elementwise_affine, memory_efficient\",\n        list(product(autocast_dtypes, (True, False), (True, False))),\n    )\n    def test_autocast_fused_rms_norm(self, dtype, elementwise_affine, memory_efficient):\n        bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)\n        bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)\n        batch_size = 16\n        normalized_shape = [32, 16]\n        native = FusedRMSNorm(\n            normalized_shape=normalized_shape,\n            elementwise_affine=elementwise_affine,\n            memory_efficient=memory_efficient,\n        ).to(dtype=dtype)\n        fused = FusedRMSNorm(\n            normalized_shape=normalized_shape,\n            elementwise_affine=elementwise_affine,\n            memory_efficient=memory_efficient,\n        ).cuda()\n        native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype)\n\n        expected 
= native(native_x.cpu())\n        with torch.amp.autocast(\"cuda\", dtype=dtype):\n            actual = fused(fused_x)\n        tols = {\"rtol\": None, \"atol\": None} if dtype == torch.half else bf16_fwd_thresholds\n        torch.testing.assert_close(\n            actual, expected.detach().clone().cuda(), **tols, check_dtype=False\n        )\n\n        g_native = torch.rand_like(expected)\n        with torch.no_grad():\n            g_fused = g_native.detach().clone().cuda()\n        expected.backward(g_native)\n        actual.backward(g_fused)\n\n        tols = {\"rtol\": 1e-3, \"atol\": 1e-3} if dtype == torch.half else bf16_bwd_thresholds\n        torch.testing.assert_close(native_x.grad.cuda(), fused_x.grad, **tols, check_dtype=False)\n\n    def _verify_export(self, fused, fused_x):\n        if importlib.util.find_spec(\"onnxscript\") is None:\n            self.skipTest(\"`onnxscript` is not found\")\n        # check that export() is working\n        import io\n\n        f = io.BytesIO()\n        torch.onnx.export(\n            fused,\n            (fused_x,),\n            f,\n            input_names=[\"x_in\"],\n            opset_version=18,\n        )\n        # Load the ONNX model\n        import onnx\n\n        model_onnx = onnx.load_from_string(f.getvalue())\n        # Get string representation\n        onnx_str = onnx.helper.printable_graph(model_onnx.graph)\n\n        assert \"x_in\" in onnx_str\n        assert \"ReduceMean\" in onnx_str or \"LayerNormalization\" in onnx_str\n\n    def test_rms_export(self):\n        batch_size = 16\n        normalized_shape = [32, 16]\n        fused = FusedRMSNorm(normalized_shape=normalized_shape, elementwise_affine=True).cuda()\n        fused_m = MixedFusedRMSNorm(normalized_shape=normalized_shape).cuda()\n        native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32)\n        self._verify_export(fused, fused_x)\n        self._verify_export(fused_m, fused_x)\n\n    def 
test_layer_norm_export(self):\n        batch_size = 16\n        normalized_shape = [32, 16]\n        fused = FusedLayerNorm(normalized_shape=normalized_shape, elementwise_affine=True).cuda()\n        fused_m = MixedFusedLayerNorm(normalized_shape=normalized_shape).cuda()\n        native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32)\n        self._verify_export(fused, fused_x)\n        self._verify_export(fused_m, fused_x)\n\n    @common_utils.parametrize(\"elementwise_affine\", (True, False))\n    def test_compile_fused_layer_norm(self, elementwise_affine):\n        batch_size = 16\n        normalized_shape = [32, 16]\n        eager_mod = FusedLayerNorm(\n            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine\n        ).cuda()\n        compiled_mod = torch.compile(fullgraph=True)(eager_mod)\n        input_shape = [batch_size] + normalized_shape\n        eager_x = torch.randn(input_shape, device=\"cuda\").requires_grad_(True)\n        compiled_x = eager_x.detach().clone().requires_grad_(True)\n\n        expected = eager_mod(eager_x)\n        actual = compiled_mod(compiled_x)\n        torch.testing.assert_close(actual, expected.detach())\n\n        g_eager = torch.rand_like(expected)\n        with torch.no_grad():\n            g_compiled = g_eager.detach().clone()\n        expected.backward(g_eager)\n        actual.backward(g_compiled)\n\n        torch.testing.assert_close(eager_x.grad, compiled_x.grad)\n\n    @common_utils.parametrize(\"elementwise_affine\", (True, False))\n    def test_compile_fused_rms_norm(self, elementwise_affine):\n        batch_size = 16\n        normalized_shape = [32, 16]\n        eager_mod = FusedRMSNorm(\n            normalized_shape=normalized_shape, elementwise_affine=elementwise_affine\n        ).cuda()\n        compiled_mod = torch.compile(fullgraph=True)(eager_mod)\n        input_shape = [batch_size] + normalized_shape\n        eager_x = torch.randn(input_shape, 
device=\"cuda\").requires_grad_(True)\n        compiled_x = eager_x.detach().clone().requires_grad_(True)\n\n        expected = eager_mod(eager_x)\n        actual = compiled_mod(compiled_x)\n        torch.testing.assert_close(actual, expected.detach())\n\n        g_eager = torch.rand_like(expected)\n        with torch.no_grad():\n            g_compiled = g_eager.detach().clone()\n        expected.backward(g_eager)\n        actual.backward(g_compiled)\n\n        torch.testing.assert_close(eager_x.grad, compiled_x.grad)\n\n\ninstantiate_device_type_tests(TestFusedLayerNorm, globals(), only_for=(\"cuda\",))\nif __name__ == \"__main__\":\n    common_utils.run_tests()\n"
  },
  {
    "path": "tests/L0/run_mlp/test_mlp.py",
    "content": "\"\"\"Tests for c++ MLP\"\"\"\n\nfrom itertools import product\nfrom time import time\n\nimport torch\nfrom torch import nn\nfrom torch.testing._internal import common_utils\nfrom torch.testing._internal.common_device_type import instantiate_device_type_tests\nfrom torch.testing._internal.common_cuda import tf32_off\n\nfrom apex.mlp import MLP\n\n\nbatch_size = 1024\nmlp_sizes = [480, 1024, 1024, 512, 256, 1]\nnum_iters = 10\n\n\n# note(crcrpar): On Ampere, this test should be run without TF32 enabled.\nclass TestMLP(common_utils.TestCase):\n    def test_creation(self):\n        MLP(mlp_sizes)\n\n    def test_numeric(self):\n        mlp = MLP(mlp_sizes).cuda()\n\n        mlp_layers = []\n        for i in range(mlp.num_layers):\n            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])\n            with torch.no_grad():\n                mlp.weights[i].copy_(linear.weight)\n                mlp.biases[i].copy_(linear.bias)\n            mlp_layers.append(linear)\n            mlp_layers.append(nn.ReLU())\n\n        ref_mlp = nn.Sequential(*mlp_layers).cuda()\n\n        test_input = (\n            torch.empty(batch_size, mlp_sizes[0], device=\"cuda\")\n            .uniform_(-1.0, 1.0)\n            .requires_grad_()\n        )\n        ref_input = test_input.clone().detach().requires_grad_()\n        mlp_out = mlp(test_input)\n        ref_out = ref_mlp(ref_input)\n        self.assertEqual(mlp_out, ref_out)\n\n        # Use mean value as scalar loss. 
Multiply by 10 so it is large enough not to zero out\n        mlp_out.mean().mul(10.0).backward()\n        ref_out.mean().mul(10.0).backward()\n        self.assertEqual(test_input.grad, ref_input.grad)\n        self.assertEqual(mlp.biases[0].grad, ref_mlp[0].bias.grad)\n\n    def _test_mlp_impl(self, use_activation: str, bias: bool, enable_autocast: bool):\n        mlp = MLP(mlp_sizes, bias=bias, activation=use_activation).cuda()\n\n        mlp_layers = []\n        for i in range(mlp.num_layers):\n            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=bias)\n            with torch.no_grad():\n                mlp.weights[i].copy_(linear.weight)\n                if bias:\n                    mlp.biases[i].copy_(linear.bias)\n            mlp_layers.append(linear)\n            if use_activation == \"relu\":\n                mlp_layers.append(nn.ReLU())\n            if use_activation == \"sigmoid\":\n                mlp_layers.append(nn.Sigmoid())\n\n        ref_mlp = nn.Sequential(*mlp_layers).cuda()\n\n        test_input = (\n            torch.empty(batch_size, mlp_sizes[0], device=\"cuda\")\n            .uniform_(-1.0, 1.0)\n            .requires_grad_()\n        )\n        ref_input = test_input.clone().detach().requires_grad_()\n\n        with torch.cuda.amp.autocast_mode.autocast(enabled=enable_autocast):\n            mlp_out = mlp(test_input)\n            mlp_loss = mlp_out.mean().mul(10.0)\n            # Use mean value as scalar loss. 
Multiply by 10 so it is large enough not to zero out\n            ref_out = ref_mlp(ref_input)\n            ref_loss = ref_out.mean().mul(10.0)\n\n        mlp_loss.backward()\n        ref_loss.backward()\n        if enable_autocast:\n            self.assertEqual(mlp_out.dtype, torch.float16)\n            self.assertEqual(ref_out.dtype, torch.float16)\n        else:\n            self.assertEqual(mlp_out, ref_out)\n            self.assertEqual(test_input.grad, ref_input.grad)\n            self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad)\n\n    @tf32_off()\n    @common_utils.parametrize(\n        \"use_activation,bias\",\n        list(product((\"none\", \"relu\", \"sigmoid\"), (True, False))),\n    )\n    def test_mlp(self, use_activation: str, bias: bool):\n        self._test_mlp_impl(use_activation, bias, enable_autocast=False)\n\n    @common_utils.parametrize(\n        \"use_activation,bias\",\n        list(product((\"none\", \"relu\", \"sigmoid\"), (True, False))),\n    )\n    def test_mlp_autocast_fp16(self, use_activation: str, bias: bool):\n        self._test_mlp_impl(use_activation, bias, enable_autocast=True)\n\n    def test_no_grad(self):\n        mlp = MLP(mlp_sizes).cuda()\n\n        mlp_layers = []\n        for i in range(mlp.num_layers):\n            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])\n            with torch.no_grad():\n                mlp.weights[i].copy_(linear.weight)\n                mlp.biases[i].copy_(linear.bias)\n            mlp_layers.append(linear)\n            mlp_layers.append(nn.ReLU(inplace=True))\n\n        ref_mlp = nn.Sequential(*mlp_layers).cuda()\n\n        test_input = torch.empty(batch_size, mlp_sizes[0], device=\"cuda\").uniform_(-1.0, 1.0)\n        ref_input = test_input.clone().detach()\n        mlp_out = mlp(test_input)\n        ref_out = ref_mlp(ref_input)\n        self.assertEqual(mlp_out, ref_out)\n\n        # Use mean value as scalar loss. 
Multiply 10 to make it big enough not zero out\n        mlp_out.mean().mul(10.0).backward()\n        ref_out.mean().mul(10.0).backward()\n        self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad)\n\n    def test_performance_half(self):\n        mlp = MLP(mlp_sizes).cuda().half()\n\n        mlp_layers = []\n        for i in range(mlp.num_layers):\n            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])\n            mlp.weights[i].data.copy_(linear.weight)\n            mlp.biases[i].data.copy_(linear.bias)\n            mlp_layers.append(linear)\n            mlp_layers.append(nn.ReLU(inplace=True))\n\n        ref_mlp = nn.Sequential(*mlp_layers).cuda().half()\n\n        test_input = (\n            torch.empty(batch_size, mlp_sizes[0], device=\"cuda\", dtype=torch.half)\n            .fill_(10.0)\n            .requires_grad_()\n        )\n        ref_input = (\n            torch.empty(batch_size, mlp_sizes[0], device=\"cuda\", dtype=torch.half)\n            .fill_(10.0)\n            .requires_grad_()\n        )\n\n        # Warm up GPU\n        for _ in range(100):\n            ref_out = ref_mlp(ref_input)\n            ref_loss = ref_out.mean()\n            ref_mlp.zero_grad()\n            ref_loss.backward()\n            mlp_out = mlp(test_input)\n            test_loss = mlp_out.mean()\n            mlp.zero_grad()\n            test_loss.backward()\n\n        torch.cuda.profiler.start()\n        torch.cuda.synchronize()\n        start_time = time()\n        for _ in range(num_iters):\n            ref_out = ref_mlp(ref_input)\n            ref_loss = ref_out.mean()\n            ref_mlp.zero_grad()\n            ref_loss.backward()\n        torch.cuda.synchronize()\n        stop_time = time()\n        ref_time = (stop_time - start_time) * 1000.0 / num_iters\n        print(f\"\\nPytorch MLP time {ref_time:.4f} ms\")\n\n        torch.cuda.synchronize()\n        start_time = time()\n        for _ in range(num_iters):\n            mlp_out = mlp(test_input)\n  
          test_loss = mlp_out.mean()\n            mlp.zero_grad()\n            test_loss.backward()\n        torch.cuda.synchronize()\n        stop_time = time()\n        actual_time = (stop_time - start_time) * 1000.0 / num_iters\n        print(f\"C++ MLP time {actual_time:.4f} ms\")\n        torch.cuda.profiler.stop()\n        self.assertLessEqual(\n            actual_time,\n            ref_time,\n            msg=f\"Custom extension took {actual_time:.4f} while PyTorch took {ref_time:.4f}\",\n        )\n\n\ninstantiate_device_type_tests(TestMLP, globals(), only_for=(\"cuda\",))\n\n\nif __name__ == \"__main__\":\n    common_utils.run_tests()\n"
  },
  {
    "path": "tests/L0/run_optimizers/__init__.py",
    "content": ""
  },
  {
    "path": "tests/L0/run_optimizers/test_adam.py",
    "content": "import copy\nimport unittest\n\nimport torch\nfrom torch import nn\nfrom torch.testing._internal.common_device_type import largeTensorTest\n\ntry:\n    import apex\nexcept ImportError:\n    HAS_APEX = False\nelse:\n    HAS_APEX = True\n\n\nclass Model(torch.nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.Conv2d(1, 6, 5)\n        self.relu1 = nn.ReLU()\n        self.pool1 = nn.MaxPool2d(2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.relu2 = nn.ReLU()\n        self.pool2 = nn.MaxPool2d(2)\n        self.fc1 = nn.Linear(256, 120)\n        self.relu3 = nn.ReLU()\n        self.fc2 = nn.Linear(120, 84)\n        self.relu4 = nn.ReLU()\n        self.fc3 = nn.Linear(84, 10)\n        self.relu5 = nn.ReLU()\n\n    def forward(self, x):\n        y = self.conv1(x)\n        y = self.relu1(y)\n        y = self.pool1(y)\n        y = self.conv2(y)\n        y = self.relu2(y)\n        y = self.pool2(y)\n        y = y.reshape(y.shape[0], -1)\n        y = self.fc1(y)\n        y = self.relu3(y)\n        y = self.fc2(y)\n        y = self.relu4(y)\n        y = self.fc3(y)\n        y = self.relu5(y)\n        return y\n\n\n@unittest.skipIf(not HAS_APEX, \"`apex` is not found.\")\nclass AdamTest(unittest.TestCase):\n    def setUp(self, seed=0):\n        super().setUp()\n        torch.manual_seed(seed)\n\n        self.model = Model().cuda()\n        self.model_ = Model().cuda()\n        self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))\n\n        self.lr = 0.00001\n        params = [p for p in self.model.parameters() if p.requires_grad]\n        self.optimizer = torch.optim.Adam(params, lr=self.lr)\n\n    def testGradScaler(self):\n        params_ = [p for p in self.model_.parameters() if p.requires_grad]\n        optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)\n        scaler = torch.amp.GradScaler(\"cuda\", enabled=True)\n        scaler_ = 
torch.amp.GradScaler(\"cuda\", enabled=True)\n\n        for i in range(100):\n            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)\n            x_ = x.clone()\n            gt = torch.rand([32, 10]).cuda()\n            gt_ = gt.clone()\n\n            # Reference\n            with torch.amp.autocast(\"cuda\", enabled=True):\n                y = self.model(x)\n                loss = ((gt - y) ** 2).mean()\n\n            scaler.scale(loss).backward()\n            scaler.step(self.optimizer)\n            scaler.update()\n\n            # DUT\n            with torch.amp.autocast(\"cuda\", enabled=True):\n                y = self.model_(x)\n                loss_ = ((gt_ - y) ** 2).mean()\n\n            scaler_.scale(loss_).backward()\n            scaler_.step(optimizer_)\n            scaler_.update()\n\n            for module in zip(self.model.modules(), self.model_.modules()):\n                m = module[0]\n                m_ = module[1]\n                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):\n                    torch.testing.assert_close(\n                        m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True\n                    )\n                    torch.testing.assert_close(\n                        m.weight.grad,\n                        m_.weight.grad,\n                        atol=1e-3,\n                        rtol=1e-3,\n                        equal_nan=True,\n                    )\n\n            # Init for next iteration\n            self.optimizer.zero_grad()\n            optimizer_.zero_grad()\n\n            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))\n\n    def testGradScalerCapturable(self):\n        params_ = [p for p in self.model_.parameters() if p.requires_grad]\n        optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=True)\n        scaler = torch.amp.GradScaler(\"cuda\", enabled=True)\n        scaler_ = torch.amp.GradScaler(\"cuda\", 
enabled=True)\n\n        for i in range(100):\n            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)\n            x_ = x.clone()\n            gt = torch.rand([32, 10]).cuda()\n            gt_ = gt.clone()\n\n            # Reference\n            with torch.amp.autocast(\"cuda\", enabled=True):\n                y = self.model(x)\n                loss = ((gt - y) ** 2).mean()\n\n            scaler.scale(loss).backward()\n            scaler.step(self.optimizer)\n            scaler.update()\n\n            # DUT\n            with torch.amp.autocast(\"cuda\", enabled=True):\n                y = self.model_(x)\n                loss_ = ((gt_ - y) ** 2).mean()\n\n            scaler_.scale(loss_).backward()\n            scaler_.step(optimizer_)\n            scaler_.update()\n\n            for module in zip(self.model.modules(), self.model_.modules()):\n                m = module[0]\n                m_ = module[1]\n                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):\n                    torch.testing.assert_close(\n                        m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True\n                    )\n                    torch.testing.assert_close(\n                        m.weight.grad,\n                        m_.weight.grad,\n                        atol=1e-3,\n                        rtol=1e-3,\n                        equal_nan=True,\n                    )\n\n            # Init for next iteration\n            self.optimizer.zero_grad()\n            optimizer_.zero_grad()\n\n            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))\n\n    def testGradScalerCapturableMaster(self):\n        # Cast conv layers to FP16\n        for m in self.model_.modules():\n            if m.__class__ in [torch.nn.Conv2d]:\n                m.half()\n        params_ = [p for p in self.model_.parameters() if p.requires_grad]\n        optimizer_ = apex.optimizers.FusedAdam(\n            params_, 
lr=self.lr, capturable=True, master_weights=True\n        )\n        scaler = torch.amp.GradScaler(\"cuda\", enabled=True)\n        scaler_ = torch.amp.GradScaler(\"cuda\", enabled=True)\n\n        for i in range(100):\n            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)\n            x_ = x.clone()\n            gt = torch.rand([32, 10]).cuda()\n            gt_ = gt.clone()\n\n            # Reference\n            with torch.amp.autocast(\"cuda\", enabled=True):\n                y = self.model(x)\n                loss = ((gt - y) ** 2).mean()\n\n            scaler.scale(loss).backward()\n            scaler.step(self.optimizer)\n            scaler.update()\n\n            # DUT\n            with torch.amp.autocast(\"cuda\", enabled=True):\n                y = self.model_(x)\n                loss_ = ((gt_ - y) ** 2).mean()\n\n            scaler_.scale(loss_).backward()\n            scaler_.step(optimizer_)\n            scaler_.update()\n\n            for module in zip(self.model.modules(), self.model_.modules()):\n                m = module[0]\n                m_ = module[1]\n                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):\n                    torch.testing.assert_close(\n                        m.weight,\n                        m_.weight.float(),\n                        atol=1e-3,\n                        rtol=1e-3,\n                        equal_nan=True,\n                    )\n                    torch.testing.assert_close(\n                        m.weight.grad,\n                        m_.weight.grad.float(),\n                        atol=1e-3,\n                        rtol=1e-3,\n                        equal_nan=True,\n                    )\n\n            # Init for next iteration\n            self.optimizer.zero_grad()\n            optimizer_.zero_grad()\n\n            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))\n\n    def testNative(self):\n        params_ = [p for p in 
self.model_.parameters() if p.requires_grad]\n        optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)\n\n        for i in range(100):\n            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)\n            x_ = x.clone()\n            gt = torch.rand([32, 10]).cuda()\n            gt_ = gt.clone()\n\n            # Reference\n            y = self.model(x)\n            loss = ((gt - y) ** 2).mean()\n\n            loss.backward()\n            self.optimizer.step()\n\n            # DUT\n            y = self.model_(x)\n            loss_ = ((gt_ - y) ** 2).mean()\n\n            loss_.backward()\n            optimizer_.step()\n\n            for module in zip(self.model.modules(), self.model_.modules()):\n                m = module[0]\n                m_ = module[1]\n                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):\n                    torch.testing.assert_close(\n                        m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True\n                    )\n                    torch.testing.assert_close(\n                        m.weight.grad,\n                        m_.weight.grad,\n                        atol=1e-3,\n                        rtol=1e-3,\n                        equal_nan=True,\n                    )\n\n            # Init for next iteration\n            self.optimizer.zero_grad()\n            optimizer_.zero_grad()\n\n            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))\n\n    @largeTensorTest(\"60GB\", \"cuda\")\n    def testLargeTensor(self):\n        t = torch.zeros(2359332864, dtype=torch.half, device=\"cuda\")\n        t2 = torch.zeros(2359332864, dtype=torch.half, device=\"cuda\")\n        grad = torch.randn_like(t)\n        t.grad = grad\n        t2.grad = grad\n        params = [t]\n        params2 = [t2]\n        optimizer = apex.optimizers.FusedAdam(params, lr=self.lr)\n        optimizer.step()\n        optimizer2 = 
torch.optim.Adam(params2, lr=self.lr)\n        optimizer2.step()\n        torch.testing.assert_close(t, t2)\n        torch.cuda.synchronize()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/L0/run_optimizers/test_fused_novograd.py",
    "content": "import torch\nfrom torch.optim import Optimizer\nimport apex\nimport unittest\n\nfrom test_fused_optimizer import TestFusedOptimizer\nfrom itertools import product\n\n\nclass Novograd(Optimizer):\n    \"\"\"\n    Implements Novograd algorithm.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.95, 0))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        grad_averaging: gradient averaging\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False)\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        betas=(0.95, 0),\n        eps=1e-8,\n        weight_decay=0,\n        grad_averaging=False,\n        amsgrad=False,\n    ):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        defaults = dict(\n            lr=lr,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            grad_averaging=grad_averaging,\n            amsgrad=amsgrad,\n        )\n\n        super(Novograd, self).__init__(params, defaults)\n\n    
def __setstate__(self, state):\n        super(Novograd, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault(\"amsgrad\", False)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n            and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError(\"Sparse gradients are not supported.\")\n                amsgrad = group[\"amsgrad\"]\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state[\"step\"] = 0\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros([]).to(state[\"exp_avg\"].device)\n                    if amsgrad:\n                        # Maintains max of all exp. moving avg. of sq. grad. 
values\n                        state[\"max_exp_avg_sq\"] = torch.zeros([]).to(state[\"exp_avg\"].device)\n\n                exp_avg, exp_avg_sq = state[\"exp_avg\"], state[\"exp_avg_sq\"]\n                if amsgrad:\n                    max_exp_avg_sq = state[\"max_exp_avg_sq\"]\n                beta1, beta2 = group[\"betas\"]\n\n                state[\"step\"] += 1\n\n                norm = torch.sum(torch.pow(grad, 2))\n\n                if exp_avg_sq == 0:\n                    exp_avg_sq.copy_(norm)\n                else:\n                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)\n\n                if amsgrad:\n                    # Maintains the maximum of all 2nd moment running avg. till now\n                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)\n                    # Use the max. for normalizing running avg. of gradient\n                    denom = max_exp_avg_sq.sqrt().add_(group[\"eps\"])\n                else:\n                    denom = exp_avg_sq.sqrt().add_(group[\"eps\"])\n\n                grad.div_(denom)\n                if group[\"weight_decay\"] != 0:\n                    grad.add_(p.data, alpha=group[\"weight_decay\"])\n                if group[\"grad_averaging\"]:\n                    grad.mul_(1 - beta1)\n                exp_avg.mul_(beta1).add_(grad)\n\n                p.data.add_(exp_avg, alpha=-group[\"lr\"])\n\n        return loss\n\n\nclass TestFusedNovoGrad(TestFusedOptimizer):\n    def __init__(self, *args, **kwargs):\n        super(TestFusedNovoGrad, self).__init__(*args, **kwargs)\n\n        # The options for NovoGrad and FusedNovoGrad are very specific if they\n        # are expected to behave the same.\n        self.options = {\n            \"lr\": 1e-3,\n            \"betas\": (0.95, 0),\n            \"eps\": 1e-8,\n            \"weight_decay\": 0,\n            \"grad_averaging\": False,\n            \"amsgrad\": False,\n        }\n\n        self.tst_options = {\n            \"lr\": 1e-3,\n    
        \"betas\": (0.95, 0),\n            \"eps\": 1e-8,\n            \"weight_decay\": 0,\n            \"grad_averaging\": False,\n            \"amsgrad\": False,\n            \"bias_correction\": False,\n            \"reg_inside_moment\": True,\n            \"norm_type\": 2,\n            \"init_zero\": False,\n            \"set_grad_none\": True,\n        }\n\n        self.ref_optim = Novograd\n        self.fused_optim = apex.optimizers.FusedNovoGrad\n\n    def test_float(self):\n        self.gen_single_type_test(param_type=torch.float)\n\n    def test_half(self):\n        self.gen_single_type_test(param_type=torch.float16)\n\n    @unittest.skipIf(torch.cuda.device_count() < 2, \"more than 1 GPU required\")\n    def test_multi_device(self):\n        devices = (\"cuda:1\", \"cuda:0\")\n        for current_dev, tensor_dev in product(devices, devices):\n            with torch.cuda.device(current_dev):\n                torch.cuda.synchronize()\n                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)\n\n    def test_multi_params(self):\n        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]\n\n        tensors = []\n        for size in sizes:\n            tensors.append(torch.rand(size, dtype=torch.float, device=\"cuda\"))\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(\n            tensors, self.options, self.tst_options\n        )\n\n        for _ in range(self.iters):\n            self.gen_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step()\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/L0/run_optimizers/test_fused_optimizer.py",
    "content": "from itertools import product\nimport random\nimport unittest\n\nimport torch\n\nimport apex\n\n\nclass TestFusedOptimizer(unittest.TestCase):\n    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):\n        self.max_abs_diff = max_abs_diff\n        self.max_rel_diff = max_rel_diff\n        self.iters = iters\n        torch.manual_seed(9876)\n\n    def tearDown(self):\n        pass\n\n    def gen_param_optim(self, tensors, options, tst_options=None):\n        # Adding this to make backward compatible with existing tests. Just in\n        # case \"tst_options\" are not provided, it gets a copy of options\n        # which contains the parameters for the reference optimizer\n        if tst_options == None:\n            tst_options = options\n\n        ref_param = []\n        tst_param = []\n        for tensor in tensors:\n            ref_param.append(torch.nn.Parameter(tensor.clone()))\n            tst_param.append(torch.nn.Parameter(tensor.clone()))\n\n        ref_optim = self.ref_optim(ref_param, **options)\n        tst_optim = self.fused_optim(tst_param, **tst_options)\n\n        return (ref_param, tst_param, ref_optim, tst_optim)\n\n    def gen_grad(self, ref_param, tst_param):\n        for p_ref, p_tst in zip(ref_param, tst_param):\n            p_ref.grad = torch.rand_like(p_ref)\n            p_tst.grad = p_ref.grad\n\n    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):\n        half_grads = []\n        for p_ref, p_tst in zip(ref_param, tst_param):\n            half_grads.append(torch.rand_like(p_ref).half())\n            p_ref.grad = half_grads[-1].float() / scale\n        return half_grads\n\n    def get_max_diff(self, ref_param, tst_param):\n        max_abs_diff = max_rel_diff = 0\n        for p_ref, p_tst in zip(ref_param, tst_param):\n            max_abs_diff_p = (p_ref - p_tst).abs().max().item()\n            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()\n\n            if max_abs_diff_p > max_abs_diff:\n 
               max_abs_diff = max_abs_diff_p\n            if max_rel_diff_p > max_rel_diff:\n                max_rel_diff = max_rel_diff_p\n\n        return max_abs_diff, max_rel_diff\n\n    def gen_single_type_test(\n        self, param_type=torch.float, device=\"cuda\", *, skip_assert: bool = False\n    ):\n        nelem = 278011\n\n        # Some ref and test optimizers may require different set of options.\n        # This is a quick workaround to add that functionality while making\n        # minimum changes in existing code.\n        # If there is no \"tst_options\" field provided, safe to initialize\n        # the test optimizer with the parameters of reference optimizer.\n        if not hasattr(self, \"tst_options\"):\n            self.tst_options = self.options\n\n        tensor = torch.rand(nelem, dtype=param_type, device=device)\n\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(\n            [tensor], self.options, self.tst_options\n        )\n\n        for i in range(self.iters):\n            self.gen_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step()\n            if skip_assert:\n                return\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n\nclass TestFusedAdam(TestFusedOptimizer):\n    def setUp(self):\n        super().setUp()\n        self.options = {\n            \"lr\": 5e-4,\n            \"betas\": (0.9, 0.999),\n            \"eps\": 1e-08,\n            \"weight_decay\": 0,\n            \"amsgrad\": False,\n        }\n        self.ref_optim = torch.optim.Adam\n        self.fused_optim = apex.optimizers.FusedAdam\n\n    def test_float(self):\n        self.gen_single_type_test(param_type=torch.float)\n\n    # NOTE(mkozuki): Current threshold values look too small for BFloat16.\n    # TODO(mkozuki): Refactor 
`TestFusedOptimizer`\n    def test_half(self):\n        self.gen_single_type_test(param_type=torch.float16, skip_assert=True)\n\n    def test_bfloat16(self):\n        self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)\n\n    @unittest.skipIf(torch.cuda.device_count() < 2, \"more than 1 GPU required\")\n    def test_multi_device(self):\n        devices = (\"cuda:0\", \"cuda:1\")\n        for current_dev, tensor_dev in product(devices, devices):\n            with torch.cuda.device(current_dev):\n                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)\n\n    @unittest.skip(\"Disable until 8/1/2019 adam/adamw upstream picked\")\n    def test_multi_params(self):\n        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]\n\n        tensors = []\n        for size in sizes:\n            tensors.append(torch.rand(size, dtype=torch.float, device=\"cuda\"))\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options)\n\n        for i in range(self.iters):\n            self.gen_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step()\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n    @unittest.skip(\"No longer support fuse scaling\")\n    def test_scale(self):\n        nelem = 278011\n        tensor = torch.rand(nelem, dtype=torch.float, device=\"cuda\")\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], self.options)\n\n        for i in range(self.iters):\n            scale = random.random() * 1000\n            half_grads = self.gen_mixed_grad(ref_param, tst_param, scale)\n            ref_optim.step()\n            tst_optim.step(grads=half_grads, scale=scale)\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n\n   
         self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n    @unittest.skip(\"No longer support output fp16 param\")\n    def test_fp16_output(self):\n        nelem = 278011\n\n        tensor = torch.rand(nelem, dtype=torch.float, device=\"cuda\")\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], self.options)\n\n        fp16_param = torch.nn.Parameter(tensor.clone().half())\n\n        for i in range(self.iters):\n            half_grads = self.gen_mixed_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step(grads=half_grads, output_params=[fp16_param])\n\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, [fp16_param.float()])\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n    def test_adam_option(self):\n        nelem = 1\n        adam_option = {\n            \"lr\": 0.01,\n            \"betas\": (0.6, 0.9),\n            \"eps\": 3e-06,\n            \"weight_decay\": 0,\n            \"amsgrad\": False,\n        }\n\n        tensor = torch.rand(nelem, dtype=torch.float, device=\"cuda\")\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)\n\n        for i in range(self.iters):\n            self.gen_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step()\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n    def test_frozen_model(self):\n        nelem = 1\n        
adam_option = {\n            \"lr\": 0.01,\n            \"betas\": (0.6, 0.9),\n            \"eps\": 3e-06,\n            \"weight_decay\": 0,\n            \"amsgrad\": False,\n        }\n\n        tensor = torch.rand(nelem, dtype=torch.float, device=\"cuda\")\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)\n\n        # Add an empty param group which may occur for pipeline parallel p-tuning\n        tst_optim.add_param_group({\"params\": []})\n\n        for i in range(self.iters):\n            self.gen_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step()\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n\nclass TestFusedAdagrad(TestFusedOptimizer):\n    def __init__(self, *args, **kwargs):\n        super(TestFusedAdagrad, self).__init__(*args, **kwargs)\n        self.options = {\"lr\": 5e-4, \"eps\": 1e-08, \"weight_decay\": 1.0e-5}\n        self.ref_optim = torch.optim.Adagrad\n        self.fused_optim = apex.optimizers.FusedAdagrad\n\n    def test_float(self):\n        self.gen_single_type_test(param_type=torch.float)\n\n    @unittest.skip(\"PyTorch optimizer is not numerically correct for fp16\")\n    def test_half(self):\n        self.gen_single_type_test(param_type=torch.float16)\n\n    @unittest.skipIf(torch.cuda.device_count() < 2, \"more than 1 GPU required\")\n    def test_multi_device(self):\n        devices = (\"cuda:0\", \"cuda:1\")\n        for current_dev, tensor_dev in product(devices, devices):\n            with torch.cuda.device(current_dev):\n                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)\n\n    def test_multi_params(self):\n        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]\n        adagrad_option = {\"lr\": 5e-4, \"eps\": 1e-08, 
\"weight_decay\": 0}\n\n        tensors = []\n        for size in sizes:\n            tensors.append(torch.rand(size, dtype=torch.float, device=\"cuda\"))\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, adagrad_option)\n\n        for _ in range(self.iters):\n            self.gen_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step()\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, self.max_rel_diff)\n\n    @unittest.skipIf(torch.cuda.device_count() < 2, \"more than 1 GPU required\")\n    def test_multi_params_different_devices_throws(self):\n        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]\n        adagrad_option = {\"lr\": 5e-4, \"eps\": 1e-08, \"weight_decay\": 0}\n\n        tensors = []\n        for i, size in enumerate(sizes):\n            tensors.append(torch.rand(size, dtype=torch.float, device=\"cuda:\" + str(i % 2)))\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, adagrad_option)\n        self.gen_grad(ref_param, tst_param)\n        with self.assertRaisesRegex(RuntimeError, \"not on the same device\"):\n            tst_optim.step()\n\n    def test_adagrad_option(self):\n        nelem = 1\n        adagrad_option = {\"lr\": 0.01, \"eps\": 3e-06, \"weight_decay\": 0}\n\n        tensor = torch.rand(nelem, dtype=torch.float, device=\"cuda\")\n        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adagrad_option)\n\n        for _ in range(self.iters):\n            self.gen_grad(ref_param, tst_param)\n            ref_optim.step()\n            tst_optim.step()\n            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)\n\n            self.assertLessEqual(max_abs_diff, self.max_abs_diff)\n            self.assertLessEqual(max_rel_diff, 
self.max_rel_diff)\n\n\nclass TestFusedSGD(TestFusedOptimizer):\n    def __init__(self, *args, **kwargs):\n        super(TestFusedSGD, self).__init__(*args, **kwargs)\n        self.options = {\"lr\": 0.25, \"momentum\": 0.125}\n        self.ref_optim = torch.optim.SGD\n        self.fused_optim = apex.optimizers.FusedSGD\n\n    def test_float(self):\n        self.gen_single_type_test(param_type=torch.float)\n\n    def test_half(self):\n        self.gen_single_type_test(param_type=torch.float16)\n\n    @unittest.skipIf(torch.cuda.device_count() < 2, \"more than 1 GPU required\")\n    def test_multi_device(self):\n        devices = (\"cuda:0\", \"cuda:1\")\n        for current_dev, tensor_dev in product(devices, devices):\n            with torch.cuda.device(current_dev):\n                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/L0/run_optimizers/test_lamb.py",
    "content": "import unittest\nimport os\n\nimport torch\nfrom torch.optim import Optimizer\nimport apex\nfrom apex.multi_tensor_apply import multi_tensor_applier\nfrom itertools import product\n\n\nclass RefLAMB(Optimizer):\n    r\"\"\"Implements Lamb algorithm.\n\n    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-6)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)\n\n    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n        super(RefLAMB, self).__init__(params, defaults)\n        if multi_tensor_applier.available:\n            import amp_C\n\n            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm\n            # Skip buffer\n            self._dummy_overflow_buf = torch.tensor(\n                [0], 
dtype=torch.int, device=self.param_groups[0][\"params\"][0].device\n            )\n            self.multi_tensor_lamb = amp_C.multi_tensor_lamb\n        else:\n            raise RuntimeError(\"apex.optimizers.FusedLAMB requires cuda extensions\")\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        # create separate grad lists for fp32, fp16, and bf16 params\n        g_all_32, g_all_16, g_all_bf16 = [], [], []\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                if p.dtype == torch.float32:\n                    g_all_32.append(p.grad.data)\n                elif p.dtype == torch.float16:\n                    g_all_16.append(p.grad.data)\n                elif p.dtype == torch.bfloat16:\n                    g_all_bf16.append(p.grad.data)\n                else:\n                    raise RuntimeError(\"FusedLAMB only supports fp16, fp32, and bf16.\")\n\n        device = self.param_groups[0][\"params\"][0].device\n        g_norm_32, g_norm_16, g_norm_bf16 = (\n            torch.zeros(1, device=device),\n            torch.zeros(1, device=device),\n            torch.zeros(1, device=device),\n        )\n        # compute the grad norm for each of the three lists (fp32, fp16, bf16)\n        if len(g_all_32) > 0:\n            g_norm_32 = multi_tensor_applier(\n                self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False\n            )[0]\n        if len(g_all_16) > 0:\n            g_norm_16 = multi_tensor_applier(\n                self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False\n            )[0]\n        if len(g_all_bf16) > 0:\n            g_norm_bf16 = 
multi_tensor_applier(\n                self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_bf16], False\n            )[0]\n\n        # blend two grad norms to get global grad norm\n        global_grad_norm = multi_tensor_applier(\n            self.multi_tensor_l2norm,\n            self._dummy_overflow_buf,\n            [[g_norm_32, g_norm_16, g_norm_bf16]],\n            False,\n        )[0]\n\n        max_grad_norm = 1.0\n        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)\n\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                p.grad.data *= clipped_ratio\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError(\n                        \"Lamb does not support sparse gradients, consider SparseAdam instad.\"\n                    )\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state[\"step\"] = 0\n                    # Exponential moving average of gradient values\n                    state[\"m\"] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state[\"v\"] = torch.zeros_like(p.data)\n\n                m_t, v_t = state[\"m\"], state[\"v\"]\n                beta1, beta2 = group[\"betas\"]\n\n                state[\"step\"] += 1\n\n                # m_t = beta1 * m + (1 - beta1) * g_t\n                m_t.mul_(beta1).add_(grad, alpha=1 - beta1)\n                # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)\n                if len(g_all_16) > 0:\n                    v_t.mul_(beta2)\n                    v_t = v_t.to(torch.float32)\n                    grad32 = grad.to(torch.float32)\n                    v_t.addcmul_(grad32, grad32, value=1 - beta2)\n                else:\n                    
v_t.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)\n\n                # Debiasing\n                m_t_hat = m_t / (1.0 - beta1 ** state[\"step\"])\n                v_t_hat = v_t / (1.0 - beta2 ** state[\"step\"])\n\n                update = m_t_hat / v_t_hat.sqrt().add(group[\"eps\"])\n\n                if group[\"weight_decay\"] != 0:\n                    update.add_(p.data, alpha=group[\"weight_decay\"])\n\n                trust_ratio = 1.0\n                w_norm = p.data.to(torch.float32).pow(2).sum().sqrt()\n                g_norm = update.pow(2).sum().sqrt()\n                if w_norm > 0 and g_norm > 0:\n                    trust_ratio = w_norm / g_norm\n\n                state[\"w_norm\"] = w_norm\n                state[\"g_norm\"] = g_norm\n                state[\"trust_ratio\"] = trust_ratio\n\n                step_size = group[\"lr\"]\n\n                p.data.add_(update, alpha=-step_size * trust_ratio)\n\n        return loss\n\n\nclass TestLamb(unittest.TestCase):\n    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):\n        self.max_abs_diff = max_abs_diff\n        self.max_rel_diff = max_rel_diff\n        self.iters = iters\n        torch.cuda.manual_seed(9876)\n\n    def tearDown(self):\n        pass\n\n    def gen_param_optim(self, tensors, lamb_option):\n        ref_param = []\n        tst_param = []\n        for tensor in tensors:\n            ref_param.append(torch.nn.Parameter(tensor.clone()))\n            tst_param.append(torch.nn.Parameter(tensor.clone()))\n\n        ref_optim = self.ref_optim(ref_param, **lamb_option)\n        tst_optim = self.tst_optim(tst_param, use_nvlamb=True, **lamb_option)\n\n        return (ref_param, tst_param, ref_optim, tst_optim)\n\n    def gen_grad(self, ref_param, tst_param):\n        for p_ref, p_tst in zip(ref_param, tst_param):\n            p_ref.grad = torch.rand_like(p_ref)\n            p_tst.grad = p_ref.grad\n\n    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):\n        
half_grads = []\n        for p_ref, _ in zip(ref_param, tst_param):\n            half_grads.append(torch.rand_like(p_ref).half())\n            p_ref.grad = half_grads[-1].float() / scale\n        return half_grads\n\n    def gen_single_type_test(self, param_type=torch.float, device=\"cuda\"):\n        nelem = 18011\n        tensor = torch.rand(nelem, dtype=param_type, device=device)\n        weight_decay = [0, 0.01]\n\n        for wd in weight_decay:\n            lamb_option = {\n                \"lr\": 5e-4,\n                \"betas\": (0.9, 0.999),\n                \"eps\": 1e-08,\n                \"weight_decay\": wd,\n            }\n            ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], lamb_option)\n\n            if isinstance(tst_optim, apex.optimizers.FusedMixedPrecisionLamb):\n                if param_type != torch.float:\n                    # joseli: This parameter is usually passed into the constructor,\n                    # but I do not want to change the testing interface.\n                    # As long as this parameter is set before the first call to step(),\n                    # then it should act normally.\n                    tst_optim.reduced_precision_dtype = param_type\n            for i in range(self.iters):\n                self.gen_grad(ref_param, tst_param)\n                ref_optim.step()\n                torch.cuda.synchronize()\n                tst_optim.step()\n                torch.cuda.synchronize()\n                torch.testing.assert_close(tst_param, ref_param)\n\n\nclass TestFusedLAMB(TestLamb):\n    def __init__(self, *args, **kwargs):\n        super(TestLamb, self).__init__(*args, **kwargs)\n        self.ref_optim = RefLAMB\n        self.tst_optim = apex.optimizers.FusedLAMB\n\n    def test_float(self):\n        self.gen_single_type_test(param_type=torch.float)\n\n    @unittest.skip(\"PyTorch optimizer is not numerically correct for fp16\")\n    def test_half(self):\n        
self.gen_single_type_test(param_type=torch.float16)\n\n    @unittest.skipIf(torch.cuda.device_count() < 2, \"more than 1 GPU required\")\n    def test_multi_device(self):\n        devices = (\"cuda:0\", \"cuda:1\")\n        for current_dev, tensor_dev in product(devices, devices):\n            with torch.cuda.device(current_dev):\n                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)\n\n    def test_multi_params(self):\n        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]\n        weight_decay = [0, 0.01]\n\n        for wd in weight_decay:\n            lamb_option = {\n                \"lr\": 5e-4,\n                \"betas\": (0.9, 0.999),\n                \"eps\": 1e-08,\n                \"weight_decay\": wd,\n            }\n            tensors = []\n            for size in sizes:\n                tensors.append(torch.rand(size, dtype=torch.float, device=\"cuda\"))\n            ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, lamb_option)\n\n            for i in range(self.iters):\n                self.gen_grad(ref_param, tst_param)\n                ref_optim.step()\n                tst_optim.step()\n                torch.testing.assert_close(tst_param, ref_param)\n\n    def test_lamb_option(self):\n        nelem = 1\n        tensor = torch.rand(nelem, dtype=torch.float, device=\"cuda\")\n        weight_decay = [0, 0.01]\n\n        for wd in weight_decay:\n            lamb_option = {\n                \"lr\": 0.01,\n                \"betas\": (0.6, 0.9),\n                \"eps\": 3e-06,\n                \"weight_decay\": wd,\n            }\n            ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], lamb_option)\n\n            for i in range(self.iters):\n                self.gen_grad(ref_param, tst_param)\n                ref_optim.step()\n                tst_optim.step()\n                torch.testing.assert_close(tst_param, ref_param)\n\n\nclass 
TestFusedMixedPrecisionLamb(TestLamb):\n    def __init__(self, *args, **kwargs):\n        super(TestLamb, self).__init__(*args, **kwargs)\n        self.ref_optim = RefLAMB\n        self.tst_optim = apex.optimizers.FusedMixedPrecisionLamb\n\n    def test_float(self):\n        self.gen_single_type_test(param_type=torch.float)\n\n    def test_bfloat16(self):\n        self.iters = 4\n        self.gen_single_type_test(param_type=torch.bfloat16)\n\n    def test_half(self):\n        self.iters = 1\n        self.gen_single_type_test(param_type=torch.float16)\n\n    @unittest.skipIf(torch.cuda.device_count() < 2, \"more than 1 GPU required\")\n    def test_multi_device(self):\n        devices = (\"cuda:0\", \"cuda:1\")\n        for current_dev, tensor_dev in product(devices, devices):\n            with torch.cuda.device(current_dev):\n                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)\n\n    def test_multi_params(self):\n        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]\n        weight_decay = [0, 0.01]\n\n        for wd in weight_decay:\n            lamb_option = {\n                \"lr\": 5e-4,\n                \"betas\": (0.9, 0.999),\n                \"eps\": 1e-08,\n                \"weight_decay\": wd,\n            }\n            tensors = []\n            for size in sizes:\n                tensors.append(torch.rand(size, dtype=torch.float, device=\"cuda\"))\n            ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, lamb_option)\n\n            for i in range(self.iters):\n                self.gen_grad(ref_param, tst_param)\n                ref_optim.step()\n                tst_optim.step()\n                torch.testing.assert_close(tst_param, ref_param)\n\n    def test_lamb_option(self):\n        nelem = 1\n        tensor = torch.rand(nelem, dtype=torch.float, device=\"cuda\")\n        weight_decay = [0, 0.01]\n\n        for wd in weight_decay:\n            lamb_option = {\n         
       \"lr\": 0.01,\n                \"betas\": (0.6, 0.9),\n                \"eps\": 3e-06,\n                \"weight_decay\": wd,\n            }\n            ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], lamb_option)\n\n            for i in range(self.iters):\n                self.gen_grad(ref_param, tst_param)\n                ref_optim.step()\n                tst_optim.step()\n                torch.testing.assert_close(tst_param, ref_param)\n\n\nif __name__ == \"__main__\":\n    script_path = os.path.dirname(os.path.realpath(__file__))\n    unittest.main()\n"
  },
  {
    "path": "tests/L0/run_test.py",
    "content": "\"\"\"L0 Tests Runner.\n\nHow to run this script?\n\n1. Run all the tests: `python /path/to/apex/tests/L0/run_test.py` If you want an xml report,\n    pass `--xml-report`, i.e. `python /path/to/apex/tests/L0/run_test.py --xml-report` and\n    the file is created in `/path/to/apex/tests/L0`.\n2. Run one of the tests (e.g. fused layer norm):\n    `python /path/to/apex/tests/L0/run_test.py --include run_fused_layer_norm`\n3. Run two or more of the tests (e.g. optimizers and fused layer norm):\n    `python /path/to/apex/tests/L0/run_test.py --include run_optimizers run_fused_layer_norm`\n\"\"\"\n\nimport argparse\nimport os\nimport unittest\nimport sys\n\n\nTEST_ROOT = os.path.dirname(os.path.abspath(__file__))\nTEST_DIRS = [\n    \"run_optimizers\",\n    \"run_fused_layer_norm\",\n    \"run_mlp\",\n]\nDEFAULT_TEST_DIRS = [\n    \"run_optimizers\",\n    \"run_fused_layer_norm\",\n    \"run_mlp\",\n]\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"L0 test runner\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    parser.add_argument(\n        \"--include\",\n        nargs=\"+\",\n        choices=TEST_DIRS,\n        default=DEFAULT_TEST_DIRS,\n        help=\"select a set of tests to run (defaults to ALL tests).\",\n    )\n    parser.add_argument(\n        \"--xml-report\",\n        default=None,\n        action=\"store_true\",\n        help=\"[deprecated] pass this argument to get a junit xml report. Use `--xml-dir`. (requires `xmlrunner`)\",\n    )\n    parser.add_argument(\n        \"--xml-dir\",\n        default=None,\n        type=str,\n        help=\"Directory to save junit test reports. 
(requires `xmlrunner`)\",\n    )\n    args, _ = parser.parse_known_args()\n    return args\n\n\ndef main(args: argparse.Namespace) -> None:\n    test_runner_kwargs = {\"verbosity\": 2}\n    Runner = unittest.TextTestRunner\n\n    xml_dir = None\n    if (args.xml_report is not None) or (args.xml_dir is not None):\n        if args.xml_report is not None:\n            import warnings\n\n            warnings.warn(\"The option of `--xml-report` is deprecated\", FutureWarning)\n\n        import xmlrunner\n        from datetime import date  # NOQA\n\n        Runner = xmlrunner.XMLTestRunner\n        if args.xml_report:\n            xml_dir = os.path.abspath(os.path.dirname(__file__))\n        else:\n            xml_dir = os.path.abspath(args.xml_dir)\n        if not os.path.exists(xml_dir):\n            os.makedirs(xml_dir)\n\n    errcode = 0\n    for test_dir in args.include:\n        if xml_dir is not None:\n            xml_output = os.path.join(\n                xml_dir,\n                f\"\"\"TEST_{test_dir}_{date.today().strftime(\"%y%m%d\")}\"\"\",\n            )\n            if not os.path.exists(xml_output):\n                os.makedirs(xml_output)\n            test_runner_kwargs[\"output\"] = xml_output\n\n        runner = Runner(**test_runner_kwargs)\n        test_dir = os.path.join(TEST_ROOT, test_dir)\n        suite = unittest.TestLoader().discover(test_dir)\n\n        print(\"\\nExecuting tests from \" + test_dir)\n\n        result = runner.run(suite)\n\n        if not result.wasSuccessful():\n            errcode = 1\n\n    sys.exit(errcode)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "tests/L1/common/compare.py",
    "content": "import argparse\nimport torch\n\nparser = argparse.ArgumentParser(description=\"Compare\")\nparser.add_argument(\"--opt-level\", type=str)\nparser.add_argument(\"--keep-batchnorm-fp32\", type=str, default=None)\nparser.add_argument(\"--loss-scale\", type=str, default=None)\nparser.add_argument(\"--fused-adam\", action=\"store_true\")\nparser.add_argument(\"--use_baseline\", action=\"store_true\")\nargs = parser.parse_args()\n\nbase_file = (\n    str(args.opt_level)\n    + \"_\"\n    + str(args.loss_scale)\n    + \"_\"\n    + str(args.keep_batchnorm_fp32)\n    + \"_\"\n    + str(args.fused_adam)\n)\n\nfile_e = \"True_\" + base_file\nfile_p = \"False_\" + base_file\nif args.use_baseline:\n    file_b = \"baselines/True_\" + base_file\n\ndict_e = torch.load(file_e)\ndict_p = torch.load(file_p)\nif args.use_baseline:\n    dict_b = torch.load(file_b)\n\ntorch.set_printoptions(precision=10)\n\nprint(file_e)\nprint(file_p)\nif args.use_baseline:\n    print(file_b)\n\n# ugly duplication here...\nif not args.use_baseline:\n    for n, (i_e, i_p) in enumerate(zip(dict_e[\"Iteration\"], dict_p[\"Iteration\"])):\n        assert i_e == i_p, \"i_e = {}, i_p = {}\".format(i_e, i_p)\n\n        loss_e = dict_e[\"Loss\"][n]\n        loss_p = dict_p[\"Loss\"][n]\n        assert loss_e == loss_p, \"Iteration {}, loss_e = {}, loss_p = {}\".format(\n            i_e, loss_e, loss_p\n        )\n        print(\n            \"{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}\".format(\n                i_e, loss_e, loss_p, dict_e[\"Speed\"][n], dict_p[\"Speed\"][n]\n            )\n        )\nelse:\n    for n, (i_e, i_p) in enumerate(zip(dict_e[\"Iteration\"], dict_p[\"Iteration\"])):\n        assert i_e == i_p, \"i_e = {}, i_p = {}\".format(i_e, i_p)\n\n        loss_e = dict_e[\"Loss\"][n]\n        loss_p = dict_p[\"Loss\"][n]\n        loss_b = dict_b[\"Loss\"][n]\n        assert loss_e == loss_p, \"Iteration {}, loss_e = {}, loss_p = {}\".format(\n            i_e, loss_e, loss_p\n  
      )\n        assert loss_e == loss_b, \"Iteration {}, loss_e = {}, loss_b = {}\".format(\n            i_e, loss_e, loss_b\n        )\n        print(\n            \"{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f}\".format(\n                i_e,\n                loss_b,\n                loss_e,\n                loss_p,\n                dict_b[\"Speed\"][n],\n                dict_e[\"Speed\"][n],\n                dict_p[\"Speed\"][n],\n            )\n        )\n"
  },
  {
    "path": "tests/L1/common/main_amp.py",
    "content": "import argparse\nimport os\nimport shutil\nimport time\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\nimport torch.optim\nimport torch.utils.data\nimport torch.utils.data.distributed\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport torchvision.models as models\n\nimport numpy as np\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\n        \"Please install apex from https://www.github.com/nvidia/apex to run this example.\"\n    )\n\nmodel_names = sorted(\n    name\n    for name in models.__dict__\n    if name.islower() and not name.startswith(\"__\") and callable(models.__dict__[name])\n)\n\nparser = argparse.ArgumentParser(description=\"PyTorch ImageNet Training\")\nparser.add_argument(\"data\", metavar=\"DIR\", help=\"path to dataset\")\nparser.add_argument(\n    \"--arch\",\n    \"-a\",\n    metavar=\"ARCH\",\n    default=\"resnet18\",\n    choices=model_names,\n    help=\"model architecture: \" + \" | \".join(model_names) + \" (default: resnet18)\",\n)\nparser.add_argument(\n    \"-j\",\n    \"--workers\",\n    default=4,\n    type=int,\n    metavar=\"N\",\n    help=\"number of data loading workers (default: 4)\",\n)\nparser.add_argument(\n    \"--epochs\", default=90, type=int, metavar=\"N\", help=\"number of total epochs to run\"\n)\nparser.add_argument(\n    \"--start-epoch\",\n    default=0,\n    type=int,\n    metavar=\"N\",\n    help=\"manual epoch number (useful on restarts)\",\n)\nparser.add_argument(\n    \"-b\",\n    \"--batch-size\",\n    default=256,\n    type=int,\n    metavar=\"N\",\n    help=\"mini-batch size per process (default: 256)\",\n)\nparser.add_argument(\n    \"--lr\",\n    
\"--learning-rate\",\n    default=0.1,\n    type=float,\n    metavar=\"LR\",\n    help=\"Initial learning rate.  Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256.  A warmup schedule will also be applied over the first 5 epochs.\",\n)\nparser.add_argument(\"--momentum\", default=0.9, type=float, metavar=\"M\", help=\"momentum\")\nparser.add_argument(\n    \"--weight-decay\",\n    \"--wd\",\n    default=1e-4,\n    type=float,\n    metavar=\"W\",\n    help=\"weight decay (default: 1e-4)\",\n)\nparser.add_argument(\n    \"--print-freq\",\n    \"-p\",\n    default=10,\n    type=int,\n    metavar=\"N\",\n    help=\"print frequency (default: 10)\",\n)\nparser.add_argument(\n    \"--resume\",\n    default=\"\",\n    type=str,\n    metavar=\"PATH\",\n    help=\"path to latest checkpoint (default: none)\",\n)\nparser.add_argument(\n    \"-e\",\n    \"--evaluate\",\n    dest=\"evaluate\",\n    action=\"store_true\",\n    help=\"evaluate model on validation set\",\n)\nparser.add_argument(\n    \"--pretrained\", dest=\"pretrained\", action=\"store_true\", help=\"use pre-trained model\"\n)\n\nparser.add_argument(\n    \"--prof\",\n    dest=\"prof\",\n    action=\"store_true\",\n    help=\"Only run 10 iterations for profiling.\",\n)\nparser.add_argument(\"--deterministic\", action=\"store_true\")\n\nparser.add_argument(\"--local_rank\", default=0, type=int)\nparser.add_argument(\"--sync_bn\", action=\"store_true\", help=\"enabling apex sync BN.\")\n\nparser.add_argument(\"--has-ext\", action=\"store_true\")\nparser.add_argument(\"--opt-level\", type=str)\nparser.add_argument(\"--keep-batchnorm-fp32\", type=str, default=None)\nparser.add_argument(\"--loss-scale\", type=str, default=None)\nparser.add_argument(\"--fused-adam\", action=\"store_true\")\n\nparser.add_argument(\"--prints-to-process\", type=int, default=10)\n\ncudnn.benchmark = True\n\n\ndef fast_collate(batch):\n    imgs = [img[0] for img in batch]\n    targets = 
torch.tensor([target[1] for target in batch], dtype=torch.int64)\n    w = imgs[0].size[0]\n    h = imgs[0].size[1]\n    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)\n    for i, img in enumerate(imgs):\n        nump_array = np.asarray(img, dtype=np.uint8)\n        if nump_array.ndim < 3:\n            nump_array = np.expand_dims(nump_array, axis=-1)\n        nump_array = np.rollaxis(nump_array, 2)\n\n        tensor[i] += torch.from_numpy(nump_array)\n\n    return tensor, targets\n\n\nbest_prec1 = 0\nargs = parser.parse_args()\n\n# Let multi_tensor_applier be the canary in the coalmine\n# that verifies if the backend is what we think it is\nassert multi_tensor_applier.available == args.has_ext\n\nprint(\"opt_level = {}\".format(args.opt_level))\nprint(\n    \"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32),\n    type(args.keep_batchnorm_fp32),\n)\nprint(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n\nprint(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\nif args.deterministic:\n    cudnn.benchmark = False\n    cudnn.deterministic = True\n    torch.manual_seed(args.local_rank)\n    torch.set_printoptions(precision=10)\n\n\ndef main():\n    global best_prec1, args\n\n    args.distributed = False\n    if \"WORLD_SIZE\" in os.environ:\n        args.distributed = int(os.environ[\"WORLD_SIZE\"]) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank % torch.cuda.device_count()\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    # create model\n    if args.pretrained:\n        print(\"=> using pre-trained model '{}'\".format(args.arch))\n        model = models.__dict__[args.arch](pretrained=True)\n    else:\n    
    print(\"=> creating model '{}'\".format(args.arch))\n        model = models.__dict__[args.arch]()\n\n    if args.sync_bn:\n        import apex\n\n        print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda()\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size * args.world_size) / 256.0\n    if args.fused_adam:\n        optimizer = optimizers.FusedAdam(model.parameters())\n    else:\n        optimizer = torch.optim.SGD(\n            model.parameters(),\n            args.lr,\n            momentum=args.momentum,\n            weight_decay=args.weight_decay,\n        )\n\n    model, optimizer = amp.initialize(\n        model,\n        optimizer,\n        # enabled=False,\n        opt_level=args.opt_level,\n        keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n        loss_scale=args.loss_scale,\n    )\n\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n\n    # define loss function (criterion) and optimizer\n    criterion = nn.CrossEntropyLoss().cuda()\n\n    # Optionally resume from a checkpoint\n    if args.resume:\n        # Use a local scope to avoid dangling references\n        def resume():\n            if os.path.isfile(args.resume):\n                print(\"=> loading checkpoint '{}'\".format(args.resume))\n                checkpoint = torch.load(\n                    args.resume,\n                    map_location=lambda storage, loc: storage.cuda(args.gpu),\n                )\n                args.start_epoch = checkpoint[\"epoch\"]\n                best_prec1 = checkpoint[\"best_prec1\"]\n                model.load_state_dict(checkpoint[\"state_dict\"])\n                
optimizer.load_state_dict(checkpoint[\"optimizer\"])\n                print(\n                    \"=> loaded checkpoint '{}' (epoch {})\".format(args.resume, checkpoint[\"epoch\"])\n                )\n            else:\n                print(\"=> no checkpoint found at '{}'\".format(args.resume))\n\n        resume()\n\n    # Data loading code\n    traindir = os.path.join(args.data, \"train\")\n    valdir = os.path.join(args.data, \"val\")\n\n    if args.arch == \"inception_v3\":\n        crop_size = 299\n        val_size = 320  # I chose this value arbitrarily, we can adjust.\n    else:\n        crop_size = 224\n        val_size = 256\n\n    train_dataset = datasets.ImageFolder(\n        traindir,\n        transforms.Compose(\n            [\n                transforms.RandomResizedCrop(crop_size),\n                transforms.RandomHorizontalFlip(),\n                # transforms.ToTensor(), Too slow\n                # normalize,\n            ]\n        ),\n    )\n    val_dataset = datasets.ImageFolder(\n        valdir,\n        transforms.Compose(\n            [\n                transforms.Resize(val_size),\n                transforms.CenterCrop(crop_size),\n            ]\n        ),\n    )\n\n    train_sampler = None\n    val_sampler = None\n    if args.distributed:\n        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)\n        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.batch_size,\n        shuffle=(train_sampler is None),\n        num_workers=args.workers,\n        pin_memory=True,\n        sampler=train_sampler,\n        collate_fn=fast_collate,\n    )\n\n    val_loader = torch.utils.data.DataLoader(\n        val_dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        num_workers=args.workers,\n        pin_memory=True,\n        sampler=val_sampler,\n        
collate_fn=fast_collate,\n    )\n\n    if args.evaluate:\n        validate(val_loader, model, criterion)\n        return\n\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            train_sampler.set_epoch(epoch)\n\n        # train for one epoch\n        train(train_loader, model, criterion, optimizer, epoch)\n        if args.prof:\n            break\n        # evaluate on validation set\n        prec1 = validate(val_loader, model, criterion)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            save_checkpoint(\n                {\n                    \"epoch\": epoch + 1,\n                    \"arch\": args.arch,\n                    \"state_dict\": model.state_dict(),\n                    \"best_prec1\": best_prec1,\n                    \"optimizer\": optimizer.state_dict(),\n                },\n                is_best,\n            )\n\n\nclass data_prefetcher:\n    def __init__(self, loader):\n        self.loader = iter(loader)\n        self.stream = torch.cuda.Stream()\n        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)\n        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)\n        # With Amp, it isn't necessary to manually convert data to half.\n        # if args.fp16:\n        #     self.mean = self.mean.half()\n        #     self.std = self.std.half()\n        self.preload()\n\n    def preload(self):\n        try:\n            self.next_input, self.next_target = next(self.loader)\n        except StopIteration:\n            self.next_input = None\n            self.next_target = None\n            return\n        with torch.cuda.stream(self.stream):\n            self.next_input = self.next_input.cuda(non_blocking=True)\n            self.next_target = self.next_target.cuda(non_blocking=True)\n     
       # With Amp, it isn't necessary to manually convert data to half.\n            # if args.fp16:\n            #     self.next_input = self.next_input.half()\n            # else:\n            self.next_input = self.next_input.float()\n            self.next_input = self.next_input.sub_(self.mean).div_(self.std)\n\n    def next(self):\n        torch.cuda.current_stream().wait_stream(self.stream)\n        input = self.next_input\n        target = self.next_target\n        self.preload()\n        return input, target\n\n\ndef train(train_loader, model, criterion, optimizer, epoch):\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    run_info_dict = {\"Iteration\": [], \"Loss\": [], \"Speed\": []}\n\n    prefetcher = data_prefetcher(train_loader)\n    input, target = prefetcher.next()\n    i = -1\n    while input is not None:\n        i += 1\n\n        # No learning rate warmup for this test, to expose bitwise inaccuracies more quickly\n        # adjust_learning_rate(optimizer, epoch, i, len(train_loader))\n\n        if args.prof:\n            if i > 10:\n                break\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        output = model(input)\n        loss = criterion(output, target)\n\n        # measure accuracy and record loss\n        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))\n\n        if args.distributed:\n            reduced_loss = reduce_tensor(loss.data)\n            prec1 = reduce_tensor(prec1)\n            prec5 = reduce_tensor(prec5)\n        else:\n            reduced_loss = loss.data\n\n        losses.update(to_python_float(reduced_loss), input.size(0))\n        top1.update(to_python_float(prec1), input.size(0))\n        top5.update(to_python_float(prec5), input.size(0))\n\n        # compute 
gradient and do SGD step\n        optimizer.zero_grad()\n\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n\n        # for param in model.parameters():\n        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())\n\n        # torch.cuda.synchronize()\n        torch.cuda.nvtx.range_push(\"step\")\n        optimizer.step()\n        torch.cuda.nvtx.range_pop()\n\n        torch.cuda.synchronize()\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n\n        end = time.time()\n\n        # If you decide to refactor this test, like examples/imagenet, to sample the loss every\n        # print_freq iterations, make sure to move this prefetching below the accuracy calculation.\n        input, target = prefetcher.next()\n\n        if i % args.print_freq == 0 and i > 1:\n            if args.local_rank == 0:\n                print(\n                    \"Epoch: [{0}][{1}/{2}]\\t\"\n                    \"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t\"\n                    \"Speed {3:.3f} ({4:.3f})\\t\"\n                    \"Data {data_time.val:.3f} ({data_time.avg:.3f})\\t\"\n                    \"Loss {loss.val:.10f} ({loss.avg:.4f})\\t\"\n                    \"Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t\"\n                    \"Prec@5 {top5.val:.3f} ({top5.avg:.3f})\".format(\n                        epoch,\n                        i,\n                        len(train_loader),\n                        args.world_size * args.batch_size / batch_time.val,\n                        args.world_size * args.batch_size / batch_time.avg,\n                        batch_time=batch_time,\n                        data_time=data_time,\n                        loss=losses,\n                        top1=top1,\n                        top5=top5,\n                    )\n                )\n            run_info_dict[\"Iteration\"].append(i)\n            
run_info_dict[\"Loss\"].append(losses.val)\n            run_info_dict[\"Speed\"].append(args.world_size * args.batch_size / batch_time.val)\n            if len(run_info_dict[\"Loss\"]) == args.prints_to_process:\n                if args.local_rank == 0:\n                    torch.save(\n                        run_info_dict,\n                        str(args.has_ext)\n                        + \"_\"\n                        + str(args.opt_level)\n                        + \"_\"\n                        + str(args.loss_scale)\n                        + \"_\"\n                        + str(args.keep_batchnorm_fp32)\n                        + \"_\"\n                        + str(args.fused_adam),\n                    )\n                quit()\n\n\ndef validate(val_loader, model, criterion):\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to evaluate mode\n    model.eval()\n\n    end = time.time()\n\n    prefetcher = data_prefetcher(val_loader)\n    input, target = prefetcher.next()\n    i = -1\n    while input is not None:\n        i += 1\n\n        # compute output\n        with torch.no_grad():\n            output = model(input)\n            loss = criterion(output, target)\n\n        # measure accuracy and record loss\n        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))\n\n        if args.distributed:\n            reduced_loss = reduce_tensor(loss.data)\n            prec1 = reduce_tensor(prec1)\n            prec5 = reduce_tensor(prec5)\n        else:\n            reduced_loss = loss.data\n\n        losses.update(to_python_float(reduced_loss), input.size(0))\n        top1.update(to_python_float(prec1), input.size(0))\n        top5.update(to_python_float(prec5), input.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if args.local_rank == 0 and i % args.print_freq == 0:\n            print(\n     
           \"Test: [{0}/{1}]\\t\"\n                \"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t\"\n                \"Speed {2:.3f} ({3:.3f})\\t\"\n                \"Loss {loss.val:.4f} ({loss.avg:.4f})\\t\"\n                \"Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t\"\n                \"Prec@5 {top5.val:.3f} ({top5.avg:.3f})\".format(\n                    i,\n                    len(val_loader),\n                    args.world_size * args.batch_size / batch_time.val,\n                    args.world_size * args.batch_size / batch_time.avg,\n                    batch_time=batch_time,\n                    loss=losses,\n                    top1=top1,\n                    top5=top5,\n                )\n            )\n\n        input, target = prefetcher.next()\n\n    print(\" * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}\".format(top1=top1, top5=top5))\n\n    return top1.avg\n\n\ndef save_checkpoint(state, is_best, filename=\"checkpoint.pth.tar\"):\n    torch.save(state, filename)\n    if is_best:\n        shutil.copyfile(filename, \"model_best.pth.tar\")\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef adjust_learning_rate(optimizer, epoch, step, len_epoch):\n    \"\"\"LR schedule that should yield 76% converged accuracy with batch size 256\"\"\"\n    factor = epoch // 30\n\n    if epoch >= 80:\n        factor = factor + 1\n\n    lr = args.lr * (0.1**factor)\n\n    \"\"\"Warmup\"\"\"\n    if epoch < 5:\n        lr = lr * float(1 + step + epoch * len_epoch) / (5.0 * len_epoch)\n\n    # if(args.local_rank == 0):\n    #     print(\"epoch = {}, step = {}, lr = {}\".format(epoch, step, 
lr))\n\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the precision@k for the specified values of k\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        # reshape, not view: a slice of the transposed result may be non-contiguous\n        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n        res.append(correct_k.mul_(100.0 / batch_size))\n    return res\n\n\ndef reduce_tensor(tensor):\n    rt = tensor.clone()\n    # dist.reduce_op is deprecated; dist.ReduceOp is the supported spelling\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= args.world_size\n    return rt\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/L1/common/run_test.sh",
    "content": "#!/bin/bash\n\nprint_banner() {\n  printf \"\\n\\n\\n\\e[30m\\e[42m$1\\e[0m\\n\\n\\n\\n\"\n}\n\nprint_banner \"Distributed status:  $1\"\n\necho $2\nDATADIR=$2\n\nif [ -n \"$3\" ]\nthen\n  USE_BASELINE=\"\"\nelse\n  USE_BASELINE=\"--use_baseline\"\nfi\n\nif [ \"$1\" == \"single_gpu\" ]\nthen\n  BASE_CMD=\"python main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5\"\nfi\n\nif [ \"$1\" == \"distributed\" ]\nthen\n  BASE_CMD=\"python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5\"\nfi\n\nADAM_ARGS=\"--opt-level O2 --keep-batchnorm-fp32 False --fused-adam\"\n\nkeep_batchnorms=(\n\"\"\n\"--keep-batchnorm-fp32 True\"\n\"--keep-batchnorm-fp32 False\"\n)\n\nloss_scales=(\n\"\"\n\"--loss-scale 1.0\"\n\"--loss-scale 128.0\"\n\"--loss-scale dynamic\"\n)\n\nopt_levels=(\n\"O0\"\n\"O1\"\n\"O2\"\n\"O3\"\n)\n\nrm True*\nrm False*\n\nset -e\n\nprint_banner \"Installing Apex with --cuda_ext and --cpp_ext\"\n\npushd ../../..\npip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" .\npopd\n\nfor opt_level in \"${opt_levels[@]}\"\ndo\n  for loss_scale in \"${loss_scales[@]}\"\n  do\n    for keep_batchnorm in \"${keep_batchnorms[@]}\"\n    do\n      if [ \"$opt_level\" == \"O1\" ] && [ -n \"${keep_batchnorm}\" ]\n      then\n        print_banner \"Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}\"\n        continue\n      fi\n      print_banner \"${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR\"\n      set -x\n      ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR\n      set +x\n    done\n  done\ndone\n\n# Handle FusedAdam separately due to limited support.\n# FusedAdam will not be tested for bitwise accuracy against the Python implementation.\n# The L0 tests already do so.  
These tests are here to ensure that it actually runs,\n# and get an idea of performance.\nfor loss_scale in \"${loss_scales[@]}\"\ndo\n  print_banner \"${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR\"\n  set -x\n  ${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR\n  set +x\ndone\n\nprint_banner \"Reinstalling apex without extensions\"\n\npushd ../../..\npip install -v --no-cache-dir .\npopd\n\nfor opt_level in \"${opt_levels[@]}\"\ndo\n  for loss_scale in \"${loss_scales[@]}\"\n  do\n    for keep_batchnorm in \"${keep_batchnorms[@]}\"\n    do\n      if [ \"$opt_level\" == \"O1\" ] && [ -n \"${keep_batchnorm}\" ]\n      then\n        print_banner \"Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}\"\n        continue\n      fi\n      print_banner \"${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR\"\n      set -x\n      ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR\n      set +x\n    done\n  done\ndone\n\nprint_banner \"Checking for bitwise accuracy between Python-only and cpp/cuda extension installs\"\n\nfor opt_level in \"${opt_levels[@]}\"\ndo\n  for loss_scale in \"${loss_scales[@]}\"\n  do\n    for keep_batchnorm in \"${keep_batchnorms[@]}\"\n    do\n      echo \"\"\n      if [ \"$opt_level\" == \"O1\" ] && [ -n \"${keep_batchnorm}\" ]\n      then\n        echo \"Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}\"\n        continue\n      fi\n      echo \"${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR\"\n      set -x\n      python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --use_baseline\n      set +x\n    done\n  done\ndone\n\nprint_banner \"Reinstalling Apex with --cuda_ext and --cpp_ext\"\n\npushd ../../..\npip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" .\npopd\n"
  },
  {
    "path": "tests/L1/cross_product/run.sh",
    "content": "#!/bin/bash\n\n# DATADIR=\"/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/\"\n# DATADIR=\"/opt/home/apex/examples/imagenet/\"\ncp ../common/* .\nbash run_test.sh single_gpu $1\n"
  },
  {
    "path": "tests/L1/cross_product_distributed/run.sh",
    "content": "#!/bin/bash\n\ncp ../common/* .\nbash run_test.sh distributed $1\n"
  },
  {
    "path": "tests/distributed/DDP/ddp_race_condition_test.py",
    "content": "import torch\nfrom torch.nn import Parameter\nfrom torch.nn import Module\nfrom apex.parallel import DistributedDataParallel as DDP\nimport argparse\nimport os\n\n\nparser = argparse.ArgumentParser(description=\"allreduce hook example\")\nparser.add_argument(\"--local_rank\", default=0, type=int)\nargs = parser.parse_args()\n\nargs.distributed = False\nif \"WORLD_SIZE\" in os.environ:\n    args.distributed = int(os.environ[\"WORLD_SIZE\"]) > 1\n\nif args.distributed:\n    args.gpu = args.local_rank % torch.cuda.device_count()\n    torch.cuda.set_device(args.gpu)\n    torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    args.world_size = torch.distributed.get_world_size()\n\ntorch.set_printoptions(precision=10)\ntorch.manual_seed(args.local_rank)\n\n\nclass Model(Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.a = Parameter(torch.cuda.FloatTensor(4096 * 4096).fill_(1.0))\n        self.b = Parameter(torch.cuda.FloatTensor(4096 * 4096).fill_(2.0))\n\n    def forward(self, input):\n        return (input * self.a) * self.b\n\n\nmodel = Model()\n# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)\n# model = DDP(model, delay_allreduce=True)\n# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])\nmodel = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)\n\nx = torch.cuda.FloatTensor(4096 * 4096)\n\npassed = True\ntorch.cuda.cudart().cudaProfilerStart()\nfor i in range(10):\n    x.fill_(i + args.local_rank)  # fill x with new values every iteration for sanity\n    model.zero_grad()\n    out = model(x)\n    loss = out.sum()\n    # torch.cuda.nvtx.range_push(\"backward\")\n    loss.backward()\n    # torch.cuda.nvtx.range_pop()\n\n    # torch.cuda.nvtx.range_push(\"synchronize() + info\")\n    # torch.cuda.synchronize()\n    print(\"i = {}\".format(i))\n\n    def info(name, param, val):\n        expected = val * 4096 * 
4096 * (2.0 * i + 1) / 2.0\n        actual = param.grad.data.sum().item()\n        print(\n            name\n            + \": grad.data_ptr() = {}, expected sum {}, got {}\".format(\n                param.grad.data_ptr(), expected, actual\n            )\n        )\n        return expected == actual\n\n    if not info(\"model.a\", model.module.a, 2.0):\n        passed = False\n    if not info(\"model.b\", model.module.b, 1.0):\n        passed = False\n    # torch.cuda.nvtx.range_pop()\ntorch.cuda.cudart().cudaProfilerStop()\n\nprint(\"passed = \", passed)\n"
  },
  {
    "path": "tests/distributed/DDP/run_race_test.sh",
    "content": "#!/bin/bash\n\nCUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py\n"
  },
  {
    "path": "tests/distributed/amp_master_params/amp_master_params.py",
    "content": "import torch\nimport argparse\nimport os\nfrom apex import amp\n\n# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)\nfrom apex.parallel import DistributedDataParallel\n\nparser = argparse.ArgumentParser()\n# FOR DISTRIBUTED:  Parse for the local_rank argument, which will be supplied\n# automatically by torch.distributed.launch.\nparser.add_argument(\"--local_rank\", default=0, type=int)\nargs = parser.parse_args()\n\n# FOR DISTRIBUTED:  If we are running under torch.distributed.launch,\n# the 'WORLD_SIZE' environment variable will also be set automatically.\nargs.distributed = False\nif \"WORLD_SIZE\" in os.environ:\n    args.distributed = int(os.environ[\"WORLD_SIZE\"]) > 1\n\nif args.distributed:\n    # FOR DISTRIBUTED:  Set the device according to local_rank.\n    torch.cuda.set_device(args.local_rank)\n\n    # FOR DISTRIBUTED:  Initialize the backend.  torch.distributed.launch will provide\n    # environment variables, and requires that you use init_method=`env://`.\n    torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n\n    torch.manual_seed(torch.distributed.get_rank())\n\ntorch.backends.cudnn.benchmark = True\n\nN, D_in, D_out = 64, 1024, 16\n\n# Each process receives its own batch of \"fake input data\" and \"fake target data.\"\n# The \"training loop\" in each process just uses this fake batch over and over.\n# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic\n# example of distributed data sampling for both training and validation.\nx = torch.randn(N, D_in, device=\"cuda\")\ny = torch.randn(N, D_out, device=\"cuda\")\n\nmodel = torch.nn.Linear(D_in, D_out).cuda()\noptimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n\nmodel, optimizer = amp.initialize(model, optimizer, opt_level=\"O2\")\n\nif args.distributed:\n    # FOR DISTRIBUTED:  After amp.initialize, wrap the model with\n    # apex.parallel.DistributedDataParallel.\n    model = 
DistributedDataParallel(model)\n    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:\n    # model = torch.nn.parallel.DistributedDataParallel(model,\n    #                                                   device_ids=[args.local_rank],\n    #                                                   output_device=args.local_rank)\n\nloss_fn = torch.nn.MSELoss()\n\nfor t in range(500):\n    optimizer.zero_grad()\n    y_pred = model(x)\n    loss = loss_fn(y_pred, y)\n    with amp.scale_loss(loss, optimizer) as scaled_loss:\n        scaled_loss.backward()\n    optimizer.step()\n\nif args.local_rank == 0:\n    print(\"final loss = \", loss)\n\ntorch.save(list(model.parameters()), \"rank{}model.pth\".format(torch.distributed.get_rank()))\ntorch.save(\n    list(amp.master_params(optimizer)),\n    \"rank{}master.pth\".format(torch.distributed.get_rank()),\n)\n"
  },
  {
    "path": "tests/distributed/amp_master_params/compare.py",
    "content": "import torch\n\nmodel_params_rank0 = torch.load(\"rank0model.pth\", map_location=lambda storage, loc: storage.cuda(0))\nmodel_params_rank1 = torch.load(\"rank1model.pth\", map_location=lambda storage, loc: storage.cuda(0))\nmaster_params_rank0 = torch.load(\n    \"rank0master.pth\", map_location=lambda storage, loc: storage.cuda(0)\n)\nmaster_params_rank1 = torch.load(\n    \"rank1master.pth\", map_location=lambda storage, loc: storage.cuda(0)\n)\n\nfor model_rank0, model_rank1, master_rank0, master_rank1 in zip(\n    model_params_rank0, model_params_rank1, master_params_rank0, master_params_rank1\n):\n    assert torch.allclose(model_rank0, model_rank1), \"Model param mismatch\"\n    assert torch.allclose(master_rank0, master_rank1), \"Master param mismatch\"\n    # Some debugging/investigation assistance code:\n    # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0)\n    # offending_val_half = model_rank0.view(-1)[maxind.item()]\n    # offending_val_float = master_rank0.view(-1)[maxind.item()]\n    # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(),\n    #       offending_val_float.half().item())\n    # rtol needs to be > 2^-11 because of denormals...\n    assert torch.allclose(model_rank0, master_rank0.half(), rtol=0.005), \"Model-master mismatch\"\n\nprint(\"OK:  Model and master params match across ranks.\")\n"
  },
  {
    "path": "tests/distributed/amp_master_params/run.sh",
    "content": "#!/bin/bash\npython -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py\n\npython compare.py\n"
  },
  {
    "path": "tests/distributed/synced_batchnorm/python_single_gpu_unit_test.py",
    "content": "import torch\nimport numpy as np\n\n\ndef compare(desc, inp1, inp2, error):\n    a = inp1.clone().detach().cpu().numpy()\n    b = inp2.clone().detach().cpu().numpy()\n    close = np.allclose(a, b, error, error)\n    if not close:\n        print(desc, close)\n        z = a - b\n        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()\n        print(\"dif    : \", z[index])\n        print(\"inp1   : \", a[index])\n        print(\"inp2   : \", b[index])\n    return close\n\n\nfeature_size = 10\nspace_size = 16\nbatch_size = 5\n\n\nerror = 1e-5\n\nnp.random.seed(1)\ndtype = np.float32\ninp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)\ngrad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)\nweight = (np.random.randn(feature_size)).astype(dtype)\nbias = (np.random.randn(feature_size)).astype(dtype)\n\ntype_tensor = torch.cuda.FloatTensor\nref_tensor = torch.cuda.DoubleTensor\n\ninp_t = type_tensor(inp)\nweight_t = type_tensor(weight)\nbias_t = type_tensor(bias)\n\ninp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ninp2_r = ref_tensor(inp)\nweight_r = ref_tensor(weight).view(-1, 1, 1)\nbias_r = ref_tensor(bias).view(-1, 1, 1)\n\ngrad_output_t = type_tensor(grad)\n\nm = inp_r.mean(1)\nb_v = inp_r.var(1, unbiased=False)\nunb_v = inp_r.var(1, unbiased=True)\n\neps = 1e-5\n\nbn = torch.nn.BatchNorm2d(feature_size).cuda()\nbn.momentum = 1.0\nbn.weight.data = weight_t.clone()\nbn.bias.data = bias_t.clone()\ninp_bn = inp_t.clone().requires_grad_()\ngrad_bn = grad_output_t.clone().detach()\nout_bn = bn(inp_bn)\nout_bn.backward(grad_bn)\n\nfrom apex.parallel.sync_batchnorm import SyncBatchNorm\n\nsbn = SyncBatchNorm(feature_size).cuda()\nsbn.momentum = 1.0\nsbn.weight.data = weight_t.clone()\nsbn.bias.data = bias_t.clone()\ninp_sbn = inp_t.clone().requires_grad_()\ngrad_sbn = grad_output_t.clone().detach()\nout_sbn = 
sbn(inp_sbn)\nout_sbn.backward(grad_sbn)\n\nsbn_result = True\nsbn_result_c_last = True\nbn_result = True\n\nout_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r\n\ncompare(\"comparing bn output: \", out_bn, out_r, error)\n\ngrad_output_t = type_tensor(grad)\n\ngrad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ngrad_output2_r = ref_tensor(grad)\n\ngrad_bias_r = grad_output_r.sum(1)\ngrad_weight_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .sum(1)\n)\n\nmean_dy_r = grad_output_r.mean(1)\nmean_dy_xmu_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .mean(1)\n)\n\ngrad_input_r = (\n    (\n        grad_output2_r\n        - mean_dy_r.view(-1, 1, 1)\n        - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)\n    )\n    * torch.rsqrt(b_v.view(-1, 1, 1) + eps)\n    * weight_r.view(-1, 1, 1)\n)\n\ncompare(\"comparing bn input grad: \", inp_bn.grad, grad_input_r, error)\nsbn_result = compare(\"comparing sbn input grad: \", inp_sbn.grad, grad_input_r, error) and sbn_result\n\ncompare(\"comparing bn/sbn output: \", out_bn, out_sbn, error)\nsbn_result = (\n    compare(\"comparing running_mean: \", bn.running_mean.data, sbn.running_mean.data, error)\n    and sbn_result\n)\nsbn_result = (\n    compare(\"comparing running_variance: \", bn.running_var.data, sbn.running_var.data, error)\n    and sbn_result\n)\ncompare(\"comparing grad_input: \", inp_bn.grad, inp_sbn.grad, error)\ncompare(\"comparing grad_bias: \", bn.bias.grad, sbn.bias.grad, error)\ncompare(\"comparing grad_bias bn to ref: \", bn.bias.grad, grad_bias_r, error)\nsbn_result = (\n    compare(\"comparing grad_bias sbn to ref: \", sbn.bias.grad, grad_bias_r, error) and 
sbn_result\n)\ncompare(\"comparing grad_weight: \", bn.weight.grad, sbn.weight.grad, error)\ncompare(\"comparing grad_weight bn to ref: \", bn.weight.grad, grad_weight_r, error)\nsbn_result = (\n    compare(\"comparing grad_weight sbn to ref: \", sbn.weight.grad, grad_weight_r, error)\n    and sbn_result\n)\n\nif sbn_result:\n    print(\"====SBN single gpu passed tests\")\nelse:\n    print(\"*SBN single gpu failed*\")\n"
  },
  {
    "path": "tests/distributed/synced_batchnorm/single_gpu_unit_test.py",
    "content": "import torch\nimport numpy as np\nimport apex\n\nif True:\n    print(\"using setup tools\")\n    import syncbn\nelse:\n    print(\"using jit\")\n    from torch.utils.cpp_extension import load\n\n    syncbn = load(name=\"syncbn\", sources=[\"../../csrc/syncbn.cpp\", \"../../csrc/welford.cu\"])\n\n\ndef compare(desc, inp1, inp2, error):\n    a = inp1.clone().detach().cpu().numpy()\n    b = inp2.clone().detach().cpu().numpy()\n    close = np.allclose(a, b, error, error)\n    if not close:\n        print(desc, close)\n        z = a - b\n        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()\n        print(\"dif    : \", z[index])\n        print(\"inp1   : \", a[index])\n        print(\"inp2   : \", b[index])\n    return close\n\n\nfeature_size = 10\nspace_size = 16\nbatch_size = 5\n\n\nerror = 1e-5\n\nnp.random.seed(1)\ndtype = np.float32\ninp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)\ngrad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)\nweight = (np.random.randn(feature_size)).astype(dtype)\nbias = (np.random.randn(feature_size)).astype(dtype)\ncount = torch.cuda.IntTensor([batch_size * space_size**2])\n\ntype_tensor = torch.cuda.FloatTensor\nref_tensor = torch.cuda.DoubleTensor\n\ninp_t = type_tensor(inp)\nweight_t = type_tensor(weight)\nbias_t = type_tensor(bias)\n\ninp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ninp2_r = ref_tensor(inp)\nweight_r = ref_tensor(weight).view(-1, 1, 1)\nbias_r = ref_tensor(bias).view(-1, 1, 1)\n\ngrad_output_t = type_tensor(grad)\n\nm = inp_r.mean(1)\nb_v = inp_r.var(1, unbiased=False)\nunb_v = inp_r.var(1, unbiased=True)\n\neps = 1e-5\n\n# mean, var, var_biased = syncbn.welford_mean_var(inp_t)\nmean, var_biased = syncbn.welford_mean_var(inp_t)\ninv_std = 1.0 / torch.sqrt(var_biased + eps)\n\nbn = torch.nn.BatchNorm2d(feature_size).cuda()\nbn.momentum = 1.0\nbn.weight.data = weight_t.clone()\nbn.bias.data 
= bias_t.clone()\ninp_bn = inp_t.clone().requires_grad_()\ngrad_bn = grad_output_t.clone().detach()\nout_bn = bn(inp_bn)\nout_bn.backward(grad_bn)\n\nsbn = apex.parallel.SyncBatchNorm(feature_size).cuda()\nsbn.momentum = 1.0\nsbn.weight.data = weight_t.clone()\nsbn.bias.data = bias_t.clone()\ninp_sbn = inp_t.clone().requires_grad_()\ngrad_sbn = grad_output_t.clone().detach()\nout_sbn = sbn(inp_sbn)\nout_sbn.backward(grad_sbn)\n\nsbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda()\nsbn_c_last.momentum = 1.0\nsbn_c_last.weight.data = weight_t.clone()\nsbn_c_last.bias.data = bias_t.clone()\ninp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_()\ngrad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()\nout_sbn_c_last = sbn_c_last(inp_sbn_c_last)\nout_sbn_c_last.backward(grad_sbn_c_last)\n\nsbn_result = True\nsbn_result_c_last = True\nbn_result = True\n\nsbn_result = compare(\"comparing mean: \", mean, m, error) and sbn_result\n# sbn_result = compare(\"comparing variance: \", var, unb_v, error) and sbn_result\nsbn_result = compare(\"comparing biased variance: \", var_biased, b_v, error) and sbn_result\n\n\nout = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)\nout_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r\n\nsbn_result = compare(\"comparing output: \", out, out_r, error) and sbn_result\ncompare(\"comparing bn output: \", out_bn, out_r, error)\n\ngrad_output_t = type_tensor(grad)\n\ngrad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ngrad_output2_r = ref_tensor(grad)\n\ngrad_bias_r = grad_output_r.sum(1)\ngrad_weight_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .sum(1)\n)\n\nsum_dy_r = grad_output_r.sum(1)\nmean_dy_r = grad_output_r.mean(1)\nsum_dy_xmu_r = (\n    ((inp2_r - 
m.view(-1, 1, 1)) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .sum(1)\n)\nmean_dy_xmu_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .mean(1)\n)\n\ngrad_input_r = (\n    (\n        grad_output2_r\n        - mean_dy_r.view(-1, 1, 1)\n        - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)\n    )\n    * torch.rsqrt(b_v.view(-1, 1, 1) + eps)\n    * weight_r.view(-1, 1, 1)\n)\n\nsum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(\n    grad_output_t, inp_t, mean, inv_std, weight_t\n)\ngrad_input = syncbn.batchnorm_backward(\n    grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count\n)\nsbn_result = compare(\"comparing bias grad: \", grad_bias, grad_bias_r, error) and sbn_result\nsbn_result = compare(\"comparing weight grad: \", grad_weight, grad_weight_r, error) and sbn_result\nsbn_result = compare(\"comparing sum_dy grad: \", sum_dy, sum_dy_r, error) and sbn_result\nsbn_result = compare(\"comparing sum_dy_xmu grad: \", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result\nsbn_result = compare(\"comparing input grad: \", grad_input, grad_input_r, error) and sbn_result\ncompare(\"comparing bn input grad: \", inp_bn.grad, grad_input_r, error)\nsbn_result = compare(\"comparing sbn input grad: \", inp_sbn.grad, grad_input_r, error) and sbn_result\n\ncompare(\"comparing bn/sbn output: \", out_bn, out_sbn, error)\nsbn_result = (\n    compare(\"comparing running_mean: \", bn.running_mean.data, sbn.running_mean.data, error)\n    and sbn_result\n)\nsbn_result = (\n    compare(\"comparing running_variance: \", bn.running_var.data, sbn.running_var.data, error)\n    and sbn_result\n)\ncompare(\"comparing grad_input: \", inp_bn.grad, inp_sbn.grad, error)\ncompare(\"comparing grad_bias: \", bn.bias.grad, sbn.bias.grad, error)\ncompare(\"comparing grad_bias bn to ref: \", bn.bias.grad, 
grad_bias_r, error)\nsbn_result = (\n    compare(\"comparing grad_bias sbn to ref: \", sbn.bias.grad, grad_bias_r, error) and sbn_result\n)\ncompare(\"comparing grad_weight: \", bn.weight.grad, sbn.weight.grad, error)\ncompare(\"comparing grad_weight bn to ref: \", bn.weight.grad, grad_weight_r, error)\nsbn_result = (\n    compare(\"comparing grad_weight sbn to ref: \", sbn.weight.grad, grad_weight_r, error)\n    and sbn_result\n)\n\ncompare(\n    \"comparing channel last bn/sbn output: \",\n    out_bn,\n    out_sbn_c_last.transpose(-1, 1).contiguous(),\n    error,\n)\nsbn_result_c_last = (\n    compare(\n        \"comparing channel last running_mean: \",\n        bn.running_mean.data,\n        sbn_c_last.running_mean.data,\n        error,\n    )\n    and sbn_result_c_last\n)\nsbn_result_c_last = (\n    compare(\n        \"comparing channel last running_variance: \",\n        bn.running_var.data,\n        sbn_c_last.running_var.data,\n        error,\n    )\n    and sbn_result_c_last\n)\ncompare(\n    \"comparing channel last grad_input: \",\n    inp_bn.grad,\n    inp_sbn_c_last.grad.transpose(-1, 1).contiguous(),\n    error,\n)\ncompare(\"comparing channel last grad_bias: \", bn.bias.grad, sbn_c_last.bias.grad, error)\nsbn_result_c_last = (\n    compare(\n        \"comparing channel last grad_bias sbn to ref: \",\n        sbn_c_last.bias.grad,\n        grad_bias_r,\n        error,\n    )\n    and sbn_result_c_last\n)\ncompare(\n    \"comparing channel last grad_weight: \",\n    bn.weight.grad,\n    sbn_c_last.weight.grad,\n    error,\n)\nsbn_result_c_last = (\n    compare(\n        \"comparing channel last grad_weight sbn to ref: \",\n        sbn_c_last.weight.grad,\n        grad_weight_r,\n        error,\n    )\n    and sbn_result_c_last\n)\n\nif sbn_result:\n    print(\"====SBN single gpu passed tests\")\nelse:\n    print(\"*SBN single gpu failed*\")\n\nif sbn_result_c_last:\n    print(\"====SBN channel last single gpu passed tests\")\nelse:\n    print(\"*SBN 
channel last single gpu failed*\")\n"
  },
  {
    "path": "tests/distributed/synced_batchnorm/test_batchnorm1d.py",
    "content": "import torch\nimport apex\n\nmodel = apex.parallel.SyncBatchNorm(4).cuda()\nmodel.weight.data.uniform_()\nmodel.bias.data.uniform_()\ndata = torch.rand((8, 4)).cuda()\n\nmodel_ref = torch.nn.BatchNorm1d(4).cuda()\nmodel_ref.load_state_dict(model.state_dict())\ndata_ref = data.clone()\n\noutput = model(data)\noutput_ref = model_ref(data_ref)\n\nassert output.allclose(output_ref)\nassert model.running_mean.allclose(model_ref.running_mean)\nassert model.running_var.allclose(model_ref.running_var)\n"
  },
  {
    "path": "tests/distributed/synced_batchnorm/test_groups.py",
    "content": "import torch\nimport numpy as np\nimport apex\nimport syncbn\nimport os\nimport argparse\nimport torch.optim as optim\n\n\ndef compare(desc, inp1, inp2, error):\n    a = inp1.clone().detach().cpu().numpy()\n    b = inp2.clone().detach().cpu().numpy()\n    close = np.allclose(a, b, error, error)\n    if not close:\n        print(desc, close)\n        z = a - b\n        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()\n        print(\"dif    : \", z[index])\n        print(\"inp1   : \", a[index])\n        print(\"inp2   : \", b[index])\n    return close\n\n\nfeature_size = 10\nspace_size = 40\nbatch_size = 32\n\n\nfrom apex.parallel import DistributedDataParallel as DDP\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--local_rank\", default=0, type=int)\nparser.add_argument(\"--fp16\", action=\"store_true\", default=False)\nparser.add_argument(\"--fp64\", action=\"store_true\", default=False)\nparser.add_argument(\"--group_size\", default=0, type=int)\nargs = parser.parse_args()\n\ntry:\n    args.world_size = int(os.environ[\"WORLD_SIZE\"])\nexcept:\n    print(\n        \"This is a multi-gpu test. 
To run it please use 'python -m torch.distributed.launch --nproc_per_node=<num gpus> test_groups.py <more options>'\"\n    )\n    exit(1)\n\ntorch.cuda.set_device(args.local_rank)\ntorch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n\nstart = (args.local_rank % args.group_size) * batch_size // args.group_size\nfinish = (args.local_rank % args.group_size + 1) * batch_size // args.group_size\n\nerror = 1e-5\ndtype = np.float32\nif args.fp16:\n    error = 1e-3\n    dtype = np.float16\nelif args.fp64:\n    error = 1e-8\n    dtype = np.float64\n\n\nnp.random.seed(18 + args.local_rank // args.group_size)\n\ninp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)\ngrad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)\nweight = np.random.randn(feature_size).astype(dtype)\nbias = np.random.randn(feature_size).astype(dtype)\n\n\ntype_tensor = torch.cuda.FloatTensor\nif args.fp16:\n    type_tensor = torch.cuda.HalfTensor\nif args.fp64:\n    type_tensor = torch.cuda.DoubleTensor\n\nref_tensor = torch.cuda.DoubleTensor\n\ninp_t = type_tensor(inp)\nweight_t = type_tensor(weight)\nbias_t = type_tensor(bias)\n\ninp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ninp2_r = ref_tensor(inp)\nweight_r = ref_tensor(weight).view(-1, 1, 1)\nbias_r = ref_tensor(bias).view(-1, 1, 1)\n\ngrad_output_t = type_tensor(grad)\n\nm = inp_r.mean(1)\nb_v = inp_r.var(1, unbiased=False)\nunb_v = inp_r.var(1, unbiased=True)\n\neps = 1e-5\n\nmean, var_biased = syncbn.welford_mean_var(inp_t)\ninv_std = 1.0 / torch.sqrt(var_biased + eps)\n\nbn = torch.nn.BatchNorm2d(feature_size).cuda()\nbn.momentum = 1.0\nbn.weight.data = weight_t.clone()\nbn.bias.data = bias_t.clone()\nif args.fp16:\n    bn.half()\nif args.fp64:\n    bn.double()\nbn = DDP(bn)\ninp_bn = inp_t.clone().requires_grad_()\ngrad_bn = grad_output_t.clone().detach()\nout_bn = bn(inp_bn)\nout_bn.backward(grad_bn)\n# compensating the 
averaging over processes done by DDP\n# in order to produce mathematically equivalent result\n# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368\nfor param in bn.parameters():\n    param.grad = param.grad / args.group_size\nbn_opt = optim.SGD(bn.parameters(), lr=1.0)\n\nsbn = apex.parallel.SyncBatchNorm(\n    feature_size,\n    process_group=apex.parallel.create_syncbn_process_group(args.group_size),\n).cuda()\nsbn.momentum = 1.0\nsbn.weight.data = weight_t.clone()\nsbn.bias.data = bias_t.clone()\nif args.fp16:\n    sbn.half()\nif args.fp64:\n    sbn.double()\nsbn = DDP(sbn)\nsbn_opt = optim.SGD(sbn.parameters(), lr=1.0)\ninp_sbn = inp_t.clone().requires_grad_()\ngrad_sbn = grad_output_t.clone().detach()\nout_sbn = sbn(inp_sbn[start:finish])\nout_sbn.backward(grad_sbn[start:finish])\n\nsbn_result = True\nbn_result = True\n\nif args.local_rank == 0:\n    sbn_result = compare(\"comparing mean: \", mean, m, error) and sbn_result\n    sbn_result = compare(\"comparing biased variance: \", var_biased, b_v, error) and sbn_result\n\nout = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)\nout_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r\n\nif args.local_rank == 0:\n    sbn_result = compare(\"comparing output: \", out, out_r, error) and sbn_result\n    compare(\"comparing bn output: \", out_bn, out_r, error)\n\ngrad_output_t = type_tensor(grad)\n\ngrad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ngrad_output2_r = ref_tensor(grad)\n\ngrad_bias_r = grad_output_r.sum(1)\ngrad_weight_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .sum(1)\n)\n\nmean_dy_r = grad_output_r.mean(1)\nmean_dy_xmu_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    
.mean(1)\n)\n\ngrad_input_r = (\n    (\n        grad_output2_r\n        - mean_dy_r.view(-1, 1, 1)\n        - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)\n    )\n    * torch.rsqrt(b_v.view(-1, 1, 1) + eps)\n    * weight_r.view(-1, 1, 1)\n)\n\nmean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(\n    grad_output_t, inp_t, mean, inv_std, weight_t\n)\ngrad_input = syncbn.batchnorm_backward(\n    grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu\n)\n\nif args.local_rank == 0:\n    sbn_result = compare(\"comparing bias grad: \", grad_bias, grad_bias_r, error) and sbn_result\n    sbn_result = (\n        compare(\"comparing weight grad: \", grad_weight, grad_weight_r, error) and sbn_result\n    )\n    sbn_result = compare(\"comparing mean_dy grad: \", mean_dy, mean_dy_r, error) and sbn_result\n    sbn_result = (\n        compare(\"comparing mean_dy_xmu grad: \", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result\n    )\n    sbn_result = compare(\"comparing input grad: \", grad_input, grad_input_r, error) and sbn_result\n    compare(\"comparing bn input grad: \", inp_bn.grad, grad_input_r, error)\n\nif args.local_rank == 0:\n    sbn_result = (\n        compare(\n            \"comparing running_mean: \",\n            bn.module.running_mean.data,\n            sbn.module.running_mean.data,\n            error,\n        )\n        and sbn_result\n    )\n    sbn_result = (\n        compare(\n            \"comparing running_variance: \",\n            bn.module.running_var.data,\n            sbn.module.running_var.data,\n            error,\n        )\n        and sbn_result\n    )\n\n# execute by both\ncompare(\"comparing layers output: \", out_bn[start:finish], out_sbn, error) and sbn_result\ncompare(\n    \"comparing layers grad_input: \",\n    inp_bn.grad[start:finish],\n    inp_sbn.grad[start:finish],\n    error,\n) and sbn_result\n\nbn_opt.step()\nsbn_opt.step()\n\nif args.local_rank == 0:\n    
compare(\"comparing bn vs sbn bias: \", bn.module.bias, sbn.module.bias, error)\n    compare(\"comparing bn vs sbn weight: \", bn.module.weight, sbn.module.weight, error)\n\n\nif sbn_result:\n    print(\"====SBN group test passed\")\nelse:\n    print(\"*SBN group test failed*\")\n"
  },
  {
    "path": "tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom apex.parallel import SyncBatchNorm as ApexSyncBatchNorm\n\nimport argparse\nimport os\nimport numpy as np\n\nvar_batch = 16\n\n\ndef compare(desc, inp1, inp2, error=1e-5):\n    a = inp1.clone().detach().cpu().numpy()\n    b = inp2.clone().detach().cpu().numpy()\n    close = np.allclose(a, b, error, error)\n    if not close:\n        print(desc, close)\n        z = a - b\n        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()\n        print(\"dif    : \", z[index])\n        print(\"inp1   : \", a[index])\n        print(\"inp2   : \", b[index])\n    return close\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--local_rank\", type=int, default=0)\nparser.add_argument(\"--apex\", action=\"store_true\")\nargs = parser.parse_args()\n\n\ntorch.manual_seed(2809)\n# Setup DDP\ntorch.cuda.set_device(args.local_rank)\ndevice = torch.device(\"cuda:{}\".format(args.local_rank))\n\ntorch.distributed.init_process_group(\n    \"nccl\",\n    init_method=\"env://\",\n    rank=args.local_rank,\n)\n\n# Setup model\nif args.apex:\n    model = nn.Sequential(nn.Conv2d(3, 6, 3, 1, 1), ApexSyncBatchNorm(6))\nelse:\n    model = nn.Sequential(nn.Conv2d(3, 6, 3, 1, 1), nn.SyncBatchNorm(6))\n\n# Setup reference model\nmodel_reference = nn.Sequential(nn.Conv2d(3, 6, 3, 1, 1), nn.BatchNorm2d(6))\n\nwith torch.no_grad():\n    model_reference[0].weight.copy_(model[0].weight)\n    model_reference[0].bias.copy_(model[0].bias)\nmodel_reference.to(device)\n\nmodel = model.to(device)\nmodel = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)\n\nglobal_batch_size = var_batch + 8\n# Create random data\nif args.local_rank == 0:\n    data = torch.randn(var_batch, 3, 8, 8, device=device, dtype=torch.float) * 50.0\n    grad = torch.randint(0, 10, (var_batch, 6, 8, 8), device=device, dtype=torch.float) / 10.0\nelse:\n    data = torch.randn(8, 
3, 8, 8, device=device)\n    grad = torch.randint(0, 10, (8, 6, 8, 8), device=device, dtype=torch.float) / 10.0\n\ndata.requires_grad_()\ndata.retain_grad = True\n\nweighted_gradient = True\n\n# DDP forward/backward\noutput = model(data)\n\nif weighted_gradient:\n    output.backward(grad * 2 / global_batch_size)\nelse:\n    output.backward(grad / output.size(0))\n\nd_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ[\"WORLD_SIZE\"]))]\ny_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ[\"WORLD_SIZE\"]))]\ndgrad_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ[\"WORLD_SIZE\"]))]\ngrad_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ[\"WORLD_SIZE\"]))]\nif args.local_rank == 0:\n    # placeholder, these random data will later be discarded.\n    torch.distributed.all_gather(d_list, torch.randn(8, 3, 8, 8, device=device))\n    torch.distributed.all_gather(y_list, torch.randn(8, 6, 8, 8, device=device))\n    torch.distributed.all_gather(dgrad_list, torch.randn(8, 3, 8, 8, device=device))\n    torch.distributed.all_gather(grad_list, torch.randn(8, 6, 8, 8, device=device))\nelse:\n    torch.distributed.all_gather(d_list, data)\n    torch.distributed.all_gather(y_list, output)\n    torch.distributed.all_gather(dgrad_list, data.grad)\n    torch.distributed.all_gather(grad_list, grad)\n\ntorch.distributed.barrier()\n\nif args.local_rank == 0:\n    ref_tensor = d_list[1:]\n    ref_tensor.insert(0, data)\n    assert ref_tensor[0].equal(data)\n    ref_tensor = torch.cat(ref_tensor, 0)\n    ref_tensor = ref_tensor.detach()\n    ref_tensor.requires_grad_()\n    ref_tensor.retain_grad()\n\n    # Reference forward/backward\n    output_reference = model_reference(ref_tensor)\n    grad_tensor = grad_list[1:]\n    grad_tensor.insert(0, grad)\n    assert grad_tensor[0].equal(grad)\n    grad_tensor = torch.cat(grad_tensor, 0)\n    if weighted_gradient:\n        
output_reference.backward(grad_tensor / output_reference.size(0))\n    else:\n        output_reference.backward(grad_tensor / output_reference.size(0))\n\n    dgrad_tensor = dgrad_list[1:]\n    dgrad_tensor.insert(0, data.grad)\n    dgrad_tensor = torch.cat(dgrad_tensor, 0)\n    # check output\n    output_tensor = y_list[1:]\n    output_tensor.insert(0, output)\n    output_tensor = torch.cat(output_tensor, 0)\n    passed = True\n    passed = passed and compare(\"check output\", output_tensor, output_reference)\n    # check stats\n    passed = passed and compare(\n        \"check running mean failed\",\n        model_reference[1].running_mean,\n        model.module[1].running_mean,\n    )\n    passed = passed and compare(\n        \"check running var failed\",\n        model_reference[1].running_var,\n        model.module[1].running_var,\n    )\n    passed = passed and compare(\n        \"bn wgrad check failed!\",\n        model_reference[1].weight.grad,\n        model.module[1].weight.grad,\n        1e-6,\n    )\n    passed = passed and compare(\n        \"conv wgrad check failed!\",\n        model_reference[0].weight.grad,\n        model.module[0].weight.grad,\n    )\n    # can't really compare dgrad directly, as we need to scale it to account for\n    # DDP\n    # passed = passed and compare(\"dgrad check failed!\", ref_tensor.grad, dgrad_tensor)\n    if passed:\n        print(\"====SBN two gpu with different batches test passed\")\n    else:\n        # a bare `assert \"<non-empty string>\"` is always truthy and never fails; raise explicitly\n        raise AssertionError(\"*failed two gpu with different batches tests*\")\n"
  },
  {
    "path": "tests/distributed/synced_batchnorm/two_gpu_unit_test.py",
    "content": "import torch\nimport numpy as np\nimport apex\nimport syncbn\nimport os\nimport argparse\nimport torch.optim as optim\n\n\ndef compare(desc, inp1, inp2, error):\n    a = inp1.clone().detach().cpu().numpy()\n    b = inp2.clone().detach().cpu().numpy()\n    close = np.allclose(a, b, error, error)\n    if not close:\n        print(desc, close)\n        z = a - b\n        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()\n        print(\"dif    : \", z[index])\n        print(\"inp1   : \", a[index])\n        print(\"inp2   : \", b[index])\n    return close\n\n\nfeature_size = 10\nspace_size = 40\nbatch_size = 32\n\n\nfrom apex.parallel import DistributedDataParallel as DDP\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--local_rank\", default=0, type=int)\nparser.add_argument(\"--fp16\", action=\"store_true\", default=False)\nparser.add_argument(\"--fp64\", action=\"store_true\", default=False)\nargs = parser.parse_args()\nargs.world_size = int(os.environ[\"WORLD_SIZE\"])\ntorch.cuda.set_device(args.local_rank)\ntorch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\nstart = args.local_rank * batch_size // args.world_size\nfinish = (args.local_rank + 1) * batch_size // args.world_size\n\nerror = 1e-5\ndtype = np.float32\nif args.fp16:\n    error = 1e-3\n    dtype = np.float16\nelif args.fp64:\n    error = 1e-8\n    dtype = np.float64\n\nnp.random.seed(18)\ninp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)\ngrad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)\nweight = np.random.randn(feature_size).astype(dtype)\nbias = np.random.randn(feature_size).astype(dtype)\n\n\ntype_tensor = torch.cuda.FloatTensor\nif args.fp16:\n    type_tensor = torch.cuda.HalfTensor\nif args.fp64:\n    type_tensor = torch.cuda.DoubleTensor\n\nref_tensor = torch.cuda.DoubleTensor\n\ninp_t = type_tensor(inp)\nweight_t = type_tensor(weight)\nbias_t = 
type_tensor(bias)\n\ninp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ninp2_r = ref_tensor(inp)\nweight_r = ref_tensor(weight).view(-1, 1, 1)\nbias_r = ref_tensor(bias).view(-1, 1, 1)\n\ngrad_output_t = type_tensor(grad)\n\nm = inp_r.mean(1)\nb_v = inp_r.var(1, unbiased=False)\nunb_v = inp_r.var(1, unbiased=True)\n\neps = 1e-5\n\nmean, var_biased = syncbn.welford_mean_var(inp_t)\ninv_std = 1.0 / torch.sqrt(var_biased + eps)\n\nbn = torch.nn.BatchNorm2d(feature_size).cuda()\nbn.momentum = 1.0\nbn.weight.data = weight_t.clone()\nbn.bias.data = bias_t.clone()\nif args.fp16:\n    bn.half()\nif args.fp64:\n    bn.double()\ninp_bn = inp_t.clone().requires_grad_()\ngrad_bn = grad_output_t.clone().detach()\nout_bn = bn(inp_bn)\nout_bn.backward(grad_bn)\n# compensating the averaging over processes done by DDP\n# in order to produce mathematically equivalent result\n# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368\nfor param in bn.parameters():\n    param.grad = param.grad / args.world_size\nbn_opt = optim.SGD(bn.parameters(), lr=1.0)\n\nsbn = apex.parallel.SyncBatchNorm(feature_size).cuda()\nsbn.momentum = 1.0\nsbn.weight.data = weight_t.clone()\nsbn.bias.data = bias_t.clone()\nif args.fp16:\n    sbn.half()\nif args.fp64:\n    sbn.double()\nsbn = DDP(sbn)\nsbn_opt = optim.SGD(sbn.parameters(), lr=1.0)\ninp_sbn = inp_t.clone().requires_grad_()\ngrad_sbn = grad_output_t.clone().detach()\nout_sbn = sbn(inp_sbn[start:finish])\nout_sbn.backward(grad_sbn[start:finish])\n\ncount = [\n    space_size**2 * ((i + 1) * batch_size // args.world_size - i * batch_size // args.world_size)\n    for i in range(0, args.world_size)\n]\ncount = torch.cuda.IntTensor(count)\n\nprint(\"--- count : \", count)\n\nsbn_result = True\nbn_result = True\n\nif args.local_rank == 0:\n    sbn_result = compare(\"comparing mean: \", mean, m, error) and sbn_result\n    sbn_result = compare(\"comparing biased variance: \", var_biased, b_v, error) and sbn_result\n\nout = 
syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)\nout_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r\n\nif args.local_rank == 0:\n    sbn_result = compare(\"comparing output: \", out, out_r, error) and sbn_result\n    compare(\"comparing bn output: \", out_bn, out_r, error)\n\ngrad_output_t = type_tensor(grad)\n\ngrad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))\ngrad_output2_r = ref_tensor(grad)\n\ngrad_bias_r = grad_output_r.sum(1)\ngrad_weight_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .sum(1)\n)\n\nsum_dy_r = grad_output_r.sum(1)\nmean_dy_r = grad_output_r.mean(1)\nmean_dy_xmu_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .mean(1)\n)\nsum_dy_xmu_r = (\n    ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r)\n    .transpose(1, 0)\n    .contiguous()\n    .view(feature_size, -1)\n    .sum(1)\n)\n\ngrad_input_r = (\n    (\n        grad_output2_r\n        - mean_dy_r.view(-1, 1, 1)\n        - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)\n    )\n    * torch.rsqrt(b_v.view(-1, 1, 1) + eps)\n    * weight_r.view(-1, 1, 1)\n)\n\nsum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(\n    grad_output_t, inp_t, mean, inv_std, weight_t\n)\ngrad_input = syncbn.batchnorm_backward(\n    grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count\n)\nif args.local_rank == 0:\n    sbn_result = compare(\"comparing bias grad: \", grad_bias, grad_bias_r, error) and sbn_result\n    sbn_result = (\n        compare(\"comparing weight grad: \", grad_weight, grad_weight_r, error) and sbn_result\n    )\n    sbn_result = compare(\"comparing sum_dy grad: \", sum_dy, sum_dy_r, error) and sbn_result\n    sbn_result = (\n        
compare(\"comparing sum_dy_xmu grad: \", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result\n    )\n    sbn_result = compare(\"comparing input grad: \", grad_input, grad_input_r, error) and sbn_result\n    compare(\"comparing bn input grad: \", inp_bn.grad, grad_input_r, error)\n\nif args.local_rank == 0:\n    sbn_result = (\n        compare(\n            \"comparing running_mean: \",\n            bn.running_mean.data,\n            sbn.module.running_mean.data,\n            error,\n        )\n        and sbn_result\n    )\n    sbn_result = (\n        compare(\n            \"comparing running_variance: \",\n            bn.running_var.data,\n            sbn.module.running_var.data,\n            error,\n        )\n        and sbn_result\n    )\n\n# execute by both\ncompare(\"comparing layers output: \", out_bn[start:finish], out_sbn, error) and sbn_result\ncompare(\n    \"comparing layers grad_input: \",\n    inp_bn.grad[start:finish],\n    inp_sbn.grad[start:finish],\n    error,\n) and sbn_result\n\nbn_opt.step()\nsbn_opt.step()\n\nif args.local_rank == 0:\n    compare(\"comparing bn vs sbn bias: \", bn.bias, sbn.module.bias, error)\n    compare(\"comparing bn vs sbn weight: \", bn.weight, sbn.module.weight, error)\n\n\nif sbn_result:\n    print(\"====SBN two gpu passed tests\")\nelse:\n    print(\"*SBN two gpu failed*\")\n"
  },
  {
    "path": "tests/distributed/synced_batchnorm/unit_test.sh",
    "content": "python python_single_gpu_unit_test.py\npython single_gpu_unit_test.py\npython test_batchnorm1d.py\npython -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py\npython -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16\npython -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex\n#beware, you need a system with at least 4 gpus to test group_size<world_size\n#python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2\n"
  },
  {
    "path": "tests/docker_extension_builds/run.sh",
    "content": "#!/bin/bash\n\nprint_banner() {\n  printf \"\\n\\n\\n\\e[30m\\e[42m$1\\e[0m\\n\\n\\n\\n\"\n}\n\nprint_green() {\n  printf \"\\e[30m\\e[42m$1\\e[0m\\n\"\n}\n\nprint_red() {\n  printf \"\\e[30m\\e[41m$1\\e[0m\\n\"\n}\n\nimages=(\n\"pytorch/pytorch:nightly-devel-cuda10.0-cudnn7\"\n\"pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-devel\"\n\"pytorch/pytorch:1.0.1-cuda10.0-cudnn7-devel\"\n\"pytorch/pytorch:1.0-cuda10.0-cudnn7-devel\"\n\"pytorch/pytorch:nightly-devel-cuda9.2-cudnn7\"\n)\n\nbranch=\"master\"\n\n# Associative array for exit codes\ndeclare -A exit_codes\nfor image in images\ndo\n  exit_codes[$image]=\"None\"\ndone\n\nfor image in \"${images[@]}\"\ndo\n  print_banner \"$image\"\n  set -x\n  docker pull $image\n  # Trying python setup.py install instead of pip install to ensure direct access to error codes.\n  # Maybe pip install would be ok too but this works.\n  docker run --runtime=nvidia --rm $image /bin/bash -c \"yes | pip uninstall apex; yes | pip uninstall apex; git clone https://github.com/NVIDIA/apex.git; cd apex; git checkout $branch; set -e;  python setup.py install --cuda_ext --cpp_ext\"\n  exit_code=$?\n  set +x\n  if [ $exit_code != 0 ]\n  then\n    print_red \"Exit code: $exit_code\"\n  else\n    print_green \"Exit code: $exit_code\"\n  fi\n  exit_codes[$image]=$exit_code\ndone\n\nsuccess=0\nfor image in \"${images[@]}\"\ndo\n  exit_code=${exit_codes[$image]}\n  if [ $exit_code != 0 ]\n  then\n    print_red \"$image : $exit_code\"\n    success=1\n  else\n    print_green \"$image : $exit_code\"\n  fi\ndone\n\nif [ $success != 0 ]\nthen\n  print_red \"Overall status:  failure\"\nelse\n  print_green \"Overall status:  success\"\nfi\n\nexit $success\n"
  }
]