[
  {
    "path": ".gitignore",
    "content": "*.pyc\n**/.ipynb_checkpoints\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and maintainers pledge to make participation in our project and\nour community a harassment-free experience for everyone, regardless of age, body\nsize, disability, ethnicity, sex characteristics, gender identity and expression,\nlevel of experience, education, socio-economic status, nationality, personal\nappearance, race, religion, or sexual identity and orientation.\n\n## Our Standards\n\nExamples of behavior that contributes to creating a positive environment\ninclude:\n\n* Using welcoming and inclusive language\n* Being respectful of differing viewpoints and experiences\n* Gracefully accepting constructive criticism\n* Focusing on what is best for the community\n* Showing empathy towards other community members\n\nExamples of unacceptable behavior by participants include:\n\n* The use of sexualized language or imagery and unwelcome sexual attention or\nadvances\n* Trolling, insulting/derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or electronic\naddress, without explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\nprofessional setting\n\n## Our Responsibilities\n\nProject maintainers are responsible for clarifying the standards of acceptable\nbehavior and are expected to take appropriate and fair corrective action in\nresponse to any instances of unacceptable behavior.\n\nProject maintainers have the right and responsibility to remove, edit, or\nreject comments, commits, code, wiki edits, issues, and other contributions\nthat are not aligned to this Code of Conduct, or to ban temporarily or\npermanently any contributor for other behaviors that they deem inappropriate,\nthreatening, offensive, or harmful.\n\n## Scope\n\nThis Code of Conduct applies within all project spaces, and it also applies when\nan individual is representing the project or its community in public spaces.\nExamples of representing a project or community include using an official\nproject e-mail address, posting via an official social media account, or acting\nas an appointed representative at an online or offline event. Representation of\na project may be further defined and clarified by project maintainers.\n\nThis Code of Conduct also applies outside the project spaces when there is a\nreasonable belief that an individual's behavior may have a negative impact on\nthe project or its community.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported by contacting the project team at <opensource-conduct@meta.com>. All\ncomplaints will be reviewed and investigated and will result in a response that\nis deemed necessary and appropriate to the circumstances. 
The project team is\nobligated to maintain confidentiality with regard to the reporter of an incident.\nFurther details of specific enforcement policies may be posted separately.\n\nProject maintainers who do not follow or enforce the Code of Conduct in good\nfaith may face temporary or permanent repercussions as determined by other\nmembers of the project's leadership.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,\navailable at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see\nhttps://www.contributor-covenant.org/faq\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to Applied AI\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Our Development Process\n... (in particular how this is synced with internal changes to the project)\n\n## Pull Requests\nWe actively welcome your pull requests.\n\n1. Fork the repo and create your branch from `main`.\n2. If you've added code that should be tested, add tests.\n3. If you've changed APIs, update the documentation.\n4. Ensure the test suite passes.\n5. Make sure your code lints.\n6. If you haven't already, complete the Contributor License Agreement (\"CLA\").\n\n## Contributor License Agreement (\"CLA\")\nIn order to accept your pull request, we need you to submit a CLA. You only need\nto do this once to work on any of Meta's open source projects.\n\nComplete your CLA here: <https://code.facebook.com/cla>\n\n## Issues\nWe use GitHub issues to track public bugs. Please ensure your description is\nclear and has sufficient instructions to be able to reproduce the issue.\n\nMeta has a [bounty program](https://www.facebook.com/whitehat/) for the safe\ndisclosure of security bugs. In those cases, please go through the process\noutlined on that page and do not file a public issue.\n\n## Coding Style\n* 2 spaces for indentation rather than tabs\n* 80 character line length\n* ...\n\n## License\nBy contributing to applied-ai, you agree that your contributions will be licensed\nunder the LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2024 Meta\n\nRedistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "assets/images/dev-discuss-asynctp/readme.md",
    "content": "This folder is for hosting the images for the AsyncTP public post at:    \n[https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)\n"
  },
  {
    "path": "assets/images/readme.md",
    "content": "Folder for housing images for the readmes.\n"
  },
  {
    "path": "dev/sr/.gitignore",
    "content": "*.o\n*.ninja\n*.txt\n*.egg-info\n*.ninja-deps\n*.ninja-log/\n*.so\ndist/\nbuild/\n"
  },
  {
    "path": "dev/sr/readme.md",
    "content": "Branch for stochastic rounding kernel\nCurrently processes 4 elements per thread to leverage rand4\n"
  },
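  {
    "path": "dev/sr/notes/rounding_rule_sketch.md",
    "content": "Illustrative note, not part of the original kernel sources; the file path and helper name here are hypothetical.\n\nA minimal pure-PyTorch sketch of the rounding rule the CUDA kernel applies: truncate a float32 value toward zero to a bfloat16 by clearing the low 16 bits of its bit pattern, then step to the next bfloat16 away from zero with probability equal to those discarded bits divided by 2^16. This mirrors float_to_bf16_stochastic in src/stochastic_rounding_cuda.cu; it is a reference sketch, not the kernel itself.\n\n```python\nimport torch\n\ndef stochastic_round_bf16_reference(x: torch.Tensor) -> torch.Tensor:\n    # Reinterpret the float32 bits; bfloat16 keeps only the top 16 of these 32 bits.\n    bits = x.contiguous().view(torch.int32)\n    low = bits & 0xFFFF      # mantissa bits that plain truncation would drop\n    kept = bits - low        # input with the low 16 bits cleared\n    # Step up by one bf16 ulp (0x10000 in float32 bits) with probability low / 2**16.\n    rand = torch.randint(0, 1 << 16, x.shape, device=x.device, dtype=torch.int32)\n    rounded = kept + ((rand < low).to(torch.int32) << 16)\n    return rounded.view(torch.float32).to(torch.bfloat16)\n\n# Boundary value taken from core_unit_tests.py; it sits between two adjacent bf16 values.\nx = torch.full((10000,), 2.1999969482421875, device='cuda')\nprint(stochastic_round_bf16_reference(x).unique())\n```\n\nBecause the round-up probability is proportional to the distance to the upper neighbor, the rounded output is unbiased in expectation, which is the property the rounding-statistics unit tests check.\n"
  },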
  {
    "path": "dev/sr/setup.py",
    "content": "\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name='stochastic_rounding_cuda',\n    version='0.1.021825',\n    ext_modules=[\n        CUDAExtension('stochastic_rounding_cuda', [\n            'src/stochastic_rounding.cu',\n            'src/stochastic_rounding_cuda.cu'\n        ],\n        extra_compile_args={\n            'cxx': ['-O3'],\n            'nvcc': [\n                '-O3',\n                '--expt-relaxed-constexpr',  # better template support\n                #'-gencode=arch=compute_70,code=sm_70',  # Volta\n                #'-gencode=arch=compute_75,code=sm_75',  # Turing\n                #'-gencode=arch=compute_80,code=sm_80'   # Amper\n                #'-gencode=arch=compute_86,code=sm_86'   # Ampere\n                '-gencode=arch=compute_90,code=sm_90',  # Hopper\n            ]\n        })\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    }\n)\n"
  },
  {
    "path": "dev/sr/src/stochastic_rounding.cu",
    "content": "\n#include <pybind11/pybind11.h>\n#include \"stochastic_rounding.hpp\"\n#include <random>\n\nnamespace py = pybind11;\n\n__host__ int getOptimalBlockSize() {\n    cudaDeviceProp prop;\n    cudaGetDeviceProperties(&prop, 0);\n    return std::min(prop.maxThreadsPerBlock, 256);\n}\n\ntorch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad) {\n    TORCH_CHECK(input.is_cuda(), \"Input tensor must be on CUDA device\");\n    TORCH_CHECK(input.is_contiguous(), \"Input tensor must be contiguous\");\n    TORCH_CHECK(input.scalar_type() == torch::kFloat32, \"Input tensor must be float32\");\n\n    const int threads_per_block = 256;\n    const int num_elements = input.numel();\n    const int elements_per_thread = 4;\n\n    const int min_blocks = (num_elements + elements_per_thread * threads_per_block - 1) /\n                          (elements_per_thread * threads_per_block);\n\n    cudaDeviceProp prop;\n    cudaGetDeviceProperties(&prop, 0);\n    const int blocks_per_sm = 4;\n    const int min_blocks_for_sms = prop.multiProcessorCount * blocks_per_sm;\n    const int num_blocks = std::max(min_blocks, min_blocks_for_sms);\n\n    auto options = torch::TensorOptions()\n                      .dtype(torch::kBFloat16)\n                      .device(input.device())\n                      .requires_grad(requires_grad);\n    auto output = torch::empty_like(input, options);\n\n    std::random_device rd;\n    std::mt19937_64 gen(rd());\n    std::uniform_int_distribution<unsigned long long> dis;\n    const unsigned long long seed = dis(gen);\n\n    stochastic_round_bf16<<<num_blocks, threads_per_block>>>(\n        input.data_ptr<float>(),\n        reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),\n        num_elements,\n        seed);\n\n    cudaError_t err = cudaGetLastError();\n    TORCH_CHECK(err == cudaSuccess,\n                \"CUDA kernel execution failed: \", cudaGetErrorString(err));\n\n    return output;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"stochastic_round_bf16\",\n          static_cast<torch::Tensor (*)(torch::Tensor, bool)>(&stochastic_round_bf16_cuda),\n          \"Stochastic rounding to BFloat16\",\n          py::arg(\"input\"),\n          py::arg(\"requires_grad\") = false);\n}\n"
  },
  {
    "path": "dev/sr/src/stochastic_rounding.hpp",
    "content": "\n#pragma once\n#include <cuda_bf16.h>\n#include <cuda_runtime.h>\n#include <vector_types.h>\n#include <torch/extension.h>\n#include <pybind11/pybind11.h>\n\nnamespace philox {\n    constexpr unsigned int W32_0   = 0x9E3779B9;\n    constexpr unsigned int W32_1   = 0xBB67AE85;\n    constexpr unsigned int M0      = 0xD2511F53;\n    constexpr unsigned int M1      = 0xCD9E8D57;\n    constexpr int ROUNDS          = 7;\n}\n\n// Forward declarations\nclass PhiloxGenerator {\npublic:\n    __device__ __forceinline__ PhiloxGenerator();\n    __device__ __forceinline__ void init(const unsigned long long seed, const unsigned int thread_id);\n    __device__ __forceinline__ uint4 next();\nprivate:\n    uint2 key;\n    uint4 counter;\n    static __device__ __forceinline__ uint2 mulhilo(const unsigned int a, const unsigned int b);\n    static __device__ __forceinline__ uint4 round(uint4 ctr, uint2 key);\n};\n\n// CUDA kernel declaration\n__global__ void stochastic_round_bf16(\n    float *__restrict__ input,\n    __nv_bfloat16 *__restrict__ output,\n    const int size,\n    const unsigned long long seed);\n\n// Host functions\n__host__ int getOptimalBlockSize();\ntorch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad = false);\n"
  },
  {
    "path": "dev/sr/src/stochastic_rounding_cuda.cu",
    "content": " #include \"stochastic_rounding.hpp\"\n#include <cstdint>\n\n// Philox RNG implementation\n\n__device__ __forceinline__ PhiloxGenerator::PhiloxGenerator() :\n    key(make_uint2(0, 0)),\n    counter(make_uint4(0, 0, 0, 0)) {}\n\n__device__ __forceinline__ void PhiloxGenerator::init(const unsigned long long seed, const unsigned int thread_id) {\n    key.x = static_cast<unsigned int>(seed);\n    key.y = static_cast<unsigned int>(seed >> 32);\n    counter = make_uint4(thread_id, 0, 0, 0);\n    __threadfence_block();\n}\n\n__device__ __forceinline__ uint2 PhiloxGenerator::mulhilo(const unsigned int a, const unsigned int b) {\n    uint2 result;\n    unsigned long long prod;\n    asm(\"mul.wide.u32 %0, %1, %2;\" : \"=l\"(prod) : \"r\"(a), \"r\"(b));\n    result.x = static_cast<unsigned int>(prod);\n    result.y = static_cast<unsigned int>(prod >> 32);\n    return result;\n}\n\n__device__ __forceinline__ uint4 PhiloxGenerator::round(uint4 ctr, uint2 key) {\n    const uint2 mul0 = mulhilo(philox::M0, ctr.x);\n    const uint2 mul1 = mulhilo(philox::M1, ctr.z);\n\n    return make_uint4(\n        mul1.y ^ ctr.y ^ key.x,\n        mul1.x,\n        mul0.y ^ ctr.w ^ key.y,\n        mul0.x\n    );\n}\n\n__device__ __forceinline__ uint4 PhiloxGenerator::next() {\n    uint4 ctr = counter;\n    uint2 k = key;\n\n    #pragma unroll\n    for (int i = 0; i < philox::ROUNDS; ++i) {\n        ctr = round(ctr, k);\n        k.x += philox::W32_0;\n        k.y += philox::W32_1;\n    }\n\n    counter.x += 4;\n    return ctr;\n}\n\n__device__ __forceinline__ __nv_bfloat16 float_to_bf16_stochastic(const float value, const uint32_t rand) {\n    const uint32_t val_bits = __float_as_uint(value);\n    const uint32_t rounding_bits = val_bits & 0xFFFF;\n    uint32_t result = val_bits & 0xFFFF0000u;\n    result += (rand & 0xFFFF) < rounding_bits ? 
0x10000u : 0;\n    return __float2bfloat16(__uint_as_float(result));\n}\n\n__device__ __forceinline__ void float4_to_bf16_stochastic(\n    const float4& values,\n    uint4& rand_vals,\n    __nv_bfloat16* output) {\n\n    float vals[4] = {values.x, values.y, values.z, values.w};\n    uint32_t rands[4] = {rand_vals.x, rand_vals.y, rand_vals.z, rand_vals.w};\n\n    #pragma unroll\n    for (int i = 0; i < 4; i++) {\n        output[i] = float_to_bf16_stochastic(vals[i], rands[i]);\n    }\n}\n\n__global__ void stochastic_round_bf16(\n    float *__restrict__ input,\n    __nv_bfloat16 *__restrict__ output,\n    const int size,\n    const unsigned long long seed) {\n\n    PhiloxGenerator rng;\n    rng.init(seed, blockIdx.x * blockDim.x + threadIdx.x);\n\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;\n    int stride = blockDim.x * gridDim.x * 4;\n\n    float4 values;\n    __nv_bfloat16 local_output[4];\n\n    // Process full vectors of 4 elements\n    for (; idx <= size - 4; idx += stride) {\n        values = *reinterpret_cast<float4*>(&input[idx]);\n        uint4 rand = rng.next();\n        float4_to_bf16_stochastic(values, rand, local_output);\n\n        for (int j = 0; j < 4; j++) {\n            output[idx + j] = local_output[j];\n        }\n    }\n\n    // Handle remaining elements\n    if (idx < size) {\n        float remaining_values[4] = {0.0f, 0.0f, 0.0f, 0.0f};\n        int remainder = size - idx;\n\n        for (int j = 0; j < remainder; j++) {\n            remaining_values[j] = input[idx + j];\n        }\n\n        values.x = remaining_values[0];\n        values.y = remaining_values[1];\n        values.z = remaining_values[2];\n        values.w = remaining_values[3];\n\n        uint4 rand = rng.next();\n        float4_to_bf16_stochastic(values, rand, local_output);\n\n        for (int j = 0; j < remainder; j++) {\n            output[idx + j] = local_output[j];\n        }\n    }\n}\n"
  },
  {
    "path": "dev/sr/test.md",
    "content": "(tkdev11) [less@devgpu115.cco2 ~/local/applied-ai/dev/sr (sr_kernel)]$ python usage.py\nLaunching kernel with blocks=1, threads_per_block=256, num_elements=12\nInput tensor: tensor([ 0.3282, -0.4513, -1.0612,  0.1446, -0.8440, -1.4669, -0.7135, -0.6183,\n        -2.2411,  2.1464,  1.4772, -1.3564], device='cuda:0')\nOutput tensor: tensor([ 0.3281, -0.4512, -1.0625,  0.1445, -0.8438, -1.4688, -0.7109, -0.6172,\n        -2.2344,  2.1406,  1.4766, -1.3516], device='cuda:0',\n       dtype=torch.bfloat16)\nOutput tensor dtype: torch.bfloat16\nSuccess!\n"
  },
  {
    "path": "dev/sr/tests/benchmark.py",
    "content": "import torch\nimport stochastic_rounding_cuda\nimport numpy as np\nimport time\nfrom tabulate import tabulate\nimport argparse\n\ndef measure_performance(func, input_tensor, warmup=0, repeats=1):\n    \"\"\"Measure performance of a function with proper CUDA synchronization\"\"\"\n    # Warmup\n    for _ in range(warmup):\n        output = func(input_tensor)\n\n    torch.cuda.synchronize()\n    start = time.perf_counter()\n\n    for _ in range(repeats):\n        output = func(input_tensor)\n\n    torch.cuda.synchronize()\n    end = time.perf_counter()\n\n    avg_time = (end - start) / repeats\n    elements_per_second = input_tensor.numel() / avg_time\n    return avg_time, elements_per_second\n\ndef benchmark_sizes(sizes= [1000, 10000, 100000, 1000000, 10000000, (10000000*10), (10000000*100)]):\n    #[ 50,000,000]): #\n    \"\"\"Benchmark different input sizes\"\"\"\n    results = []\n\n    for size in sizes:\n        # Create input tensor\n        x = torch.randn(size, device='cuda')\n\n        # Measure stochastic rounding\n        time_stoch, throughput_stoch = measure_performance(\n            stochastic_rounding_cuda.stochastic_round_bf16, x)\n\n        # Measure regular BF16 casting\n        time_regular, throughput_regular = measure_performance(\n            lambda t: t.to(torch.bfloat16), x)\n\n        results.append([\n            size,\n            time_stoch * 1000,  # convert to ms\n            throughput_stoch / 1e6,  # convert to GElements/s\n            time_regular * 1000,\n            throughput_regular / 1e6,\n            throughput_regular / throughput_stoch  # speedup\n        ])\n\n    print(\"\\nSize Comparison:\")\n    print(tabulate(results,\n                  headers=['Size', 'Stoch Time (ms)', 'Stoch ME/s',\n                          'Regular Time (ms)', 'Regular ME/s', 'Casting faster by'],\n                  floatfmt='.3f'))\n\ndef benchmark_shapes(total_size=1000000):\n    \"\"\"Benchmark different tensor shapes with same total size\"\"\"\n    shapes = [\n        (total_size,),           # 1D\n        (1000, total_size//1000),  # 2D\n        (100, 100, total_size//10000),  # 3D\n    ]\n\n    results = []\n    for shape in shapes:\n        x = torch.randn(*shape, device='cuda')\n        time_stoch, throughput_stoch = measure_performance(\n            stochastic_rounding_cuda.stochastic_round_bf16, x)\n\n        results.append([\n            'x'.join(str(d) for d in shape),\n            time_stoch * 1000,\n            throughput_stoch / 1e9\n        ])\n\n    print(\"\\nShape Comparison (same total size):\")\n    print(tabulate(results,\n                  headers=['Shape', 'Time (ms)', 'GElements/s'],\n                  floatfmt='.3f'))\n\ndef stress_test(duration=10):\n    \"\"\"Run a stress test for specified duration\"\"\"\n    print(f\"\\nRunning stress test for {duration} seconds...\")\n\n    size = 1000000\n    x = torch.randn(size, device='cuda')\n    start_time = time.time()\n    iterations = 0\n\n    while time.time() - start_time < duration:\n        stochastic_rounding_cuda.stochastic_round_bf16(x)\n        iterations += 1\n\n    print(f\"Completed {iterations} iterations without errors\")\n    print(f\"Average throughput: {(iterations * size) / (duration * 1e9):.2f} GElements/s\")\n\ndef memory_test(max_size=1e9):\n    \"\"\"Test memory scaling\"\"\"\n    sizes = np.logspace(3, min(9, np.log10(max_size)), num=7, dtype=int)\n    results = []\n\n    for size in sizes:\n        try:\n            torch.cuda.empty_cache()\n            x = 
torch.randn(size, device='cuda')\n            torch.cuda.synchronize()\n\n            # Measure peak memory during operation\n            torch.cuda.reset_peak_memory_stats()\n            _ = stochastic_rounding_cuda.stochastic_round_bf16(x)\n            torch.cuda.synchronize()\n\n            peak_memory = torch.cuda.max_memory_allocated() / 1e6  # MB\n            results.append([size, peak_memory])\n\n        except RuntimeError as e:\n            print(f\"Out of memory at size {size}\")\n            break\n\n    print(\"\\nMemory Usage:\")\n    print(tabulate(results,\n                  headers=['Size', 'Peak Memory (MB)'],\n                  floatfmt='.2f'))\n\ndef main():\n    parser = argparse.ArgumentParser(description='Benchmark stochastic rounding')\n    parser.add_argument('--sizes', action='store_true', help='Run size benchmarks')\n    parser.add_argument('--shapes', action='store_true', help='Run shape benchmarks')\n    parser.add_argument('--stress', action='store_true', help='Run stress test')\n    parser.add_argument('--memory', action='store_true', help='Run memory test')\n    parser.add_argument('--all', action='store_true', help='Run all benchmarks')\n\n    args = parser.parse_args()\n\n    # Print device information\n    device = torch.cuda.get_device_properties(0)\n    print(f\"\\nRunning on: {device.name}\")\n    print(f\"Compute Capability: {device.major}.{device.minor}\")\n\n\n    if args.all or args.sizes:\n        benchmark_sizes()\n\n    if args.all or args.shapes:\n        benchmark_shapes()\n\n    if args.all or args.stress:\n        stress_test()\n\n    if args.all or args.memory:\n        memory_test()\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "dev/sr/tests/core_unit_tests.py",
    "content": "import torch\nimport numpy as np\nfrom collections import Counter\nimport unittest\nimport stochastic_rounding_cuda\nimport time\n\nclass TestStochasticRounding(unittest.TestCase):\n    def setup(self):\n        # Ensure deterministic behavior for some tests\n        torch.manual_seed(42)\n        np.random.seed(42)\n\n    def _test_rounding_statistics_helper(self, value, lower_value, upper_value, tensor_size=10000, rounds=100):\n        \"\"\"Helper method for testing stochastic rounding statistics\"\"\"\n        print(f\"\\nInput value: {value}\")\n        MAX_VARIANCE = 0.03\n        x = torch.full((tensor_size,), value, device='cuda')\n        torch.cuda.manual_seed(42)\n\n        # Single round test - isolate and show the round up and round down values\n        single_result = stochastic_rounding_cuda.stochastic_round_bf16(x)\n        print(f\"Possible rounded values: {torch.unique(single_result)}\")\n\n        # Multiple rounds\n        results = torch.empty((rounds, tensor_size), device='cuda', dtype=torch.bfloat16)\n        for i in range(rounds):\n            results[i] = stochastic_rounding_cuda.stochastic_round_bf16(x)\n\n        prob_up = (results == upper_value).float().mean().item()\n        print(f\"Kernel's probability of rounding up: {prob_up:.4f}\")\n\n        distance_to_lower = abs(value - lower_value)\n        total_distance = upper_value - lower_value\n        expected_prob = distance_to_lower / total_distance\n        print(f\"Expected probability: {expected_prob:.4f}\")\n\n        self.assertTrue(abs(prob_up - expected_prob) < MAX_VARIANCE)\n\n    def test_special_values(self):\n        \"\"\"Test handling of special values like inf, -inf, nan\"\"\"\n        special_values = torch.tensor([float('inf'), float('-inf'), float('nan'), 0.0, -0.0],\n                                    device='cuda')\n        rounded = stochastic_rounding_cuda.stochastic_round_bf16(special_values)\n\n        # Check inf and -inf are preserved\n        self.assertTrue(torch.isinf(rounded[0]))\n        self.assertTrue(torch.isinf(rounded[1]))\n        self.assertTrue(rounded[0] > 0)\n        self.assertTrue(rounded[1] < 0)\n\n        # Check nan is preserved\n        self.assertTrue(torch.isnan(rounded[2]))\n\n        # Check zeros are preserved\n        self.assertEqual(rounded[3].item(), 0.0)\n        self.assertEqual(rounded[4].item(), 0.0)\n\n    def test_small_values(self):\n        \"\"\"Test handling of small values near zero\"\"\"\n        small_values = torch.tensor([1e-38, -1e-38, 1e-20, -1e-20], device='cuda')\n        rounded = stochastic_rounding_cuda.stochastic_round_bf16(small_values)\n\n        # Check that very small values are handled properly\n        self.assertTrue(torch.all(torch.isfinite(rounded)))\n\n    def test_vectorized_loading(self):\n        \"\"\"Test if vectorized loading works correctly for different tensor sizes\"\"\"\n        sizes = [4, 8, 9, 16, 32, 100]  # Test various sizes including non-aligned\n\n        for size in sizes:\n            x = torch.linspace(1, size, size, device='cuda')\n            rounded = stochastic_rounding_cuda.stochastic_round_bf16(x)\n\n            # Check output size matches input\n            self.assertEqual(rounded.size(0), size)\n\n            # Check dtype\n            self.assertEqual(rounded.dtype, torch.bfloat16)\n\n    def test_large_values(self):\n        \"\"\"Test handling of large values\"\"\"\n        large_values = torch.tensor([1e38, -1e38, 1e20, -1e20], device='cuda')\n        rounded = 
stochastic_rounding_cuda.stochastic_round_bf16(large_values)\n\n        # Values should be preserved approximately in BF16 range\n        self.assertTrue(torch.all(torch.isfinite(rounded)))\n\n    def test_rounding_statistics(self):\n        \"\"\"Test if rounding probabilities match expected distribution\"\"\"\n        self._test_rounding_statistics_helper(2.1999969482421875, 2.1875, 2.2031)\n\n    def test_rounding_statistics_2(self):\n        \"\"\"Test stochastic rounding with different BF16 boundary values\"\"\"\n        self._test_rounding_statistics_helper(1.7999992370605469, 1.7969, 1.8047)\n\n    def test_rounding_statistics_small(self):\n        \"\"\"Test stochastic rounding for number between 0 and 1\"\"\"\n        self._test_rounding_statistics_helper(0.7499847412109375, 0.7480, 0.7500)\n\n    def test_rounding_statistics_large(self):\n        \"\"\"Test stochastic rounding for large number, over 100\"\"\"\n        self._test_rounding_statistics_helper(128.99998474121094, 128.875, 129.000)\n\n\n\nif __name__ == '__main__':\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "dev/sr/usage.py",
    "content": "import torch\nimport stochastic_rounding_cuda\n\n# Create input tensor\ninput_tensor = torch.randn(12, device='cuda', dtype=torch.float32)\n\n# Apply stochastic rounding\noutput_tensor = stochastic_rounding_cuda.stochastic_round_bf16(input_tensor)\nprint(f\"Input tensor: {input_tensor}\")\nprint(f\"Output tensor: {output_tensor}\")\nprint(f\"Output tensor dtype: {output_tensor.dtype}\")\nprint(f\"Success!\")\n\n'''\n# Test tensor\nx = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00, -1.3683e+00,\n                  4.0467e-01, 1.0759e-03, 2.8418e-01, -4.9392e-01,\n                  8.7239e-01, -9.0545e-01, 1.1134e+00, 0],  # -2.6872e+00\n                device='cuda')\n\n# Convert to BF16\ny = stochastic_rounding_cuda.stochastic_round_bf16(x)\nprint(f\"Input: {x}\")\nprint(f\"Output: {y}\")\n'''\n"
  },
  {
    "path": "dev/sr/usage2.py",
    "content": "import torch\nimport stochastic_rounding_cuda\n\n# Test tensor\nx = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00], device='cuda')\n\n# Compare with regular rounding\ny_normal = x.to(torch.bfloat16)\ny_stochastic = stochastic_rounding_cuda.stochastic_round_bf16(x)\n\nprint(f\"Input: {x}\")\nprint(f\"Normal BF16: {y_normal}\")\nprint(f\"Stochastic BF16: {y_stochastic}\")\n"
  },
  {
    "path": "dev/triton_groupGEMM/groupgemm.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-unsafe\n\nimport functools\nfrom typing import Optional\n\nimport tma_utils as utils\n\nimport torch\n\nimport triton\nimport triton.language as tl\nfrom triton.runtime import driver  # @manual\n\n\n_NV_CONFIGS = [\n    triton.Config(\n        {\n            \"BLOCK_SIZE_M\": block_size_m,\n            \"BLOCK_SIZE_N\": block_size_n,\n            \"BLOCK_SIZE_K\": block_size_k,\n        },\n        num_stages=num_stages,\n        num_warps=num_warps,\n        num_ctas=num_ctas,\n    )\n    for block_size_m in [64, 128]\n    for block_size_n in [64, 128, 256]\n    for block_size_k in [64, 128, 256]\n    for num_stages in [3, 4]\n    for num_warps in [4, 8]\n    for num_ctas in [1]\n]\n\n_AMD_CONFIGS = [\n    triton.Config(\n        {\n            \"BLOCK_SIZE_M\": block_size_m,\n            \"BLOCK_SIZE_N\": block_size_n,\n            \"BLOCK_SIZE_K\": block_size_k,\n            \"waves_per_eu\": waves_per_cu,\n            \"matrix_instr_nonkdim\": matrix_instr_nonkdim,\n        },\n        num_stages=num_stages,\n        num_warps=num_warps,\n    )\n    for block_size_m in [32, 64, 128]\n    for block_size_n in [32, 64, 128, 256]\n    for block_size_k in [128, 256]\n    for num_stages in [1, 2]\n    for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]\n    for matrix_instr_nonkdim in [16]\n]\n\n\ndef early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):\n    device = torch.cuda.current_device()\n    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages\n    if dtsize is None:\n        dtsize = named_args[\"c_ptr\"].element_size()\n    if dtype is None:\n        dtype = named_args[\"c_ptr\"].dtype\n\n    pruned_configs = []\n    for config in configs:\n        kw = config.kwargs\n        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (\n            kw[\"BLOCK_SIZE_M\"],\n            kw[\"BLOCK_SIZE_N\"],\n            kw[\"BLOCK_SIZE_K\"],\n            config.num_stages,\n        )\n        G, M, N, K = (\n            named_args[\"G\"],\n            named_args[\"M_BUCKET\"],\n            named_args[\"N\"],\n            named_args[\"K\"],\n        )\n\n        # 1. make sure we have enough smem\n        max_shared_memory = driver.active.utils.get_device_properties(device)[\n            \"max_shared_mem\"\n        ]\n        if torch.version.hip:\n            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize\n        else:\n            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize\n        if required_shared_memory > max_shared_memory:\n            continue\n\n        M_PER_GROUP = M // G\n        MIN_M_TILES = 32 if torch.version.hip else 64\n        # 2. make sure we don't load M tiles that are too big\n        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):\n            continue\n        # 3. make sure we don't load N tiles that are too small\n        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):\n            continue\n\n        num_sm = driver.active.utils.get_device_properties(device)[\n            \"multiprocessor_count\"\n        ]\n        N_TILES = N // BLOCK_N\n        MIN_N_TILES = 32 if torch.version.hip else 64\n        # 4. 
make sure we don't load N tiles that are too big\n        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:\n            continue\n        # 5. make sure we don't load N tiles that are too small\n        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:\n            continue\n        # 6. make sure K can be evenly divided\n        if K % BLOCK_K != 0:\n            continue\n\n        pruned_configs.append(config)\n\n    return pruned_configs\n\n\n@triton.autotune(\n    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,\n    key=[\"G\", \"M_BUCKET\", \"N\", \"K\"],\n    prune_configs_by={\"early_config_prune\": early_config_prune},\n)\n@triton.jit\ndef _kernel_grouped_gemm(\n    a_desc_ptr,\n    b_desc_ptr,\n    c_ptr,\n    workspace,\n    m_sizes,\n    # problem sizes\n    G: tl.constexpr,\n    M_BUCKET: tl.constexpr,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    NUM_SMS: tl.constexpr,\n    USE_TMA_LOAD: tl.constexpr,\n    USE_TMA_STORE: tl.constexpr,\n    # tile sizes\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n) -> None:\n    tidx = tl.program_id(0)\n\n    dtype: tl.dtype = c_ptr.dtype.element_ty\n    TMA_SIZE: tl.constexpr = tl.constexpr(128)\n    if USE_TMA_STORE:\n        c_desc_ptr = workspace + tidx * TMA_SIZE\n    else:\n        c_desc_ptr = None\n\n    M_end_offset = 0\n    iterated_tiles = 0\n    for g in tl.range(G):\n        # Move across groups\n        M_start_offset = M_end_offset\n        m_size = tl.load(m_sizes + g)\n        M_end_offset = M_start_offset + m_size\n\n        if m_size > 0:\n            N_start_offset = g * N\n            n_size = N\n            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)\n            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)\n            num_tiles = num_m_tiles * num_n_tiles\n\n            if USE_TMA_STORE:\n                # pyre-ignore\n                tl.extra.cuda.experimental_device_tensormap_create2d(\n                    desc_ptr=c_desc_ptr,\n                    global_address=c_ptr + M_start_offset * N,\n                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],\n                    global_size=[m_size, n_size],\n                    element_ty=c_ptr.dtype.element_ty,\n                )\n                # pyre-ignore\n                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)\n\n            # Move across tiles\n            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:\n                gidx = tidx - iterated_tiles\n                # Split M first and N second.\n                tile_m_idx = gidx % num_m_tiles\n                tile_n_idx = gidx // num_m_tiles\n\n                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n                tl.static_assert(K % BLOCK_SIZE_K == 0)\n                if USE_TMA_LOAD:\n                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        a = tl._experimental_descriptor_load(\n                            a_desc_ptr,\n                            [m_offset, k_offset],\n                            [BLOCK_SIZE_M, BLOCK_SIZE_K],\n                            dtype,\n                        )\n                        b = tl._experimental_descriptor_load(\n                            b_desc_ptr,\n                            [n_offset, k_offset],\n                      
      [BLOCK_SIZE_N, BLOCK_SIZE_K],\n                            dtype,\n                        )\n                        accumulator += tl.dot(a, b.T)\n                else:\n                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                    offs_k = tl.arange(0, BLOCK_SIZE_K)\n                    a_ptrs = (\n                        a_desc_ptr\n                        + (M_start_offset + offs_am[:, None]) * K\n                        + offs_k[None, :]\n                    )\n                    b_ptrs = (\n                        b_desc_ptr\n                        + (N_start_offset + offs_bn[:, None]) * K\n                        + offs_k[None, :]\n                    )\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)\n                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)\n                        accumulator += tl.dot(a, b.T)\n                        a_ptrs += BLOCK_SIZE_K\n                        b_ptrs += BLOCK_SIZE_K\n\n                if USE_TMA_STORE:\n                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n                    tl._experimental_descriptor_store(\n                        c_desc_ptr,\n                        accumulator.to(c_ptr.dtype.element_ty),\n                        [m_offset, n_offset],\n                    )\n                else:\n                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                    c = accumulator.to(c_ptr.dtype.element_ty)\n                    tl.store(\n                        c_ptr\n                        + (M_start_offset + offs_am[:, None]) * N\n                        + offs_bn[None, :],\n                        c,\n                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,\n                    )\n                tidx += NUM_SMS\n\n            iterated_tiles += num_tiles\n\n\nTT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv\n\n\n@triton.autotune(\n    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,\n    key=[\"G\", \"M_BUCKET\", \"N\", \"K\"],\n    prune_configs_by={\n        \"early_config_prune\": functools.partial(\n            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1\n        )\n    },\n)\n@triton.jit\ndef _kernel_grouped_gemm_fp8_rowwise(\n    a_desc_ptr,\n    a_scale_ptr,\n    b_desc_ptr,\n    b_scale_ptr,\n    c_ptr,\n    workspace,\n    m_sizes,\n    # problem sizes\n    G: tl.constexpr,\n    M_BUCKET: tl.constexpr,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    NUM_SMS: tl.constexpr,\n    USE_TMA_LOAD: tl.constexpr,\n    USE_TMA_STORE: tl.constexpr,\n    # tile sizes\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n) -> None:\n    tidx = tl.program_id(0)\n\n    dtype = TT_FP8_DTYPE\n    TMA_SIZE: tl.constexpr = tl.constexpr(128)\n    if USE_TMA_STORE:\n        c_desc_ptr = workspace + tidx * TMA_SIZE\n    else:\n        c_desc_ptr = None\n\n    M_end_offset = 0\n    iterated_tiles = 0\n    for g in tl.range(G):\n        # Move across groups\n        M_start_offset = M_end_offset\n        m_size = tl.load(m_sizes + g)\n        M_end_offset = M_start_offset + 
m_size\n\n        if m_size > 0:\n            N_start_offset = g * N\n            n_size = N\n            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)\n            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)\n            num_tiles = num_m_tiles * num_n_tiles\n\n            if USE_TMA_STORE:\n                # pyre-ignore\n                tl.extra.cuda.experimental_device_tensormap_create2d(\n                    desc_ptr=c_desc_ptr,\n                    global_address=c_ptr + M_start_offset * N,\n                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],\n                    global_size=[m_size, n_size],\n                    element_ty=c_ptr.dtype.element_ty,\n                )\n                # pyre-ignore\n                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)\n\n            # Move across tiles\n            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:\n                gidx = tidx - iterated_tiles\n                # Split M first and N second.\n                tile_m_idx = gidx % num_m_tiles\n                tile_n_idx = gidx // num_m_tiles\n\n                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n                tl.static_assert(K % BLOCK_SIZE_K == 0)\n                if USE_TMA_LOAD:\n                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        a = tl._experimental_descriptor_load(\n                            a_desc_ptr,\n                            [m_offset, k_offset],\n                            [BLOCK_SIZE_M, BLOCK_SIZE_K],\n                            dtype,\n                        )\n                        b = tl._experimental_descriptor_load(\n                            b_desc_ptr,\n                            [n_offset, k_offset],\n                            [BLOCK_SIZE_N, BLOCK_SIZE_K],\n                            dtype,\n                        )\n                        accumulator += tl.dot(a, b.T)\n                else:\n                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                    offs_k = tl.arange(0, BLOCK_SIZE_K)\n                    a_ptrs = (\n                        a_desc_ptr\n                        + (M_start_offset + offs_am[:, None]) * K\n                        + offs_k[None, :]\n                    )\n                    b_ptrs = (\n                        b_desc_ptr\n                        + (N_start_offset + offs_bn[:, None]) * K\n                        + offs_k[None, :]\n                    )\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)\n                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)\n                        accumulator += tl.dot(a, b.T)\n                        a_ptrs += BLOCK_SIZE_K\n                        b_ptrs += BLOCK_SIZE_K\n\n                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                a_scale = tl.load(\n                    a_scale_ptr + M_start_offset + offs_am[:, None],\n                    mask=offs_am[:, None] < m_size,\n                )\n                b_scale = tl.load(\n       
             b_scale_ptr + N_start_offset + offs_bn[None, :],\n                    mask=offs_bn[None, :] < n_size,\n                )\n                c = accumulator.to(tl.float32) * a_scale * b_scale\n\n                if USE_TMA_STORE:\n                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n                    tl._experimental_descriptor_store(\n                        c_desc_ptr,\n                        c.to(c_ptr.dtype.element_ty),\n                        [m_offset, n_offset],\n                    )\n                else:\n                    tl.store(\n                        c_ptr\n                        + (M_start_offset + offs_am[:, None]) * N\n                        + offs_bn[None, :],\n                        c,\n                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,\n                    )\n                tidx += NUM_SMS\n\n            iterated_tiles += num_tiles\n\n\ndef _grouped_gemm(\n    x: torch.Tensor,\n    w: torch.Tensor,\n    m_sizes: torch.Tensor,\n    x_scale: Optional[torch.Tensor] = None,\n    w_scale: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    if not utils.HAS_TMA_DESC:\n        raise NotImplementedError(\"Grouped GEMM without TMA is not supported yet\")\n\n    G = m_sizes.shape[0]\n\n    assert x.is_contiguous()\n    assert w.is_contiguous()\n    assert m_sizes.is_contiguous()\n\n    M, K = x.shape\n    N = w.shape[0] // G\n    assert K == w.shape[1]\n\n    y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)\n\n    NUM_SMS = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n    USE_TMA_LOAD = not torch.version.hip\n    USE_TMA_STORE = False\n\n    desc_helper = None\n    desc_x = x\n    desc_w = w\n    workspace = None\n\n    if USE_TMA_LOAD:\n        desc_helper = utils.TmaAutoTuneHelper()\n        desc_helper.init_tma_descriptor(\"x\")\n        desc_helper.init_tma_descriptor(\"w\")\n        desc_x = desc_helper.get_tma_descriptor_kernel_param(\"x\")\n        desc_w = desc_helper.get_tma_descriptor_kernel_param(\"w\")\n\n    if USE_TMA_STORE:\n        workspace = torch.empty(\n            NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,\n            device=x.device,\n            dtype=torch.uint8,\n        )\n\n    def grid(META):\n        if USE_TMA_LOAD:\n            nonlocal desc_helper\n            desc_helper.fill_2d_tma_descriptor(\n                \"x\",\n                x.data_ptr(),\n                M,\n                K,\n                META[\"BLOCK_SIZE_M\"],\n                META[\"BLOCK_SIZE_K\"],\n                x.element_size(),\n            )\n\n            desc_helper.fill_2d_tma_descriptor(\n                \"w\",\n                w.data_ptr(),\n                N * G,\n                K,\n                META[\"BLOCK_SIZE_N\"],\n                META[\"BLOCK_SIZE_K\"],\n                w.element_size(),\n            )\n\n        return (NUM_SMS,)\n\n    M_BUCKET = triton.next_power_of_2(M)\n    if x_scale is not None and w_scale is not None:\n        assert x_scale.is_contiguous()\n        assert w_scale.is_contiguous()\n        _kernel_grouped_gemm_fp8_rowwise[grid](\n            desc_x,\n            x_scale,\n            desc_w,\n            w_scale,\n            y,\n            workspace,\n            m_sizes,\n            G,\n            M_BUCKET,\n            N,\n            K,\n            NUM_SMS,\n            USE_TMA_LOAD,\n            USE_TMA_STORE,\n        )\n    else:\n    
    assert x_scale is None\n        assert w_scale is None\n        _kernel_grouped_gemm[grid](\n            desc_x,\n            desc_w,\n            y,\n            workspace,\n            m_sizes,\n            G,\n            M_BUCKET,\n            N,\n            K,\n            NUM_SMS,\n            USE_TMA_LOAD,\n            USE_TMA_STORE,\n        )\n\n    return y\n\n\ndef grouped_gemm(\n    x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor\n) -> torch.Tensor:\n    return _grouped_gemm(x, w, m_sizes)\n\n\ndef grouped_gemm_fp8_rowwise(\n    x: torch.Tensor,\n    w: torch.Tensor,\n    m_sizes: torch.Tensor,\n    x_scale: torch.Tensor,\n    w_scale: torch.Tensor,\n) -> torch.Tensor:\n    return _grouped_gemm(x, w, m_sizes, x_scale, w_scale)\n"
  },
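  {
    "path": "dev/triton_groupGEMM/notes/usage_sketch.md",
    "content": "Illustrative note, not part of the original sources; the file path is hypothetical and the shapes and tolerances below mirror testing/unit_tests.py.\n\ngrouped_gemm multiplies G independent groups in one kernel launch: x is (M, K) with its rows partitioned into G contiguous groups whose sizes are given by the int32 tensor m_sizes (summing to M); w stacks the G weight blocks as (G * N, K); the result is (M, N), where the rows of group g equal x_g @ w[g*N:(g+1)*N].T. A GPU with TMA support (e.g. H100) is required, otherwise _grouped_gemm raises NotImplementedError. Run from the dev/triton_groupGEMM directory so groupgemm.py is importable.\n\n```python\nimport torch\nfrom groupgemm import grouped_gemm\n\nG, M, N, K = 4, 256, 64, 128\ndevice = torch.device('cuda')\n\nx = torch.randn(M, K, dtype=torch.bfloat16, device=device)\nw = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)   # G weight blocks stacked on dim 0\nm_sizes = torch.full((G,), M // G, dtype=torch.int32, device=device)  # rows of x per group, sums to M\n\ny = grouped_gemm(x, w, m_sizes)  # (M, N)\n\n# Loop reference: each row group multiplies its own (N, K) weight block.\nref = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\nstart = 0\nfor g in range(G):\n    end = start + int(m_sizes[g].item())\n    ref[start:end] = x[start:end] @ w[g * N : (g + 1) * N].T\n    start = end\n\ntorch.testing.assert_close(y, ref, atol=1e-5, rtol=1.6e-2)\n```\n"
  },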
  {
    "path": "dev/triton_groupGEMM/testing/base_testing.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-strict\n\nimport logging\n\n# Configure logging to print to stdout\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\"\n)\n\nimport os\nimport sys\nimport unittest\nfrom typing import Tuple\n\nimport torch\n\n# Add parent directory to path\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\n\n\nif torch.cuda.is_available():\n    # from fp8_gemm import quantize_fp8_row\n    from groupgemm import grouped_gemm  # , grouped_gemm_fp8_rowwise\n    from tma_utils import HAS_TMA_DESC\n\n\n@unittest.skipIf(\n    not torch.cuda.is_available()\n    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9\n    or not HAS_TMA_DESC,\n    \"Skip when H100 or TMA is not available\",\n)\nclass TestGroupedGEMM(unittest.TestCase):\n    def setUp(self) -> None:\n        torch.manual_seed(0)\n\n    \"\"\"def test_grouped_gemm_fp8_rowwise(self) -> None:\n        def _test_grouped_gemm_fp8_rowwise(\n            shape: Tuple[int, int, int, int],\n            device: torch.device,\n        ) -> None:\n            G, M, N, K = shape\n            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n            m_ends, _ = torch.sort(\n                torch.randint(\n                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32\n                )\n            )\n            m_ends = m_ends.tolist()\n            m_starts = [0] + m_ends\n            m_ends = m_ends + [M]\n            m_sizes = torch.tensor(\n                [m_ends[i] - m_starts[i] for i in range(G)], device=device\n            ).to(torch.int32)\n\n            a_fp8, a_scale = quantize_fp8_row(a)\n            b_fp8, b_scale = quantize_fp8_row(b)\n\n            result = grouped_gemm_fp8_rowwise(\n                a_fp8,\n                b_fp8,\n                m_sizes,\n                a_scale,\n                b_scale,\n            )\n            self.assertTrue(result.shape == (M, N))\n\n            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n            # Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation.\n            for g in range(G):\n                m_start = m_starts[g]\n                m_end = m_ends[g]\n                n_start = g * N\n                n_end = (g + 1) * N\n\n                expected_result[m_start:m_end, :] = (\n                    a_fp8[m_start:m_end, :].to(torch.float32)\n                    @ b_fp8[n_start:n_end, :].to(torch.float32).T\n                    * a_scale[m_start:m_end][:, None]\n                    * b_scale[n_start:n_end][None, :]\n                ).to(torch.bfloat16)\n\n            torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)\n\n        for G in (1, 4, 16):\n            for M in (64, 512):\n                logging.info(f\"Testing FP8 GMM with G={G}, M={M}\")\n                _test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device(\"cuda\"))\n        \"\"\"\n\n    def test_grouped_gemm_bf16(self) -> None:\n        def _test_grouped_gemm_bf16(\n            shape: Tuple[int, int, int, int],\n            device: 
torch.device,\n        ) -> None:\n            G, M, N, K = shape\n            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n            m_ends, _ = torch.sort(\n                torch.randint(\n                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32\n                )\n            )\n            m_ends = m_ends.tolist()\n            m_starts = [0] + m_ends\n            m_ends = m_ends + [M]\n            m_sizes = torch.tensor(\n                [m_ends[i] - m_starts[i] for i in range(G)], device=device\n            ).to(torch.int32)\n\n            result = grouped_gemm(\n                a,\n                b,\n                m_sizes,\n            )\n            self.assertTrue(result.shape == (M, N))\n\n            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n            for g in range(G):\n                m_start = m_starts[g]\n                m_end = m_ends[g]\n                expected_result[m_start:m_end, :] = (\n                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T\n                )\n\n            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n        for G in (1, 4, 16):\n            for M in (64, 512):\n                logging.info(f\"Testing BF16 GMM with G={G}, M={M}\")\n                _test_grouped_gemm_bf16((G, M, 256, 256), torch.device(\"cuda\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main(exit=False)\n"
  },
  {
    "path": "dev/triton_groupGEMM/testing/unit_tests.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-unsafe\n# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm\n\n\nimport logging\nimport unittest\nfrom typing import Tuple\n\nimport torch\n\n# Add parent directory to path\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\n\nfrom groupgemm import grouped_gemm\n\n\nclass TestGroupedGEMM(unittest.TestCase):\n    def test_grouped_gemm_bf16(self) -> None:\n        def _test_grouped_gemm_bf16(\n            shape: Tuple[int, int, int, int],\n            device: torch.device,\n        ) -> None:\n            G, M, N, K = shape\n            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n            m_ends, _ = torch.sort(\n                torch.randint(\n                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32\n                )\n            )\n            m_ends = m_ends.tolist()\n            m_starts = [0] + m_ends\n            m_ends = m_ends + [M]\n            m_sizes = torch.tensor(\n                [m_ends[i] - m_starts[i] for i in range(G)], device=device\n            ).to(torch.int32)\n            result = grouped_gemm(\n                a,\n                b,\n                m_sizes,\n            )\n            self.assertTrue(result.shape == (M, N))\n            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n            for g in range(G):\n                m_start = m_starts[g]\n                m_end = m_ends[g]\n                expected_result[m_start:m_end, :] = (\n                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T\n                )\n            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n        for G in (1, 4, 16):\n            for M in (64, 512):\n                logging.info(f\"Testing BF16 GMM with G={G}, M={M}\")\n                _test_grouped_gemm_bf16((G, M, 256, 256), torch.device(\"cuda\"))\n\n    def test_grouped_gemm_bf16_various_dimensions(self) -> None:\n        \"\"\"Test grouped_gemm with bf16 precision and various dimensions\"\"\"\n\n        def _test_grouped_gemm_bf16(\n            shape: Tuple[int, int, int, int],\n            device: torch.device,\n        ) -> None:\n            G, M, N, K = shape\n            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n            m_ends, _ = torch.sort(\n                torch.randint(\n                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32\n                )\n            )\n            m_ends = m_ends.tolist()\n            m_starts = [0] + m_ends\n            m_ends = m_ends + [M]\n            m_sizes = torch.tensor(\n                [m_ends[i] - m_starts[i] for i in range(G)], device=device\n            ).to(torch.int32)\n            result = grouped_gemm(\n                a,\n                b,\n                m_sizes,\n            )\n            self.assertTrue(result.shape == (M, N))\n            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n            for g in range(G):\n                m_start = m_starts[g]\n                m_end = m_ends[g]\n                
expected_result[m_start:m_end, :] = (\n                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T\n                )\n            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n        for G in (4, 8):\n            for M in (128, 256):\n                for N, K in [(128, 256), (256, 128), (64, 64)]:\n                    logging.info(f\"Testing BF16 GMM with G={G}, M={M}, N={N}, K={K}\")\n                    _test_grouped_gemm_bf16((G, M, N, K), torch.device(\"cuda\"))\n\n    def test_grouped_gemm_bf16_edge_cases(self) -> None:\n        \"\"\"Test grouped_gemm with bfloat16 for various edge cases\"\"\"\n        device = torch.device(\"cuda\")\n\n        # Test with G=1 (single group case)\n        G, M, N, K = 1, 32, 32, 32\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([M], device=device).to(torch.int32)\n        result = grouped_gemm(a, b, m_sizes)\n        expected_result = a @ b.T\n        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n        # Test with uneven group sizes\n        G, M, N, K = 3, 100, 32, 32\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([25, 50, 25], device=device).to(torch.int32)\n        result = grouped_gemm(a, b, m_sizes)\n        self.assertTrue(result.shape == (M, N))\n        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n        m_start = 0\n        for g in range(G):\n            m_end = m_start + m_sizes[g].item()\n            expected_result[m_start:m_end, :] = (\n                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T\n            )\n            m_start = m_end\n        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n        # Test with extremely small matrices\n        G, M, N, K = 2, 8, 8, 8\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([4, 4], device=device).to(torch.int32)\n        result = grouped_gemm(a, b, m_sizes)\n        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n        m_start = 0\n        for g in range(G):\n            m_end = m_start + m_sizes[g].item()\n            expected_result[m_start:m_end, :] = (\n                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T\n            )\n            m_start = m_end\n        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n        # Test with large group count but small matrix sizes\n        G, M, N, K = 32, 128, 16, 16\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.ones(G, device=device).to(torch.int32) * (M // G)\n        m_sizes[-1] += M % G  # Adjust the last group size to account for remainder\n        result = grouped_gemm(a, b, m_sizes)\n        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n        m_start = 0\n        for g in range(G):\n            m_end = m_start + m_sizes[g].item()\n            expected_result[m_start:m_end, :] = (\n                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T\n            )\n            m_start = m_end\n        
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n    def test_grouped_gemm_bf16_invalid_inputs(self) -> None:\n        \"\"\"Test grouped_gemm with invalid inputs to ensure proper error handling\"\"\"\n        device = torch.device(\"cuda\")\n\n        # Test with mismatched dimensions\n        G, M, N, K = 2, 64, 32, 32\n        a = torch.randn(\n            M, K + 1, dtype=torch.bfloat16, device=device\n        )  # Wrong K dimension\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([32, 32], device=device).to(torch.int32)\n\n        with self.assertRaises(RuntimeError):\n            grouped_gemm(a, b, m_sizes)\n\n        # Test with mismatched G and m_sizes length\n        G, M, N, K = 2, 64, 32, 32\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([32, 32, 32], device=device).to(\n            torch.int32\n        )  # Too many groups\n\n        with self.assertRaises((RuntimeError, ValueError, IndexError)):\n            grouped_gemm(a, b, m_sizes)\n\n        # Test with incorrect sum of m_sizes\n        G, M, N, K = 2, 64, 32, 32\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([32, 40], device=device).to(torch.int32)  # Sum > M\n\n        with self.assertRaises((RuntimeError, ValueError, IndexError)):\n            grouped_gemm(a, b, m_sizes)\n\n        # Test with negative m_sizes values\n        G, M, N, K = 2, 64, 32, 32\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([40, -8], device=device).to(\n            torch.int32\n        )  # Negative group size\n\n        with self.assertRaises((RuntimeError, ValueError)):\n            grouped_gemm(a, b, m_sizes)\n\n    def test_grouped_gemm_bf16_deterministic(self) -> None:\n        \"\"\"Test that grouped_gemm produces deterministic results with the same inputs\"\"\"\n        G, M, N, K = 4, 128, 64, 64\n        device = torch.device(\"cuda\")\n\n        # Fix the random seed for reproducibility\n        torch.manual_seed(42)\n\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([32, 32, 32, 32], device=device).to(torch.int32)\n\n        # First run\n        result1 = grouped_gemm(a, b, m_sizes)\n\n        # Second run with same inputs\n        result2 = grouped_gemm(a, b, m_sizes)\n\n        # Results should be exactly the same\n        self.assertTrue(torch.all(result1 == result2))\n\n    def test_grouped_gemm_bf16_large_matrices(self) -> None:\n        \"\"\"Test grouped_gemm with larger matrices to stress test performance and stability\"\"\"\n        device = torch.device(\"cuda\")\n\n        # Test with large matrices but fewer groups\n        G, M, N, K = 2, 2048, 512, 1024\n        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)\n        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)\n        m_sizes = torch.tensor([1024, 1024], device=device).to(torch.int32)\n\n        result = grouped_gemm(a, b, m_sizes)\n        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)\n\n        m_start = 0\n        for g in range(G):\n     
       m_end = m_start + m_sizes[g].item()\n            expected_result[m_start:m_end, :] = (\n                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T\n            )\n            m_start = m_end\n\n        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)\n\n\nif __name__ == \"__main__\":\n    unittest.main(argv=[\"first-arg-is-ignored\"], exit=False)\n"
  },
  {
    "path": "dev/triton_groupGEMM/tma_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-unsafe\n# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm\n\nimport sys\n\nimport torch\nimport triton  # @manual\n\nimport triton.language as tl  # @manual\n\n\ndef map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:\n    \"\"\"\n    Maps torch dtype to triton dtype.\n\n    Args:\n        dtype (torch.dtype): input dtype.\n\n    Returns:\n        tl.dtype: triton dtype.\n    \"\"\"\n    if dtype == torch.float16:\n        return tl.float16\n    elif dtype == torch.bfloat16:\n        return tl.bfloat16\n    elif dtype == torch.float32:\n        return tl.float32\n    elif dtype == torch.int32:\n        return tl.int32\n    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:\n        return tl.float8e4nv\n    else:\n        raise ValueError(f\"Unsupported dtype {dtype}\")\n\n\n# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).\nHAS_TMA_DESC = \"nv_tma_desc_type\" in dir(tl)\n\nif HAS_TMA_DESC:\n    print(\n        \"TMA benchmarks will be running with experimental grid constant TMA descriptor.\",\n        file=sys.stderr,\n    )\nelse:\n    print(\n        \"Missing: This group gemm code will not run without TMA descriptor support....\",\n        file=sys.stderr,\n    )\n    raise NotImplementedError(\"grouped Gemm without TMA is not supported\")\n\n\nclass TmaAutoTuneHelper:\n\n    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498\n    class KernelParamWrapper:\n        def __init__(self, desc):\n            self.desc = desc\n\n        def tma_desc_cpu_ptr(self):\n            return self.desc.data_ptr()\n\n    TMA_SIZE = 128\n\n    def __init__(self):\n        self.fill_1d_tma_descriptor_inner = (\n            triton.runtime.driver.active.utils.fill_1d_tma_descriptor\n        )\n        self.fill_2d_tma_descriptor_inner = (\n            triton.runtime.driver.active.utils.fill_2d_tma_descriptor\n        )\n        if HAS_TMA_DESC:\n            self.descriptors = {}\n        else:\n            self.cuda_descriptors = {}\n\n    # Call this method outside of the lambda function for grid size\n    def init_tma_descriptor(self, name):\n        if HAS_TMA_DESC:\n            self.descriptors[name] = torch.empty(\n                TmaAutoTuneHelper.TMA_SIZE, device=\"cpu\", dtype=torch.int8\n            )\n        else:\n            self.cuda_descriptors[name] = torch.empty(\n                TmaAutoTuneHelper.TMA_SIZE, device=\"cuda\", dtype=torch.int8\n            )\n\n    # Call this method inside the lambda function for grid size\n    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):\n        if HAS_TMA_DESC:\n            desc_x = self.descriptors[name]\n            assert desc_x.data_ptr() % 64 == 0\n            self.fill_1d_tma_descriptor_inner(\n                ptr, dim, block_dim, element_size, desc_x.data_ptr()\n            )\n        else:\n            desc_x = self.cuda_descriptors[name]\n            buf_x = torch.empty_like(desc_x, device=\"cpu\", pin_memory=True)\n            self.fill_1d_tma_descriptor_inner(\n                ptr, dim, block_dim, element_size, buf_x.data_ptr()\n            )\n            desc_x.copy_(buf_x, non_blocking=True)\n\n    # 
Call this method inside the lambda function for grid size\n    def fill_2d_tma_descriptor(\n        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size\n    ):\n        if HAS_TMA_DESC:\n            desc_x = self.descriptors[name]\n            assert desc_x.data_ptr() % 64 == 0\n            self.fill_2d_tma_descriptor_inner(\n                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()\n            )\n        else:\n            desc_x = self.cuda_descriptors[name]\n            buf_x = torch.empty_like(desc_x, device=\"cpu\", pin_memory=True)\n            self.fill_2d_tma_descriptor_inner(\n                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()\n            )\n            desc_x.copy_(buf_x, non_blocking=True)\n\n    def get_tma_descriptor_kernel_param(self, name):\n        if HAS_TMA_DESC:\n            assert self.descriptors[name] is not None\n            return self.KernelParamWrapper(self.descriptors[name])\n        else:\n            assert self.cuda_descriptors[name] is not None\n            return self.cuda_descriptors[name]\n"
  },
  {
    "path": "dev/triton_groupGEMM/triton_tutorial_groupgemm.py",
    "content": "\"\"\"\nGroup GEMM\n============================\nThis group gemm kernel launches a fixed number of CTA to compute a group\nof gemms. The scheduling is static and we do it on device.\n\"\"\"\n\n# Copyright (c) 2023 - 2025 NVIDIA Corporation & Affiliates. All rights reserved.\n#\n# Permission is hereby granted, free of charge, to any person obtaining\n# a copy of this software and associated documentation files\n# (the \"Software\"), to deal in the Software without restriction,\n# including without limitation the rights to use, copy, modify, merge,\n# publish, distribute, sublicense, and/or sell copies of the Software,\n# and to permit persons to whom the Software is furnished to do so,\n# subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be\n# included in all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\n# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\n# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,\n# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\n# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n\n# from: https://github.com/triton-lang/triton/blob/main/python/tutorials/08-grouped-gemm.py\n\nfrom typing import Optional\n\nimport torch\n\nimport triton\nimport triton.language as tl\n\nDEVICE = triton.runtime.driver.active.get_active_torch_device()\n\n\ndef is_cuda():\n    return triton.runtime.driver.active.get_current_target().backend == \"cuda\"\n\n\ndef supports_tma():\n    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9\n\n\ndef num_sms():\n    if is_cuda():\n        return torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n    return 148\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\n                \"BLOCK_SIZE_M\": 128,\n                \"BLOCK_SIZE_N\": 128,\n                \"BLOCK_SIZE_K\": 32,\n                \"NUM_SM\": 84,\n            }\n        ),\n        triton.Config(\n            {\n                \"BLOCK_SIZE_M\": 128,\n                \"BLOCK_SIZE_N\": 128,\n                \"BLOCK_SIZE_K\": 32,\n                \"NUM_SM\": 128,\n            }\n        ),\n        triton.Config(\n            {\n                \"BLOCK_SIZE_M\": 64,\n                \"BLOCK_SIZE_N\": 64,\n                \"BLOCK_SIZE_K\": 32,\n                \"NUM_SM\": 84,\n            }\n        ),\n        triton.Config(\n            {\n                \"BLOCK_SIZE_M\": 64,\n                \"BLOCK_SIZE_N\": 64,\n                \"BLOCK_SIZE_K\": 32,\n                \"NUM_SM\": 128,\n            }\n        ),\n        triton.Config(\n            {\n                \"BLOCK_SIZE_M\": 128,\n                \"BLOCK_SIZE_N\": 128,\n                \"BLOCK_SIZE_K\": 64,\n                \"NUM_SM\": num_sms(),\n            }\n        ),\n        triton.Config(\n            {\n                \"BLOCK_SIZE_M\": 64,\n                \"BLOCK_SIZE_N\": 128,\n                \"BLOCK_SIZE_K\": 64,\n                \"NUM_SM\": num_sms(),\n            }\n        ),\n    ],\n    key=[\"group_size\"],\n)\n@triton.jit\ndef grouped_matmul_kernel(\n    # device tensor of matrices pointers\n    group_a_ptrs,\n    group_b_ptrs,\n    group_c_ptrs,\n    # device tensor of gemm sizes. 
its shape is [group_size, 3]\n    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm\n    group_gemm_sizes,\n    # device tensor of leading dimension sizes. its shape is [group_size, 3]\n    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm\n    g_lds,\n    # number of gemms\n    group_size,\n    # number of virtual SM\n    NUM_SM: tl.constexpr,\n    # tile sizes\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n):\n    tile_idx = tl.program_id(0)\n    last_problem_end = 0\n    for g in range(group_size):\n        # get the gemm size of the current problem\n        gm = tl.load(group_gemm_sizes + g * 3)\n        gn = tl.load(group_gemm_sizes + g * 3 + 1)\n        gk = tl.load(group_gemm_sizes + g * 3 + 2)\n        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)\n        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)\n        num_tiles = num_m_tiles * num_n_tiles\n        # iterate through the tiles in the current gemm problem\n        while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:\n            # pick up a tile from the current gemm problem\n            k = gk\n            lda = tl.load(g_lds + g * 3)\n            ldb = tl.load(g_lds + g * 3 + 1)\n            ldc = tl.load(g_lds + g * 3 + 2)\n            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))\n            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))\n            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))\n            # figure out tile coordinates\n            tile_idx_in_gemm = tile_idx - last_problem_end\n            tile_m_idx = tile_idx_in_gemm // num_n_tiles\n            tile_n_idx = tile_idx_in_gemm % num_n_tiles\n\n            # do regular gemm here\n            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n            offs_k = tl.arange(0, BLOCK_SIZE_K)\n            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]\n            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]\n            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n                # hint to Triton compiler to do proper loop pipelining\n                tl.multiple_of(a_ptrs, [16, 16])\n                tl.multiple_of(b_ptrs, [16, 16])\n                # assume full tile for now\n                a = tl.load(a_ptrs)\n                b = tl.load(b_ptrs)\n                accumulator += tl.dot(a, b)\n                a_ptrs += BLOCK_SIZE_K\n                b_ptrs += BLOCK_SIZE_K * ldb\n            c = accumulator.to(tl.float16)\n\n            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]\n\n            # assumes full tile for now\n            tl.store(c_ptrs, c)\n\n            # go to the next tile by advancing NUM_SM\n            tile_idx += NUM_SM\n\n        # get ready to go to the next gemm problem\n        last_problem_end = last_problem_end + num_tiles\n\n\ndef group_gemm_fn(group_A, group_B):\n    assert len(group_A) == len(group_B)\n    group_size = len(group_A)\n\n    A_addrs = []\n    B_addrs = []\n    C_addrs = []\n    g_sizes = []\n    g_lds = []\n    group_C = []\n    for i in range(group_size):\n        A = 
group_A[i]\n        B = group_B[i]\n        assert A.shape[1] == B.shape[0]\n        M, K = A.shape\n        K, N = B.shape\n        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)\n        group_C.append(C)\n        A_addrs.append(A.data_ptr())\n        B_addrs.append(B.data_ptr())\n        C_addrs.append(C.data_ptr())\n        g_sizes += [M, N, K]\n        g_lds += [A.stride(0), B.stride(0), C.stride(0)]\n\n    # note these are device tensors\n    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)\n    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)\n    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)\n    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)\n    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)\n    # we use a fixed number of CTA, and it's auto-tunable\n    grid = lambda META: (META[\"NUM_SM\"],)\n    grouped_matmul_kernel[grid](\n        d_a_ptrs,\n        d_b_ptrs,\n        d_c_ptrs,\n        d_g_sizes,\n        d_g_lds,\n        group_size,\n    )\n\n    return group_C\n\n\ntma_configs = [\n    triton.Config(\n        {\"BLOCK_SIZE_M\": BM, \"BLOCK_SIZE_N\": BN, \"BLOCK_SIZE_K\": BK},\n        num_stages=s,\n        num_warps=w,\n    )\n    for BM in [128]\n    for BN in [128, 256]\n    for BK in [64, 128]\n    for s in ([3, 4])\n    for w in [4, 8]\n]\n\n\n@triton.autotune(\n    tma_configs,\n    key=[\"group_a_ptrs\", \"group_b_ptrs\", \"group_c_ptrs\", \"group_size\"],\n)\n@triton.jit\ndef grouped_matmul_tma_kernel(\n    # device tensor of matrices pointers\n    group_a_ptrs,\n    group_b_ptrs,\n    group_c_ptrs,\n    # device tensor of gemm sizes. its shape is [group_size, 3]\n    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm\n    group_gemm_sizes,\n    # device tensor of leading dimension sizes. 
its shape is [group_size, 3]\n    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm\n    g_lds,\n    # number of gemms\n    group_size,\n    # number of virtual SM\n    NUM_SM: tl.constexpr,\n    # tile sizes\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    # is the output FP8 or FP16\n    FP8: tl.constexpr,\n):\n    dtype = tl.float8e4nv if FP8 else tl.float16\n    tile_idx = tl.program_id(0)\n    last_problem_end = 0\n    for g in range(group_size):\n        # get the gemm size of the current problem\n        gm = tl.load(group_gemm_sizes + g * 3)\n        gn = tl.load(group_gemm_sizes + g * 3 + 1)\n        gk = tl.load(group_gemm_sizes + g * 3 + 2)\n        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)\n        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)\n        num_tiles = num_m_tiles * num_n_tiles\n        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:\n            # pick up a tile from the current gemm problem\n            lda = tl.load(g_lds + g * 3)\n            ldb = tl.load(g_lds + g * 3 + 1)\n            ldc = tl.load(g_lds + g * 3 + 2)\n\n            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))\n            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))\n            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))\n\n            a_desc = tl._experimental_make_tensor_descriptor(\n                a_ptr,\n                shape=[gm, gk],\n                strides=[lda, 1],\n                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],\n            )\n\n            b_desc = tl._experimental_make_tensor_descriptor(\n                b_ptr,\n                shape=[gn, gk],\n                strides=[ldb, 1],\n                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],\n            )\n            c_desc = tl._experimental_make_tensor_descriptor(\n                c_ptr,\n                shape=[gm, gn],\n                strides=[ldc, 1],\n                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],\n            )\n\n            # iterate through the tiles in the current gemm problem\n            while (\n                tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles\n            ):\n                k = gk\n                # figure out tile coordinates\n                tile_idx_in_gemm = tile_idx - last_problem_end\n                tile_m_idx = tile_idx_in_gemm // num_n_tiles\n                tile_n_idx = tile_idx_in_gemm % num_n_tiles\n\n                # do regular gemm here\n                offs_am = tile_m_idx * BLOCK_SIZE_M\n                offs_bn = tile_n_idx * BLOCK_SIZE_N\n\n                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])\n                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])\n                    accumulator += tl.dot(a, b.T)\n\n                offs_cm = tile_m_idx * BLOCK_SIZE_M\n                offs_cn = tile_n_idx * BLOCK_SIZE_N\n\n                c = accumulator.to(dtype)\n                c_desc.store([offs_cm, offs_cn], c)\n\n                # go to the next tile by advancing NUM_SM\n                tile_idx += NUM_SM\n\n        # get ready to go to the next gemm problem\n        last_problem_end = last_problem_end + num_tiles\n\n\ndef group_gemm_tma_fn(group_A, group_B):\n\n    assert supports_tma()\n\n    assert len(group_A) == 
len(group_B)\n    group_size = len(group_A)\n\n    A_addrs = []\n    B_addrs = []\n    C_addrs = []\n    g_sizes = []\n    g_lds = []\n    group_C = []\n    for i in range(group_size):\n        A = group_A[i]\n        B = group_B[i]\n        assert A.shape[1] == B.shape[1]\n        M, K = A.shape\n        N, K = B.shape\n        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)\n        group_C.append(C)\n        A_addrs.append(A.data_ptr())\n        B_addrs.append(B.data_ptr())\n        C_addrs.append(C.data_ptr())\n        g_sizes += [M, N, K]\n        g_lds += [A.stride(0), B.stride(0), C.stride(0)]\n    # note these are device tensors\n    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)\n    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)\n    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)\n    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)\n    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)\n\n    # we use a fixed number of CTA, and it's auto-tunable\n\n    # TMA descriptors require a global memory allocation\n    def alloc_fn(size: int, alignment: int, stream: Optional[int]):\n        return torch.empty(size, device=\"cuda\", dtype=torch.int8)\n\n    triton.set_allocator(alloc_fn)\n\n    grid = lambda META: (META[\"NUM_SM\"],)\n    grouped_matmul_tma_kernel[grid](\n        d_a_ptrs,\n        d_b_ptrs,\n        d_c_ptrs,\n        d_g_sizes,\n        d_g_lds,\n        group_size,\n        FP8=torch.float8_e4m3fn == group_A[0].dtype,\n        NUM_SM=num_sms(),\n    )\n    return group_C\n\n\ngroup_m = [1024, 512, 256, 128]\ngroup_n = [1024, 512, 256, 128]\ngroup_k = [1024, 512, 256, 128]\ngroup_A = []\ngroup_B = []\ngroup_B_T = []\nassert len(group_m) == len(group_n)\nassert len(group_n) == len(group_k)\ngroup_size = len(group_m)\nfor i in range(group_size):\n    M = group_m[i]\n    N = group_n[i]\n    K = group_k[i]\n    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)\n    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)\n    B_T = B.T.contiguous()\n    group_A.append(A)\n    group_B.append(B)\n    group_B_T.append(B_T)\n\ntri_out = group_gemm_fn(group_A, group_B)\nref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]\nfor i in range(group_size):\n    assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0)\n\nif supports_tma():\n    tri_tma_out = group_gemm_tma_fn(group_A, group_B_T)\n    for i in range(group_size):\n        assert torch.allclose(ref_out[i], tri_tma_out[i], atol=1e-2, rtol=0)\n\n\n# only launch the kernel, no tensor preparation here to remove all overhead\ndef triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):\n    grid = lambda META: (META[\"NUM_SM\"],)\n    grouped_matmul_kernel[grid](\n        a_ptrs,\n        b_ptrs,\n        c_ptrs,\n        sizes,\n        lds,\n        group_size,\n    )\n\n\ndef triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, dtype):\n    grid = lambda META: (META[\"NUM_SM\"],)\n    grouped_matmul_tma_kernel[grid](\n        a_ptrs,\n        b_ptrs,\n        c_ptrs,\n        sizes,\n        lds,\n        group_size,\n        FP8=torch.float8_e4m3fn == dtype,\n        NUM_SM=num_sms(),\n    )\n\n\ndef torch_perf_fn(group_A, group_B):\n    for a, b in zip(group_A, group_B):\n        torch.matmul(a, b)\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        # argument names to use as an x-axis for the plot\n        x_names=[\"N\"],\n        x_vals=[2**i for i in range(7, 11)],  # different possible values for 
`x_name`\n        line_arg=\"provider\",\n        # argument name whose value corresponds to a different line in the plot\n        # possible values for `line_arg``\n        line_vals=[\"cublas\", \"triton\"] + ([\"triton-tma\"] if supports_tma() else []),\n        # label name for the lines\n        line_names=[\"cuBLAS\", \"Triton\"] + ([\"Triton + TMA\"] if supports_tma() else []),\n        # line styles\n        styles=[(\"green\", \"-\"), (\"blue\", \"-\")]\n        + ([(\"red\", \"-\")] if supports_tma() else []),\n        ylabel=\"runtime(ms)\",  # label name for the y-axis\n        plot_name=\"group-gemm-performance\",\n        # name for the plot. Used also as a file name for saving the plot.\n        args={},\n    )\n)\ndef benchmark_square_matrices(N, provider):\n    group_size = 4\n    group_A = []\n    group_B = []\n    group_B_T = []\n    A_addrs = []\n    B_addrs = []\n    B_T_addrs = []\n    C_addrs = []\n    g_sizes = []\n    g_lds = []\n    group_C = []\n    for i in range(group_size):\n        A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)\n        B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)\n        C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)\n        B_T = B.T.contiguous()\n        group_A.append(A)\n        group_B.append(B)\n        group_B_T.append(B_T)\n        group_C.append(C)\n        A_addrs.append(A.data_ptr())\n        B_addrs.append(B.data_ptr())\n        B_T_addrs.append(B_T.data_ptr())\n        C_addrs.append(C.data_ptr())\n        g_sizes += [N, N, N]\n        g_lds += [N, N, N]\n\n    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)\n    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)\n    d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)\n    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)\n    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)\n    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)\n\n    quantiles = [0.5, 0.2, 0.8]\n    if provider == \"cublas\":\n        ms, min_ms, max_ms = triton.testing.do_bench(\n            lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles\n        )\n    if provider == \"triton\":\n        ms, min_ms, max_ms = triton.testing.do_bench(\n            lambda: triton_perf_fn(\n                d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size\n            ),\n            quantiles=quantiles,\n        )\n    if provider == \"triton-tma\":\n        ms, min_ms, max_ms = triton.testing.do_bench(\n            lambda: triton_tma_perf_fn(\n                d_a_ptrs,\n                d_b_t_ptrs,\n                d_c_ptrs,\n                d_g_sizes,\n                d_g_lds,\n                group_size,\n                dtype=torch.float16,\n            ),\n            quantiles=quantiles,\n        )\n    return ms, max_ms, min_ms\n\n\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        # argument names to use as an x-axis for the plot\n        x_names=[\"M\"],\n        x_vals=[2**i for i in range(7, 11)],  # different possible values for `x_name`\n        line_arg=\"provider\",\n        # argument name whose value corresponds to a different line in the plot\n        # possible values for `line_arg``\n        line_vals=[\"cublas\", \"triton\"] + ([\"triton-tma\"] if supports_tma() else []),\n        # label name for the lines\n        line_names=[\"cuBLAS\", \"Triton\"] + ([\"Triton + TMA\"] if supports_tma() else []),\n        # line styles\n        styles=[(\"green\", \"-\"), (\"blue\", \"-\")]\n        + 
([(\"red\", \"-\")] if supports_tma() else []),\n        ylabel=\"runtime(ms)\",  # label name for the y-axis\n        plot_name=\"group-gemm-performance-m-8192-k-8192\",\n        # name for the plot. Used also as a file name for saving the plot.\n        args={},\n    )\n)\ndef benchmark_batches(M, provider):\n    N = 8192\n    K = 8192\n    group_size = 4\n    group_A = []\n    group_B = []\n    group_B_T = []\n    A_addrs = []\n    B_addrs = []\n    B_T_addrs = []\n    C_addrs = []\n    g_sizes = []\n    g_lds = []\n    g_T_lds = []\n    group_C = []\n    for i in range(group_size):\n        A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)\n        B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)\n        C = torch.empty((M, N), device=DEVICE, dtype=torch.float16)\n        B_T = B.T.contiguous()\n        group_A.append(A)\n        group_B.append(B)\n        group_B_T.append(B_T)\n        group_C.append(C)\n        A_addrs.append(A.data_ptr())\n        B_addrs.append(B.data_ptr())\n        B_T_addrs.append(B_T.data_ptr())\n        C_addrs.append(C.data_ptr())\n        g_sizes += [M, N, K]\n        g_lds += [A.stride(0), B.stride(0), C.stride(0)]\n        g_T_lds += [A.stride(0), B_T.stride(0), C.stride(0)]\n\n    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)\n    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)\n    d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)\n    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)\n    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)\n    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)\n    d_g_t_lds = torch.tensor(g_T_lds, dtype=torch.int32, device=DEVICE)\n\n    quantiles = [0.5, 0.2, 0.8]\n    if provider == \"cublas\":\n        ms, min_ms, max_ms = triton.testing.do_bench(\n            lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles\n        )\n    if provider == \"triton\":\n        ms, min_ms, max_ms = triton.testing.do_bench(\n            lambda: triton_perf_fn(\n                d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size\n            ),\n            quantiles=quantiles,\n        )\n    if provider == \"triton-tma\":\n        ms, min_ms, max_ms = triton.testing.do_bench(\n            lambda: triton_tma_perf_fn(\n                d_a_ptrs,\n                d_b_t_ptrs,\n                d_c_ptrs,\n                d_g_sizes,\n                d_g_t_lds,\n                group_size,\n                dtype=torch.float16,\n            ),\n            quantiles=quantiles,\n        )\n    return ms, max_ms, min_ms\n\n\nbenchmark_square_matrices.run(show_plots=True, print_data=True)\nbenchmark_batches.run(show_plots=True, print_data=True)\n"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/readme.md",
    "content": "##  Experimental\n\nTriton Group GEMM for supporting MoE training. \n"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/testing/fast_verification.py",
    "content": "import logging\n\nimport torch\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\"\n)\n\n# import the reference implementations\nfrom pytorch_reference_backwards import (\n    _compute_grad_w_pytorch,\n    _compute_grad_x_pytorch,\n    _pytorch_fallback_backward,\n    _pytorch_reference_backward,\n)\n\n# Import the grouped GEMM modules\nfrom tgrouped_gemm_backwards import grouped_gemm_backward\nfrom tgrouped_gemm_forward import grouped_gemm_forward as grouped_gemm\n\n\ndef test_backward_pass():\n    \"\"\"\n    A simple test for the grouped GEMM backward pass with detailed error handling.\n    \"\"\"\n    try:\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        # Test parameters\n        G = 20  # Number of groups\n        M = 1024  # Input dimension\n        N = 512  # Output dimension per group\n        K = 256  # Hidden dimension\n\n        # Create input and weight tensors\n        x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)\n        w = torch.randn(\n            N * G, K, dtype=torch.bfloat16, device=device, requires_grad=True\n        )\n\n        # Create group sizes\n        m_sizes = torch.zeros(G, device=device, dtype=torch.int32)\n        base_size = M // G\n        remainder = M % G\n\n        for i in range(G):\n            m_sizes[i] = base_size + (1 if i < remainder else 0)\n\n        # Log the setup\n        print(f\"Test setup - G: {G}, M: {M}, N: {N}, K: {K}\")\n        print(f\"Input x shape: {x.shape}\")\n        logging.info(f\"Weight w shape: {w.shape}\")\n        logging.info(f\"Group sizes: {m_sizes}\")\n\n        # Step 1: Run forward pass\n        logging.info(\"Running forward pass\")\n        result = grouped_gemm(x, w, m_sizes)\n        logging.info(f\"Forward result shape: {result.shape}\")\n\n        # Create a gradient for backpropagation\n        grad_output = torch.randn_like(result)\n        logging.info(f\"Created gradient with shape: {grad_output.shape}\")\n\n        # Step 2: Run backward pass directly\n        logging.info(\"Running backward pass directly\")\n        grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)\n\n        # Verify gradient shapes\n        logging.info(\n            f\"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}\"\n        )\n\n        # Step 3: Verify gradient computation using PyTorch's autograd\n        # First create autograd-enabled tensors\n        x_autograd = x.detach().clone().requires_grad_(True)\n        w_autograd = w.detach().clone().requires_grad_(True)\n\n        # Create a PyTorch reference implementation to compare against\n        logging.info(\"Running PyTorch reference implementation\")\n\n        # Compute reference result\n        reference_result = torch.zeros_like(result)\n        m_start = 0\n        for g in range(G):\n            m_size = m_sizes[g].item()\n            m_end = m_start + m_size\n            n_start = g * N\n            n_end = (g + 1) * N\n\n            if m_size > 0:\n                reference_result[m_start:m_end, n_start:n_end] = (\n                    x_autograd[m_start:m_end, :] @ w_autograd[n_start:n_end, :].T\n                )\n\n            m_start = m_end\n\n        # Backpropagate using PyTorch\n        reference_result.backward(grad_output)\n\n        # Compare gradients\n        logging.info(\"Comparing gradients with PyTorch reference\")\n        grad_x_error = (grad_x - 
x_autograd.grad).abs().max().item()\n        grad_w_error = (grad_w - w_autograd.grad).abs().max().item()\n\n        logging.info(\n            f\"Maximum gradient error - grad_x: {grad_x_error}, grad_w: {grad_w_error}\"\n        )\n\n        # Check if gradients are close using allclose\n        rtol = 1e-2  # Relative tolerance for bfloat16\n        atol = 1e-2  # Absolute tolerance for bfloat16\n\n        grad_x_close = torch.allclose(grad_x, x_autograd.grad, rtol=rtol, atol=atol)\n        if not grad_x_close:\n            logging.warning(\"FAILED: Gradient mismatch detected in grad_x\")\n        else:\n            logging.info(\n                \"✓ SUCCESS! grad_X matches the PyTorch reference (allclose check passed)\"\n            )\n\n        grad_w_close = torch.allclose(grad_w, w_autograd.grad, rtol=rtol, atol=atol)\n        if not grad_w_close:\n            logging.warning(\"FAILED: Gradient mismatch detected in grad_w\")\n        else:\n            logging.info(\n                \"✓ SUCCESS! grad_W matches the PyTorch reference (allclose check passed)\"\n            )\n\n        logging.info(\n            f\"Gradients allclose check - grad_x: {grad_x_close}, grad_w: {grad_w_close}\"\n        )\n\n        if grad_x_close and grad_w_close:\n            logging.info(\n                \"✓ SUCCESS: Gradients match the PyTorch reference (allclose check passed)\"\n            )\n        else:\n            logging.error(\"✗ FAILURE: Gradient mismatch detected in allclose check\")\n\n        # Additional diagnostics (for failed cases or in general)\n        if True:  # not grad_x_close:\n            # Find where the largest differences are\n            diff_x = (grad_x - x_autograd.grad).abs()\n            max_idx_x = diff_x.argmax().item()\n            flat_idx_x = max_idx_x\n            idx_x = np.unravel_index(flat_idx_x, grad_x.shape)\n            logging.error(\n                f\"Largest grad_x difference at {idx_x}: \"\n                f\"{grad_x[idx_x].item()} vs {x_autograd.grad[idx_x].item()}\"\n            )\n            # Count zeros\n            zeros_grad_x = (grad_x == 0).sum().item()\n            zeros_autograd_x = (x_autograd.grad == 0).sum().item()\n            logging.error(\n                f\"Zeros in grad_x: {zeros_grad_x}/{grad_x.numel()} ({zeros_grad_x/grad_x.numel()*100:.2f}%)\"\n            )\n            logging.error(\n                f\"Zeros in x_autograd.grad: {zeros_autograd_x}/{x_autograd.grad.numel()} ({zeros_autograd_x/x_autograd.grad.numel()*100:.2f}%)\"\n            )\n\n        if True:  # not grad_w_close:\n            # Find where the largest differences are\n            diff_w = (grad_w - w_autograd.grad).abs()\n            max_idx_w = diff_w.argmax().item()\n            flat_idx_w = max_idx_w\n            idx_w = np.unravel_index(flat_idx_w, grad_w.shape)\n            logging.error(\n                f\"Largest grad_w difference at {idx_w}: \"\n                f\"{grad_w[idx_w].item()} vs {w_autograd.grad[idx_w].item()}\"\n            )\n            # Count zeros\n            zeros_grad_w = (grad_w == 0).sum().item()\n            zeros_autograd_w = (w_autograd.grad == 0).sum().item()\n            logging.error(\n                f\"Zeros in grad_w: {zeros_grad_w}/{grad_w.numel()} ({zeros_grad_w/grad_w.numel()*100:.2f}%)\"\n            )\n            logging.error(\n                f\"Zeros in w_autograd.grad: {zeros_autograd_w}/{w_autograd.grad.numel()} ({zeros_autograd_w/w_autograd.grad.numel()*100:.2f}%)\"\n            )\n\n        return 
grad_x_close and grad_w_close\n\n    except Exception as e:\n        logging.error(f\"Test failed with error: {e}\")\n        import traceback\n\n        logging.error(traceback.format_exc())\n        return False\n\n\nif __name__ == \"__main__\":\n    print(\"Running test_backward_pass\")\n    logging.debug(\"Running test_backward_pass\")\n    # Add numpy import for unravel_index\n    import numpy as np\n\n    success = test_backward_pass()\n    logging.info(f\"Test {'succeeded' if success else 'failed'}\")\n"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n\nimport torch\n\n# This is a series of helper functions for grouped GEMM backward that compute the gradients\n# using eager PyTorch operations. They are used as a verification reference for the Triton kernels.\n# They can also used as a fallback when the Triton kernels cannot be used, though lets hope that is not needed.\n\n\ndef _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x):\n    \"\"\"\n    Compute grad_x using pure PyTorch operations with FP32 precision\n    \"\"\"\n    G = m_sizes.shape[0]\n    M, K = grad_x.shape\n    N = w.shape[0] // G\n\n    # Zero out the output tensor first\n    grad_x.zero_()\n\n    # Store original dtype and convert to float32 for computation\n    orig_dtype = grad_x.dtype\n    grad_output_fp32 = grad_output.float()\n    w_fp32 = w.float()\n    grad_x_fp32 = torch.zeros_like(grad_x, dtype=torch.float32)\n\n    # Process each group separately\n    m_start = 0\n    for g in range(G):\n        m_size = m_sizes[g].item()\n        if m_size > 0:\n            m_end = m_start + m_size\n            n_start = g * N\n            n_end = (g + 1) * N\n\n            # Get slices for this group\n            grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end]\n            w_slice = w_fp32[n_start:n_end]\n\n            # Process in chunks for better precision on large matrices\n            CHUNK_SIZE = 256\n            for chunk_start in range(0, m_size, CHUNK_SIZE):\n                chunk_end = min(chunk_start + CHUNK_SIZE, m_size)\n                chunk_size = chunk_end - chunk_start\n\n                # Compute matrix multiplication with higher precision\n                grad_output_chunk = grad_output_slice[chunk_start:chunk_end]\n                result_chunk = torch.matmul(\n                    grad_output_chunk.double(), w_slice.double()\n                )\n\n                # Store the result\n                grad_x_fp32[m_start + chunk_start : m_start + chunk_end].copy_(\n                    result_chunk.float()\n                )\n\n        m_start = m_end\n\n    # Convert back to original dtype\n    grad_x.copy_(grad_x_fp32.to(orig_dtype))\n\n\ndef _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w):\n    \"\"\"\n    Compute grad_w using pure PyTorch operations with FP64 precision for better accuracy.\n    \"\"\"\n    G = m_sizes.shape[0]\n    N_times_G, K = grad_w.shape\n    N = N_times_G // G\n\n    # Zero out the output tensor first\n    grad_w.zero_()\n\n    # Store original dtype and convert to float32 for computation\n    orig_dtype = grad_w.dtype\n    grad_output_fp32 = grad_output.float()\n    x_fp32 = x.float()\n    grad_w_fp32 = torch.zeros_like(grad_w, dtype=torch.float32)\n\n    # Handle potential K dimension mismatches\n    K_x = x.shape[1]\n    min_K = min(K, K_x)\n\n    # Process each group separately\n    m_start = 0\n    for g in range(G):\n        m_size = m_sizes[g].item()\n        if m_size > 0:\n            m_end = m_start + m_size\n            n_start = g * N\n            n_end = (g + 1) * N\n\n            # Get slices for this group\n            grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end]\n            x_slice = x_fp32[m_start:m_end, :min_K]\n\n            # Process in chunks for better precision\n            CHUNK_SIZE = 32\n            result = torch.zeros(\n     
           (grad_output_slice.shape[1], min_K),\n                dtype=torch.float64,\n                device=grad_output_slice.device,\n            )\n\n            for chunk_start in range(0, m_size, CHUNK_SIZE):\n                chunk_end = min(chunk_start + CHUNK_SIZE, m_size)\n\n                # Get chunks\n                grad_output_chunk = grad_output_slice[chunk_start:chunk_end].double()\n                x_chunk = x_slice[chunk_start:chunk_end].double()\n\n                # Matrix multiplication in FP64\n                chunk_result = torch.matmul(grad_output_chunk.t(), x_chunk)\n                result += chunk_result\n\n            # Handle K dimension padding if needed\n            if K > min_K:\n                temp_result = torch.zeros(\n                    (grad_output_slice.shape[1], K),\n                    dtype=torch.float32,\n                    device=grad_output_slice.device,\n                )\n                temp_result[:, :min_K] = result.float()\n                grad_w_fp32[n_start:n_end].copy_(temp_result)\n            else:\n                grad_w_fp32[n_start:n_end].copy_(result.float())\n\n        m_start = m_end\n\n    # Convert back to original dtype\n    grad_w.copy_(grad_w_fp32.to(orig_dtype))\n\n\ndef _pytorch_fallback_backward(grad_output, x, w, m_sizes):\n    \"\"\"\n    Pure PyTorch implementation of grouped GEMM backward with high precision.\n    Used as a fallback when the Triton kernels cannot be used.\n    \"\"\"\n    logging.info(\n        \"WARNING:  Using PyTorch fallback for grouped GEMM backward with high precision\"\n    )\n\n    # Ensure inputs are contiguous\n    x = x.contiguous()\n    w = w.contiguous()\n    grad_output = grad_output.contiguous()\n    m_sizes = m_sizes.contiguous()\n\n    # Allocate output tensors\n    grad_x = torch.zeros_like(x)\n    grad_w = torch.zeros_like(w)\n\n    # Compute gradients using the helper functions\n    _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x)\n    _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w)\n\n    return grad_x, grad_w\n\n\ndef _pytorch_reference_backward(grad_output, x, w, m_sizes):\n    \"\"\"\n    Pure PyTorch implementation of grouped GEMM backward for validation.\n    Simple version that's easy to verify but may be less numerically accurate\n    for large matrices.\n    \"\"\"\n    # Create output gradients\n    grad_x = torch.zeros_like(x)\n    grad_w = torch.zeros_like(w)\n\n    # Compute group-by-group\n    G = m_sizes.shape[0]\n    N = w.shape[0] // G\n\n    m_start = 0\n    for g in range(G):\n        m_size = m_sizes[g].item()\n        if m_size > 0:\n            m_end = m_start + m_size\n            n_start = g * N\n            n_end = (g + 1) * N\n\n            # Compute gradients\n            grad_x[m_start:m_end] = torch.matmul(\n                grad_output[m_start:m_end, n_start:n_end], w[n_start:n_end]\n            )\n            grad_w[n_start:n_end] = torch.matmul(\n                grad_output[m_start:m_end, n_start:n_end].t(), x[m_start:m_end]\n            )\n\n        m_start += m_size\n\n    return grad_x, grad_w\n\n\n# ========== End helper functions ==========\n"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport logging\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom tma_utils import TmaAutoTuneHelper\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\"\n)\n\n\"\"\"\nBackward pass for grouped GEMM with Triton, where grouping is N*G\nWe are computing gradients with respect to both the input (`grad_x`) and the weights (`grad_w`).\n\"\"\"\n\n\n# =============== Start Triton Kernels ===============\n@triton.jit\ndef _kernel_grouped_gemm_backward_x_scheduled(\n    grad_y_ptr,  # grad of dl/dY [M, N*G]\n    w_t_ptr,  # w transposed [K, N*G]\n    grad_x_ptr,  # output of kernel [M, K]\n    group_offsets_ptr,  # Pre-computed group offsets [G+1]\n    workspace,  # Workspace for TMA descriptors\n    G,  # Number of groups\n    M,  # Total M dimension size\n    N,  # N per group\n    K,  # K dimension size\n    stride_go_m,\n    stride_go_n,\n    stride_w_n,\n    stride_w_k,\n    stride_gx_m,\n    stride_gx_k,\n    NUM_SMS,\n    USE_TMA_LOAD: tl.constexpr = False,\n    USE_TMA_STORE: tl.constexpr = False,\n    BLOCK_SIZE_M: tl.constexpr = 64,\n    BLOCK_SIZE_N: tl.constexpr = 64,\n    BLOCK_SIZE_K: tl.constexpr = 64,\n    GROUP_SIZE_M: tl.constexpr = 8,\n    EVEN_K: tl.constexpr = True,\n) -> None:\n    \"\"\"\n    Scheduled grouped GEMM backward for X with TMA support.\n\n    For each group g, computes: grad_x[g] = grad_y[g] @ w_t[g].T\n\n    Where:\n    - grad_y is [M, N*G]\n    - w_t is [K, N*G] (transposed from [N*G, K])\n    - grad_x is [M, K]\n    \"\"\"\n    # Get coordinates for the current program\n    tidx = tl.program_id(axis=0)\n    dtype = grad_x_ptr.dtype.element_ty\n    TMA_SIZE: tl.constexpr = 128\n\n    # Initialize workspace pointer if using TMA store\n    if USE_TMA_STORE:\n        c_desc_ptr = workspace + tidx * TMA_SIZE\n    else:\n        c_desc_ptr = None\n\n    # Calculate work distribution parameters\n    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n    num_pid_in_group = num_pid_m * num_pid_k\n\n    # Process all assigned work items\n    pid = tidx\n    while pid < G * num_pid_in_group:\n        # Calculate work distribution for this pid\n        group_id = pid // num_pid_in_group\n        pid_in_group = pid % num_pid_in_group\n        pid_m = pid_in_group % num_pid_m\n        pid_k = pid_in_group // num_pid_m\n\n        # Get group boundaries\n        valid_group = group_id < G\n        group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0)\n        group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0)\n        group_size = group_end - group_start\n\n        # Calculate a mask for valid processing (valid group and non-empty)\n        valid_work = valid_group & (group_size > 0)\n\n        # Only process if we have valid work\n        if valid_work:\n            # Compute offsets for this group\n            n_start = group_id * N\n\n            # Block dimensions\n            m_block_offset = pid_m * BLOCK_SIZE_M\n            k_block_offset = pid_k * BLOCK_SIZE_K\n\n            # Setup TMA descriptor for output if using TMA\n            if USE_TMA_STORE:\n                m_size = tl.minimum(\n                    BLOCK_SIZE_M, group_end - (group_start + 
m_block_offset)\n                )\n                if m_size > 0:\n                    tl.extra.cuda.experimental_device_tensormap_create2d(\n                        desc_ptr=c_desc_ptr,\n                        global_address=grad_x_ptr\n                        + (group_start + m_block_offset) * stride_gx_m\n                        + k_block_offset * stride_gx_k,\n                        load_size=[\n                            m_size,\n                            tl.minimum(BLOCK_SIZE_K, K - k_block_offset),\n                        ],\n                        global_size=[m_size, K],\n                        element_ty=dtype,\n                    )\n                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)\n\n            # Initialize offsets for this block\n            offs_m = group_start + m_block_offset + tl.arange(0, BLOCK_SIZE_M)\n\n            # For K dimension, optimize memory access if EVEN_K is True\n            offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K)\n\n            # Create masks\n            m_mask = offs_m < group_end\n            k_mask = offs_k < K\n\n            # Initialize accumulator\n            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n            # Loop over the reduction dimension (N)\n            # Use smaller steps to improve precision and avoid numerical issues\n            for n_offset in range(0, N, BLOCK_SIZE_N):\n                # Handle boundary conditions for the reduction dimension\n                n_size = tl.minimum(BLOCK_SIZE_N, N - n_offset)\n                offs_n = n_start + n_offset + tl.arange(0, BLOCK_SIZE_N)\n                n_mask = offs_n < (n_start + N)\n\n                # Fixed stride formats to ensure consistent memory access\n                grad_y_block = tl.load(\n                    grad_y_ptr\n                    + offs_m[:, None] * stride_go_m\n                    + offs_n[None, :] * stride_go_n,\n                    mask=m_mask[:, None] & n_mask[None, :],\n                    other=0.0,\n                )\n\n                # Load w_t [K, N*G] block with correct strides\n                w_t_block = tl.load(\n                    w_t_ptr\n                    + offs_k[:, None] * stride_w_k\n                    + offs_n[None, :] * stride_w_n,\n                    mask=k_mask[:, None] & n_mask[None, :],\n                    other=0.0,\n                )\n\n                # grad_y @ w_t.T\n                # Allow TF32 if K is even and divisible by 8\n                if EVEN_K:\n                    accumulator += tl.dot(\n                        grad_y_block.to(tl.float32),\n                        w_t_block.to(tl.float32).T,\n                        allow_tf32=True,\n                    )\n                else:\n                    accumulator += tl.dot(\n                        grad_y_block.to(tl.float32),\n                        w_t_block.to(tl.float32).T,\n                        allow_tf32=False,\n                    )\n\n            # Store result to grad_x with explicit strides\n            if USE_TMA_STORE:\n                # TMA store\n                tl._experimental_descriptor_store(\n                    c_desc_ptr,\n                    accumulator.to(dtype),\n                    [0, 0],  # Starting offset in the output block\n                )\n            else:\n                # Standard store\n                tl.store(\n                    grad_x_ptr\n                    + offs_m[:, None] * stride_gx_m\n                    + offs_k[None, :] * 
stride_gx_k,\n                    accumulator.to(dtype),\n                    mask=m_mask[:, None] & k_mask[None, :],\n                )\n\n        pid = pid + NUM_SMS\n\n\n@triton.jit\ndef _kernel_grouped_gemm_backward_w_scheduled(\n    x_t_ptr,  # x transposed [K, M]\n    grad_y_ptr,  # grad of dl/dY [M, N*G]\n    grad_w_ptr,  # output of kernel (grad_w) [N*G, K]\n    group_offsets_ptr,  # Pre-computed group offsets [G+1]\n    workspace,  # Workspace for TMA descriptors\n    G,  # Number of groups\n    M,  # Total M dimension size\n    N,  # N per group\n    K,  # K dimension size\n    stride_x_m,\n    stride_x_k,\n    stride_go_m,\n    stride_go_n,\n    stride_gw_n,\n    stride_gw_k,\n    NUM_SMS,\n    USE_TMA_LOAD: tl.constexpr = False,\n    USE_TMA_STORE: tl.constexpr = False,\n    BLOCK_SIZE_N: tl.constexpr = 64,\n    BLOCK_SIZE_K: tl.constexpr = 64,\n    BLOCK_SIZE_M: tl.constexpr = 32,\n    GROUP_SIZE_N: tl.constexpr = 8,\n    EVEN_K: tl.constexpr = True,\n) -> None:\n    \"\"\"\n    Scheduled implementation of grouped GEMM backward for W with TMA support.\n\n    For each group g, computes:\n        grad_w[g] = grad_y[g].T @ x[g]\n\n    Where:\n    - x_t is [K, M] (transposed from [M, K])\n    - grad_y is [M, N*G]\n    - grad_w is [N*G, K]\n    \"\"\"\n    # Define coordinates for the current program\n    tidx = tl.program_id(axis=0)\n    dtype = grad_w_ptr.dtype.element_ty\n    TMA_SIZE: tl.constexpr = 128\n\n    # Initialize workspace pointer if using TMA store\n    if USE_TMA_STORE:\n        c_desc_ptr = workspace + tidx * TMA_SIZE\n    else:\n        c_desc_ptr = None\n\n    # Calculate work distribution parameters\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n    num_pid_in_group = num_pid_n * num_pid_k\n\n    # Process all assigned work items\n    pid = tidx\n    while pid < G * num_pid_in_group:\n        # Calculate work distribution for this pid\n        group_id = pid // num_pid_in_group\n        pid_in_group = pid % num_pid_in_group\n        pid_n = pid_in_group % num_pid_n\n        pid_k = pid_in_group // num_pid_n\n\n        # Get group boundaries\n        valid_group = group_id < G\n        group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0)\n        group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0)\n        group_size = group_end - group_start\n\n        # Calculate a mask for valid processing (valid group and non-empty)\n        valid_work = valid_group & (group_size > 0)\n\n        # Only process if we have valid work\n        if valid_work:\n            # Compute offsets for this group\n            n_start = group_id * N\n\n            # Block dimensions\n            n_block_offset = pid_n * BLOCK_SIZE_N\n            k_block_offset = pid_k * BLOCK_SIZE_K\n\n            # Setup TMA descriptor for output if using TMA\n            if USE_TMA_STORE:\n                n_size = tl.minimum(BLOCK_SIZE_N, N - n_block_offset)\n                if n_size > 0:\n                    tl.extra.cuda.experimental_device_tensormap_create2d(\n                        desc_ptr=c_desc_ptr,\n                        global_address=grad_w_ptr\n                        + (n_start + n_block_offset) * stride_gw_n\n                        + k_block_offset * stride_gw_k,\n                        load_size=[\n                            n_size,\n                            tl.minimum(BLOCK_SIZE_K, K - k_block_offset),\n                        ],\n                        global_size=[n_size, K],\n           
             element_ty=dtype,\n                    )\n                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)\n\n            # Initialize offsets for this block\n            offs_n = n_start + n_block_offset + tl.arange(0, BLOCK_SIZE_N)\n            offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K)\n\n            # Create masks\n            n_mask = offs_n < (n_start + N)\n            k_mask = offs_k < K\n\n            # Initialize accumulator\n            accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)\n\n            # Loop over the reduction dimension (M) with smaller steps to avoid overflow\n            for m_offset in range(0, group_size, BLOCK_SIZE_M):\n                # Handle boundary conditions for the reduction dimension\n                m_size = tl.minimum(BLOCK_SIZE_M, group_size - m_offset)\n                offs_m = group_start + m_offset + tl.arange(0, BLOCK_SIZE_M)\n                m_mask = offs_m < group_end\n\n                # Load grad_y [M, N*G] block with explicit strides\n                grad_y_block = tl.load(\n                    grad_y_ptr\n                    + offs_m[:, None] * stride_go_m\n                    + offs_n[None, :] * stride_go_n,\n                    mask=m_mask[:, None] & n_mask[None, :],\n                    other=0.0,\n                )\n\n                # Load x_t [K, M] block with explicit strides\n                x_t_block = tl.load(\n                    x_t_ptr\n                    + offs_k[:, None] * stride_x_k\n                    + offs_m[None, :] * stride_x_m,\n                    mask=k_mask[:, None] & m_mask[None, :],\n                    other=0.0,\n                )\n\n                # Matrix multiplication: (grad_y_block.T @ x_t_block.T)\n                if EVEN_K:\n                    accumulator += tl.dot(\n                        grad_y_block.to(\n                            tl.float32\n                        ).T,  # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M]\n                        x_t_block.to(\n                            tl.float32\n                        ).T,  # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K]\n                        allow_tf32=True,\n                    )\n                else:\n                    accumulator += tl.dot(\n                        grad_y_block.to(\n                            tl.float32\n                        ).T,  # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M]\n                        x_t_block.to(\n                            tl.float32\n                        ).T,  # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K]\n                        allow_tf32=False,\n                    )\n\n            # Store result to grad_w with explicit strides\n            if USE_TMA_STORE:\n                # TMA store\n                tl._experimental_descriptor_store(\n                    c_desc_ptr,\n                    accumulator.to(dtype),\n                    [0, 0],  # Starting offset in the output block\n                )\n            else:\n                # Standard store with explicit strides\n                tl.store(\n                    grad_w_ptr\n                    + offs_n[:, None] * stride_gw_n\n                    + offs_k[None, :] * stride_gw_k,\n                    accumulator.to(dtype),\n                    mask=n_mask[:, None] & k_mask[None, :],\n                )\n\n        pid = pid + NUM_SMS\n\n\n# ========== End Triton kernels ==========\n\n# ========== Begin grouped_gemm_backward cover function ==========\n\ndef grouped_gemm_backward(\n    grad_output: 
torch.Tensor,\n    x: torch.Tensor,\n    w: torch.Tensor,\n    m_sizes: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Backward pass for grouped matrix multiplication using scheduled kernels with TMA support.\n\n    Args:\n        grad_output: Gradient with respect to output, shape [M, N*G]\n        x: Input tensor from forward pass, shape [M, K]\n        w: Weight tensor from forward pass, shape [N*G, K]\n        m_sizes: Group sizes tensor, shape [G]\n\n    Returns:\n        Tuple of gradients with respect to x and w: (grad_x, grad_w)\n    \"\"\"\n    logging.info(\"Starting grouped_gemm_backward with TMA-enabled scheduling\")\n\n    # Check CUDA availability\n    if not torch.cuda.is_available():\n        logging.error(\"CUDA not available for backward pass\")\n        raise RuntimeError(\"CUDA not available for backward pass\")\n        # return _pytorch_fallback_backward(grad_output, x, w, m_sizes)\n\n    # Get GPU parameters - TODO: this could use PyTorch's cached device properties\n    device_props = torch.cuda.get_device_properties(\"cuda\")\n    NUM_SMS = device_props.multi_processor_count\n\n    # Check TMA support\n    has_tma = hasattr(tl.extra, \"cuda\") and device_props.major >= 9\n\n    if has_tma:\n        logging.info(f\"TMA support detected on GPU with {NUM_SMS} SMs\")\n        USE_TMA_LOAD = True  # TODO: currently has no effect; TMA loads were removed to focus on numerical correctness first.\n        USE_TMA_STORE = False\n    else:\n        logging.warning(\"TMA support not detected, disabling TMA optimizations\")\n        USE_TMA_LOAD = False\n        USE_TMA_STORE = False\n\n    # Validate input dimensions\n    G = m_sizes.shape[0]\n    M, K_x = x.shape\n    N_times_G, K_w = w.shape\n\n    # Check that K dimensions match\n    if K_x != K_w:\n        logging.warning(f\"K dimension mismatch: x has K={K_x}, w has K={K_w}\")\n        raise ValueError(\"K dimensions must match for grouped GEMM backward\")\n        # return _pytorch_fallback_backward(grad_output, x, w, m_sizes)\n\n    try:\n        # Ensure contiguous tensors\n        grad_output = grad_output.contiguous()\n        x = x.contiguous()\n        w = w.contiguous()\n        m_sizes = m_sizes.contiguous()\n\n        # Allocate output tensors\n        grad_x = torch.zeros_like(x)\n        grad_w = torch.zeros_like(w)\n\n        # Determine N per group\n        # N*G is the second dimension size of grad_output\n        N = N_times_G // G\n\n        # Set stride values\n        # Direct access pattern for grad_output tensor\n        stride_go_m = grad_output.stride(0)  # grad_output in M dimension\n        stride_go_n = grad_output.stride(1)  # grad_output in N dimension\n\n        # Pattern match the transposed weight tensor\n        stride_w_n = 1  # transposed weights in N dimension\n        stride_w_k = N * G  # transposed weights in K dimension\n\n        # Pattern match the output grad_x tensor\n        stride_gx_m = grad_x.stride(0)  # grad_x in M dimension\n        stride_gx_k = grad_x.stride(1)  # grad_x in K dimension\n\n        # Pattern match the transposed x tensor\n        stride_x_m = 1  # Stride for transposed x in M dimension\n        stride_x_k = M  # Stride for transposed x in K dimension\n\n        # Pattern match the output grad_w tensor\n        stride_gw_n = grad_w.stride(0)  # grad_w in N dimension\n        stride_gw_k = grad_w.stride(1)  # grad_w in K dimension\n\n        # Pre-compute group offsets for indexing\n        group_offsets = torch.zeros(G + 1, device=m_sizes.device, 
dtype=torch.int32)\n        m_offset = 0\n        for g in range(G):\n            group_offsets[g] = m_offset\n            m_offset += m_sizes[g].item()\n        group_offsets[G] = m_offset  # Total M\n\n        # Check if K dimension is even (optimize memory access patterns)\n        EVEN_K = (K_x % 8) == 0\n        logging.info(f\"EVEN_K optimization enabled: {EVEN_K} (K={K_x})\")\n\n        # Transpose x and w for backward computation\n        x_t = x.T.contiguous()  # Shape: [K, M]\n        w_t = w.T.contiguous()  # Shape: [K, N*G]\n\n        # Allocate workspace for TMA descriptors if needed\n        if USE_TMA_LOAD or USE_TMA_STORE:\n            workspace = torch.empty((NUM_SMS * 128), device=x.device, dtype=torch.uint8)\n        else:\n            # Empty tensor when TMA is not used\n            workspace = torch.empty(0, device=x.device, dtype=torch.uint8)\n\n        # Set block sizes based on K dimension\n        # For larger K, use smaller blocks to reduce register pressure\n        BLOCK_SIZE = 64 if K_x <= 64 else 32\n\n        BLOCK_SIZE_K = BLOCK_SIZE\n        BLOCK_SIZE_M = BLOCK_SIZE\n        BLOCK_SIZE_N = BLOCK_SIZE\n\n        # Determine maximum size needed and set the grid size\n        num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)\n        num_pid_k = triton.cdiv(K_x, BLOCK_SIZE_K)\n        num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)\n\n        # Compute total number of blocks needed for each kernel\n        total_blocks_x = G * num_pid_m * num_pid_k\n        total_blocks_w = G * num_pid_n * num_pid_k\n\n        try:\n            logging.info(\"Computing grad_x with TMA-enabled kernel\")\n\n            # Fixed grid size based on SM count\n            grid = (NUM_SMS,)\n\n            _kernel_grouped_gemm_backward_x_scheduled[grid](\n                grad_output,\n                w_t,  # Using transposed weights\n                grad_x,\n                group_offsets,\n                workspace,\n                G,\n                M,\n                N,\n                K_x,\n                stride_go_m,\n                stride_go_n,\n                stride_w_n,\n                stride_w_k,\n                stride_gx_m,\n                stride_gx_k,\n                NUM_SMS,\n                USE_TMA_LOAD,\n                USE_TMA_STORE,\n                BLOCK_SIZE_M=BLOCK_SIZE_M,\n                BLOCK_SIZE_N=BLOCK_SIZE_N,\n                BLOCK_SIZE_K=BLOCK_SIZE_K,\n                EVEN_K=EVEN_K,\n            )\n            logging.info(\n                \"Kernel run success: grad_X computation successful with TMA-enabled kernel\"\n            )\n        except Exception as e:\n            logging.error(f\"FAILED: Error in TMA-enabled backward_x kernel: {e}\")\n            logging.info(\"WARNING: Falling back to PyTorch for grad_x\")\n            _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x)\n\n        try:\n            logging.info(\"Computing grad_w with TMA-enabled kernel\")\n\n            # Fixed grid size based on SM count\n            grid = (NUM_SMS,)\n\n            _kernel_grouped_gemm_backward_w_scheduled[grid](\n                x_t,  # Using transposed inputs\n                grad_output,\n                grad_w,\n                group_offsets,\n                workspace,\n                G,\n                M,\n                N,\n                K_w,\n                stride_x_m,\n                stride_x_k,\n                stride_go_m,\n                stride_go_n,\n                stride_gw_n,\n                stride_gw_k,\n                NUM_SMS,\n    
            USE_TMA_LOAD,\n                USE_TMA_STORE,\n                BLOCK_SIZE_N=BLOCK_SIZE_N,\n                BLOCK_SIZE_K=BLOCK_SIZE_K,\n                BLOCK_SIZE_M=BLOCK_SIZE_M,\n                EVEN_K=EVEN_K,\n            )\n            logging.info(\n                \"Kernel run success: grad_W computation successful with TMA-enabled kernel\"\n            )\n        except Exception as e:\n            logging.error(f\"FAILED: Error in TMA-enabled backward_w kernel: {e}\")\n            logging.warning(\n                \"PyTorch fallback for grad_w is currently disabled; grad_w will remain zero-initialized\"\n            )\n            # _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w)\n\n        return grad_x, grad_w\n    except Exception as e:\n        logging.error(f\"Error in grouped_gemm_backward: {e}\")\n        # return _pytorch_fallback_backward(grad_output, x, w, m_sizes)\n        raise\n"
  },
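  {
    "path": "kernels/MoE/group_GEMM/triton/example_grouped_gemm_backward_check.py",
    "content": "# Hypothetical usage sketch (illustrative only, not part of the FBGEMM-derived sources above):\n# exercises grouped_gemm_backward and checks it against a plain per-group PyTorch reference.\n# The module name `tgroup_gemm_backward`, this file path, and the sizes below are assumptions;\n# only the documented signature (grad_output [M, N*G], x [M, K], w [N*G, K], m_sizes [G]) and\n# the per-group math grad_x[g] = grad_y[g] @ w[g], grad_w[g] = grad_y[g].T @ x[g] come from\n# the backward module.\n\nimport torch\n\nfrom tgroup_gemm_backward import grouped_gemm_backward  # module name assumed\n\n\ndef reference_backward(grad_output, x, w, m_sizes):\n    \"\"\"Per-group PyTorch reference for the grouped GEMM backward pass.\"\"\"\n    G = m_sizes.shape[0]\n    N = w.shape[0] // G\n    grad_x = torch.zeros_like(x)\n    grad_w = torch.zeros_like(w)\n    m_start = 0\n    for g in range(G):\n        m_end = m_start + int(m_sizes[g])\n        grad_y_g = grad_output[m_start:m_end, g * N : (g + 1) * N]  # [m_g, N]\n        grad_x[m_start:m_end] = grad_y_g @ w[g * N : (g + 1) * N]  # [m_g, K]\n        grad_w[g * N : (g + 1) * N] = grad_y_g.t() @ x[m_start:m_end]  # [N, K]\n        m_start = m_end\n    return grad_x, grad_w\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    G, N, K = 4, 256, 256\n    m_sizes = torch.tensor([128, 64, 192, 128], device=\"cuda\", dtype=torch.int32)\n    M = int(m_sizes.sum())\n    x = torch.randn(M, K, device=\"cuda\", dtype=torch.bfloat16)\n    w = torch.randn(N * G, K, device=\"cuda\", dtype=torch.bfloat16)\n    grad_output = torch.randn(M, N * G, device=\"cuda\", dtype=torch.bfloat16)\n\n    grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)\n    ref_x, ref_w = reference_backward(grad_output, x, w, m_sizes)\n    print(\"grad_x max abs diff:\", (grad_x - ref_x).abs().max().item())\n    print(\"grad_w max abs diff:\", (grad_w - ref_w).abs().max().item())\n"
  },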
  {
    "path": "kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-unsafe\n\n# This is copied from FBGEMM, with some modifications.  Not kept in sync, Original code:\n# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py\n\nimport functools\nfrom typing import Optional\n\nimport tma_utils as utils\n\nimport torch\n\nimport triton\nimport triton.language as tl\nfrom triton.runtime import driver  # @manual\n\n\"\"\"\n_NV_CONFIGS = [\n    triton.Config(\n        {\n            \"BLOCK_SIZE_M\": block_size_m,\n            \"BLOCK_SIZE_N\": block_size_n,\n            \"BLOCK_SIZE_K\": block_size_k,\n        },\n        num_stages=num_stages,\n        num_warps=num_warps,\n        num_ctas=num_ctas,\n    )\n    for block_size_m in [64, 128]\n    for block_size_n in [64, 128, 256]\n    for block_size_k in [64, 128, 256]\n    for num_stages in [3, 4]\n    for num_warps in [4, 8]\n    for num_ctas in [1]\n]\n\n_AMD_CONFIGS = [\n    triton.Config(\n        {\n            \"BLOCK_SIZE_M\": block_size_m,\n            \"BLOCK_SIZE_N\": block_size_n,\n            \"BLOCK_SIZE_K\": block_size_k,\n            \"waves_per_eu\": waves_per_cu,\n            \"matrix_instr_nonkdim\": matrix_instr_nonkdim,\n        },\n        num_stages=num_stages,\n        num_warps=num_warps,\n    )\n    for block_size_m in [32, 64, 128]\n    for block_size_n in [32, 64, 128, 256]\n    for block_size_k in [128, 256]\n    for num_stages in [1, 2]\n    for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]\n    for matrix_instr_nonkdim in [16]\n]\n\n\ndef early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):\n    device = torch.cuda.current_device()\n    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages\n    if dtsize is None:\n        dtsize = named_args[\"c_ptr\"].element_size()\n    if dtype is None:\n        dtype = named_args[\"c_ptr\"].dtype\n\n    pruned_configs = []\n    for config in configs:\n        kw = config.kwargs\n        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (\n            kw[\"BLOCK_SIZE_M\"],\n            kw[\"BLOCK_SIZE_N\"],\n            kw[\"BLOCK_SIZE_K\"],\n            config.num_stages,\n        )\n        G, M, N, K = (\n            named_args[\"G\"],\n            named_args[\"M_BUCKET\"],\n            named_args[\"N\"],\n            named_args[\"K\"],\n        )\n\n        # 1. make sure we have enough smem\n        max_shared_memory = driver.active.utils.get_device_properties(device)[\n            \"max_shared_mem\"\n        ]\n        if torch.version.hip:\n            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize\n        else:\n            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize\n        if required_shared_memory > max_shared_memory:\n            continue\n\n        M_PER_GROUP = M // G\n        MIN_M_TILES = 32 if torch.version.hip else 64\n        # 2. make sure we don't load M tiles that are too big\n        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):\n            continue\n        # 3. 
make sure we don't load N tiles that are too small\n        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):\n            continue\n\n        num_sm = driver.active.utils.get_device_properties(device)[\n            \"multiprocessor_count\"\n        ]\n        N_TILES = N // BLOCK_N\n        MIN_N_TILES = 32 if torch.version.hip else 64\n        # 4. make sure we don't load N tiles that are too big\n        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:\n            continue\n        # 5. make sure we don't load N tiles that are too small\n        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:\n            continue\n        # 6. make sure K can be evenly divided\n        if K % BLOCK_K != 0:\n            continue\n\n        pruned_configs.append(config)\n\n    return pruned_configs\n\n\n@triton.autotune(\n    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,\n    key=[\"G\", \"M_BUCKET\", \"N\", \"K\"],\n    prune_configs_by={\"early_config_prune\": early_config_prune},\n)\n\"\"\"\n\n\n@triton.jit\ndef _kernel_grouped_gemm(\n    a_desc_ptr,\n    b_desc_ptr,\n    c_ptr,\n    workspace,\n    m_sizes,\n    # problem sizes\n    G: tl.constexpr,\n    M_BUCKET: tl.constexpr,\n    N: tl.constexpr,  # N is per group\n    K: tl.constexpr,\n    NUM_SMS: tl.constexpr,\n    USE_TMA_LOAD: tl.constexpr,\n    USE_TMA_STORE: tl.constexpr,\n    # tile sizes\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n) -> None:\n    tidx = tl.program_id(0)\n\n    dtype: tl.dtype = c_ptr.dtype.element_ty\n    TMA_SIZE: tl.constexpr = tl.constexpr(128)\n    if USE_TMA_STORE:\n        c_desc_ptr = workspace + tidx * TMA_SIZE\n    else:\n        c_desc_ptr = None\n\n    M_end_offset = 0\n    iterated_tiles = 0\n    for g in tl.range(G):\n        # Move across groups\n        M_start_offset = M_end_offset\n        m_size = tl.load(m_sizes + g)\n        M_end_offset = M_start_offset + m_size\n\n        if m_size > 0:\n            # Compute for this group\n            N_start_offset = g * N\n            n_size = N  # N is already per group\n\n            # Calculate the number of tiles for this group\n            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)\n            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)\n            num_tiles = num_m_tiles * num_n_tiles\n\n            if USE_TMA_STORE:\n                # Set up TMA descriptor for output\n                # pyre-ignore\n                tl.extra.cuda.experimental_device_tensormap_create2d(\n                    desc_ptr=c_desc_ptr,\n                    global_address=c_ptr\n                    + M_start_offset * (N * G)\n                    + N_start_offset,  # Offset to this group's output\n                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],\n                    global_size=[m_size, n_size],\n                    element_ty=c_ptr.dtype.element_ty,\n                )\n                # pyre-ignore\n                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)\n\n            # Move across tiles\n            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:\n                gidx = tidx - iterated_tiles\n                # Split M first and N second.\n                tile_m_idx = gidx % num_m_tiles\n                tile_n_idx = gidx // num_m_tiles\n\n                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n                tl.static_assert(K % BLOCK_SIZE_K == 0)\n\n                if USE_TMA_LOAD:\n                    # Use TMA to 
load input and weight blocks\n                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        # Load input block [M, K]\n                        a = tl._experimental_descriptor_load(\n                            a_desc_ptr,\n                            [m_offset, k_offset],\n                            [BLOCK_SIZE_M, BLOCK_SIZE_K],\n                            dtype,\n                        )\n\n                        # Load weight block [N, K]\n                        b = tl._experimental_descriptor_load(\n                            b_desc_ptr,\n                            [n_offset, k_offset],\n                            [BLOCK_SIZE_N, BLOCK_SIZE_K],\n                            dtype,\n                        )\n\n                        # Compute matrix multiplication\n                        accumulator += tl.dot(a, b.T)\n                else:\n                    # Manual load without TMA\n                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n                    a_ptrs = (\n                        a_desc_ptr\n                        + (M_start_offset + offs_am[:, None]) * K\n                        + offs_k[None, :]\n                    )\n\n                    b_ptrs = (\n                        b_desc_ptr\n                        + (N_start_offset + offs_bn[:, None]) * K\n                        + offs_k[None, :]\n                    )\n\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        # Load with bounds checking\n                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)\n                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)\n\n                        # Compute matrix multiplication\n                        accumulator += tl.dot(a, b.T)\n\n                        # Update pointers for next block\n                        a_ptrs += BLOCK_SIZE_K\n                        b_ptrs += BLOCK_SIZE_K\n\n                # Store result\n                if USE_TMA_STORE:\n                    # Store using TMA\n                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n\n                    tl._experimental_descriptor_store(\n                        c_desc_ptr,\n                        accumulator.to(c_ptr.dtype.element_ty),\n                        [m_offset, n_offset],\n                    )\n                else:\n                    # Manual store\n                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n                    c = accumulator.to(c_ptr.dtype.element_ty)\n\n                    tl.store(\n                        c_ptr\n                        + (M_start_offset + offs_am[:, None])\n                        * (N * G)  # Row stride is N*G\n                        + (\n                            N_start_offset + offs_bn[None, :]\n                        ),  # Column offset to this group's N\n                        c,\n                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,\n                    
)\n\n                tidx += NUM_SMS  # Move to next tile\n\n            iterated_tiles += num_tiles\n\n\nTT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv\n\n\n\"\"\"@triton.autotune(\n    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,\n    key=[\"G\", \"M_BUCKET\", \"N\", \"K\"],\n    prune_configs_by={\n        \"early_config_prune\": functools.partial(\n            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1\n        )\n    },\n)\n\"\"\"\n\n\n@triton.jit\ndef _kernel_grouped_gemm_fp8_rowwise(\n    a_desc_ptr,\n    a_scale_ptr,\n    b_desc_ptr,\n    b_scale_ptr,\n    c_ptr,\n    workspace,\n    m_sizes,\n    # problem sizes\n    G: tl.constexpr,\n    M_BUCKET: tl.constexpr,\n    N: tl.constexpr,  # N is per group\n    K: tl.constexpr,\n    NUM_SMS: tl.constexpr,\n    USE_TMA_LOAD: tl.constexpr,\n    USE_TMA_STORE: tl.constexpr,\n    # tile sizes\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n) -> None:\n    tidx = tl.program_id(0)\n\n    dtype = TT_FP8_DTYPE\n    TMA_SIZE: tl.constexpr = tl.constexpr(128)\n    if USE_TMA_STORE:\n        c_desc_ptr = workspace + tidx * TMA_SIZE\n    else:\n        c_desc_ptr = None\n\n    M_end_offset = 0\n    iterated_tiles = 0\n    for g in tl.range(G):\n        # Move across groups\n        M_start_offset = M_end_offset\n        m_size = tl.load(m_sizes + g)\n        M_end_offset = M_start_offset + m_size\n\n        if m_size > 0:\n            # Compute for this group\n            N_start_offset = g * N\n            n_size = N  # N is already per group\n\n            # Calculate the number of tiles for this group\n            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)\n            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)\n            num_tiles = num_m_tiles * num_n_tiles\n\n            if USE_TMA_STORE:\n                # Set up TMA descriptor for output\n                # pyre-ignore\n                tl.extra.cuda.experimental_device_tensormap_create2d(\n                    desc_ptr=c_desc_ptr,\n                    global_address=c_ptr\n                    + M_start_offset * (N * G)\n                    + N_start_offset,  # Offset to this group's output\n                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],\n                    global_size=[m_size, n_size],\n                    element_ty=c_ptr.dtype.element_ty,\n                )\n                # pyre-ignore\n                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)\n\n            # Move across tiles\n            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:\n                gidx = tidx - iterated_tiles\n                # Split M first and N second.\n                tile_m_idx = gidx % num_m_tiles\n                tile_n_idx = gidx // num_m_tiles\n\n                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n                tl.static_assert(K % BLOCK_SIZE_K == 0)\n\n                if USE_TMA_LOAD:\n                    # Use TMA to load input and weight blocks with FP8 support\n                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        # Load input block [M, K] with FP8\n                        a = tl._experimental_descriptor_load(\n                            a_desc_ptr,\n                            [m_offset, 
k_offset],\n                            [BLOCK_SIZE_M, BLOCK_SIZE_K],\n                            dtype,\n                        )\n\n                        # Load weight block [N, K] with FP8\n                        b = tl._experimental_descriptor_load(\n                            b_desc_ptr,\n                            [n_offset, k_offset],\n                            [BLOCK_SIZE_N, BLOCK_SIZE_K],\n                            dtype,\n                        )\n\n                        # Compute matrix multiplication\n                        accumulator += tl.dot(a, b.T)\n                else:\n                    # Manual load without TMA for FP8\n                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n                    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n                    a_ptrs = (\n                        a_desc_ptr\n                        + (M_start_offset + offs_am[:, None]) * K\n                        + offs_k[None, :]\n                    )\n\n                    b_ptrs = (\n                        b_desc_ptr\n                        + (N_start_offset + offs_bn[:, None]) * K\n                        + offs_k[None, :]\n                    )\n\n                    for k_offset in range(0, K, BLOCK_SIZE_K):\n                        # Load with bounds checking\n                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)\n                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)\n\n                        # Compute matrix multiplication\n                        accumulator += tl.dot(a, b.T)\n\n                        # Update pointers for next block\n                        a_ptrs += BLOCK_SIZE_K\n                        b_ptrs += BLOCK_SIZE_K\n\n                # Load FP8 scales\n                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n                a_scale = tl.load(\n                    a_scale_ptr + M_start_offset + offs_am[:, None],\n                    mask=offs_am[:, None] < m_size,\n                )\n\n                b_scale = tl.load(\n                    b_scale_ptr + N_start_offset + offs_bn[None, :],\n                    mask=offs_bn[None, :] < n_size,\n                )\n\n                # Apply scales to result\n                c = accumulator.to(tl.float32) * a_scale * b_scale\n\n                # Store result\n                if USE_TMA_STORE:\n                    # Store using TMA\n                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)\n                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)\n\n                    tl._experimental_descriptor_store(\n                        c_desc_ptr,\n                        c.to(c_ptr.dtype.element_ty),\n                        [m_offset, n_offset],\n                    )\n                else:\n                    # Manual store\n                    tl.store(\n                        c_ptr\n                        + (M_start_offset + offs_am[:, None])\n                        * (N * G)  # Row stride is N*G\n                        + (\n                            N_start_offset + offs_bn[None, :]\n                        ),  # Column offset to this group's N\n                        c,\n                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,\n                    )\n\n                tidx += 
NUM_SMS  # Move to next tile\n\n            iterated_tiles += num_tiles\n\n\ndef _grouped_gemm(\n    x: torch.Tensor,\n    w: torch.Tensor,\n    m_sizes: torch.Tensor,\n    x_scale: Optional[torch.Tensor] = None,\n    w_scale: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    if not utils.HAS_TMA_DESC:\n        raise NotImplementedError(\"Grouped GEMM without TMA is not supported yet\")\n\n    G = m_sizes.shape[0]\n\n    assert x.is_contiguous()\n    assert w.is_contiguous()\n    assert m_sizes.is_contiguous()\n\n    M, K = x.shape\n    N_times_G = w.shape[0]\n\n    # Ensure N is per group\n    assert (\n        N_times_G % G == 0\n    ), f\"Weight dimension ({N_times_G}) must be divisible by groups ({G})\"\n    N = N_times_G // G\n\n    assert K == w.shape[1], f\"Input K ({K}) must match weight K ({w.shape[1]})\"\n\n    # Create output tensor with correct shape [M, N*G]\n    y = torch.empty((M, N_times_G), device=x.device, dtype=torch.bfloat16)\n\n    NUM_SMS = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n    USE_TMA_LOAD = True  # not torch.version.hip\n    USE_TMA_STORE = True\n\n    desc_helper = None\n    desc_x = x\n    desc_w = w\n    workspace = None\n\n    if USE_TMA_LOAD:\n        desc_helper = utils.TmaAutoTuneHelper()\n        desc_helper.init_tma_descriptor(\"x\")\n        desc_helper.init_tma_descriptor(\"w\")\n        desc_x = desc_helper.get_tma_descriptor_kernel_param(\"x\")\n        desc_w = desc_helper.get_tma_descriptor_kernel_param(\"w\")\n\n    if USE_TMA_STORE:\n        workspace = torch.empty(\n            NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,\n            device=x.device,\n            dtype=torch.uint8,\n        )\n\n    # Skip autotuning - use fixed grid size\n    grid_size = (min(NUM_SMS, 4),)  # Use smaller grid for small inputs\n    M_BUCKET = triton.next_power_of_2(M)\n\n    try:\n\n        if USE_TMA_LOAD and desc_helper is not None:\n            # Fixed block sizes that work well for most cases\n            BLOCK_SIZE_M = 64\n            BLOCK_SIZE_N = 64\n            BLOCK_SIZE_K = 32\n\n            desc_helper.fill_2d_tma_descriptor(\n                \"x\",\n                x.data_ptr(),\n                M,\n                K,\n                BLOCK_SIZE_M,\n                BLOCK_SIZE_K,\n                x.element_size(),\n            )\n\n            desc_helper.fill_2d_tma_descriptor(\n                \"w\",\n                w.data_ptr(),\n                N_times_G,\n                K,\n                BLOCK_SIZE_N,\n                BLOCK_SIZE_K,\n                w.element_size(),\n            )\n    except Exception as e:\n        print(f\"Error in TMA descriptor setup: {e}\")\n\n    if x_scale is not None and w_scale is not None:\n        assert x_scale.is_contiguous()\n        assert w_scale.is_contiguous()\n        # Call kernel directly without autotuning\n        _kernel_grouped_gemm_fp8_rowwise[grid_size](\n            desc_x,\n            x_scale,\n            desc_w,\n            w_scale,\n            y,\n            workspace,\n            m_sizes,\n            G,\n            M_BUCKET,\n            N,  # N is per group\n            K,\n            NUM_SMS,\n            USE_TMA_LOAD,\n            USE_TMA_STORE,\n            BLOCK_SIZE_M=64,  # Fixed block sizes\n            BLOCK_SIZE_N=64,\n            BLOCK_SIZE_K=32,\n        )\n    else:\n        assert x_scale is None\n        assert w_scale is None\n        # Call kernel directly without autotuning\n        _kernel_grouped_gemm[grid_size](\n            
desc_x,\n            desc_w,\n            y,\n            workspace,\n            m_sizes,\n            G,\n            M_BUCKET,\n            N,  # N is per group\n            K,\n            NUM_SMS,\n            USE_TMA_LOAD,\n            USE_TMA_STORE,\n            BLOCK_SIZE_M=64,  # Fixed block sizes\n            BLOCK_SIZE_N=64,\n            BLOCK_SIZE_K=32,\n        )\n\n    # Verify the output shape\n    expected_output_shape = (M, N_times_G)\n    assert y.shape == expected_output_shape, (\n        f\"Output shape mismatch: got {y.shape}, \" f\"expected {expected_output_shape}\"\n    )\n\n    return y\n\n\ndef grouped_gemm_forward(\n    x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor\n) -> torch.Tensor:\n    return _grouped_gemm(x, w, m_sizes)\n\n\ndef grouped_gemm_fp8_rowwise(\n    x: torch.Tensor,\n    w: torch.Tensor,\n    m_sizes: torch.Tensor,\n    x_scale: torch.Tensor,\n    w_scale: torch.Tensor,\n) -> torch.Tensor:\n    return _grouped_gemm(x, w, m_sizes, x_scale, w_scale)\n"
  },
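  {
    "path": "kernels/MoE/group_GEMM/triton/example_grouped_gemm_forward_check.py",
    "content": "# Hypothetical usage sketch (illustrative only, not part of the FBGEMM-derived sources above):\n# runs grouped_gemm_forward from tgroup_gemm_forward.py and compares it with a per-group\n# PyTorch reference. This file path and the sizes below are assumptions; the shapes\n# (x [M, K], w [N*G, K], m_sizes [G], output [M, N*G] in bfloat16) come from the forward\n# module. A TMA-capable GPU is required, since _grouped_gemm asserts utils.HAS_TMA_DESC.\n\nimport torch\n\nfrom tgroup_gemm_forward import grouped_gemm_forward\n\n\ndef reference_forward(x, w, m_sizes):\n    \"\"\"Per-group reference: rows of group g get x[g] @ w[g].T in columns [g*N, (g+1)*N).\"\"\"\n    G = m_sizes.shape[0]\n    N = w.shape[0] // G\n    y = torch.zeros(x.shape[0], w.shape[0], device=x.device, dtype=torch.bfloat16)\n    m_start = 0\n    for g in range(G):\n        m_end = m_start + int(m_sizes[g])\n        y[m_start:m_end, g * N : (g + 1) * N] = x[m_start:m_end] @ w[g * N : (g + 1) * N].t()\n        m_start = m_end\n    return y\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    G, N, K = 4, 256, 256\n    m_sizes = torch.tensor([128, 64, 192, 128], device=\"cuda\", dtype=torch.int32)\n    M = int(m_sizes.sum())\n    x = torch.randn(M, K, device=\"cuda\", dtype=torch.bfloat16)\n    w = torch.randn(N * G, K, device=\"cuda\", dtype=torch.bfloat16)\n\n    y = grouped_gemm_forward(x, w, m_sizes)  # [M, N*G], bfloat16\n    y_ref = reference_forward(x, w, m_sizes)\n    print(\"forward max abs diff:\", (y - y_ref).abs().max().item())\n"
  },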
  {
    "path": "kernels/MoE/group_GEMM/triton/utils/tma_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-unsafe\n# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm\n\nimport sys\n\nimport torch\nimport triton  # @manual\n\nimport triton.language as tl  # @manual\n\n\ndef map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:\n    \"\"\"\n    Maps torch dtype to triton dtype.\n\n    Args:\n        dtype (torch.dtype): input dtype.\n\n    Returns:\n        tl.dtype: triton dtype.\n    \"\"\"\n    if dtype == torch.float16:\n        return tl.float16\n    elif dtype == torch.bfloat16:\n        return tl.bfloat16\n    elif dtype == torch.float32:\n        return tl.float32\n    elif dtype == torch.int32:\n        return tl.int32\n    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:\n        return tl.float8e4nv\n    else:\n        raise ValueError(f\"Unsupported dtype {dtype}\")\n\n\n# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).\nHAS_TMA_DESC = \"nv_tma_desc_type\" in dir(tl)\n\nif HAS_TMA_DESC:\n    print(\n        \"TMA benchmarks will be running with experimental grid constant TMA descriptor.\",\n        file=sys.stderr,\n    )\nelse:\n    print(\n        \"Missing: This group gemm code will not run without TMA descriptor support....\",\n        file=sys.stderr,\n    )\n    raise NotImplementedError(\"grouped Gemm without TMA is not supported\")\n\n\nclass TmaAutoTuneHelper:\n\n    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498\n    class KernelParamWrapper:\n        def __init__(self, desc):\n            self.desc = desc\n\n        def tma_desc_cpu_ptr(self):\n            return self.desc.data_ptr()\n\n    TMA_SIZE = 128\n\n    def __init__(self):\n        self.fill_1d_tma_descriptor_inner = (\n            triton.runtime.driver.active.utils.fill_1d_tma_descriptor\n        )\n        self.fill_2d_tma_descriptor_inner = (\n            triton.runtime.driver.active.utils.fill_2d_tma_descriptor\n        )\n        if HAS_TMA_DESC:\n            self.descriptors = {}\n        else:\n            self.cuda_descriptors = {}\n\n    # Call this method outside of the lambda function for grid size\n    def init_tma_descriptor(self, name):\n        if HAS_TMA_DESC:\n            self.descriptors[name] = torch.empty(\n                TmaAutoTuneHelper.TMA_SIZE, device=\"cpu\", dtype=torch.int8\n            )\n        else:\n            self.cuda_descriptors[name] = torch.empty(\n                TmaAutoTuneHelper.TMA_SIZE, device=\"cuda\", dtype=torch.int8\n            )\n\n    # Call this method inside the lambda function for grid size\n    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):\n        if HAS_TMA_DESC:\n            desc_x = self.descriptors[name]\n            assert desc_x.data_ptr() % 64 == 0\n            self.fill_1d_tma_descriptor_inner(\n                ptr, dim, block_dim, element_size, desc_x.data_ptr()\n            )\n        else:\n            desc_x = self.cuda_descriptors[name]\n            buf_x = torch.empty_like(desc_x, device=\"cpu\", pin_memory=True)\n            self.fill_1d_tma_descriptor_inner(\n                ptr, dim, block_dim, element_size, buf_x.data_ptr()\n            )\n            desc_x.copy_(buf_x, non_blocking=True)\n\n    # 
Call this method inside the lambda function for grid size\n    def fill_2d_tma_descriptor(\n        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size\n    ):\n        if HAS_TMA_DESC:\n            desc_x = self.descriptors[name]\n            assert desc_x.data_ptr() % 64 == 0\n            self.fill_2d_tma_descriptor_inner(\n                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()\n            )\n        else:\n            desc_x = self.cuda_descriptors[name]\n            buf_x = torch.empty_like(desc_x, device=\"cpu\", pin_memory=True)\n            self.fill_2d_tma_descriptor_inner(\n                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()\n            )\n            desc_x.copy_(buf_x, non_blocking=True)\n\n    def get_tma_descriptor_kernel_param(self, name):\n        if HAS_TMA_DESC:\n            assert self.descriptors[name] is not None\n            return self.KernelParamWrapper(self.descriptors[name])\n        else:\n            assert self.cuda_descriptors[name] is not None\n            return self.cuda_descriptors[name]\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/Makefile",
    "content": "\n# Makefile for SM100 GEMM PyTorch Extension\n\n# Set these paths according to your installation\nCUTLASS_PATH ?= /path/to/cutlass\nCUDA_HOME ?= $(shell python -c \"import torch; print(torch.utils.cpp_extension.CUDA_HOME)\")\n\n# Build the extension\nbuild:\n\tCUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace\n\n# Install the extension\ninstall:\n\tCUTLASS_PATH=$(CUTLASS_PATH) pip install .\n\n# Clean build artifacts\nclean:\n\trm -rf build/ dist/ *.egg-info/ sm100_gemm*.so\n\n# Test the installation\ntest:\n\tpython python_interface.py\n\n# Check CUDA device capability\ncheck_device:\n\tpython -c \"import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')\"\n\n.PHONY: build install clean test check_device\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_log",
    "content": "# ninja log v5\n0\t15279\t1748131038212164071\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o\t1163be77f63db063\n6\t13596\t1748131241209889865\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o\t79aa61597088743a\n8\t13684\t1748132015451659084\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o\t89ead7aaccf82852\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/build.ninja",
    "content": "ninja_required_version = 1.3\ncxx = c++\nnvcc = /usr/local/cuda-12.8/bin/nvcc\n\ncflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c\npost_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1018\"' -DTORCH_EXTENSION_NAME=sm100_gemm\ncuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c\ncuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''\"'\"'-fPIC'\"'\"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1018\"' -DTORCH_EXTENSION_NAME=sm100_gemm\ncuda_dlink_post_cflags = \nsycl_dlink_post_cflags = \nldflags = \n\nrule compile\n  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags\n  depfile = $out.d\n  deps = gcc\n\nrule cuda_compile\n  depfile = $out.d\n  deps = gcc\n  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags\n\n\n\n\n\n\n\nbuild /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm.cu\nbuild /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm_pytorch.cpp\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/driver.py",
    "content": "# ==============================================================================\n# python_interface.py - High-level Python interface\n# ==============================================================================\n\n\nimport torch\n\ntry:\n    import sm100_gemm  # The compiled extension - this has to go after import torch...but auto-formatting is blocking\nexcept ImportError:\n    print(\"❌ SM100 not ready!\")\n    raise ImportError(\n        \"SM100 not ready! Please build the extension using `python setup.py install`\"\n    )\n\n\ndef sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0):\n    \"\"\"\n    Perform GEMM using SM100 optimized kernel: D = alpha * A @ B^T + beta * C\n\n    Args:\n        A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16\n        B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16\n        C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32\n                                   If None, creates zero tensor\n        alpha (float): Scaling factor for A @ B^T\n        beta (float): Scaling factor for C\n\n    Returns:\n        torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32\n\n    Note:\n        - A and B are K-major (transposed in BLAS terms)\n        - C and D are N-major (row-major)\n        - All tensors must be on CUDA\n        - M must be multiple of 128, N multiple of 256, K multiple of 64\n    \"\"\"\n\n    # Input validation\n    assert A.dtype == torch.float16, f\"A must be float16, got {A.dtype}\"\n    assert B.dtype == torch.float16, f\"B must be float16, got {B.dtype}\"\n    assert A.is_cuda and B.is_cuda, \"A and B must be on CUDA\"\n    assert A.is_contiguous() and B.is_contiguous(), \"A and B must be contiguous\"\n\n    M, K = A.shape\n    N, K_B = B.shape\n    assert K == K_B, f\"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}\"\n\n    # Check alignment requirements\n    assert M % 128 == 0, f\"M={M} must be multiple of 128\"\n    assert N % 256 == 0, f\"N={N} must be multiple of 256\"\n    assert K % 64 == 0, f\"K={K} must be multiple of 64\"\n\n    # Create C if not provided\n    if C is None:\n        C = torch.zeros(M, N, dtype=torch.float32, device=A.device)\n    else:\n        assert C.dtype == torch.float32, f\"C must be float32, got {C.dtype}\"\n        assert C.is_cuda, \"C must be on CUDA\"\n        assert C.is_contiguous(), \"C must be contiguous\"\n        assert C.shape == (\n            M,\n            N,\n        ), f\"C shape {C.shape} must match output shape ({M}, {N})\"\n\n    # Call the extension\n    return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta)\n\n\ndef benchmark_sm100_vs_torch(\n    M=1024, N=2048, K=256, num_warmup=1, num_trials=10\n):  # M=512, N=1024, K=256, num_warmup=10, num_trials=100):\n    \"\"\"\n    Benchmark SM100 GEMM against PyTorch's native GEMM\n    \"\"\"\n    # Ensure dimensions are aligned\n    M = ((M + 127) // 128) * 128\n    N = ((N + 255) // 256) * 256\n    K = ((K + 63) // 64) * 64\n\n    print(f\"Benchmarking GEMM with shape: ({M}, {N}, {K})\")\n\n    # Create test tensors\n    A = torch.randn(M, K, dtype=torch.float16, device=\"cuda\")\n    B = torch.randn(N, K, dtype=torch.float16, device=\"cuda\")\n    C = torch.randn(M, N, dtype=torch.float16, device=\"cuda\")\n    C32 = C.to(torch.float32).clone()\n\n    # Keep A and B as FP16 for PyTorch\n    A_fp16 = A\n    B_fp16 = B\n\n    # Warmup\n    for _ in range(num_warmup):\n        # PyTorch GEMM (using FP16)\n        torch_result = 
torch.addmm(C, A_fp16, B_fp16.T)\n\n        # SM100 GEMM\n        sm100_result = sm100_gemm_f16(A, B, C32)\n\n    torch.cuda.synchronize()\n\n    # Benchmark PyTorch\n    torch.cuda.synchronize()\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n\n    start.record()\n    for _ in range(num_trials):\n        torch_result = torch.addmm(C, A_fp16, B_fp16.T)\n    end.record()\n    torch.cuda.synchronize()\n    torch_time = start.elapsed_time(end) / num_trials\n\n    # Benchmark SM100\n    start.record()\n    for _ in range(num_trials):\n        sm100_result = sm100_gemm_f16(A, B, C32)\n    end.record()\n    torch.cuda.synchronize()\n    sm100_time = start.elapsed_time(end) / num_trials\n\n    # Check correctness\n    max_diff = torch.max(torch.abs(torch_result - sm100_result.to(torch.float16)))\n    rel_error = max_diff / torch.max(torch.abs(torch_result))\n\n    # Calculate FLOPS\n    flops = 2 * M * N * K  # Multiply-add operations\n    torch_tflops = flops / (torch_time * 1e-3) / 1e12\n    sm100_tflops = flops / (sm100_time * 1e-3) / 1e12\n\n    print(f\"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)\")\n    print(f\"SM100 time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)\")\n    print(f\"Speedup: {torch_time/sm100_time:.2f}x\")\n    print(f\"Max difference: {max_diff:.6f}\")\n    print(f\"Relative error: {rel_error:.6f}\")\n\n    return {\n        \"torch_time\": torch_time,\n        \"sm100_time\": sm100_time,\n        \"speedup\": torch_time / sm100_time,\n        \"torch_tflops\": torch_tflops,\n        \"sm100_tflops\": sm100_tflops,\n        \"max_diff\": max_diff.item(),\n        \"rel_error\": rel_error.item(),\n    }\n\n\n# Example usage and test\nif __name__ == \"__main__\":\n    # Test basic functionality\n    print(\"Testing SM100 GEMM...\")\n\n    M, N, K = 512, 1024, 256\n    A = torch.randn(M, K, dtype=torch.float16, device=\"cuda\")\n    B = torch.randn(N, K, dtype=torch.float16, device=\"cuda\")\n    C = torch.randn(M, N, dtype=torch.float32, device=\"cuda\")\n\n    # Test the GEMM\n    result = sm100_gemm_f16(A, B, C, alpha=1.0, beta=0.5)\n    print(f\"Result shape: {result.shape}, dtype: {result.dtype}\")\n\n    # Run benchmark\n    print(\"\\nRunning benchmark...\")\n    benchmark_results = benchmark_sm100_vs_torch(M, N, K)\n\n# ==============================================================================\n# Makefile for easy building\n# ==============================================================================\n'''\nMAKEFILE_CONTENT = \"\"\"\n# Makefile for SM100 GEMM PyTorch Extension\n\n# Set these paths according to your installation\nCUTLASS_PATH ?= /path/to/cutlass\nCUDA_HOME ?= $(shell python -c \"import torch; print(torch.utils.cpp_extension.CUDA_HOME)\")\n\n# Build the extension\nbuild:\n\tCUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace\n\n# Install the extension\ninstall:\n\tCUTLASS_PATH=$(CUTLASS_PATH) pip install .\n\n# Clean build artifacts\nclean:\n\trm -rf build/ dist/ *.egg-info/ sm100_gemm*.so\n\n# Test the installation\ntest:\n\tpython python_interface.py\n\n# Check CUDA device capability\ncheck_device:\n\tpython -c \"import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')\"\n\n.PHONY: build install clean test check_device\n\"\"\"\n\n# Write Makefile\nwith open(\"Makefile\", \"w\") as f:\n    f.write(MAKEFILE_CONTENT)\n\nprint(\"Setup files created!\")\nprint(\"To build:\")\nprint(\"1. 
Set CUTLASS_PATH environment variable to your CUTLASS installation\")\nprint(\"2. Run: make build\")\nprint(\"3. Test: make test\")\n'''\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/setup.py",
    "content": "# setup.py\nimport os\n\nimport pybind11\nimport torch\nfrom pybind11 import get_cmake_dir\nfrom pybind11.setup_helpers import build_ext, Pybind11Extension\nfrom setuptools import Extension, setup\nfrom torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension\n\n# IMPORTANT: The following two lines are the only ones you need to change\n# Get CUTLASS path (you'll need to set this to your CUTLASS installation)\nCUTLASS_PATH = os.environ.get(\"CUTLASS_PATH\", \"/home/less/local/cutlas40\")\n\n# CUDA and PyTorch paths\ncuda_home = torch.utils.cpp_extension.CUDA_HOME\npytorch_includes = torch.utils.cpp_extension.include_paths()\n\next_modules = [\n    CUDAExtension(\n        name=\"sm100_gemm\",\n        sources=[\n            \"sm100_gemm_pytorch.cpp\",  # PyTorch bindings (C++)\n            \"sm100_gemm.cu\",  # CUDA kernel implementation\n        ],\n        include_dirs=[\n            # PyTorch includes\n            *pytorch_includes,\n            # CUTLASS includes\n            f\"{CUTLASS_PATH}/include\",\n            f\"{CUTLASS_PATH}/tools/util/include\",\n            # CUDA includes\n            f\"{cuda_home}/include\",\n        ],\n        library_dirs=[\n            f\"{cuda_home}/lib64\",\n        ],\n        libraries=[\"cuda\", \"cudart\"],\n        extra_compile_args={\n            \"cxx\": [\n                \"-O3\",\n                \"-std=c++17\",\n                \"-DCUTLASS_ARCH_MMA_SM100_SUPPORTED\",\n                \"-DCUTE_SM100_ENABLED\",\n            ],\n            \"nvcc\": [\n                \"-O3\",\n                \"-std=c++17\",\n                \"--expt-relaxed-constexpr\",\n                \"--expt-extended-lambda\",\n                \"-gencode=arch=compute_100a,code=sm_100a\",  # SM100 architecture\n                \"-DCUTLASS_ARCH_MMA_SM100_SUPPORTED\",\n                \"-DCUTE_SM100_ENABLED\",\n                \"--use_fast_math\",\n                \"-Xcompiler=-fPIC\",\n                \"-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1\",  # Enable TCGEN05_TMEM\n            ],\n        },\n        extra_link_args=[\"-lcuda\", \"-lcudart\"],\n        language=\"c++\",\n    )\n]\n\nsetup(\n    name=\"sm100_gemm\",\n    ext_modules=ext_modules,\n    cmdclass={\"build_ext\": BuildExtension},\n    zip_safe=False,\n    python_requires=\">=3.8\",\n    install_requires=[\"torch>=1.12.0\"],\n)\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.cu",
    "content": "// sm100_gemm_kernel.cu - CUDA kernel implementation\n#include \"sm100_gemm.h\"\n\n#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)\n\n#include <cutlass/arch/barrier.h>\n#include <cutlass/cluster_launch.hpp>\n#include <cutlass/half.h>\n#include <cutlass/util/print_error.hpp>\n\n#include <cute/algorithm/cooperative_copy.hpp>\n#include <cute/arch/cluster_sm90.hpp>\n#include <cute/arch/tmem_allocator_sm100.hpp>\n#include <cute/numeric/integral_constant.hpp>\n#include <cute/tensor.hpp>\n\nusing namespace cute;\n\n// Shared storage structure\ntemplate <class TypeA, class TypeB, class ASmemLayout, class BSmemLayout>\nstruct SharedStorage {\n  alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;\n  alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;\n  alignas(16) cute::uint64_t mma_barrier;\n  alignas(16) cute::uint32_t tmem_base_ptr;\n\n  CUTE_DEVICE constexpr auto tensor_sA() {\n    return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{});\n  }\n  CUTE_DEVICE constexpr auto tensor_sB() {\n    return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{});\n  }\n};\n\n// Device kernel\ntemplate <class SharedStorage, class ATensor, class BTensor, class CTensor,\n          class DTensor, class MmaTiler_MNK, class TiledMMA,\n          class ClusterShape_MNK, class Alpha, class Beta>\n__global__ static void\ngemm_device(ATensor mA, BTensor mB, CTensor mC, DTensor mD,\n            MmaTiler_MNK mma_tiler, TiledMMA tiled_mma,\n            ClusterShape_MNK cluster_shape, Alpha alpha, Beta beta) {\n  // Step 1: The Prologue\n  Layout cluster_layout_vmnk = tiled_divide(\n      make_layout(cluster_shape), make_tile(typename TiledMMA::AtomThrID{}));\n\n  auto mma_coord_vmnk =\n      make_coord(blockIdx.x % size<0>(cluster_layout_vmnk),\n                 blockIdx.x / size<0>(cluster_layout_vmnk), blockIdx.y, _);\n\n  auto mma_coord = select<1, 2, 3>(mma_coord_vmnk);\n  Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X, _1>{});\n  Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step<X, _1, _1>{});\n  Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1, _1, X>{});\n  Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1, _1, X>{});\n\n  // SMEM allocation\n  extern __shared__ char shared_memory[];\n  SharedStorage &shared_storage =\n      *reinterpret_cast<SharedStorage *>(shared_memory);\n\n  Tensor tCsA = shared_storage.tensor_sA();\n  Tensor tCsB = shared_storage.tensor_sB();\n\n  // MMA partitioning\n  auto mma_v = get<0>(mma_coord_vmnk);\n  ThrMMA cta_mma = tiled_mma.get_slice(mma_v);\n  Tensor tCgA = cta_mma.partition_A(gA);\n  Tensor tCgB = cta_mma.partition_B(gB);\n  Tensor tCgC = cta_mma.partition_C(gC);\n  Tensor tCgD = cta_mma.partition_C(gD);\n\n  // Fragment allocation\n  Tensor tCrA = cta_mma.make_fragment_A(tCsA);\n  Tensor tCrB = cta_mma.make_fragment_B(tCsB);\n  Tensor tCtAcc = cta_mma.make_fragment_C(tCgC);\n\n  uint32_t elect_one_thr = cute::elect_one_sync();\n  uint32_t elect_one_warp = (threadIdx.x / 32 == 0);\n\n  using TmemAllocator = cute::TMEM::Allocator1Sm;\n  TmemAllocator tmem_allocator{};\n\n  // TMEM allocation\n  if (elect_one_warp) {\n    tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,\n                            &shared_storage.tmem_base_ptr);\n  }\n  __syncthreads();\n  tCtAcc.data() = shared_storage.tmem_base_ptr;\n\n  // Barrier initialization\n  if (elect_one_warp && elect_one_thr) {\n    cute::initialize_barrier(shared_storage.mma_barrier, 1);\n  }\n  int mma_barrier_phase_bit = 0;\n  
__syncthreads();\n\n  // Step 2: The Mainloop\n  tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;\n\n  for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) {\n    // Load A and B tiles\n    cooperative_copy<128>(threadIdx.x, tCgA(_, _, _, k_tile), tCsA);\n    cooperative_copy<128>(threadIdx.x, tCgB(_, _, _, k_tile), tCsB);\n\n    __syncthreads();\n\n    // Execute MMAs\n    if (elect_one_warp) {\n      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n        gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCtAcc);\n        tiled_mma.accumulate_ = UMMA::ScaleOut::One;\n      }\n      cutlass::arch::umma_arrive(&shared_storage.mma_barrier);\n    }\n    cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);\n    mma_barrier_phase_bit ^= 1;\n  }\n\n  // Step 3: The Epilogue\n  TiledCopy tiled_t2r_copy =\n      make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);\n  ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);\n\n  Tensor tDgC = thr_t2r_copy.partition_D(tCgC);\n  Tensor tDrC = make_fragment_like(tDgC);\n  copy(tDgC, tDrC);\n\n  Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc);\n  Tensor tDgD = thr_t2r_copy.partition_D(tCgD);\n  using AccType = typename decltype(tCtAcc)::value_type;\n  Tensor tDrAcc = make_tensor<AccType>(shape(tDgD));\n  copy(tiled_t2r_copy, tDtAcc, tDrAcc);\n\n  // AXPBY and store result\n  axpby(alpha, tDrAcc, beta, tDrC);\n  copy(tDrC, tDgD);\n\n  __syncthreads();\n\n  // Cleanup TMEM\n  if (elect_one_warp) {\n    tmem_allocator.release_allocation_lock();\n    tmem_allocator.free(shared_storage.tmem_base_ptr,\n                        TmemAllocator::Sm100TmemCapacityColumns);\n  }\n}\n\n// Host function that launches the kernel\ncudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,\n                                  int M, int N, int K, float alpha, float beta,\n                                  cudaStream_t stream) {\n  // Define types\n  using TypeA = cutlass::half_t;\n  using TypeB = cutlass::half_t;\n  using TypeC = float;\n  using TypeD = float;\n\n  // Create layouts (K-major for A and B, N-major for C and D)\n  auto layout_A = make_layout(make_shape(M, K), make_stride(K, Int<1>{}));\n  auto layout_B = make_layout(make_shape(N, K), make_stride(K, Int<1>{}));\n  auto layout_C = make_layout(make_shape(M, N), make_stride(N, Int<1>{}));\n  auto layout_D = layout_C;\n\n  // Create CuTe tensors\n  auto mA =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeA *>(d_A)), layout_A);\n  auto mB =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeB *>(d_B)), layout_B);\n  auto mC =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeC *>(d_C)), layout_C);\n  auto mD =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeD *>(d_D)), layout_D);\n\n  // Create TiledMMA\n  TiledMMA tiled_mma =\n      make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, 128, 256,\n                                          UMMA::Major::K, UMMA::Major::K>{});\n\n  // Define MMA tiler sizes\n  auto bM = tile_size<0>(tiled_mma);            // 128\n  auto bN = tile_size<1>(tiled_mma);            // 256\n  auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // 64\n  auto mma_tiler = make_shape(bM, bN, bK);\n\n  // Check alignment\n  if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) {\n    return cudaErrorInvalidValue;\n  }\n\n  // Create SMEM layouts\n  auto mma_shape_A = partition_shape_A(\n      tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));\n  auto mma_shape_B = partition_shape_B(\n      
tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));\n\n  auto sA_layout =\n      UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);\n  auto sB_layout =\n      UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);\n\n  using SMEMStorage =\n      SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;\n\n  // Cluster configuration\n  auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});\n\n  // Launch parameters\n  dim3 dimBlock(128);\n  dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape),\n                  size<2>(cluster_shape));\n  dim3 dimGrid(ceil_div(M, int(bM)), ceil_div(N, int(bN)));\n  int smemBytes = sizeof(SMEMStorage);\n\n  // Get kernel pointer\n  auto *kernel_ptr =\n      &gemm_device<SMEMStorage, decltype(mA), decltype(mB), decltype(mC),\n                   decltype(mD), decltype(mma_tiler), decltype(tiled_mma),\n                   decltype(cluster_shape), float, float>;\n\n  // Set kernel attributes\n  cudaError_t error = cudaFuncSetAttribute(\n      kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);\n  if (error != cudaSuccess) {\n    return error;\n  }\n\n  // Launch kernel\n  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster,\n                                         smemBytes};\n  cutlass::Status status = cutlass::launch_kernel_on_cluster(\n      params, (void const *)kernel_ptr, mA, mB, mC, mD, mma_tiler, tiled_mma,\n      cluster_shape, alpha, beta);\n\n  return (status == cutlass::Status::kSuccess) ? cudaSuccess\n                                               : cudaErrorLaunchFailure;\n}\n\n#else\n\ncudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,\n                                  int M, int N, int K, float alpha, float beta,\n                                  cudaStream_t stream) {\n  return cudaErrorNotSupported;\n}\n\n#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/PKG-INFO",
    "content": "Metadata-Version: 2.4\nName: sm100_gemm\nVersion: 0.0.0\nRequires-Python: >=3.8\nRequires-Dist: torch>=1.12.0\nDynamic: requires-dist\nDynamic: requires-python\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/SOURCES.txt",
    "content": "setup.py\nsm100_gemm.cu\nsm100_gemm_pytorch.cpp\nsm100_gemm.egg-info/PKG-INFO\nsm100_gemm.egg-info/SOURCES.txt\nsm100_gemm.egg-info/dependency_links.txt\nsm100_gemm.egg-info/not-zip-safe\nsm100_gemm.egg-info/requires.txt\nsm100_gemm.egg-info/top_level.txt"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/dependency_links.txt",
    "content": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/not-zip-safe",
    "content": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/requires.txt",
    "content": "torch>=1.12.0\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/top_level.txt",
    "content": "sm100_gemm\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.h",
    "content": "// sm100_gemm_kernel.h - Header file for CUDA kernel\n#pragma once\n\n#include <cuda_runtime.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/**\n * Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta * C\n *\n * @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major)\n * @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major)\n * @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major)\n * @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major)\n * @param M Number of rows in A and C/D\n * @param N Number of rows in B and columns in C/D\n * @param K Number of columns in A and B\n * @param alpha Scaling factor for A @ B^T\n * @param beta Scaling factor for C\n * @param stream CUDA stream (currently unused, for future async support)\n *\n * @return cudaSuccess on success, error code otherwise\n *\n * Requirements:\n * - M must be multiple of 128\n * - N must be multiple of 256\n * - K must be multiple of 64\n * - All pointers must be valid device memory\n * - Tensors must be contiguous with specified layouts\n */\ncudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,\n                                  int M, int N, int K, float alpha, float beta,\n                                  cudaStream_t stream = 0);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp",
    "content": "// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"sm100_gemm.h\"\n\n// Check if SM100 support is available at compile time\nbool is_sm100_supported() {\n#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)\n  return true;\n#else\n  return false;\n#endif\n}\n\n// Check if current GPU supports SM100 at runtime\nbool check_sm100_device() {\n  int device;\n  cudaGetDevice(&device);\n\n  cudaDeviceProp props;\n  cudaError_t error = cudaGetDeviceProperties(&props, device);\n  if (error != cudaSuccess) {\n    return false;\n  }\n\n  // Check for SM100 architecture (compute capability 10.0a)\n  return (props.major == 10 && props.minor == 0);\n}\n\ntorch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B,\n                             const torch::Tensor &C, float alpha = 1.0f,\n                             float beta = 0.0f) {\n\n  // Check compile-time support\n  TORCH_CHECK(\n      is_sm100_supported(),\n      \"SM100 support not compiled. Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED\");\n\n  // Check runtime device support\n  TORCH_CHECK(check_sm100_device(),\n              \"Current GPU does not support SM100 architecture (requires \"\n              \"compute capability 10.0a)\");\n\n  // Input validation\n  TORCH_CHECK(A.device().is_cuda(), \"A must be a CUDA tensor\");\n  TORCH_CHECK(B.device().is_cuda(), \"B must be a CUDA tensor\");\n  TORCH_CHECK(C.device().is_cuda(), \"C must be a CUDA tensor\");\n  TORCH_CHECK(A.dtype() == torch::kFloat16, \"A must be float16\");\n  TORCH_CHECK(B.dtype() == torch::kFloat16, \"B must be float16\");\n  TORCH_CHECK(C.dtype() == torch::kFloat32, \"C must be float32\");\n  TORCH_CHECK(A.is_contiguous(), \"A must be contiguous\");\n  TORCH_CHECK(B.is_contiguous(), \"B must be contiguous\");\n  TORCH_CHECK(C.is_contiguous(), \"C must be contiguous\");\n  TORCH_CHECK(A.dim() == 2, \"A must be 2D\");\n  TORCH_CHECK(B.dim() == 2, \"B must be 2D\");\n  TORCH_CHECK(C.dim() == 2, \"C must be 2D\");\n\n  // Get dimensions\n  int64_t M = A.size(0);\n  int64_t K = A.size(1);\n  int64_t N = B.size(0);\n  int64_t K_B = B.size(1);\n\n  TORCH_CHECK(K == K_B, \"Inner dimensions must match: A.shape[1]=\", K,\n              \", B.shape[1]=\", K_B);\n  TORCH_CHECK(C.size(0) == M && C.size(1) == N, \"C dimensions (\", C.size(0),\n              \", \", C.size(1), \") must match output shape (\", M, \", \", N, \")\");\n\n  // Check alignment requirements for SM100\n  TORCH_CHECK(M % 128 == 0, \"M=\", M, \" must be multiple of 128\");\n  TORCH_CHECK(N % 256 == 0, \"N=\", N, \" must be multiple of 256\");\n  TORCH_CHECK(K % 64 == 0, \"K=\", K, \" must be multiple of 64\");\n\n  // Check size limits (avoid overflow in int conversion)\n  TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX,\n              \"Dimensions too large for int conversion\");\n\n  // Create output tensor\n  auto D = torch::empty_like(C);\n\n  // Set CUDA device guard\n  const auto device = A.device();\n  c10::cuda::CUDAGuard device_guard(device);\n\n  // Get current CUDA stream\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream();\n\n  // Launch the kernel\n  cudaError_t error = launch_sm100_gemm_f16(\n      A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(),\n      static_cast<int>(M), static_cast<int>(N), static_cast<int>(K), alpha,\n      beta, stream);\n\n  // Check for 
launch errors\n  TORCH_CHECK(error == cudaSuccess,\n              \"SM100 GEMM kernel launch failed: \", cudaGetErrorString(error));\n\n  // Check for kernel execution errors\n  C10_CUDA_CHECK(cudaGetLastError());\n\n  return D;\n}\n\n// Utility functions for debugging and information\ntorch::Tensor get_device_info() {\n  int device;\n  cudaGetDevice(&device);\n\n  cudaDeviceProp props;\n  cudaGetDeviceProperties(&props, device);\n\n  // Return device info as a tensor (for easy Python access)\n  auto info = torch::zeros({4}, torch::kInt32);\n  auto accessor = info.accessor<int32_t, 1>();\n\n  accessor[0] = props.major;          // Compute capability major\n  accessor[1] = props.minor;          // Compute capability minor\n  accessor[2] = is_sm100_supported(); // Compile-time support\n  accessor[3] = check_sm100_device(); // Runtime device support\n\n  return info;\n}\n\nstd::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {\n  // Return properly aligned dimensions for SM100\n  int64_t aligned_M = ((M + 127) / 128) * 128;\n  int64_t aligned_N = ((N + 255) / 256) * 256;\n  int64_t aligned_K = ((K + 63) / 64) * 64;\n\n  return {aligned_M, aligned_N, aligned_K};\n}\n\ntorch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,\n                                    torch::ScalarType dtype,\n                                    torch::Device device) {\n  // Create a tensor with SM100-aligned dimensions\n  TORCH_CHECK(shape.size() == 2, \"Shape must be 2D\");\n\n  auto aligned_shape =\n      get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? shape[2] : 64);\n\n  if (shape.size() == 2) {\n    return torch::zeros({aligned_shape[0], aligned_shape[1]},\n                        torch::TensorOptions().dtype(dtype).device(device));\n  } else {\n    return torch::zeros({aligned_shape[0], aligned_shape[2]},\n                        torch::TensorOptions().dtype(dtype).device(device));\n  }\n}\n\n// Python bindings\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.doc() = \"SM100 GEMM PyTorch Extension\";\n\n  // Main GEMM function\n  m.def(\"sm100_gemm_f16\", &sm100_gemm_f16,\n        \"SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + \"\n        \"beta * C\",\n        py::arg(\"A\"), py::arg(\"B\"), py::arg(\"C\"), py::arg(\"alpha\") = 1.0f,\n        py::arg(\"beta\") = 0.0f);\n\n  // Utility functions\n  m.def(\"is_sm100_supported\", &is_sm100_supported,\n        \"Check if SM100 support was compiled in\");\n\n  m.def(\"check_sm100_device\", &check_sm100_device,\n        \"Check if current GPU supports SM100 architecture\");\n\n  m.def(\"get_device_info\", &get_device_info,\n        \"Get device compute capability and SM100 support info\");\n\n  m.def(\"get_aligned_shape\", &get_aligned_shape,\n        \"Get SM100-aligned dimensions for given shape\", py::arg(\"M\"),\n        py::arg(\"N\"), py::arg(\"K\"));\n\n  m.def(\"create_aligned_tensor\", &create_aligned_tensor,\n        \"Create tensor with SM100-aligned dimensions\", py::arg(\"shape\"),\n        py::arg(\"dtype\"), py::arg(\"device\"));\n\n  // Constants for alignment requirements\n  m.attr(\"MMA_TILE_M\") = 128;\n  m.attr(\"MMA_TILE_N\") = 256;\n  m.attr(\"MMA_TILE_K\") = 64;\n}\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_log",
    "content": "# ninja log v5\n1\t15202\t1748185895110710199\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o\t342153d32d365f0b\n7\t78\t1748186494782816813\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o\t342153d32d365f0b\n6\t15086\t1748186805607894090\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o\t342153d32d365f0b\n6\t14058\t1748187024415643408\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o\t6c5f77cfca7cfb81\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/build.ninja",
    "content": "ninja_required_version = 1.3\ncxx = c++\nnvcc = /usr/local/cuda-12.8/bin/nvcc\n\ncflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c\npost_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1018\"' -DTORCH_EXTENSION_NAME=sm100_gemm\ncuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c\ncuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''\"'\"'-fPIC'\"'\"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE=\"_gcc\"' '-DPYBIND11_STDLIB=\"_libstdcpp\"' '-DPYBIND11_BUILD_ABI=\"_cxxabi1018\"' -DTORCH_EXTENSION_NAME=sm100_gemm\ncuda_dlink_post_cflags = \nsycl_dlink_post_cflags = \nldflags = \n\nrule compile\n  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags\n  depfile = $out.d\n  deps = gcc\n\nrule cuda_compile\n  depfile = $out.d\n  deps = gcc\n  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags\n\n\n\n\n\n\n\nbuild /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu\nbuild /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/driver.py",
    "content": "# python_interface.py - High-level Python interface with TMA support\n\nimport torch\n\ntry:\n    import sm100_gemm  # The compiled extension\nexcept ImportError:\n    print(\"❌ SM100 not ready!\")\n    raise ImportError(\n        \"SM100 not ready! Please build the extension using `python setup.py install`\"\n    )\n\n\ndef check_sm100_compatibility():\n    \"\"\"Check if SM100 is supported and available\"\"\"\n    compile_support = sm100_gemm.is_sm100_supported()\n    device_support = sm100_gemm.check_sm100_device()\n\n    info = sm100_gemm.get_device_info()\n    major, minor, compile_flag, device_flag = info.tolist()\n\n    print(f\"Device compute capability: {major}.{minor}\")\n    print(f\"Compile-time SM100 support: {bool(compile_flag)}\")\n    print(f\"Runtime SM100 device support: {bool(device_flag)}\")\n\n    if not compile_support:\n        print(\n            \"❌ SM100 support not compiled in. Rebuild with CUTLASS_ARCH_MMA_SM100_SUPPORTED\"\n        )\n    elif not device_support:\n        print(\"❌ Current GPU does not support SM100 (need compute capability 10.0a)\")\n    else:\n        print(\"✅ SM100 with TMA ready!\")\n\n    return compile_support and device_support\n\n\ndef sm100_gemm_f16_tma(A, B, C=None, alpha=1.0, beta=0.0, check_alignment=True):\n    \"\"\"\n    Perform GEMM using SM100 optimized kernel with TMA: D = alpha * A @ B^T + beta * C\n\n    Args:\n        A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16\n        B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16\n        C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32\n                                   If None, creates zero tensor\n        alpha (float): Scaling factor for A @ B^T\n        beta (float): Scaling factor for C\n        check_alignment (bool): Whether to check and suggest aligned dimensions\n\n    Returns:\n        torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32\n\n    Note:\n        - Uses TMA (Tensor Memory Accelerator) for efficient memory transfers\n        - A and B are K-major (transposed in BLAS terms)\n        - C and D are N-major (row-major)\n        - All tensors must be on CUDA\n        - M must be multiple of 128, N multiple of 256, K multiple of 64\n    \"\"\"\n\n    # Input validation\n    assert A.dtype == torch.float16, f\"A must be float16, got {A.dtype}\"\n    assert B.dtype == torch.float16, f\"B must be float16, got {B.dtype}\"\n    assert A.is_cuda and B.is_cuda, \"A and B must be on CUDA\"\n    assert A.is_contiguous() and B.is_contiguous(), \"A and B must be contiguous\"\n\n    M, K = A.shape\n    N, K_B = B.shape\n    assert K == K_B, f\"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}\"\n\n    # Check or fix alignment requirements\n    if check_alignment:\n        aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K)\n\n        if M != aligned_M or N != aligned_N or K != aligned_K:\n            print(f\"Warning: Dimensions ({M}, {N}, {K}) not aligned for SM100\")\n            print(\n                f\"Suggested aligned dimensions: ({aligned_M}, {aligned_N}, {aligned_K})\"\n            )\n            print(\"Consider padding tensors or use create_aligned_tensors()\")\n\n    # Strict alignment check\n    assert (\n        M % sm100_gemm.MMA_TILE_M == 0\n    ), f\"M={M} must be multiple of {sm100_gemm.MMA_TILE_M}\"\n    assert (\n        N % sm100_gemm.MMA_TILE_N == 0\n    ), f\"N={N} must be multiple of {sm100_gemm.MMA_TILE_N}\"\n    
assert (\n        K % sm100_gemm.MMA_TILE_K == 0\n    ), f\"K={K} must be multiple of {sm100_gemm.MMA_TILE_K}\"\n\n    # Create C if not provided\n    if C is None:\n        C = torch.zeros(M, N, dtype=torch.float32, device=A.device)\n    else:\n        assert C.dtype == torch.float32, f\"C must be float32, got {C.dtype}\"\n        assert C.is_cuda, \"C must be on CUDA\"\n        assert C.is_contiguous(), \"C must be contiguous\"\n        assert C.shape == (\n            M,\n            N,\n        ), f\"C shape {C.shape} must match output shape ({M}, {N})\"\n\n    # Call the extension (now uses TMA internally)\n    return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta)\n\n\n# Keep the old name for compatibility\nsm100_gemm_f16 = sm100_gemm_f16_tma\n\n\ndef create_aligned_tensors(\n    M, N, K, device=\"cuda\", dtype_AB=torch.float16, dtype_C=torch.float32\n):\n    \"\"\"\n    Create properly aligned tensors for SM100 GEMM with TMA\n\n    Returns:\n        tuple: (A, B, C) tensors with aligned dimensions\n    \"\"\"\n    aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K)\n\n    A = torch.zeros(aligned_M, aligned_K, dtype=dtype_AB, device=device)\n    B = torch.zeros(aligned_N, aligned_K, dtype=dtype_AB, device=device)\n    C = torch.zeros(aligned_M, aligned_N, dtype=dtype_C, device=device)\n\n    return A, B, C\n\n\ndef pad_to_aligned(tensor, target_shape=None, dim_requirements=None):\n    \"\"\"\n    Pad tensor to meet SM100 alignment requirements\n\n    Args:\n        tensor: Input tensor to pad\n        target_shape: Specific target shape (optional)\n        dim_requirements: Tuple of (M_align, N_align, K_align) requirements\n\n    Returns:\n        Padded tensor and padding info for later unpadding\n    \"\"\"\n    if dim_requirements is None:\n        dim_requirements = (\n            sm100_gemm.MMA_TILE_M,\n            sm100_gemm.MMA_TILE_N,\n            sm100_gemm.MMA_TILE_K,\n        )\n\n    if tensor.dim() == 2:\n        M, N = tensor.shape\n\n        if target_shape:\n            target_M, target_N = target_shape\n        else:\n            target_M = (\n                (M + dim_requirements[0] - 1) // dim_requirements[0]\n            ) * dim_requirements[0]\n            target_N = (\n                (N + dim_requirements[1] - 1) // dim_requirements[1]\n            ) * dim_requirements[1]\n\n        pad_M = target_M - M\n        pad_N = target_N - N\n\n        # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)\n        padded = torch.nn.functional.pad(tensor, (0, pad_N, 0, pad_M))\n\n        return padded, (M, N, pad_M, pad_N)\n    else:\n        raise ValueError(\"Only 2D tensors supported\")\n\n\ndef unpad_result(tensor, padding_info):\n    \"\"\"Remove padding from result tensor\"\"\"\n    orig_M, orig_N, pad_M, pad_N = padding_info\n    return tensor[:orig_M, :orig_N]\n\n\ndef benchmark_sm100_vs_torch(\n    M=512,\n    N=1024,\n    K=256,\n    num_warmup=1,\n    num_trials=10,\n    auto_align=True,\n    compare_tma=True,\n):\n    \"\"\"\n    Benchmark SM100 GEMM with TMA against PyTorch's native GEMM\n    \"\"\"\n    # Ensure dimensions are aligned\n    if auto_align:\n        M = (\n            (M + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M\n        ) * sm100_gemm.MMA_TILE_M\n        N = (\n            (N + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N\n        ) * sm100_gemm.MMA_TILE_N\n        K = (\n            (K + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K\n        ) * sm100_gemm.MMA_TILE_K\n\n    
print(f\"Benchmarking GEMM with TMA for shape: ({M}, {N}, {K})\")\n\n    # Check SM100 availability\n    if not check_sm100_compatibility():\n        print(\"SM100 not available, skipping benchmark\")\n        return None\n\n    # Create test tensors\n    A = torch.randn(M, K, dtype=torch.float16, device=\"cuda\")\n    B = torch.randn(N, K, dtype=torch.float16, device=\"cuda\")\n    C = torch.randn(M, N, dtype=torch.float32, device=\"cuda\")\n\n    # PyTorch baseline (using mixed precision)\n    A_fp32 = A.float()\n    B_fp32 = B.float()\n\n    # Warmup\n    for _ in range(num_warmup):\n        # PyTorch GEMM\n        torch_result = torch.addmm(C, A_fp32, B_fp32.T)\n\n        # SM100 GEMM with TMA\n        sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)\n\n    torch.cuda.synchronize()\n\n    # Benchmark PyTorch\n    torch.cuda.synchronize()\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n    # warmup\n    torch_result = torch.addmm(C, A_fp32, B_fp32.T)\n\n    start.record()\n    for _ in range(num_trials):\n        torch_result = torch.addmm(C, A_fp32, B_fp32.T)\n    end.record()\n    torch.cuda.synchronize()\n    torch_time = start.elapsed_time(end) / num_trials\n\n    # Benchmark SM100 with TMA\n    # warmup\n    sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)\n\n    start.record()\n    for _ in range(num_trials):\n        sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)\n    end.record()\n    torch.cuda.synchronize()\n    sm100_time = start.elapsed_time(end) / num_trials\n\n    # Check correctness\n    max_diff = torch.max(torch.abs(torch_result - sm100_result))\n    rel_error = max_diff / torch.max(torch.abs(torch_result))\n\n    # Calculate FLOPS\n    flops = 2 * M * N * K  # Multiply-add operations\n    torch_tflops = flops / (torch_time * 1e-3) / 1e12\n    sm100_tflops = flops / (sm100_time * 1e-3) / 1e12\n\n    print(f\"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)\")\n    print(f\"SM100+TMA time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)\")\n    print(f\"Speedup: {torch_time/sm100_time:.2f}x\")\n    print(f\"Max difference: {max_diff:.6f}\")\n    print(f\"Relative error: {rel_error:.6f}\")\n    print(f\"🚀 TMA provides efficient memory transfers for large matrices!\")\n\n    return {\n        \"torch_time\": torch_time,\n        \"sm100_time\": sm100_time,\n        \"speedup\": torch_time / sm100_time,\n        \"torch_tflops\": torch_tflops,\n        \"sm100_tflops\": sm100_tflops,\n        \"max_diff\": max_diff.item(),\n        \"rel_error\": rel_error.item(),\n    }\n\n\n# Neural network layer implementations with TMA\nclass SM100LinearTMA(torch.nn.Module):\n    \"\"\"\n    Linear layer using SM100 GEMM with TMA for forward pass\n    \"\"\"\n\n    def __init__(self, in_features, out_features, bias=True, device=\"cuda\"):\n        super().__init__()\n\n        # Align dimensions\n        self.orig_in_features = in_features\n        self.orig_out_features = out_features\n\n        aligned_in = (\n            (in_features + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K\n        ) * sm100_gemm.MMA_TILE_K\n        aligned_out = (\n            (out_features + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N\n        ) * sm100_gemm.MMA_TILE_N\n\n        self.in_features = aligned_in\n        self.out_features = aligned_out\n\n        # Parameters (with padding)\n        self.weight = torch.nn.Parameter(\n            
torch.randn(aligned_out, aligned_in, dtype=torch.float16, device=device)\n            * 0.1\n        )\n\n        if bias:\n            self.bias = torch.nn.Parameter(\n                torch.zeros(aligned_out, dtype=torch.float32, device=device)\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n        print(\n            f\"SM100LinearTMA: {in_features} -> {out_features} (aligned: {aligned_in} -> {aligned_out})\"\n        )\n        print(\"🚀 Using TMA for efficient memory transfers\")\n\n    def forward(self, x):\n        # Pad input if necessary\n        batch_size = x.size(0)\n\n        # Align batch size\n        aligned_batch = (\n            (batch_size + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M\n        ) * sm100_gemm.MMA_TILE_M\n\n        if x.size(1) != self.in_features or batch_size != aligned_batch:\n            x_padded = torch.zeros(\n                aligned_batch, self.in_features, dtype=torch.float16, device=x.device\n            )\n            x_padded[:batch_size, : self.orig_in_features] = x\n            x = x_padded\n\n        # Prepare bias\n        if self.bias is not None:\n            C = (\n                self.bias.unsqueeze(0)\n                .expand(aligned_batch, self.out_features)\n                .contiguous()\n            )\n            beta = 1.0\n        else:\n            C = torch.zeros(\n                aligned_batch, self.out_features, dtype=torch.float32, device=x.device\n            )\n            beta = 0.0\n\n        # SM100 GEMM with TMA: output = x @ weight^T + bias\n        output = sm100_gemm_f16_tma(\n            x, self.weight, C, alpha=1.0, beta=beta, check_alignment=False\n        )\n\n        # Remove padding\n        return output[:batch_size, : self.orig_out_features]\n\n\ndef benchmark_tma_vs_cooperative_copy(M=512, N=1024, K=256, num_trials=50):\n    \"\"\"\n    TMA addition\n    \"\"\"\n\n    results = benchmark_sm100_vs_torch(M, N, K, num_trials=num_trials)\n\n    if results:\n        print(f\"\\nTMA-accelerated SM100 GEMM achieved:\")\n        print(f\"   Performance: {results['sm100_tflops']:.2f} TFLOPS\")\n        print(f\"   Speedup: {results['speedup']:.2f}x over PyTorch\")\n        print(f\"   Memory efficiency: Hardware-optimized transfers\")\n\n\ndef stress_test_large_matrices():\n    \"\"\"\n    Test TMA performance with large matrices that benefit most from TMA\n    \"\"\"\n    print(\"\\n=== Large Matrix Stress Test with TMA ===\")\n\n    # Test progressively larger matrices\n    test_sizes = [\n        (1024, 2048, 512),  # 1GB+ tensors\n        (2048, 4096, 1024),  # 4GB+ tensors\n        (4096, 8192, 2048),  # 16GB+ tensors (if memory allows)\n    ]\n\n    for M, N, K in test_sizes:\n        try:\n            print(f\"\\nTesting size: ({M}, {N}, {K})\")\n\n            # Check memory requirements\n            memory_A = M * K * 2  # FP16\n            memory_B = N * K * 2  # FP16\n            memory_C = M * N * 4  # FP32\n            total_memory = (memory_A + memory_B + memory_C * 2) / (1024**3)  # GB\n\n            print(f\"Memory requirement: {total_memory:.2f} GB\")\n\n            if total_memory > 20:  # Skip if > 20GB\n                print(\"⚠️  Skipping due to memory constraints\")\n                continue\n\n            # Create tensors\n            A, B, C = create_aligned_tensors(M, N, K)\n            A[:M, :K].normal_(0, 0.1)\n            B[:N, :K].normal_(0, 0.1)\n            C[:M, :N].normal_(0, 0.1)\n\n            # Warmup\n            for _ in range(3):\n    
            result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)\n\n            torch.cuda.synchronize()\n\n            # Benchmark\n            start = torch.cuda.Event(enable_timing=True)\n            end = torch.cuda.Event(enable_timing=True)\n\n            num_trials = 10\n            start.record()\n            for _ in range(num_trials):\n                result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)\n            end.record()\n\n            torch.cuda.synchronize()\n            avg_time = start.elapsed_time(end) / num_trials\n\n            # Calculate performance\n            flops = 2 * M * N * K\n            tflops = flops / (avg_time * 1e-3) / 1e12\n            bandwidth = total_memory / (avg_time * 1e-3)  # GB/s\n\n            print(f\"✅ Time: {avg_time:.2f} ms\")\n            print(f\"✅ Performance: {tflops:.2f} TFLOPS\")\n            print(f\"✅ Bandwidth: {bandwidth:.1f} GB/s\")\n            print(f\"🚀 TMA enables efficient handling of large matrices!\")\n\n        except torch.cuda.OutOfMemoryError:\n            print(f\"❌ Out of memory for size ({M}, {N}, {K})\")\n            break\n        except Exception as e:\n            print(f\"❌ Error: {e}\")\n            break\n\n\n# Example usage and test\nif __name__ == \"__main__\":\n    print(\"=== SM100 GEMM with TMA Extension Test ===\")\n\n    # Check compatibility first\n    if not check_sm100_compatibility():\n        print(\"Exiting due to compatibility issues\")\n        exit(1)\n\n    print(\"\\n=== Testing basic TMA functionality ===\")\n\n    # Test with properly aligned dimensions\n    M, N, K = 512, 1024, 256\n    A, B, C = create_aligned_tensors(M, N, K)\n\n    # Fill with random data (only the actual needed portion)\n    A[:M, :K].normal_()\n    B[:N, :K].normal_()\n    C[:M, :N].normal_()\n\n    # Test the TMA GEMM\n    result = sm100_gemm_f16_tma(A, B, C, alpha=1.0, beta=0.5, check_alignment=False)\n    print(\n        f\"✅ TMA GEMM test passed. Result shape: {result.shape}, dtype: {result.dtype}\"\n    )\n\n    print(\"\\n=== Testing SM100LinearTMA layer ===\")\n\n    # Test linear layer with TMA\n    layer = SM100LinearTMA(256, 512, bias=True)\n    x = torch.randn(128, 256, dtype=torch.float16, device=\"cuda\")\n    output = layer(x)\n    print(f\"✅ TMA Linear layer test passed. Output shape: {output.shape}\")\n\n    print(\"\\n=== Testing padding utilities ===\")\n\n    # Test padding for misaligned tensors\n    misaligned_A = torch.randn(300, 200, dtype=torch.float16, device=\"cuda\")\n    padded_A, pad_info = pad_to_aligned(misaligned_A)\n    print(f\"Original shape: {misaligned_A.shape}, Padded shape: {padded_A.shape}\")\n\n    unpadded = unpad_result(padded_A, pad_info)\n    print(f\"✅ Padding test passed. 
Unpadded shape: {unpadded.shape}\")\n\n    print(\"\\n=== Running TMA performance benchmark ===\")\n\n    # Run benchmark\n    benchmark_results = benchmark_sm100_vs_torch(M=512, N=1024, K=256, num_trials=50)\n\n    if benchmark_results:\n        print(f\"\\n✅ All TMA tests passed!\")\n        print(\n            f\"🚀 SM100+TMA achieved {benchmark_results['speedup']:.2f}x speedup over PyTorch\"\n        )\n        print(f\"🚀 TMA provides hardware-accelerated memory transfers!\")\n\n        # Run additional TMA-specific tests\n        benchmark_tma_vs_cooperative_copy(M=1024, N=2048, K=512)\n\n        # Test with larger matrices if memory allows\n        print(\"\\n=== Testing TMA with larger matrices ===\")\n        stress_test_large_matrices()\n\n    else:\n        print(\"❌ Benchmark failed\")\n\n    print(\"\\n=== TMA Summary ===\")\n    print(\"🚀 TMA (Tensor Memory Accelerator) provides:\")\n    print(\"   • Hardware-accelerated global->shared memory transfers\")\n    print(\"   • Reduced CPU overhead and better bandwidth utilization\")\n    print(\"   • Automatic memory layout optimization\")\n    print(\"   • Essential for peak performance on large matrices\")\n    print(\"   • Enables scaling to multi-GB tensor operations\")\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/setup.py",
    "content": "# setup.py\nimport os\n\nimport pybind11\nimport torch\nfrom pybind11 import get_cmake_dir\nfrom pybind11.setup_helpers import build_ext, Pybind11Extension\nfrom setuptools import Extension, setup\nfrom torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension\n\n# IMPORTANT: The following two lines are the only ones you need to change\n# Get CUTLASS path (you'll need to set this to your CUTLASS installation)\nCUTLASS_PATH = os.environ.get(\"CUTLASS_PATH\", \"/home/less/local/cutlas40\")\n\n# CUDA and PyTorch paths\ncuda_home = torch.utils.cpp_extension.CUDA_HOME\npytorch_includes = torch.utils.cpp_extension.include_paths()\n\next_modules = [\n    CUDAExtension(\n        name=\"sm100_gemm\",\n        sources=[\n            \"sm100_gemm_pytorch.cpp\",  # PyTorch bindings (C++)\n            \"sm100_gemm.cu\",  # CUDA kernel implementation\n        ],\n        include_dirs=[\n            # PyTorch includes\n            *pytorch_includes,\n            # CUTLASS includes\n            f\"{CUTLASS_PATH}/include\",\n            f\"{CUTLASS_PATH}/tools/util/include\",\n            # CUDA includes\n            f\"{cuda_home}/include\",\n        ],\n        library_dirs=[\n            f\"{cuda_home}/lib64\",\n        ],\n        libraries=[\"cuda\", \"cudart\"],\n        extra_compile_args={\n            \"cxx\": [\n                \"-O3\",\n                \"-std=c++17\",\n                \"-DCUTLASS_ARCH_MMA_SM100_SUPPORTED\",\n                \"-DCUTE_SM100_ENABLED\",\n            ],\n            \"nvcc\": [\n                \"-O3\",\n                \"-std=c++17\",\n                \"--expt-relaxed-constexpr\",\n                \"--expt-extended-lambda\",\n                \"-gencode=arch=compute_100a,code=sm_100a\",  # SM100 architecture\n                \"-DCUTLASS_ARCH_MMA_SM100_SUPPORTED\",\n                \"-DCUTE_SM100_ENABLED\",\n                \"--use_fast_math\",\n                \"-Xcompiler=-fPIC\",\n                \"-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1\",  # Enable TCGEN05_TMEM\n            ],\n        },\n        extra_link_args=[\"-lcuda\", \"-lcudart\"],\n        language=\"c++\",\n    )\n]\n\nsetup(\n    name=\"sm100_gemm\",\n    ext_modules=ext_modules,\n    cmdclass={\"build_ext\": BuildExtension},\n    zip_safe=False,\n    python_requires=\">=3.8\",\n    install_requires=[\"torch>=1.12.0\"],\n)\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu",
    "content": "// sm100_gemm_kernel.cu - CUDA kernel implementation with TMA\n#include \"sm100_gemm.h\"\n\n#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)\n\n#include <cutlass/arch/barrier.h>\n#include <cutlass/cluster_launch.hpp>\n#include <cutlass/half.h>\n#include <cutlass/util/print_error.hpp>\n\n#include <cute/algorithm/cooperative_copy.hpp>\n#include <cute/arch/cluster_sm90.hpp>\n#include <cute/arch/tmem_allocator_sm100.hpp>\n#include <cute/numeric/integral_constant.hpp>\n#include <cute/tensor.hpp>\n\nusing namespace cute;\n\n// Shared storage structure with TMA barriers\ntemplate <class TypeA, class TypeB, class ASmemLayout, class BSmemLayout>\nstruct SharedStorage {\n  alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;\n  alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;\n  alignas(16) cute::uint64_t\n      mma_barrier; // Barrier to track MMA computation on SMEM\n  alignas(16) cute::uint64_t\n      tma_barrier; // Barrier to track TMA data transfers to SMEM\n  alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation\n\n  CUTE_DEVICE constexpr auto tensor_sA() {\n    return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{});\n  }\n  CUTE_DEVICE constexpr auto tensor_sB() {\n    return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{});\n  }\n};\n\n// Device kernel with TMA\ntemplate <class SharedStorage, class ATensor, class BTensor, class CTensor,\n          class DTensor, class MmaTiler_MNK, class TiledMMA,\n          class ClusterShape_MNK, class TmaAtomA, class TmaAtomB, class Alpha,\n          class Beta>\n__global__ static void gemm_device_tma(\n    ATensor mA, BTensor mB, CTensor mC, DTensor mD, MmaTiler_MNK mma_tiler,\n    TiledMMA tiled_mma, ClusterShape_MNK cluster_shape,\n    CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A,\n    CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B, Alpha alpha, Beta beta) {\n  // Step 1: The Prologue\n  Layout cluster_layout_vmnk = tiled_divide(\n      make_layout(cluster_shape), make_tile(typename TiledMMA::AtomThrID{}));\n\n  auto mma_coord_vmnk =\n      make_coord(blockIdx.x % size<0>(cluster_layout_vmnk),\n                 blockIdx.x / size<0>(cluster_layout_vmnk), blockIdx.y, _);\n\n  auto mma_coord = select<1, 2, 3>(mma_coord_vmnk);\n  Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X, _1>{});\n  Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step<X, _1, _1>{});\n  Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1, _1, X>{});\n  Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1, _1, X>{});\n\n  // SMEM allocation\n  extern __shared__ char shared_memory[];\n  SharedStorage &shared_storage =\n      *reinterpret_cast<SharedStorage *>(shared_memory);\n\n  Tensor tCsA = shared_storage.tensor_sA();\n  Tensor tCsB = shared_storage.tensor_sB();\n\n  // MMA partitioning\n  auto mma_v = get<0>(mma_coord_vmnk);\n  ThrMMA cta_mma = tiled_mma.get_slice(mma_v);\n  Tensor tCgA = cta_mma.partition_A(gA);\n  Tensor tCgB = cta_mma.partition_B(gB);\n  Tensor tCgC = cta_mma.partition_C(gC);\n  Tensor tCgD = cta_mma.partition_C(gD);\n\n  // Fragment allocation\n  Tensor tCrA = cta_mma.make_fragment_A(tCsA);\n  Tensor tCrB = cta_mma.make_fragment_B(tCsB);\n  Tensor tCtAcc = cta_mma.make_fragment_C(tCgC);\n\n  uint32_t elect_one_thr = cute::elect_one_sync();\n  uint32_t elect_one_warp = (threadIdx.x / 32 == 0);\n\n  using TmemAllocator = cute::TMEM::Allocator1Sm;\n  TmemAllocator tmem_allocator{};\n\n  // TMEM allocation\n  if (elect_one_warp) {\n    
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,\n                            &shared_storage.tmem_base_ptr);\n  }\n  __syncthreads();\n  tCtAcc.data() = shared_storage.tmem_base_ptr;\n\n  // TMA Setup\n  // TMA partitioning with dedicated custom partitioner\n  // The Int<0>, Layout<_1> indicates that the TMAs are not multicasted\n  // group_modes<0,3> transforms the tensor shape for TMA operation\n  auto [tAgA, tAsA] =\n      tma_partition(tma_atom_A, Int<0>{}, Layout<_1>{}, group_modes<0, 3>(tCsA),\n                    group_modes<0, 3>(tCgA));\n\n  auto [tBgB, tBsB] =\n      tma_partition(tma_atom_B, Int<0>{}, Layout<_1>{}, group_modes<0, 3>(tCsB),\n                    group_modes<0, 3>(tCgB));\n\n  // Calculate total bytes that TMA will transfer each tile to track completion\n  int tma_transaction_bytes =\n      sizeof(make_tensor_like(tAsA)) + sizeof(make_tensor_like(tBsB));\n\n  // Barrier initialization\n  if (elect_one_warp && elect_one_thr) {\n    cute::initialize_barrier(shared_storage.mma_barrier, 1);\n    cute::initialize_barrier(shared_storage.tma_barrier, 1);\n  }\n  int mma_barrier_phase_bit = 0;\n  int tma_barrier_phase_bit = 0;\n  __syncthreads();\n\n  // Step 2: The Mainloop with TMA\n  tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;\n\n  for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) {\n    // Step 2a: TMA Load Operations\n    // Execute asynchronous TMA loads with single thread\n    if (elect_one_warp && elect_one_thr) {\n      cute::set_barrier_transaction_bytes(shared_storage.tma_barrier,\n                                          tma_transaction_bytes);\n      copy(tma_atom_A.with(shared_storage.tma_barrier), tAgA(_, k_tile),\n           tAsA); // Load MmaTile_M x MmaTile_K A tile\n      copy(tma_atom_B.with(shared_storage.tma_barrier), tBgB(_, k_tile),\n           tBsB); // Load MmaTile_N x MmaTile_K B tile\n    }\n\n    // Step 2b: Wait for TMA loads and execute MMAs\n    // Wait for TMA loads to SMEM to complete\n    cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit);\n    tma_barrier_phase_bit ^= 1;\n\n    // Execute MMAs\n    if (elect_one_warp) {\n      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {\n        gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCtAcc);\n        tiled_mma.accumulate_ = UMMA::ScaleOut::One;\n      }\n      cutlass::arch::umma_arrive(&shared_storage.mma_barrier);\n    }\n\n    // Wait MMAs to complete to avoid overwriting the A and B SMEM\n    cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);\n    mma_barrier_phase_bit ^= 1;\n  }\n\n  // Step 3: The Epilogue\n  TiledCopy tiled_t2r_copy =\n      make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);\n  ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);\n\n  Tensor tDgC = thr_t2r_copy.partition_D(tCgC);\n  Tensor tDrC = make_fragment_like(tDgC);\n  copy(tDgC, tDrC);\n\n  Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc);\n  Tensor tDgD = thr_t2r_copy.partition_D(tCgD);\n  using AccType = typename decltype(tCtAcc)::value_type;\n  Tensor tDrAcc = make_tensor<AccType>(shape(tDgD));\n  copy(tiled_t2r_copy, tDtAcc, tDrAcc);\n\n  // AXPBY and store result\n  axpby(alpha, tDrAcc, beta, tDrC);\n  copy(tDrC, tDgD);\n\n  __syncthreads();\n\n  // Cleanup TMEM\n  if (elect_one_warp) {\n    tmem_allocator.release_allocation_lock();\n    tmem_allocator.free(shared_storage.tmem_base_ptr,\n                        TmemAllocator::Sm100TmemCapacityColumns);\n  }\n}\n\n// Host setup\n// Host function that creates 
TMA descriptors and launches the kernel\ncudaError_t launch_sm100_gemm_f16_tma(void *d_A, void *d_B, void *d_C,\n                                      void *d_D, int M, int N, int K,\n                                      float alpha, float beta,\n                                      cudaStream_t stream) {\n  // Define types\n  using TypeA = cutlass::half_t;\n  using TypeB = cutlass::half_t;\n  using TypeC = float;\n  using TypeD = float;\n\n  // Create layouts (K-major for A and B, N-major for C and D)\n  auto layout_A = make_layout(make_shape(M, K), make_stride(K, Int<1>{}));\n  auto layout_B = make_layout(make_shape(N, K), make_stride(K, Int<1>{}));\n  auto layout_C = make_layout(make_shape(M, N), make_stride(N, Int<1>{}));\n  auto layout_D = layout_C;\n\n  // Create CuTe tensors\n  auto mA =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeA *>(d_A)), layout_A);\n  auto mB =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeB *>(d_B)), layout_B);\n  auto mC =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeC *>(d_C)), layout_C);\n  auto mD =\n      make_tensor(make_gmem_ptr(reinterpret_cast<TypeD *>(d_D)), layout_D);\n\n  // Create TiledMMA\n  TiledMMA tiled_mma =\n      make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, 128, 256,\n                                          UMMA::Major::K, UMMA::Major::K>{});\n\n  // Define MMA tiler sizes\n  auto bM = tile_size<0>(tiled_mma);            // 128\n  auto bN = tile_size<1>(tiled_mma);            // 256\n  auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // 64\n  auto mma_tiler = make_shape(bM, bN, bK);\n\n  // Check alignment\n  if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) {\n    return cudaErrorInvalidValue;\n  }\n\n  // Create SMEM layouts\n  auto mma_shape_A = partition_shape_A(\n      tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));\n  auto mma_shape_B = partition_shape_B(\n      tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));\n\n  auto sA_layout =\n      UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);\n  auto sB_layout =\n      UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);\n\n  using SMEMStorage =\n      SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;\n\n  // Cluster configuration\n  auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});\n\n  // Create TMA descriptors for A and B matrices\n  Copy_Atom tma_atom_A =\n      make_tma_atom(SM90_TMA_LOAD{},        // TMA Load Op\n                    mA,                     // Source GMEM tensor\n                    sA_layout,              // Destination SMEM layout\n                    select<0, 2>(mma_tiler) // MK Tiler for TMA operation\n      );\n  Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA));\n\n  Copy_Atom tma_atom_B =\n      make_tma_atom(SM90_TMA_LOAD{},        // TMA Load Op\n                    mB,                     // Source GMEM tensor\n                    sB_layout,              // Destination SMEM layout\n                    select<1, 2>(mma_tiler) // NK Tiler for TMA operation\n      );\n  Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB));\n\n  // Launch parameters\n  dim3 dimBlock(128);\n  dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape),\n                  size<2>(cluster_shape));\n  dim3 dimGrid(ceil_div(M, int(bM)), ceil_div(N, int(bN)));\n  int smemBytes = sizeof(SMEMStorage);\n\n  // Get kernel pointer\n  auto *kernel_ptr =\n      &gemm_device_tma<SMEMStorage, decltype(mA_tma), 
decltype(mB_tma),\n                       decltype(mC), decltype(mD), decltype(mma_tiler),\n                       decltype(tiled_mma), decltype(cluster_shape),\n                       decltype(tma_atom_A), decltype(tma_atom_B), float,\n                       float>;\n\n  // Set kernel attributes\n  cudaError_t error = cudaFuncSetAttribute(\n      kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);\n  if (error != cudaSuccess) {\n    return error;\n  }\n\n  // Launch kernel\n  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster,\n                                         smemBytes};\n  cutlass::Status status = cutlass::launch_kernel_on_cluster(\n      params, (void const *)kernel_ptr, mA_tma, mB_tma, mC, mD, mma_tiler,\n      tiled_mma, cluster_shape, tma_atom_A, tma_atom_B, alpha, beta);\n\n  return (status == cutlass::Status::kSuccess) ? cudaSuccess\n                                               : cudaErrorLaunchFailure;\n}\n\n// Wrapper function to choose between TMA and non-TMA versions\ncudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,\n                                  int M, int N, int K, float alpha, float beta,\n                                  cudaStream_t stream) {\n  // For now, always use TMA version for better performance\n  return launch_sm100_gemm_f16_tma(d_A, d_B, d_C, d_D, M, N, K, alpha, beta,\n                                   stream);\n}\n\n#else\n\ncudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,\n                                  int M, int N, int K, float alpha, float beta,\n                                  cudaStream_t stream) {\n  return cudaErrorNotSupported;\n}\n\n#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/PKG-INFO",
    "content": "Metadata-Version: 2.4\nName: sm100_gemm\nVersion: 0.0.0\nRequires-Python: >=3.8\nRequires-Dist: torch>=1.12.0\nDynamic: requires-dist\nDynamic: requires-python\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/SOURCES.txt",
    "content": "setup.py\nsm100_gemm.cu\nsm100_gemm_pytorch.cpp\nsm100_gemm.egg-info/PKG-INFO\nsm100_gemm.egg-info/SOURCES.txt\nsm100_gemm.egg-info/dependency_links.txt\nsm100_gemm.egg-info/not-zip-safe\nsm100_gemm.egg-info/requires.txt\nsm100_gemm.egg-info/top_level.txt"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/dependency_links.txt",
    "content": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/not-zip-safe",
    "content": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/requires.txt",
    "content": "torch>=1.12.0\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/top_level.txt",
    "content": "sm100_gemm\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.h",
    "content": "// sm100_gemm_kernel.h - Header file for CUDA kernel\n#pragma once\n\n#include <cuda_runtime.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/**\n * Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta * C\n *\n * @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major)\n * @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major)\n * @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major)\n * @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major)\n * @param M Number of rows in A and C/D\n * @param N Number of rows in B and columns in C/D\n * @param K Number of columns in A and B\n * @param alpha Scaling factor for A @ B^T\n * @param beta Scaling factor for C\n * @param stream CUDA stream (currently unused, for future async support)\n *\n * @return cudaSuccess on success, error code otherwise\n *\n * Requirements:\n * - M must be multiple of 128\n * - N must be multiple of 256\n * - K must be multiple of 64\n * - All pointers must be valid device memory\n * - Tensors must be contiguous with specified layouts\n */\ncudaError_t launch_sm100_gemm_f16_tma(void *d_A, void *d_B, void *d_C,\n                                      void *d_D, int M, int N, int K,\n                                      float alpha, float beta,\n                                      cudaStream_t stream = 0);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp",
    "content": "// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"sm100_gemm.h\"\n\n// Check if SM100 support is available at compile time\nbool is_sm100_supported() {\n#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)\n  return true;\n#else\n  return false;\n#endif\n}\n\n// Check if current GPU supports SM100 at runtime\nbool check_sm100_device() {\n  int device;\n  cudaGetDevice(&device);\n\n  cudaDeviceProp props;\n  cudaError_t error = cudaGetDeviceProperties(&props, device);\n  if (error != cudaSuccess) {\n    return false;\n  }\n\n  // Check for SM100 architecture (compute capability 10.0a)\n  return (props.major == 10 && props.minor == 0);\n}\n\ntorch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B,\n                             const torch::Tensor &C, float alpha = 1.0f,\n                             float beta = 0.0f) {\n\n  // Check compile-time support\n  TORCH_CHECK(\n      is_sm100_supported(),\n      \"SM100 support not compiled. Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED\");\n\n  // Check runtime device support\n  TORCH_CHECK(check_sm100_device(),\n              \"Current GPU does not support SM100 architecture (requires \"\n              \"compute capability 10.0a)\");\n\n  // Input validation\n  TORCH_CHECK(A.device().is_cuda(), \"A must be a CUDA tensor\");\n  TORCH_CHECK(B.device().is_cuda(), \"B must be a CUDA tensor\");\n  TORCH_CHECK(C.device().is_cuda(), \"C must be a CUDA tensor\");\n  TORCH_CHECK(A.dtype() == torch::kFloat16, \"A must be float16\");\n  TORCH_CHECK(B.dtype() == torch::kFloat16, \"B must be float16\");\n  TORCH_CHECK(C.dtype() == torch::kFloat32, \"C must be float32\");\n  TORCH_CHECK(A.is_contiguous(), \"A must be contiguous\");\n  TORCH_CHECK(B.is_contiguous(), \"B must be contiguous\");\n  TORCH_CHECK(C.is_contiguous(), \"C must be contiguous\");\n  TORCH_CHECK(A.dim() == 2, \"A must be 2D\");\n  TORCH_CHECK(B.dim() == 2, \"B must be 2D\");\n  TORCH_CHECK(C.dim() == 2, \"C must be 2D\");\n\n  // Get dimensions\n  int64_t M = A.size(0);\n  int64_t K = A.size(1);\n  int64_t N = B.size(0);\n  int64_t K_B = B.size(1);\n\n  TORCH_CHECK(K == K_B, \"Inner dimensions must match: A.shape[1]=\", K,\n              \", B.shape[1]=\", K_B);\n  TORCH_CHECK(C.size(0) == M && C.size(1) == N, \"C dimensions (\", C.size(0),\n              \", \", C.size(1), \") must match output shape (\", M, \", \", N, \")\");\n\n  // Check alignment requirements for SM100\n  TORCH_CHECK(M % 128 == 0, \"M=\", M, \" must be multiple of 128\");\n  TORCH_CHECK(N % 256 == 0, \"N=\", N, \" must be multiple of 256\");\n  TORCH_CHECK(K % 64 == 0, \"K=\", K, \" must be multiple of 64\");\n\n  // Check size limits (avoid overflow in int conversion)\n  TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX,\n              \"Dimensions too large for int conversion\");\n\n  // Create output tensor\n  auto D = torch::empty_like(C);\n\n  // Set CUDA device guard\n  const auto device = A.device();\n  c10::cuda::CUDAGuard device_guard(device);\n\n  // Get current CUDA stream\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream();\n\n  // Launch the kernel\n  cudaError_t error = launch_sm100_gemm_f16_tma(\n      A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(),\n      static_cast<int>(M), static_cast<int>(N), static_cast<int>(K), alpha,\n      beta, stream);\n\n  // Check 
for launch errors\n  TORCH_CHECK(error == cudaSuccess,\n              \"SM100 GEMM kernel launch failed: \", cudaGetErrorString(error));\n\n  // Check for kernel execution errors\n  C10_CUDA_CHECK(cudaGetLastError());\n\n  return D;\n}\n\n// Utility functions for debugging and information\ntorch::Tensor get_device_info() {\n  int device;\n  cudaGetDevice(&device);\n\n  cudaDeviceProp props;\n  cudaGetDeviceProperties(&props, device);\n\n  // Return device info as a tensor (for easy Python access)\n  auto info = torch::zeros({4}, torch::kInt32);\n  auto accessor = info.accessor<int32_t, 1>();\n\n  accessor[0] = props.major;          // Compute capability major\n  accessor[1] = props.minor;          // Compute capability minor\n  accessor[2] = is_sm100_supported(); // Compile-time support\n  accessor[3] = check_sm100_device(); // Runtime device support\n\n  return info;\n}\n\nstd::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {\n  // Return properly aligned dimensions for SM100\n  int64_t aligned_M = ((M + 127) / 128) * 128;\n  int64_t aligned_N = ((N + 255) / 256) * 256;\n  int64_t aligned_K = ((K + 63) / 64) * 64;\n\n  return {aligned_M, aligned_N, aligned_K};\n}\n\ntorch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,\n                                    torch::ScalarType dtype,\n                                    torch::Device device) {\n  // Create a tensor with SM100-aligned dimensions\n  TORCH_CHECK(shape.size() == 2, \"Shape must be 2D\");\n\n  auto aligned_shape =\n      get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? shape[2] : 64);\n\n  if (shape.size() == 2) {\n    return torch::zeros({aligned_shape[0], aligned_shape[1]},\n                        torch::TensorOptions().dtype(dtype).device(device));\n  } else {\n    return torch::zeros({aligned_shape[0], aligned_shape[2]},\n                        torch::TensorOptions().dtype(dtype).device(device));\n  }\n}\n\n// Python bindings\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.doc() = \"SM100 GEMM PyTorch Extension\";\n\n  // Main GEMM function\n  m.def(\"sm100_gemm_f16\", &sm100_gemm_f16,\n        \"SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + \"\n        \"beta * C\",\n        py::arg(\"A\"), py::arg(\"B\"), py::arg(\"C\"), py::arg(\"alpha\") = 1.0f,\n        py::arg(\"beta\") = 0.0f);\n\n  // Utility functions\n  m.def(\"is_sm100_supported\", &is_sm100_supported,\n        \"Check if SM100 support was compiled in\");\n\n  m.def(\"check_sm100_device\", &check_sm100_device,\n        \"Check if current GPU supports SM100 architecture\");\n\n  m.def(\"get_device_info\", &get_device_info,\n        \"Get device compute capability and SM100 support info\");\n\n  m.def(\"get_aligned_shape\", &get_aligned_shape,\n        \"Get SM100-aligned dimensions for given shape\", py::arg(\"M\"),\n        py::arg(\"N\"), py::arg(\"K\"));\n\n  m.def(\"create_aligned_tensor\", &create_aligned_tensor,\n        \"Create tensor with SM100-aligned dimensions\", py::arg(\"shape\"),\n        py::arg(\"dtype\"), py::arg(\"device\"));\n\n  // Constants for alignment requirements\n  m.attr(\"MMA_TILE_M\") = 128;\n  m.attr(\"MMA_TILE_N\") = 256;\n  m.attr(\"MMA_TILE_K\") = 64;\n}\n"
  },
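  {
    "path": "kernels/blackwell/cute_gemm_02_tma/example_usage.py",
    "content": "# example_usage.py - hypothetical usage sketch, not part of the original sources.\n# Shows how the sm100_gemm extension bound in sm100_gemm_pytorch.cpp could be called,\n# assuming it builds via the provided setup.py and imports as 'sm100_gemm'\n# (the module name suggested by top_level.txt). Computes D = alpha * A @ B^T + beta * C.\nimport torch\n\nimport sm100_gemm\n\n# Shapes must satisfy the SM100 alignment rules: M % 128 == 0, N % 256 == 0, K % 64 == 0.\nM, N, K = 256, 512, 128\nA = torch.randn(M, K, device='cuda', dtype=torch.float16)\nB = torch.randn(N, K, device='cuda', dtype=torch.float16)  # N x K, so the kernel computes A @ B^T\nC = torch.zeros(M, N, device='cuda', dtype=torch.float32)\n\nD = sm100_gemm.sm100_gemm_f16(A, B, C, alpha=1.0, beta=0.0)\n\n# Reference check in FP32.\nref = A.float() @ B.float().t()\nprint('max abs err:', (D - ref).abs().max().item())\n"
  },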
  {
    "path": "kernels/cuda/cutlass_gemm/broadcast_load_epilogue_c3x.hpp",
    "content": "/***************************************************************************************************\n * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights\n *reserved. SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice,\n *this list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n *POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n\n//\n// This file is a modified excerpt of\n// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp\n// from https://github.com/NVIDIA/cutlass v3.5.0\n// It has been modified to support either row/column or scalar broadcasting\n// where the tensor being loaded from is always passed in via a device pointer.\n// This lets one compiled kernel handle all cases of per-tensor or\n// per-channel/per-token quantization.\n//\n// This interface also allows the scales to be passed in as tensors that\n// consistently reside on the device, which avoids an issue with a previous\n// implementation where scalars needed to be on the CPU since they\n// were passed in via float values. 
This created a potential performance hazard\n// if scales were initially on the device, and caused torch.compile graphs\n// breaks when moving scales to the CPU.\n//\n#pragma once\n\n// Turn off clang-format for the entire file to keep it close to upstream\n// clang-format off\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/arch/barrier.h\"\n\n#include \"cute/tensor.hpp\"\n#include \"cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp\"\n\nnamespace cutlass::epilogue::fusion {\n\nusing namespace cute;\nusing namespace detail;\n\n// Row vector broadcast\ntemplate<\n  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least\n  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races\n  int Stages,\n  class CtaTileShapeMNK,\n  class Element,\n  class StrideMNL = Stride<_0,_1,_0>,\n  int Alignment = 128 / sizeof_bits_v<Element>\n>\nstruct Sm90RowOrScalarBroadcast {\n  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, \"sub-16B alignment not supported yet\");\n  static_assert(\n    (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias\n    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>));  // batched row vector broadcast\n\n  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem\n  struct SharedStorage {\n    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;\n  };\n\n  // This struct has been modified to have a bool indicating that ptr_row is a \n  // scalar that must be broadcast, instead of containing a scalar that is \n  // valid if ptr_row is null.\n  struct Arguments {\n    Element const* ptr_row = nullptr;\n    bool row_broadcast = true;\n    StrideMNL dRow = {};\n  };\n\n  using Params = Arguments;\n\n  template <class ProblemShape>\n  static constexpr Params\n  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {\n    return args;\n  }\n\n  template <class ProblemShape>\n  static size_t\n  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {\n    return 0;\n  }\n\n  template <class ProblemShape>\n  static cutlass::Status\n  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,\n    CudaHostAdapter* cuda_adapter = nullptr) {\n    return cutlass::Status::kSuccess;\n  }\n\n  CUTLASS_HOST_DEVICE\n  Sm90RowOrScalarBroadcast() { }\n\n  CUTLASS_HOST_DEVICE\n  Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)\n      : params(params),\n        smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }\n\n  Params params;\n  Element* smem_row;\n\n  CUTLASS_DEVICE bool\n  is_producer_load_needed() const {\n    return true;\n  }\n\n  CUTLASS_DEVICE bool\n  is_C_load_needed() const {\n    return false;\n  }\n\n  CUTLASS_DEVICE bool\n  is_zero() const {\n    return (!params.row_broadcast && *(params.ptr_row) == Element(0));\n  }\n\n  template <int EpiTiles, class GTensor, class STensor>\n  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {\n    CUTLASS_DEVICE\n    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)\n      : gRow(cute::forward<GTensor>(gRow)),\n        sRow(cute::forward<STensor>(sRow)),\n        params(params) {}\n\n    GTensor gRow;                                                                                 // (CTA_M,CTA_N)\n    STensor sRow;                             
                                                    // (CTA_M,CTA_N,PIPE)\n    Params const& params;\n\n    CUTLASS_DEVICE void\n    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {\n      if (!params.row_broadcast) {\n        return;\n      }\n\n      if (issue_tma_load) {\n        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size\n        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;\n        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);\n        // Issue the TMA bulk copy\n        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);\n        // Filter so we don't issue redundant copies over stride-0 modes\n        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;\n        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));\n      }\n    }\n  };\n\n  template <class... Args>\n  CUTLASS_DEVICE auto\n  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {\n\n    auto [M, N, K, L] = args.problem_shape_mnkl;\n    auto [m, n, k, l] = args.tile_coord_mnkl;\n    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);\n    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));            // (CTA_M,CTA_N)\n    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE)\n                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),\n                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));\n\n    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;\n    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(\n      cute::move(gRow), cute::move(sRow), params);\n  }\n\n  template <int EpiTiles, class RTensor, class STensor>\n  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {\n    CUTLASS_DEVICE\n    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)\n      : tCrRow(cute::forward<RTensor>(tCrRow)),\n        tCsRow(cute::forward<STensor>(tCsRow)),\n        params(params) {}\n\n    RTensor tCrRow;                                                               // (CPY,CPY_M,CPY_N)\n    STensor tCsRow;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)\n    Params const& params;\n\n    CUTLASS_DEVICE void\n    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {\n      if (!params.row_broadcast) {\n        fill(tCrRow, *(params.ptr_row));\n        return;\n      }\n\n      if (epi_m == 0) { // Assumes M-major subtile loop\n        // Filter so we don't issue redundant copies over stride-0 modes\n        // (only works if 0-strides are in same location, which is by construction)\n        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;\n        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));\n      }\n    }\n\n    template <typename ElementAccumulator, int FragmentSize>\n    CUTLASS_DEVICE Array<Element, FragmentSize>\n    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {\n      Array<Element, FragmentSize> frg_row;\n\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < FragmentSize; ++i) {\n        frg_row[i] 
= tCrRow(epi_v * FragmentSize + i);\n      }\n\n      return frg_row;\n    }\n  };\n\n  template <\n    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy\n    class... Args\n  >\n  CUTLASS_DEVICE auto\n  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {\n\n    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE)\n                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),\n                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));\n    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)\n                      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);\n    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                                           // (CPY,CPY_M,CPY_N)\n\n    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;\n    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(\n      cute::move(tCrRow), cute::move(tCsRow), params);\n  }\n};\n\n/////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Column vector broadcast\ntemplate<\n  int Stages,\n  class CtaTileShapeMNK,\n  class Element,\n  class StrideMNL = Stride<_1,_0,_0>,\n  int Alignment = 128 / sizeof_bits_v<Element>\n>\nstruct Sm90ColOrScalarBroadcast {\n  static_assert(Stages == 0, \"Column broadcast doesn't support smem usage yet\");\n  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, \"sub-16B alignment not supported yet\");\n  static_assert(\n    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias\n    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. 
batched per-row bias\n\n  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem\n  struct SharedStorage { };\n\n  // This struct has been modified to have a bool indicating that ptr_col is a \n  // scalar that must be broadcast, instead of containing a scalar that is \n  // valid if ptr_col is null.\n  struct Arguments {\n    Element const* ptr_col = nullptr;\n    bool col_broadcast = true;\n    StrideMNL dCol = {};\n  };\n\n  using Params = Arguments;\n\n  template <class ProblemShape>\n  static constexpr Params\n  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {\n    return args;\n  }\n\n  template <class ProblemShape>\n  static size_t\n  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {\n    return 0;\n  }\n\n  template <class ProblemShape>\n  static cutlass::Status\n  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,\n    CudaHostAdapter* cuda_adapter = nullptr) {\n    return cutlass::Status::kSuccess;\n  }\n\n  CUTLASS_DEVICE bool\n  is_producer_load_needed() const {\n    return false;\n  }\n\n  CUTLASS_DEVICE bool\n  is_C_load_needed() const {\n    return false;\n  }\n\n  CUTLASS_DEVICE bool\n  is_zero() const {\n    return (!params.col_broadcast && *(params.ptr_col) == Element(0));\n  }\n\n  CUTLASS_HOST_DEVICE\n  Sm90ColOrScalarBroadcast() { }\n\n  CUTLASS_HOST_DEVICE\n  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)\n      : params(params) { }\n\n  Params params;\n\n  template <class... Args>\n  CUTLASS_DEVICE auto\n  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {\n    return EmptyProducerLoadCallbacks{};\n  }\n\n  template<class GTensor, class RTensor>\n  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {\n    CUTLASS_DEVICE\n    ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)\n      : tCgCol(cute::forward<GTensor>(tCgCol)),\n        tCrCol(cute::forward<RTensor>(tCrCol)),\n        params(params) {}\n\n    GTensor tCgCol;                                                                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)\n    RTensor tCrCol;                                                                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)\n    Params const& params;\n\n    CUTLASS_DEVICE void\n    begin() {\n      if (!params.col_broadcast) {\n        fill(tCrCol, *(params.ptr_col));\n        return;\n      }\n\n      // Filter so we don't issue redundant copies over stride-0 modes\n      // (only works if 0-strides are in same location, which is by construction)\n      copy_aligned(filter(tCgCol), filter(tCrCol));\n    }\n\n    template <typename ElementAccumulator, int FragmentSize>\n    CUTLASS_DEVICE Array<Element, FragmentSize>\n    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {\n      Array<Element, FragmentSize> frg_col;\n      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);\n\n      CUTLASS_PRAGMA_UNROLL\n      for (int i = 0; i < FragmentSize; ++i) {\n        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);\n      }\n\n      return frg_col;\n    }\n\n  };\n\n  template <\n    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy\n    class... 
Args\n  >\n  CUTLASS_DEVICE auto\n  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {\n\n    auto [M, N, K, L] = args.problem_shape_mnkl;\n    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);\n    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(                         // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)\n      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);\n    Tensor tCrCol = make_tensor_like(tCgCol);                                          // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)\n\n    return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(\n      cute::move(tCgCol), cute::move(tCrCol), params);\n  }\n};\n\n}"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/common.hpp",
    "content": "#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include <climits>\n\n/**\n * Helper function for checking CUTLASS errors\n */\n#define CUTLASS_CHECK(status)                        \\\n  {                                                  \\\n    TORCH_CHECK(status == cutlass::Status::kSuccess, \\\n                cutlassGetStatusString(status))      \\\n  }\n\ninline uint32_t next_pow_2(uint32_t const num) {\n  if (num <= 1) return num;\n  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));\n}\n\ninline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {\n  int max_shared_mem_per_block_opt_in = 0;\n  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,\n                        cudaDevAttrMaxSharedMemoryPerBlockOptin,\n                        device);\n  return max_shared_mem_per_block_opt_in;\n}\n"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/cutlass.cpp",
    "content": "#include <torch/extension.h>\n#include<torch/all.h>\n\nvoid cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b,\n                            torch::Tensor const& a_scales, torch::Tensor const& b_scales);\n\ntorch::Tensor cutlass_scaled_mm(torch::Tensor a, torch::Tensor b, torch::Tensor a_scales, torch::Tensor b_scales) {\n    \n    auto acc_dtype = torch::kFloat16;\n    auto options = torch::TensorOptions().dtype(acc_dtype).device(a.device());\n    torch::Tensor out = torch::empty({a.size(0), b.size(1)}, options);\n\n    cutlass_scaled_mm_sm90(out, a, b, a_scales, b_scales);\n    return out;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"cutlass_scaled_mm\", &cutlass_scaled_mm, \"CUTLASS Scaled Matrix Multiplication\");\n}"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/cutlass_kernel.cu",
    "content": "// clang-format will break include orders\n// clang-format off\n#include <cudaTypedefs.h>\n\n#if defined CUDA_VERSION && CUDA_VERSION >= 12000\n\n#include <torch/all.h>\n\n#include <ATen/cuda/CUDAContext.h>\n\n#include <iostream>\n#include <sstream>\n#include <vector>\n\n#include \"cutlass/cutlass.h\"\n\n#include \"cute/tensor.hpp\"\n#include \"cute/atom/mma_atom.hpp\"\n#include \"cutlass/numeric_types.h\"\n\n#include \"cutlass/gemm/device/gemm_universal_adapter.h\"\n#include \"cutlass/gemm/kernel/gemm_universal.hpp\"\n#include \"cutlass/epilogue/collective/collective_builder.hpp\"\n#include \"cutlass/gemm/collective/collective_builder.hpp\"\n\n#include \"broadcast_load_epilogue_c3x.hpp\"\n#include \"common.hpp\"\n// clang-format on\n\nusing namespace cute;\n\n/*\n   This file defines quantized GEMM operations using the CUTLASS 3.x API, for\n   NVIDIA GPUs with sm90a (Hopper) or later.\n\n   Epilogue functions can be defined to post-process the output before it is\n   written to GPU memory.\n   Epilogues must contain a public type named EVTCompute of type Sm90EVT,\n   as well as a static prepare_args function that constructs an\n   EVTCompute::Arguments struct.\n*/\n\nnamespace {\n\n// A wrapper for the GEMM kernel that is used to guard against compilation on\n// architectures that will never use the kernel. The purpose of this is to\n// reduce the size of the compiled binary.\n// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef\n// into code that will be executed on the device where it is defined.\ntemplate <typename Kernel>\nstruct enable_sm90_or_later : Kernel {\n  template <typename... Args>\n  CUTLASS_DEVICE void operator()(Args&&... args) {\n  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900\n    Kernel::operator()(std::forward<Args>(args)...);\n  #endif\n  }\n};\n\n/*\n   This epilogue function defines a quantized GEMM operation similar to\n   torch.scaled_mm_.\n\n   A and B may be both either int8 or fp8_e4m3. A can be\n   quantized per-tensor or per-row. B can be quantized per-tensor or per-column.\n   Any combination of per-tensor and per-row or column is supported.\n   A and B must have symmetric quantization (zero point == 0).\n\n   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the\n   scales are applied elementwise with numpy-style broadcasting.\n\n   ScaleA and ScaleB define the epilogue functions that apply the scales for\n   the A and B operands respectively. 
These scales may be either per-tensor or\n   per row or column.\n*/\ntemplate <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>\nstruct ScaledEpilogue {\n private:\n  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;\n\n  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<\n      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,\n      Stride<Int<1>, Int<0>, Int<0>>>;\n\n  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<\n      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,\n      Stride<Int<0>, Int<1>, Int<0>>>;\n\n  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<\n      cutlass::multiplies, float, float,\n      cutlass::FloatRoundStyle::round_to_nearest>;\n\n  using EVTCompute0 =\n      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;\n\n  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<\n      cutlass::multiplies, ElementD, float,\n      cutlass::FloatRoundStyle::round_to_nearest>;\n\n public:\n  using EVTCompute =\n      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;\n  using ArgumentType = typename EVTCompute::Arguments;\n\n  static ArgumentType prepare_args(torch::Tensor const& a_scales,\n                                   torch::Tensor const& b_scales) {\n    using ScaleA_Args = typename ScaleA::Arguments;\n    using ScaleB_Args = typename ScaleB::Arguments;\n\n    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};\n    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};\n\n    return ArgumentType{a_args, {b_args}};\n  }\n};\n\ntemplate <typename ElementAB_, typename ElementD_,\n          template <typename, typename, typename> typename Epilogue_,\n          typename TileShape, typename ClusterShape, typename KernelSchedule,\n          typename EpilogueSchedule>\nstruct cutlass_3x_gemm {\n  using ElementAB = ElementAB_;\n  using ElementD = ElementD_;\n  using ElementAcc =\n      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,\n                                float>::type;\n\n  using EpilogueDescriptor =\n      cutlass::epilogue::collective::detail::EpilogueDescriptor<\n          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,\n          ElementD, EpilogueSchedule>;\n\n  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;\n\n  using StrideD = Stride<int64_t, Int<1>, Int<0>>;\n  using ElementC = void;\n  using StrideC = StrideD;\n\n  using EVTCompute = typename Epilogue::EVTCompute;\n\n  using CollectiveEpilogue =\n      typename cutlass::epilogue::collective::CollectiveBuilder<\n          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,\n          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,\n          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,\n          EpilogueSchedule, EVTCompute>::CollectiveOp;\n\n  static constexpr size_t CEStorageSize =\n      sizeof(typename CollectiveEpilogue::SharedStorage);\n  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<\n      static_cast<int>(CEStorageSize)>;\n\n  // clang-format off\n  using CollectiveMainloop =\n      typename cutlass::gemm::collective::CollectiveBuilder<\n          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, \n          ElementAB, cutlass::layout::RowMajor, 16, \n          ElementAB, cutlass::layout::ColumnMajor, 16, \n          ElementAcc, TileShape, ClusterShape,\n          Stages,\n          
KernelSchedule>::CollectiveOp;\n  // clang-format on\n\n  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<\n      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,\n      cutlass::gemm::PersistentScheduler>>;\n\n  struct GemmKernel : public KernelType {};\n};\n\ntemplate <typename Gemm, typename... EpilogueArgs>\nvoid cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,\n                         torch::Tensor const& b,\n                         EpilogueArgs&&... epilogue_params) {\n  using ElementAB = typename Gemm::ElementAB;\n  using ElementD = typename Gemm::ElementD;\n\n  int32_t m = a.size(0);\n  int32_t n = b.size(1);\n  int32_t k = a.size(1);\n\n  int64_t lda = a.stride(0);\n  int64_t ldb = b.stride(1);\n  int64_t ldc = out.stride(0);\n\n  using StrideA = Stride<int64_t, Int<1>, int64_t>;\n  using StrideB = Stride<int64_t, Int<1>, int64_t>;\n  using StrideC = typename Gemm::StrideC;\n\n  StrideA a_stride{lda, Int<1>{}, 0};\n  StrideB b_stride{ldb, Int<1>{}, 0};\n  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};\n\n  using GemmKernel = typename Gemm::GemmKernel;\n  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};\n\n  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());\n  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());\n  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,\n                                                       b_stride};\n\n  auto c_ptr = static_cast<ElementD*>(out.data_ptr());\n  typename GemmKernel::EpilogueArguments epilogue_args{\n      Gemm::Epilogue::prepare_args(\n          std::forward<EpilogueArgs>(epilogue_params)...),\n      c_ptr, c_stride, c_ptr, c_stride};\n\n  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,\n                                      prob_shape, mainloop_args, epilogue_args};\n\n  // Launch the CUTLASS GEMM kernel.\n  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;\n  GemmOp gemm_op;\n  // CUTLASS_CHECK(gemm_op.can_implement(args));\n\n  size_t workspace_size = gemm_op.get_workspace_size(args);\n  auto const workspace_options =\n      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());\n  auto workspace = torch::empty(workspace_size, workspace_options);\n\n  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());\n\n  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);\n  CUTLASS_CHECK(status);\n}\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue>\nstruct sm90_fp8_config_default {\n  // M in (128, inf)\n  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());\n  using KernelSchedule =\n      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using TileShape = Shape<_128, _128, _128>;\n  using ClusterShape = Shape<_2, _1, _1>;\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue>\nstruct sm90_fp8_config_M128 {\n  // M in (64, 128]\n  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());\n  using KernelSchedule =\n      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using 
TileShape = Shape<_64, _128, _128>;\n  using ClusterShape = Shape<_2, _1, _1>;\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue>\nstruct sm90_fp8_config_M64 {\n  // M in [1, 64]\n  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());\n  using KernelSchedule =\n      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using TileShape = Shape<_64, _64, _128>;\n  using ClusterShape = Shape<_1, _8, _1>;\n\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue>\nstruct sm90_int8_config_default {\n  // For M > 128 and any N\n  static_assert(std::is_same<InType, int8_t>());\n  using KernelSchedule =\n      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using TileShape = Shape<_128, _128, _128>;\n  using ClusterShape = Shape<_2, _1, _1>;\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue>\nstruct sm90_int8_config_M128 {\n  // For M in (64, 128] and any N\n  static_assert(std::is_same<InType, int8_t>());\n  using KernelSchedule =\n      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using TileShape = Shape<_64, _128, _128>;\n  using ClusterShape = Shape<_2, _1, _1>;\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue>\nstruct sm90_int8_config_M64 {\n  // For M in (32, 64] and any N\n  static_assert(std::is_same<InType, int8_t>());\n  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using TileShape = Shape<_64, _64, _256>;\n  using ClusterShape = Shape<_1, _1, _1>;\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue>\nstruct sm90_int8_config_M32_NBig {\n  // For M in [1, 32] and N >= 8192\n  static_assert(std::is_same<InType, int8_t>());\n  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using TileShape = Shape<_64, _128, _256>;\n  using ClusterShape = Shape<_1, _4, _1>;\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\ntemplate <typename InType, typename OutType,\n          template <typename, 
typename, typename> typename Epilogue>\nstruct sm90_int8_config_M32_NSmall {\n  // For M in [1, 32] and N < 8192\n  static_assert(std::is_same<InType, int8_t>());\n  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;\n  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;\n  using TileShape = Shape<_64, _64, _256>;\n  using ClusterShape = Shape<_1, _8, _1>;\n  using Cutlass3xGemm =\n      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,\n                      KernelSchedule, EpilogueSchedule>;\n};\n\n}  // namespace\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue,\n          typename... EpilogueArgs>\nvoid cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,\n                                    torch::Tensor const& b,\n                                    EpilogueArgs&&... args) {\n  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());\n  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);\n  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);\n\n  using Cutlass3xGemmDefault =\n      typename sm90_fp8_config_default<InType, OutType,\n                                       Epilogue>::Cutlass3xGemm;\n  using Cutlass3xGemmM64 =\n      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;\n  using Cutlass3xGemmM128 =\n      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;\n\n  uint32_t const m = a.size(0);\n  uint32_t const mp2 =\n      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2\n\n  if (mp2 <= 64) {\n    // m in [1, 64]\n    return cutlass_gemm_caller<Cutlass3xGemmM64>(\n        out, a, b, std::forward<EpilogueArgs>(args)...);\n  } else if (mp2 <= 128) {\n    // m in (64, 128]\n    return cutlass_gemm_caller<Cutlass3xGemmM128>(\n        out, a, b, std::forward<EpilogueArgs>(args)...);\n  } else {\n    // m in (128, inf)\n    return cutlass_gemm_caller<Cutlass3xGemmDefault>(\n        out, a, b, std::forward<EpilogueArgs>(args)...);\n  }\n}\n\ntemplate <typename InType, typename OutType,\n          template <typename, typename, typename> typename Epilogue,\n          typename... EpilogueArgs>\nvoid cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,\n                                     torch::Tensor const& b,\n                                     EpilogueArgs&&... 
args) {\n  static_assert(std::is_same<InType, int8_t>());\n  TORCH_CHECK(a.dtype() == torch::kInt8);\n  TORCH_CHECK(b.dtype() == torch::kInt8);\n\n  using Cutlass3xGemmDefault =\n      typename sm90_int8_config_default<InType, OutType,\n                                        Epilogue>::Cutlass3xGemm;\n  using Cutlass3xGemmM128 =\n      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;\n  using Cutlass3xGemmM64 =\n      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;\n  using Cutlass3xGemmM32NBig =\n      typename sm90_int8_config_M32_NBig<InType, OutType,\n                                         Epilogue>::Cutlass3xGemm;\n  using Cutlass3xGemmM32NSmall =\n      typename sm90_int8_config_M32_NSmall<InType, OutType,\n                                           Epilogue>::Cutlass3xGemm;\n\n  uint32_t const n = out.size(1);\n  bool const is_small_n = n < 8192;\n\n  uint32_t const m = a.size(0);\n  uint32_t const mp2 =\n      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2\n\n  if (mp2 <= 32) {\n    // m in [1, 32]\n    if (is_small_n) {\n      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(\n          out, a, b, std::forward<EpilogueArgs>(args)...);\n    } else {\n      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(\n          out, a, b, std::forward<EpilogueArgs>(args)...);\n    }\n  } else if (mp2 <= 64) {\n    // m in (32, 64]\n    return cutlass_gemm_caller<Cutlass3xGemmM64>(\n        out, a, b, std::forward<EpilogueArgs>(args)...);\n  } else if (mp2 <= 128) {\n    // m in (64, 128]\n    return cutlass_gemm_caller<Cutlass3xGemmM128>(\n        out, a, b, std::forward<EpilogueArgs>(args)...);\n  } else {\n    // m in (128, inf)\n    return cutlass_gemm_caller<Cutlass3xGemmDefault>(\n        out, a, b, std::forward<EpilogueArgs>(args)...);\n  }\n}\n\nvoid cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,\n                            torch::Tensor const& b,\n                            torch::Tensor const& a_scales,\n                            torch::Tensor const& b_scales) {\n  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);\n  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);\n\n  if (a.dtype() == torch::kInt8) {\n    TORCH_CHECK(b.dtype() == torch::kInt8);\n\n    if (out.dtype() == torch::kBFloat16) {\n      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,\n                                             ScaledEpilogue>(\n          out, a, b, a_scales, b_scales);\n    } else {\n      TORCH_CHECK(out.dtype() == torch::kFloat16);\n      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,\n                                             ScaledEpilogue>(\n          out, a, b, a_scales, b_scales);\n    }\n  } else {\n    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);\n    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);\n\n    if (out.dtype() == torch::kBFloat16) {\n      return cutlass_gemm_sm90_fp8_dispatch<\n          cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(\n          out, a, b, a_scales, b_scales);\n    } else {\n      TORCH_CHECK(out.dtype() == torch::kFloat16);\n      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,\n                                            cutlass::half_t, ScaledEpilogue>(\n          out, a, b, a_scales, b_scales);\n    }\n  }\n}\n\n#endif"
  },
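  {
    "path": "kernels/cuda/cutlass_gemm/scaled_mm_reference.py",
    "content": "# scaled_mm_reference.py - hypothetical reference sketch, not part of the original sources.\n# Illustrates the broadcasting semantics that ScaledEpilogue in cutlass_kernel.cu applies:\n# D = (a_scales * A) @ (b_scales * B), with a_scales per-tensor (1 element) or per-row\n# (M elements) and b_scales per-tensor or per-column (N elements), both in float32.\nimport torch\n\n\ndef scaled_mm_reference(a, b, a_scales, b_scales, out_dtype=torch.float16):\n    # a: (M, K), b: (K, N); the scales are applied to the accumulator, which is\n    # mathematically equivalent to scaling the operands row-/column-wise.\n    acc = a.float() @ b.float()\n    out = a_scales.view(-1, 1) * (b_scales.view(1, -1) * acc)\n    return out.to(out_dtype)\n\n\n# Usage mirrors test_cutlass_gemm.py: per-tensor scales are 1-element tensors,\n# per-row / per-column scales have M and N elements respectively.\n"
  },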
  {
    "path": "kernels/cuda/cutlass_gemm/readme.md",
    "content": "Currently the CPP extension builds with Cutlass 3.5.1 (credit to @SamirMoustafa for the update).  \n3.6 will fail atm due to a refactor in the TMA descriptor.  \n"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name='cutlass_gemm',\n    ext_modules=[\n        CUDAExtension(\n            name='pingpong_gemm',\n            sources=['cutlass.cpp', 'cutlass_kernel.cu'],\n            extra_compile_args={\n                'nvcc': [\n                    '-DNDEBUG',\n                    '-O3', \n                    '-g', \n                    '-lineinfo',\n                    '--keep', \n                    '--ptxas-options=--warn-on-local-memory-usage',\n                    '--ptxas-options=--warn-on-spills',\n                    '--resource-usage',\n                    '--source-in-ptx',\n                    '-DCUTLASS_DEBUG_TRACE_LEVEL=1',\n                    '-gencode=arch=compute_90a, code=sm_90a',\n                ]\n            },\n            include_dirs=[\n                '/home/adhoq26/cutlass/include',\n                '/home/adhoq26/cutlass/tools/util/include',\n            ],\n            libraries=['cuda'],\n            library_dirs=['/usr/local/cuda-12.4/lib64'],\n        )\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    }\n)"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/test_cutlass_gemm.py",
    "content": "from pingpong_gemm import cutlass_scaled_mm\nimport torch\n\nm, k, n = 16, 4096, 4096\ndtype = torch.float8_e4m3fn\nout_dtype = torch.float16\n\na = torch.empty(m, k).normal_(mean=0.0, std=0.5).to(dtype=dtype, device='cuda')\nbt = torch.empty(n, k).normal_(mean=0.0, std=0.5).to(dtype=dtype, device='cuda').t()\nscale_a = torch.ones((1,)).to(dtype=torch.float32, device='cuda')\nscale_b = torch.ones((1,)).to(dtype=torch.float32, device='cuda')\ny = cutlass_scaled_mm(a, bt, scale_a, scale_b)\nprint(y)"
  },
  {
    "path": "kernels/cuda/inference/README.md",
    "content": "cuda kernels\n"
  },
  {
    "path": "kernels/cuda/inference/hadamard_transform/hadamard_transform.cpp",
    "content": "#include <torch/extension.h>\n#include <pybind11/pybind11.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n\nusing namespace torch::indexing;\n\ntemplate <torch::ScalarType dtype>\nvoid run_fht(void* a, void* out, uint32_t numel, uint32_t had_size, cudaStream_t stream);\n\nconstexpr bool is_power_of_two(uint32_t x) {\n    return x && !(x & (x - 1));\n}\n\ntorch::Tensor hadamard_transform(at::Tensor& in, bool inplace) {\n    auto dtype = in.scalar_type();\n    TORCH_CHECK(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, \"Only fp16 and bf16 supported currently\");\n    TORCH_CHECK(in.is_cuda());\n    \n    const int had_size = in.size(-1);\n    TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)),\n        \"Only power of two Hadamard sizes up to 2^15 are supported, got \", had_size);\n    \n    const auto res_shape = in.sizes();\n    torch::Tensor x = in.reshape({-1, had_size});\n    \n    auto numel = in.numel();\n    if (numel % 256 != 0) {\n        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size}));\n    }\n    \n    if (x.stride(-1) != 1) {\n        x = x.contiguous();\n    }\n    torch::Tensor out = inplace ? x : torch::empty_like(x);\n\n    at::cuda::CUDAGuard device_guard{(char)x.get_device()};\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    if (dtype == torch::ScalarType::Half) {\n        run_fht<torch::ScalarType::Half>(x.data_ptr(), out.data_ptr(), x.numel(), had_size, stream);\n    } else {\n        run_fht<torch::ScalarType::BFloat16>(x.data_ptr(), out.data_ptr(), x.numel(), had_size, stream);\n    }\n\n    if (numel % 256 != 0) {\n        out = out.index({Slice(0, numel / had_size)});\n    }\n\n    if (inplace && out.data_ptr() != in.data_ptr()) {\n        in.copy_(out.view(res_shape));\n        return in;\n    }\n    return out.reshape(res_shape);\n}\n\nnamespace py = pybind11;\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"hadamard_transform\", &hadamard_transform, \"A function to perform a fast Hadamard transform\", py::arg(\"x\"), py::arg(\"inplace\")=false);\n}"
  },
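  {
    "path": "kernels/cuda/inference/hadamard_transform/example_usage.py",
    "content": "# example_usage.py - hypothetical usage sketch, not part of the original sources.\n# Shows how the hadamard_transform binding above could be called, assuming the\n# compiled extension imports as 'hadamard_transform'. The last dimension must be a\n# power of two no larger than 2**15, and the input must be fp16 or bf16 on CUDA.\nimport torch\n\nimport hadamard_transform as ht\n\nx = torch.randn(64, 1024, device='cuda', dtype=torch.float16)\ny = ht.hadamard_transform(x)             # out-of-place transform along the last dim\nht.hadamard_transform(x, inplace=True)   # in-place variant writes back into x\nprint(y.shape, y.dtype)\n"
  },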
  {
    "path": "kernels/cuda/inference/hadamard_transform/hadamard_transform_cuda.cu",
    "content": "#include <torch/extension.h>\n#include <stdint.h>\n#include <cuda_runtime.h>\n#include <mma.h>\n#include <cuda/annotated_ptr>\n#include <c10/cuda/CUDAException.h>\n\n#ifndef __CUDACC__\n#define __launch_bounds__(x,y)\n#endif\n\n#define MAX_WARPS_PER_SM 48\n\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n\ntypedef uint32_t b32;\ntypedef uint16_t b16;\n\nconstexpr int launch_configs_big[7][3] = {\n    // default\n    {2, 1, 24},\n    {2, 2, 16}, \n    {2, 4, 8}, \n    {2, 8, 4}, \n    {2, 16, 3},\n    {4, 16, 2},\n    {8, 16, 1}\n    // // extra coalescing\n    // {2, 1, 24},\n    // {2, 2, 16}, \n    // {2, 4, 8}, \n    // {2, 8, 4}, \n    // {4, 8, 3},\n    // {8, 8, 2},\n    // {16, 8, 1}\n    // // less coalescing\n    // {2, 1, 24},\n    // {2, 2, 16}, \n    // {2, 4, 8}, \n    // {2, 8, 4}, \n    // {1, 32, 1},\n    // {2, 32, 1},\n    // {4, 32, 1}\n};\n\n// a 4x2, b 2x2, c 2x2\ntemplate <torch::ScalarType dtype>\n__device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32& c0, b32& c1){\n    static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16);\n    // d, a, b, c\n    b32 zero = 0;\n    if constexpr(dtype == torch::ScalarType::Half) {\n        asm (\n            \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 \"\n            \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\\t\"\n            : \"=r\"(c0), \"=r\"(c1) : \"r\"(a0), \"r\"(a1), \"r\"(a2), \"r\"(a3), \"r\"(b0), \"r\"(b1), \"r\"(zero), \"r\"(zero)\n        );\n    } else {\n        b32 temp0, temp1, temp2, temp3;\n        asm (\n            \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n            \"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\\n\\t\"\n            : \"=r\"(temp0), \"=r\"(temp1), \"=r\"(temp2), \"=r\"(temp3) : \"r\"(a0), \"r\"(a1), \"r\"(a2), \"r\"(a3), \"r\"(b0), \"r\"(b1), \"r\"(zero), \"r\"(zero), \"r\"(zero), \"r\"(zero)\n        );\n        asm (\"cvt.rn.bf16x2.f32 %0, %1, %2;\\n\\t\" : \"=r\"(c0) : \"r\"(temp1), \"r\"(temp0));\n        asm (\"cvt.rn.bf16x2.f32 %0, %1, %2;\\n\\t\" : \"=r\"(c1) : \"r\"(temp3), \"r\"(temp2));\n    }\n}\n\n// a 4x2, b 4x2, c 4x2\ntemplate <torch::ScalarType dtype>\n__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32 b2, b32 b3, b32& c0, b32& c1, b32& c2, b32& c3){\n    mma_m16_n8_k16_b16_b16_b16_noacc<dtype>(a0, a1, a2, a3, b0, b1, c0, c1);\n    mma_m16_n8_k16_b16_b16_b16_noacc<dtype>(a0, a1, a2, a3, b2, b3, c2, c3);\n}\n\n__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(b32& a0) {\n    asm (\n        \"movmatrix.sync.aligned.m8n8.trans.b16 \"\n        \"%0, %1;\\n\\t\"\n        : \"=r\"(a0) : \"r\"(a0)\n    );\n}\n\n#define p_p(i) ((val_1p[i] & 0x0000FFFF) | val_1p[i] << 16)\n#define p_n(i) ((val_1p[i] & 0x0000FFFF) | val_1n[i] << 16)\n#define n_p(i) ((val_1n[i] & 0x0000FFFF) | val_1p[i] << 16)\n#define n_n(i) ((val_1n[i] & 0x0000FFFF) | val_1n[i] << 16)\n\ntemplate<int num_chunks, int warps_per_block, int log_had_size, int blocks_per_sm, bool enable_mask, torch::ScalarType dtype>\n__global__ void __launch_bounds__(32 * warps_per_block, blocks_per_sm)\n// a is column major, b is row major\nhadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) {\n    static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, \"Only fp16 and bf16 supported currently\");\n\n    b32 b_frag_all[num_chunks][4]; // for all 
chunks, holds matrix fragment (which takes 4 regs of b16x2 * 32 threads)\n\n    uint blockid = blockIdx.x * warps_per_block + threadIdx.x / 32;\n    uint threadid = threadIdx.x % 32;\n    extern __shared__ b32 bfrag_arr[]; // num_chunks * warps_per_block * 128\n    int real_num_chunks = ((blockid + 1) * num_chunks) > total_num_chunks ? (total_num_chunks - (blockid * num_chunks)) : num_chunks;\n    int diff_num_chunks = real_num_chunks - num_chunks;\n\n    b32* a_start_ptr = (b32*) (a + blockid * num_chunks * 256); // offset a to where this warp starts\n    b32* out_start_ptr = (b32*) (out + blockid * num_chunks * 256);\n    b32* a_ptr = a_start_ptr + threadid * 4;\n    b32* b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128 + threadid * 4;\n\n    #if (__CUDA_ARCH__ < 900) // SM80, SM89\n    uint64_t cache_policy;\n    asm volatile(\n        \"createpolicy.fractional.L2::evict_first.b64 %0, 1.0;\\n\"\n        : \"=l\"(cache_policy)\n    );\n    #endif\n\n    #pragma unroll\n    for (int k = 0; k < num_chunks; k++) {\n        size_t shared_ptr = __cvta_generic_to_shared(b_frag_ptr);\n        #if (__CUDA_ARCH__ >= 900) // SM90\n            asm volatile(\n                \"cp.async.cg.shared.global [%0], [%1], 16;\\n\"\n                \"cp.async.commit_group;\\n\"\n                :: \"l\"(shared_ptr), \"l\"(a_ptr)\n            );\n        #else // SM80, SM89\n            asm volatile(\n                \"cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2;\\n\"\n                \"cp.async.commit_group;\\n\"\n                :: \"l\"(shared_ptr), \"l\"(a_ptr), \"l\"(cache_policy)\n            );\n        #endif\n\n        a_ptr += 128;\n        b_frag_ptr += 128;\n    }\n\n    // generate hadamard 16x16 (up to 2 of them)\n    constexpr b16 fp16_1p[4] = {0b0011100110101000, 0b0011100000000000, 0b0011010110101000, 0b0011010000000000};\n    constexpr b16 fp16_1n[4] = {0b1011100110101000, 0b1011100000000000, 0b1011010110101000, 0b1011010000000000};\n    constexpr b16 bf16_1p[4] = {0b0011111100110101, 0b0011111100000000, 0b0011111010110101, 0b0011111010000000};\n    constexpr b16 bf16_1n[4] = {0b1011111100110101, 0b1011111100000000, 0b1011111010110101, 0b1011111010000000};\n\n    #define val_type_1p(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i]))\n    #define val_type_1n(i) (((dtype) == torch::ScalarType::Half) ? 
(fp16_1n[i]) : (bf16_1n[i]))\n    constexpr b16 val_1p[4] = {val_type_1p(0), val_type_1p(1), val_type_1p(2), val_type_1p(3)};\n    constexpr b16 val_1n[4] = {val_type_1n(0), val_type_1n(1), val_type_1n(2), val_type_1n(3)};\n\n    constexpr b32 p_p[4] = {p_p(0), p_p(1), p_p(2), p_p(3)};\n    constexpr b32 p_n[4] = {p_n(0), p_n(1), p_n(2), p_n(3)};\n    constexpr b32 n_p[4] = {n_p(0), n_p(1), n_p(2), n_p(3)};\n    constexpr b32 n_n[4] = {n_n(0), n_n(1), n_n(2), n_n(3)};\n    const b32 had_16_p1[4][4] = {\n        {\n            0b10001000010001000010001000010001,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b10001000010001000010001000010001\n        },\n        {\n            0b11001100100010000011001100100010,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b11001100100010000011001100100010\n        },\n        {\n            0b11111111101010101100110010011001,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b11111111101010101100110010011001\n        },\n        {\n            0b11111111101010101100110010011001,\n            0b11111111101010101100110010011001,\n            0b11111111101010101100110010011001,\n            0b00000000010101010011001101100110\n        }\n    };\n    const b32 had_16_p2[4][4] = {\n        {\n            0b10000000010000000010000000010000,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b10000000010000000010000000010000\n        },\n        {\n            0b11000000100001000011000000100001,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b11000000100001000011000000100001\n        },\n        {\n            0b11110000101001011100001110010110,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b11110000101001011100001110010110\n        },\n        {\n            0b11110000101001011100001110010110,\n            0b11110000101001011100001110010110,\n            0b11110000101001011100001110010110,\n            0b00001111010110100011110001101001\n        }\n    };\n    const b32 had_16_mask[3][4] = {\n        {\n            0b10001000010001000010001000010001,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b10001000010001000010001000010001\n        },\n        {\n            0b11001100110011000011001100110011,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b11001100110011000011001100110011\n        },\n        {\n            0b11111111111111111111111111111111,\n            0b00000000000000000000000000000000,\n            0b00000000000000000000000000000000,\n            0b11111111111111111111111111111111\n        }\n    };\n    b32 had_frag[8];\n    #pragma unroll\n    for (int i = 0; i < 2; i++) {\n        int c_log_h = (i == 0) ? 
MIN(4, log_had_size) : log_had_size % 4;\n        #pragma unroll\n        for (int j = 0; j < 4; j++) {\n            if (c_log_h < 4) {\n                bool mask = had_16_mask[c_log_h - 1][j] & (1 << (31 - threadid));\n                if (!mask) {\n                    had_frag[i * 4 + j] = 0;\n                    continue;\n                }\n            }\n            bool pred1 = had_16_p1[c_log_h - 1][j] & (1 << (31 - threadid));\n            bool pred2 = had_16_p2[c_log_h - 1][j] & (1 << (31 - threadid));\n            b32 val = pred1 ? (pred2 ? p_p[c_log_h - 1] : p_n[c_log_h - 1]) : (pred2 ? n_p[c_log_h - 1] : n_n[c_log_h - 1]);\n            had_frag[i * 4 + j] = val;\n        }\n        if constexpr(log_had_size <= 4 || log_had_size % 4 == 0) break;\n    }\n\n    // log had size above 8, only used for above 2^8 = 256 size\n    constexpr int part8_log_had_size = log_had_size - 8;\n\n    b32* a_chunk_ptr = a_start_ptr; // first chunk starts at this warp's data starts\n    b32* out_chunk_ptr = out_start_ptr;\n\n    #pragma unroll\n    for (int l = 0; l < 2; l++) {\n        if constexpr(log_had_size <= 8) { // l == 0 guaranteed, redundant simplified version of else body, to help compiler warnings\n            b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128;\n        } else {\n            b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * (l == 0 ? 128 : (128 >> part8_log_had_size));\n        }\n\n        if (l == 1) {\n            if constexpr(log_had_size > 8) {\n                __syncthreads(); // sync between first and second iterations if above size 256\n\n                if constexpr(log_had_size >= 12) {\n                    // sizes 4k and above\n\n                    // a + threadblock offset + warp offset\n                    // can then index into all chunks owned by this warp\n                    b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block));\n\n                    #pragma unroll\n                    for (int j = 0; j < 4; j++) {\n                        #pragma unroll\n                        for (int k = 0; k < num_chunks; k++) {\n                            // here, j represents register, and k represents 8-offset/chunk\n                            int real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data\n                            \n                            int real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread #\n                            int chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data)\n                            int thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads)\n                            int thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads\n                            int reg_idx = (j / 2) * 8 + (j % 2); // index due to target register\n                            int idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index\n\n                            // fix idx for majorness\n                            int rowidx = idx % (1 << part8_log_had_size);\n                            int colidx = idx >> part8_log_had_size;\n\n                            // store[rowidx * 128 + colidx] = 
data;\n                            b32 data = store[rowidx * 128 + colidx];\n\n                            // compiler generates excessive instructions, so we manually do the if statement\n                            #pragma unroll\n                            for (int i = 0; i < num_chunks; i++) {\n                                asm volatile (\n                                    \"{\\n\\t\"\n                                    \"  .reg .pred p0;\\n\\t\"\n                                    \"  setp.eq.u32 p0, %1, %2;\\n\\t\"\n                                    \"  @p0 mov.b32 %0, %3;\\n\\t\"\n                                    \"}\\n\\t\"\n                                    : \"+r\"(b_frag_all[i][j]) // Output operand %0\n                                    : \"r\"(real_chunk_num), \"r\"(i), \"r\"(data) // Input operands %1, %2, %3\n                                );\n                            }\n                        }\n                    }\n\n                    #pragma unroll\n                    for (int j = 0; j < 4; j++) {\n                        #pragma unroll\n                        for (int k = 1; k < num_chunks; k++) {\n                            int threadid_contig = threadid % num_chunks;\n                            int threadid_mul = threadid / num_chunks;\n                            int threadid2 = (threadid_contig + num_chunks - k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to\n                            b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2);\n                        }\n                    }\n                }\n            }\n        }\n\n        #pragma unroll\n        for (int k = 0; k < num_chunks; k++) {\n            if constexpr(enable_mask) {\n                if (k >= real_num_chunks)\n                    break;\n            }\n            if (l == 0) {\n                // bad fix for k not being recognized as a constexpr by compiler\n                // asm(\"cp.async.wait_group %0;\\n\" :: \"n\"(num_chunks - k - 1));\n                #define SWITCH_WAIT_ASYNC_LOAD_GROUP(i) case i: asm volatile(\"cp.async.wait_group %0;\\n\" :: \"n\"(num_chunks - i - 1)); break;\n                if constexpr(enable_mask) {\n                    switch(k + diff_num_chunks) {\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(0)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(1)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(2)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(3)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(4)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(5)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(6)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(7)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(8)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(9)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(10)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(11)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(12)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(13)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(14)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(15)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(16)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(17)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(18)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(19)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(20)\n                        
SWITCH_WAIT_ASYNC_LOAD_GROUP(21)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(22)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(23)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(24)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(25)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(26)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(27)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(28)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(29)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(30)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(31)\n                    }\n                } else {\n                    switch(k) {\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(0)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(1)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(2)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(3)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(4)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(5)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(6)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(7)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(8)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(9)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(10)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(11)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(12)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(13)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(14)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(15)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(16)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(17)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(18)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(19)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(20)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(21)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(22)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(23)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(24)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(25)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(26)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(27)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(28)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(29)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(30)\n                        SWITCH_WAIT_ASYNC_LOAD_GROUP(31)\n                    }\n                }\n            }\n\n            if (l == 0) {\n                // loading for the first iteration\n\n                // thread 0 loads  [t0r0, t16r1, t0r2, t16r3]\n                // thread 16 loads [t0r1, t16r0, t0r3, t16r2]\n                // allows full coalescing, same for t1/t17, t2/t18, etc.\n                #pragma unroll\n                for (int j = 0; j < 4; j++) {\n                    int reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2));\n                    int real_thread_id = (reg == 0 || reg == 2) ? 
threadid : (threadid ^ 16);\n                    int real_row = real_thread_id % 4;\n                    int real_col = real_thread_id / 4;\n                    b_frag_all[k][j] = b_frag_ptr[(real_row + (reg % 2) * 4) + (real_col + (j / 2) * 8) * 8];\n                }\n\n                // for t16 swap r0/r1 and r2/r3 to have [t16r0, t0r1, t16r2, t0r3]\n                // so registers are in right order, same for t17, t18, etc.\n                if ((threadid & 16) != 0) {\n                    b32 temp = b_frag_all[k][0];\n                    b_frag_all[k][0] = b_frag_all[k][1];\n                    b_frag_all[k][1] = temp;\n\n                    temp = b_frag_all[k][2];\n                    b_frag_all[k][2] = b_frag_all[k][3];\n                    b_frag_all[k][3] = temp;\n                }\n\n                // t0 and t16 swap r1 and r3 to have their own data,\n                // same for t1/t17, t2/18, etc.\n                #pragma unroll\n                for (int j = 1; j < 4; j += 2) {\n                    b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16);\n                }\n            } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings\n                if constexpr(log_had_size < 12) {\n                    // sizes 512, 1k, and 2k\n\n                    // for 512:\n                    //     thread 0 loads  [t0r0, t0r1, t16r2, t16r3]\n                    //     thread 16 loads [t0r2, t0r3, t16r0, t16r1]\n                    //     same for t1/t17, t2/t18, etc.\n                    // for 1k and 2k:\n                    //     thread 0 loads [t0r0, t0r1, t1r2, t1r3]\n                    //     thread 1 loads [t0r2, t0r3, t1r0, t1r1]\n                    //     same for t2/t3, t4/t5, etc.\n                    // allows full coalescing for 512 and 1k, 16x coalescing for 2k\n                    constexpr int xor_val = log_had_size == 9 ? 16 : 1;\n\n                    #pragma unroll\n                    for (int j = 0; j < 4; j++) {\n                        int reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4;\n                        int real_thread_id = reg < 2 ? 
threadid : (threadid ^ xor_val);\n                        int idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2);\n                        int rowidx = idx % (1 << part8_log_had_size);\n                        int colidx = idx >> part8_log_had_size;\n                        b_frag_all[k][j] = b_frag_ptr[rowidx * 128 + colidx];\n                    }\n\n                    if ((threadid & xor_val) != 0) {\n                        b32 temp = b_frag_all[k][0];\n                        b_frag_all[k][0] = b_frag_all[k][2];\n                        b_frag_all[k][2] = temp;\n\n                        temp = b_frag_all[k][1];\n                        b_frag_all[k][1] = b_frag_all[k][3];\n                        b_frag_all[k][3] = temp;\n                    }\n\n                    #pragma unroll\n                    for (int j = 2; j < 4; j++) {\n                        b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val);\n                    }\n                }\n            }\n\n            if (l == 1) {\n                // for second iteration, we load 2 consecutive b16s (1 b32) per register,\n                // but tensor core register layout requires 2 b16s that are in the\n                // same column/consecutive rows to be in the same register, so do the swap\n                b32 f0 = ((b_frag_all[k][1] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF);\n                b32 f1 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][2] & 0xFFFF);\n                b32 f2 = (b_frag_all[k][1] & 0xFFFF0000) | (b_frag_all[k][0] >> 16);\n                b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][2] >> 16);\n                b_frag_all[k][0] = f0;\n                b_frag_all[k][1] = f1;\n                b_frag_all[k][2] = f2;\n                b_frag_all[k][3] = f3;\n            }\n\n            #pragma unroll\n            for(int i = 0, remaining_log_had_size = log_had_size - l * 8; i < 2 && remaining_log_had_size > 0; i++) {\n                int had_off = ((remaining_log_had_size < 4) && !(log_had_size <= 4 || log_had_size % 4 == 0)) ? 
4 : 0;\n                mma_m16_n16_k16_b16_b16_b16_noacc<dtype>(had_frag[had_off + 0], had_frag[had_off + 1], had_frag[had_off + 2], had_frag[had_off + 3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3]);\n\n                remaining_log_had_size -= 4;\n                if (remaining_log_had_size <= 0 && i == 0) {\n                    // TODO: consider different storing so no need for transpose\n                    matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][0]);\n                    matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][1]);\n                    matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][2]);\n                    matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][3]);\n                } else {\n                    // swap and use output directly as b_frag for next iteration as an actually free transpose\n                    b32 temp = b_frag_all[k][1];\n                    b_frag_all[k][1] = b_frag_all[k][2];\n                    b_frag_all[k][2] = temp;\n                }\n            }\n\n            if (l == 1) {\n                // invert swap from above for second iteration\n                b32 f0 = ((b_frag_all[k][2] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF);\n                b32 f1 = (b_frag_all[k][2] & 0xFFFF0000) | (b_frag_all[k][0] >> 16);\n                b32 f2 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][1] & 0xFFFF);\n                b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][1] >> 16);\n                b_frag_all[k][0] = f0;\n                b_frag_all[k][1] = f1;\n                b_frag_all[k][2] = f2;\n                b_frag_all[k][3] = f3;\n            }\n\n            if (l == 0) {\n                // inverse of coalesced load for first iteration to store result\n                #pragma unroll\n                for (int j = 1; j < 4; j += 2) {\n                    b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16);\n                }\n\n                if ((threadid & 16) != 0) {\n                    b32 temp = b_frag_all[k][0];\n                    b_frag_all[k][0] = b_frag_all[k][1];\n                    b_frag_all[k][1] = temp;\n\n                    temp = b_frag_all[k][2];\n                    b_frag_all[k][2] = b_frag_all[k][3];\n                    b_frag_all[k][3] = temp;\n                }\n\n                // if only going up to 256 size, store directly back to global memory,\n                // otherwise store back to shared memory for next iteration\n                b32* store = (log_had_size <= 8) ? out_chunk_ptr : b_frag_ptr;\n\n                #pragma unroll\n                for (int j = 0; j < 4; j++) {\n                    int reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2));\n                    int real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16);\n                    int real_row = real_thread_id % 4;\n                    int real_col = real_thread_id / 4;\n                    store[(real_row + (reg % 2) * 4) + (real_col + (reg / 2) * 8) * 8] = b_frag_all[k][j];\n                }\n            } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings\n                if (log_had_size < 12) {\n                    // inverse of coalesced load for sizes 512, 1k and 2k to store result\n                    constexpr int xor_val = log_had_size == 9 ? 
16 : 1;\n                    #pragma unroll\n                    for (int j = 2; j < 4; j++) {\n                        b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val);\n                    }\n\n                    if ((threadid & xor_val) != 0) {\n                        b32 temp = b_frag_all[k][0];\n                        b_frag_all[k][0] = b_frag_all[k][2];\n                        b_frag_all[k][2] = temp;\n\n                        temp = b_frag_all[k][1];\n                        b_frag_all[k][1] = b_frag_all[k][3];\n                        b_frag_all[k][3] = temp;\n                    }\n\n                    b32* store = (b32*)(out + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 256 + (256 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block) + k));\n                    #pragma unroll\n                    for (int j = 0; j < 4; j++) {\n                        int reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4;\n                        b32 data = b_frag_all[k][j];\n                        int real_thread_id = reg < 2 ? threadid : (threadid ^ xor_val);\n                        int idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2);\n                        int rowidx = idx % (1 << part8_log_had_size);\n                        int colidx = idx >> part8_log_had_size;\n                        store[rowidx * 128 + colidx] = data;\n                    }\n                }\n                // for size 4k and above, wait to process all chunks so a final store can be performed coalesced\n            }\n\n            a_chunk_ptr += 128; // (only affects first 256 size) move on to next chunk by skipping 256 elements in b16 (= 128 in b32)\n            out_chunk_ptr += 128;\n            if constexpr(log_had_size > 8) {\n                b_frag_ptr += (l == 0 ? 
128 : (128 >> part8_log_had_size));\n            } else { // else is redundant, simplified version of if body, to help compiler warnings\n                b_frag_ptr += 128;\n            }\n        }\n        if (log_had_size <= 8)\n            break;\n    }\n\n    if constexpr(log_had_size >= 12) {\n        // for sizes 4k and above, perform final coalesced store after processing all chunks\n        #pragma unroll\n        for (int j = 0; j < 4; j++) {\n            #pragma unroll\n            for (int k = 1; k < num_chunks; k++) {\n                int threadid_contig = threadid % num_chunks;\n                int threadid_mul = threadid / num_chunks;\n                int threadid2 = (threadid_contig + k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to\n                b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2);\n            }\n        }\n\n        // a + threadblock offset + warp offset\n        // can then index into all chunks owned by this warp\n        b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block));\n\n        #pragma unroll\n        for (int j = 0; j < 4; j++) {\n            #pragma unroll\n            for (int k = 0; k < num_chunks; k++) {\n                // here, j represents register, and k represents 8-offset/chunk\n                int real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data\n\n                // b32 data = b_frag_all[real_chunk_num][j]; // target thread data\n                b32 data;\n                #pragma unroll\n                for (int i = 0; i < num_chunks; i++) {\n                    if (real_chunk_num == i) data = b_frag_all[i][j];\n                }\n                \n                int real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread #\n                int chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data)\n                int thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads)\n                int thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads\n                int reg_idx = (j / 2) * 8 + (j % 2); // index due to target register\n                int idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index\n\n                // fix idx for majorness\n                int rowidx = idx % (1 << part8_log_had_size);\n                int colidx = idx >> part8_log_had_size;\n\n                store[rowidx * 128 + colidx] = data;\n            }\n        }\n\n        __syncthreads();\n        store = ((b32*) out) + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 128;\n        int4* store4 = (int4*) store;\n        int4* bfrag_arr4 = (int4*) bfrag_arr;\n        // flush smem, simply linearly write to store\n        // always divisible by 128*32b, so (32*4)*32b is ok\n        #pragma unroll\n        for (int warp_off = 0; warp_off < (num_chunks * warps_per_block * 128 / 4); warp_off += 32 * warps_per_block) {\n            int total_off = warp_off + threadid + (blockid % warps_per_block) * 32;\n            store4[total_off] = bfrag_arr4[total_off];\n        }\n    }\n\n}\n\nconstexpr int ceil_div(int a, int b) {\n    return (a + b - 1) / b;\n}\n\ntemplate 
<torch::ScalarType dtype, int chunks_per_warp, int warps_per_block, int log_had_size, int blocks_per_sm, bool check_masking = false>\nvoid __forceinline__ run_kernel(b16* a_mat, b16* out, int num_chunks, cudaStream_t stream) {\n    int shared_size = chunks_per_warp * warps_per_block * 128 * 4;\n    dim3 block_size = 32 * warps_per_block;\n\n    #define CHECK_SHARED_LIM() {                                                                              \\\n        if (shared_size > 48 * 1024) {                                                                        \\    \n            C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \\\n        }                                                                                                     \\\n    }                                                                                                         \\\n\n    if constexpr(check_masking) {\n        if (num_chunks % (chunks_per_warp * warps_per_block) != 0) {\n            dim3 grid_size = ceil_div(ceil_div(num_chunks, chunks_per_warp), warps_per_block);\n            auto kernel = hadamard_transform_kernel<chunks_per_warp, warps_per_block, log_had_size, blocks_per_sm, true, dtype>;\n            CHECK_SHARED_LIM();\n            kernel<<<dim3(grid_size), dim3(block_size), shared_size, stream>>>(a_mat, out, num_chunks);\n        } else {\n            dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block;\n            auto kernel = hadamard_transform_kernel<chunks_per_warp, warps_per_block, log_had_size, blocks_per_sm, false, dtype>;\n            CHECK_SHARED_LIM();\n            kernel<<<dim3(grid_size), dim3(block_size), shared_size, stream>>>(a_mat, out, num_chunks);\n        }\n    } else {\n        dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block;\n        auto kernel = hadamard_transform_kernel<chunks_per_warp, warps_per_block, log_had_size, blocks_per_sm, false, dtype>;\n        CHECK_SHARED_LIM();\n        kernel<<<dim3(grid_size), dim3(block_size), shared_size, stream>>>(a_mat, out, num_chunks);\n    }\n    \n    C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate <torch::ScalarType dtype>\nvoid run_fht(void* a_mat_ptr, void* out_ptr, uint32_t numel, uint32_t had_size, cudaStream_t stream) {\n    uint32_t num_chunks = numel / 256; // caller required to ensure divisible by 256\n    // for size 256, use (2, 1)\n    // for size 32k use (8, 16)\n    constexpr int chunks_per_warp_small = 1;// 8;\n    constexpr int warps_per_block_small = 1;//2;//16;\n    constexpr int blocks_per_sm_small = 24;\n    constexpr int chunks_per_warp_large = 2;\n    constexpr int warps_per_block_large = 1;\n    constexpr int blocks_per_sm_large = 24;\n\n    // constexpr torch::ScalarType dtype = torch::ScalarType::Half;\n\n    b16* a_mat = (b16*) a_mat_ptr;\n    b16* out = (b16*) out_ptr;\n\n    if (numel <= 256) {\n        switch (had_size) {\n            case (1<<1): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 1, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;\n            case (1<<2): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 2, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;\n            case (1<<3): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 3, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;\n            case (1<<4): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 4, blocks_per_sm_small>(a_mat, out, num_chunks, stream); 
break;\n            case (1<<5): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 5, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;\n            case (1<<6): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 6, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;\n            case (1<<7): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 7, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;\n            case (1<<8): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 8, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;\n        }\n    } else {\n        switch (had_size) {\n            case (1<<1):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 1, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<2):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 2, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<3):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 3, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<4):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 4, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<5):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 5, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<6):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 6, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<7):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 7, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<8):  run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 8, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;\n            case (1<<9):  run_kernel<dtype, launch_configs_big[0][0], launch_configs_big[0][1], 9 , launch_configs_big[0][2]>(a_mat, out, num_chunks, stream); break;\n            case (1<<10): run_kernel<dtype, launch_configs_big[1][0], launch_configs_big[1][1], 10, launch_configs_big[1][2]>(a_mat, out, num_chunks, stream); break;\n            case (1<<11): run_kernel<dtype, launch_configs_big[2][0], launch_configs_big[2][1], 11, launch_configs_big[2][2]>(a_mat, out, num_chunks, stream); break;\n            case (1<<12): run_kernel<dtype, launch_configs_big[3][0], launch_configs_big[3][1], 12, launch_configs_big[3][2]>(a_mat, out, num_chunks, stream); break;\n            case (1<<13): run_kernel<dtype, launch_configs_big[4][0], launch_configs_big[4][1], 13, launch_configs_big[4][2]>(a_mat, out, num_chunks, stream); break;\n            case (1<<14): run_kernel<dtype, launch_configs_big[5][0], launch_configs_big[5][1], 14, launch_configs_big[5][2]>(a_mat, out, num_chunks, stream); break;\n            case (1<<15): run_kernel<dtype, launch_configs_big[6][0], launch_configs_big[6][1], 15, launch_configs_big[6][2]>(a_mat, out, num_chunks, stream); break;\n        }\n    }\n}\n\ntemplate void run_fht<torch::ScalarType::Half>(void* a_mat_ptr, void* out_ptr, uint32_t numel, uint32_t had_size, cudaStream_t stream);\ntemplate void run_fht<torch::ScalarType::BFloat16>(void* a_mat_ptr, void* out_ptr, uint32_t numel, uint32_t had_size, cudaStream_t stream);"
  },
  {
    "path": "kernels/cuda/inference/hadamard_transform/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nversions = [\n    \"-gencode\",\n    \"arch=compute_80,code=sm_80\",\n    \"-gencode\",\n    \"arch=compute_89,code=sm_89\",\n    \"-gencode\",\n    \"arch=compute_90,code=sm_90\",\n] # TODO: assumes installed CUDA toolkit supports sm_80 to sm_90\n\nsetup(\n    name='faster_hadamard_transform',\n    ext_modules=[\n        CUDAExtension(\n            name=\"faster_hadamard_transform\",\n            sources=[\n                \"hadamard_transform.cpp\",\n                \"hadamard_transform_cuda.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": [\"-O3\"],\n                \"nvcc\": [\n                    \"-O3\",\n                    \"-lineinfo\",\n                    '--ptxas-options=--warn-on-local-memory-usage',\n                    '--ptxas-options=--warn-on-spills',\n                ] + versions\n            }\n        ),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    }\n)"
  },
  {
    "path": "kernels/cuda/inference/hadamard_transform/test.py",
    "content": "import torch\nimport faster_hadamard_transform\nimport scipy.linalg\nimport math\n\n# set to false to check performance\ncorrectness_check = True\n# set to warmup count + 1 to check performance\n# for quick testing, 2 is good.\nruns_per_size = 2\n\n# hadamard sizes\ntest_sizes_m = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]\n\ntest_elem_counts = [1 << i for i in range(9, 26, 1)] # 32MB # 64MB # 2**28 = 256M\n\nprint(\"test_sizes_m: \", test_sizes_m)\nprint(\"test_elem_counts: \", test_elem_counts)\n\ntest_count = len(test_sizes_m) * len(test_elem_counts)\ntests_done = 0\nfailed_tests = 0\n\ndef get_scale(size):\n    return math.sqrt(1 / size)\n\ntruth_hadamards = [torch.tensor(scipy.linalg.hadamard(size), device='cuda', dtype=torch.float32) * get_scale(size) for size in test_sizes_m]\ntruth_hadamards = [(x.to(torch.float16), x.to(torch.bfloat16)) for x in truth_hadamards]\ntruth_hadamards_fp16, truth_hadamards_bf16 = zip(*truth_hadamards)\ntruth_hadamards_fp16 = list(truth_hadamards_fp16)\ntruth_hadamards_bf16 = list(truth_hadamards_bf16)\ndel truth_hadamards\n\ndef truth_hadamard_transform_inplace(a: torch.Tensor, truth_hadamards):\n    target_index = -1\n    for i in range(len(test_sizes_m)):\n        if test_sizes_m[i] == a.shape[1]:\n            target_index = i\n            break\n    return a @ truth_hadamards[int(target_index)]\n\ndef test_hadamard_transform_inplace_rowmajor(a: torch.Tensor):\n    faster_hadamard_transform.hadamard_transform(a, inplace=True)\n    return a\n\ntorch.manual_seed(0)\n\ndef check_correctness(m, elem_c, a, result, truth, atol=1e-2, rtol=0):\n    success = torch.allclose(truth, result, atol=atol, rtol=rtol)\n\n    if not success:\n        torch.set_printoptions(threshold=100)\n        print(f'Failed test: {m}x{elem_c // m}')\n        print(f'Input:')\n        print(a)\n        print(f'Expected:')\n        print(truth)\n        print(f'Got:')\n        print(result)\n                    # worst element\n        diff = torch.abs(truth - result)\n        max_diff = torch.max(diff)\n        print(f'Max diff: {max_diff}')\n        print(f'Max diff index: {torch.argmax(diff)}')\n        diff_input = torch.abs(a - result)\n        max_diff_input = torch.max(diff_input)\n        print(f'Max diff input: {max_diff_input}')\n        print('')\n        exit(1)\n\nfor m in test_sizes_m:\n    for elem_c in test_elem_counts:\n        if elem_c < m:\n            tests_done += runs_per_size\n            if tests_done % 100 == 0 or tests_done == test_count:\n                print(f'{tests_done}/{test_count} tests done')\n            continue\n        print(f'Testing size {m}x{elem_c // m}')\n\n        a = torch.randn((elem_c // m, m), device='cuda', dtype=torch.float32)\n        # a = torch.zeros((m, elem_c // m), device='cuda', dtype=torch.float16)\n        # for i in range(min(a.shape[0], a.shape[1])):\n        #     a[i, i] = 1.0\n        if correctness_check:\n            for i in range(runs_per_size):\n                # run test here\n                a_result_fp16 = a.clone().to(torch.float16)\n                a_truth_fp16 = a.clone().to(torch.float16)\n                result_fp16 = test_hadamard_transform_inplace_rowmajor(a_result_fp16)\n                truth_fp16 = truth_hadamard_transform_inplace(a_truth_fp16, truth_hadamards_fp16)\n                check_correctness(m, elem_c, a, result_fp16, truth_fp16, atol=1e-2) # TODO: NOTE: we are not accurate down to 3 decimal places (atol)\n\n                a_result_bf16 = 
a.clone().to(torch.bfloat16)\n                a_truth_bf16 = a.clone().to(torch.bfloat16)\n                result_bf16 = test_hadamard_transform_inplace_rowmajor(a_result_bf16)\n                truth_bf16 = truth_hadamard_transform_inplace(a_truth_bf16, truth_hadamards_bf16)\n                check_correctness(m, elem_c, a, result_bf16, truth_bf16, atol=5e-2) # TODO: NOTE: need 5x atol to pass for bf16\n        else:\n            # run in a row so that warmup is valid\n            a_result = a.to(torch.float16) # kernel supports fp16/bf16 only; we can clobber the result because we are only interested in timing\n            for i in range(runs_per_size):\n                a_result = test_hadamard_transform_inplace_rowmajor(a_result)\n            a_truth = a.to(torch.float16)\n            for i in range(runs_per_size):\n                a_truth = truth_hadamard_transform_inplace(a_truth, truth_hadamards_fp16)\n            a_memcpy = a\n            # also can compare timing to memcpy\n            temp = torch.empty_like(a)\n            for i in range(runs_per_size):\n                temp.copy_(a_memcpy)\n            # do nothing with results since we are only interested in timing\n            # NOTE: make sure to disable clearing cache in Nsight Compute\n\n        tests_done += 1\n        if tests_done % 100 == 0 or tests_done == test_count:\n            print(f'{tests_done}/{test_count} size tests done')"
  },
  {
    "path": "kernels/cuda/training/README.md",
    "content": "kernels with backward pass support\n"
  },
  {
    "path": "kernels/cuda/tutorials/README.md",
    "content": "CUDA tutorials\n"
  },
  {
    "path": "kernels/cuda/tutorials/flash2.cu",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n\n// flash2\n\n__global__\nvoid forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,\nconst int Tc, const int Tr, const int Bc, const int Br, const float sm_scale,\nfloat* l, float* m, float* O)\n{\n    int tidx = threadIdx.x;\n    int bidx = blockIdx.x;  // batch index\n    int bidy = blockIdx.y;  // head index\n\n    int qkv_offset = (bidx * gridDim.y * N * d) + (bidy*N*d);\n    int lm_offset = (bidx * gridDim.y *N) + (bidy *N); //l and m offset\n\n    extern __shared__ float sram[];\n    int tile_size = Bc * d; size of Qi, Kj, Vj\n\n    float* Qi = sram;\n    float * Kj = &sram[tile_size];\n    float* Vj = &sram[tile_size *2];\n    float* S = &sram[tile_size *3];\n\n    for (int j=0; j < Tc; j++) {\n\n        // load Kj, Vj to sram\n        for (int x=0; x < d; x++) {\n            Kj[(tx*d)+x] = K[qkv_offset + (tile_size *j) + (tx*d) +x];\n            Vj[(tx*d) + x] = V[qkv_offset +(tile_size *j) + (tx*d) +x];\n        }\n        __synchthreads();\n\n    }\n}\n\n\n    for (int j = 0; j < Tc; j++) {\n\n        // Load Kj, Vj to SRAM\n        for (int x = 0; x < d; x++) {\n            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];\n            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];\n        }\n        __syncthreads();  // such that the inner loop can use the correct Kj, Vj\n"
  },
  {
    "path": "kernels/needs_perf_help/fp8_gemm_bench.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-strict\n\nfrom typing import Callable, Tuple\n\n#import click\n\nimport torch\nimport triton  # @manual\n\nfrom fp8_gemm_rowwise import (\n    matmul_fp8_block,\n    matmul_fp8_row,\n    quantize_fp8_block,\n    quantize_fp8_row,\n)\nfrom torch._tensor import Tensor\n\n\n#@click.command()\n#@click.option(\"--cuda-graph\", type=bool, default=True)\n#@click.option(\"--rowwise-tma\", is_flag=True, default=False)\ndef bench(cuda_graph: bool, rowwise_tma: bool=True) -> None:\n    \"\"\"Benchmark bf16 vs scale/cast + fp8.\"\"\"\n\n    def _run_benchmark(\n        bench_factory: Callable[\n            [torch.Tensor, torch.Tensor], Callable[[], torch.Tensor]\n        ],\n        shape: Tuple[int, int, int] = (1024, 1024, 1024),\n        tag: str = \"\",\n    ) -> None:\n        # Benchmarks the function returned by bench_factory.\n        # Any pre-processing that should not be benchmarked can occur inside bench_factory.\n        m, n, k = shape\n\n        input_shape = (m, k)\n        weight_shape = (n, k)\n\n        base_dtype = torch.bfloat16\n        input_ = torch.randn(input_shape, device=\"cuda\", dtype=base_dtype)\n        weight_ = torch.randn(weight_shape, device=\"cuda\", dtype=base_dtype)\n\n        gemm_fn = bench_factory(input_, weight_)\n\n        if cuda_graph:\n            bench_stream = torch.cuda.Stream()\n            with torch.cuda.stream(bench_stream):\n                ms = triton.testing.do_bench_cudagraph(\n                    lambda: gemm_fn(),\n                    rep=100,\n                )\n        else:\n            ms = triton.testing.do_bench(\n                lambda: gemm_fn(),\n                warmup=25,\n                rep=100,\n            )\n\n        tflops = (2 * m * n * k) / 1e12\n        sec = ms / 1e3\n        perf_str = f\"{tflops / sec:.2f}\"\n        print(\n            f\"{(tag + ':').ljust(40)}\\tshape {str(shape):<25} tflops {perf_str:<8} ms {ms:.3f}\"\n        )\n\n    shapes = [\n        (8192, 8192, 512),\n        (8192, 8192, 8192),\n        (65536, 8192, 7168),\n        (65536, 3584, 8192),\n        (8192, 14336, 4096),\n    ]\n    for shape in shapes:\n        _run_benchmark(bf16_bench, shape=shape, tag=\"bf16\")\n        _run_benchmark(scale_row_bench, shape=shape, tag=\"fp8 scale + row gemm\")\n        _run_benchmark(scale_block_bench, shape=shape, tag=\"fp8 scale + block gemm\")\n        _run_benchmark(\n            row_gemm_bench,\n            shape=shape,\n            tag=\"fp8 row gemm only | fp8_fast_accum=True\",\n        )\n        _run_benchmark(\n            row_gemm_bench_no_fast_acc,\n            shape=shape,\n            tag=\"fp8 row gemm only | fp8_fast_accum=False\",\n        )\n        _run_benchmark(\n            row_gemm_bench_imprecise_acc,\n            shape=shape,\n            tag=\"fp8 row gemm only | max_num_imprecise_acc=32\",\n        )\n        _run_benchmark(block_gemm_bench, shape=shape, tag=\"fp8 block gemm only\")\n        if rowwise_tma:\n            _run_benchmark(\n                row_gemm_bench_tma,\n                shape=shape,\n                tag=\"fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True\",\n            )\n\n\ndef bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:\n    def gemm_fn() -> Tensor:\n        return torch.matmul(x, w.T)\n\n    return 
gemm_fn\n\n\ndef scale_row_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:\n    # Benchmark quantize(x) + gemm for inference.\n    def run_gemm() -> Tensor:\n        x_fp8: Tensor\n        w_fp8: Tensor\n        x_scale: Tensor\n        w_scale: Tensor\n        x_fp8, x_scale = quantize_fp8_row(x)\n        w_fp8, w_scale = quantize_fp8_row(w)\n        return matmul_fp8_row(\n            x_fp8,\n            w_fp8,\n            x_scale,\n            w_scale,\n            dot_out_dtype=torch.float32,\n            allow_tf32=True,\n            fp8_fast_accum=True,\n        )\n\n    return run_gemm\n\n\ndef row_gemm_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:\n    # Benchmark only row-wise gemm, caching scaling.\n    x_fp8: Tensor\n    w_fp8: Tensor\n    x_scale: Tensor\n    w_scale: Tensor\n    x_fp8, x_scale = quantize_fp8_row(x)\n    w_fp8, w_scale = quantize_fp8_row(w)\n\n    def run_gemm() -> Tensor:\n        return matmul_fp8_row(\n            x_fp8,\n            w_fp8,\n            x_scale,\n            w_scale,\n            dot_out_dtype=torch.float32,\n            allow_tf32=True,\n            fp8_fast_accum=True,\n        )\n\n    return run_gemm\n\n\ndef row_gemm_bench_tma(x: Tensor, w: Tensor) -> Callable[[], Tensor]:\n    # Benchmark only row-wise gemm with TMA persistent\n    x_fp8: Tensor\n    w_fp8: Tensor\n    x_scale: Tensor\n    w_scale: Tensor\n    x_fp8, x_scale = quantize_fp8_row(x)\n    w_fp8, w_scale = quantize_fp8_row(w)\n\n    def run_gemm() -> Tensor:\n        return matmul_fp8_row(\n            x_fp8,\n            w_fp8,\n            x_scale,\n            w_scale,\n            dot_out_dtype=torch.float32,\n            allow_tf32=True,\n            fp8_fast_accum=True,\n            tma_persistent=True,\n        )\n\n    return run_gemm\n"
  },
  {
    "path": "kernels/needs_perf_help/fp8_rowwise_tma_persistent.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# pyre-unsafe\nimport logging\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport triton  # @manual\n\nimport triton.language as tl  # @manual\nfrom torch._tensor import Tensor\n\nfrom triton import Config  # @manual\nfrom triton.ops.matmul_perf_model import (  # @manual\n    early_config_prune,\n    estimate_matmul_time,\n)\nfrom triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper  # @manual\n\nlogger: logging.Logger = logging.getLogger(__name__)\n\n\ndef get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:\n    \"\"\"\n    Helper function to get constant values for the current platform.\n\n    Returns:\n        pt_dtype (torch.dtype): The correct torch fp8 datatype.\n        tl_dtype (tl.dtype): The correct triton fp8 datatype.\n        max_fp8 (float): The maximum reprsentable value for the fp8 datatype.\n        eps (float): Minimum clip value to prevent divide by zero.\n    \"\"\"\n    if torch.version.hip is not None:\n        pt_fp8_dtype = torch.float8_e4m3fnuz\n        tl_fp8_dtype = tl.float8e4b8\n    else:\n        pt_fp8_dtype = torch.float8_e4m3fn\n        tl_fp8_dtype = tl.float8e4nv\n    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12\n\n\ndef convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:\n    \"\"\"\n    Converts tensor to triton fp8 type.\n\n    Args:\n        tensor (torch.Tensor): input tensor.\n        dtype (tl.dtype): target triton dtype.\n\n    Returns:\n        triton.TensorWrapper: fp8 tensor.\n    \"\"\"\n    return tl_reinterpret(tensor, dtype=dtype)\n\n\ndef init_to_zero(name):\n    return lambda nargs: nargs[name].zero_()\n\n\ndef get_configs_io_bound() -> List[Config]:\n    \"\"\"\n    Returns a list of configs for matmul that are IO bound.\n\n    Returns:\n        List[Config]: list of configs.\n    \"\"\"\n    configs = []\n    for num_stages in [2, 3, 4, 5, 6]:\n        for block_m in [16, 32]:\n            for block_k in [32, 64]:\n                for block_n in [32, 64, 128, 256]:\n                    num_warps = 2 if block_n <= 64 else 4\n                    configs.append(\n                        Config(\n                            {\n                                \"BLOCK_M\": block_m,\n                                \"BLOCK_N\": block_n,\n                                \"BLOCK_K\": block_k,\n                                \"SPLIT_K\": 1,\n                            },\n                            num_stages=num_stages,\n                            num_warps=num_warps,\n                        )\n                    )\n                    # split_k\n                    for split_k in []:  # Disabled [2, 4, 8, 16]:\n                        configs.append(\n                            Config(\n                                {\n                                    \"BLOCK_M\": block_m,\n                                    \"BLOCK_N\": block_n,\n                                    \"BLOCK_K\": block_k,\n                                    \"SPLIT_K\": split_k,\n                                },\n                                num_stages=num_stages,\n                                num_warps=num_warps,\n                                pre_hook=init_to_zero(\"C\"),\n                            )\n                        )\n    return 
configs\n\n\n@triton.jit\ndef _kernel_matmul_fp8_row_tma_persistent(\n    A_ptr,\n    B_ptr,\n    C_ptr,\n    M,\n    N,\n    K,\n    A_scale,\n    B_scale,\n    stride_am,\n    stride_ak,\n    stride_bn,\n    stride_bk,\n    stride_cm,\n    stride_cn,\n    dot_out_dtype: tl.constexpr,\n    allow_tf32: tl.constexpr,\n    fp8_fast_accum: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    GROUP_M: tl.constexpr,\n    AB_DTYPE: tl.constexpr,\n    NUM_SMS: tl.constexpr,\n) -> None:\n    \"\"\"Matmul kernel of [M, K] @ [N, K] with row-wise scales\n\n    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.\n\n    Args:\n        A (TensorWrapper): [M, K] input tensor.\n        B (TensorWrapper): [N, K] input tensor.\n        C (TensorWrapper): [M, N] output tensor.\n        M (int): M dimension of input tensor.\n        N (int): N dimension of input tensor.\n        K (int): K dimension of input tensor.\n        A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A\n        B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B\n        stride_am (int): Stride of M dimension of A.\n        stride_ak (int): Stride of K dimension of A.\n        stride_bn (int): Stride of N dimension of B.\n        stride_bk (int): Stride of K dimension of B.\n        stride_cm (int): Stride of M dimension of C.\n        stride_cn (int): Stride of N dimension of C.\n        dot_out_dtype (torch.dtype): Output type of tensor core.\n        allow_tf32 (bool): Whether to use TF32 for tensor core.\n        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.\n        BLOCK_M (int): Block size for M dimension.\n        BLOCK_N (int): Block size for N dimension.\n        BLOCK_K (int): Block size for K dimension.\n        GROUP_M (int): Number of groups for M dimension swizzle.\n        SPLIT_K (int): Number of SM's to launch per row.\n        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.\n        AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.\n    \"\"\"\n    # Matrix multiplication.\n    start_pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_M)\n    num_pid_n = tl.cdiv(N, BLOCK_N)\n    k_tiles = tl.cdiv(K, BLOCK_K)\n    num_tiles = num_pid_m * num_pid_n\n\n    tiles_per_SM = num_tiles // NUM_SMS\n    if start_pid < num_tiles % NUM_SMS:\n        tiles_per_SM += 1\n\n    tile_id = start_pid - NUM_SMS\n    ki = -1\n\n    pid_m = 0\n    pid_n = 0\n    offs_am = 0\n    offs_bn = 0\n\n    num_pid_in_group = GROUP_M * num_pid_n\n\n    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)\n\n    dtype_fp8 = tl.float8e4nv\n    scale_dtype = tl.float32\n\n    for _ in range(0, k_tiles * tiles_per_SM):\n        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)\n        if ki == 0:\n            tile_id += NUM_SMS\n            group_id = tile_id // num_pid_in_group\n            first_pid_m = group_id * GROUP_M\n            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)\n            pid_m = first_pid_m + (tile_id % group_size_m)\n            pid_n = (tile_id % num_pid_in_group) // group_size_m\n\n            offs_am = pid_m * BLOCK_M\n            offs_bn = pid_n * BLOCK_N\n            offs_am = tl.multiple_of(offs_am, BLOCK_M)\n            offs_bn = tl.multiple_of(offs_bn, BLOCK_N)\n\n        offs_k = ki * BLOCK_K\n\n        a = tl._experimental_descriptor_load(\n            A_ptr, [offs_am, offs_k], [BLOCK_M, 
BLOCK_K], dtype_fp8\n        )\n        b = tl._experimental_descriptor_load(\n            B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8\n        )\n        acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)\n\n        if ki == k_tiles - 1:\n            # rematerialize rm and rn to save registers\n            rm = pid_m * BLOCK_M\n            rn = pid_n * BLOCK_N\n\n            # # Invert scaling.\n            a_scale = tl._experimental_descriptor_load(\n                A_scale, [rm], [BLOCK_M], scale_dtype\n            )\n            b_scale = tl._experimental_descriptor_load(\n                B_scale, [rn], [BLOCK_N], scale_dtype\n            )\n            # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.\n            scale = a_scale[:, None] * b_scale[None, :]\n            acc *= scale\n            acc = acc.to(C_ptr.dtype.element_ty)\n\n            tl._experimental_descriptor_store(C_ptr, acc, [rm, rn])\n            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)\n\n\ndef matmul_fp8_row(\n    a: torch.Tensor,\n    b: torch.Tensor,\n    a_scale: torch.Tensor,\n    b_scale: torch.Tensor,\n    dot_out_dtype: Optional[torch.dtype] = None,\n    allow_tf32: bool = True,\n    fp8_fast_accum: bool = True,\n    imprecise_acc: bool = False,\n    tma_persistent: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].\n\n    Args:\n        a (torch.Tensor): [M, K] input tensor.\n        b (torch.Tensor): [N, K] input tensor.\n        a_scale (torch.Tensor): [M] reciprocal scale tensor per row. A * a_scale = original A\n        b_scale (torch.Tensor): [N] reciprocal scale tensor per row. B * b_scale = original B\n        dot_out_dtype (torch.dtype): Output type of tensor core.\n        allow_tf32 (bool): Whether to use TF32 for tensor core.\n        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.\n        tma_persistent (bool): Whether to use TMA persistent kernel impl.\n\n    Returns:\n        torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])\n    \"\"\"\n    # Get datatypes and constants to use.\n    _, tl_dtype, _, _ = get_fp8_constants()\n    # Reinterpret inputs into proper triton fp8 dtype.\n    a_tl = convert_fp8_type(a, tl_dtype)\n    b_tl = convert_fp8_type(b, tl_dtype)\n    M, N, K, m_key, n_key, k_key, c, dot_out_dtype_triton, device = prep_matmul(\n        a_tl, b_tl, dot_out_dtype\n    )\n    # launch kernel\n    if a.device == torch.device(\"cpu\"):\n        logger.info(\n            \"FP8 Row-wise Triton kernel not supported on cpu, fallback to torch\"\n        )\n        return (\n            torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T)\n            * (a_scale[:, None] * b_scale[None, :])\n        ).to(dtype=c.dtype)\n\n    def grid(META):\n        return (\n            triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n            META[\"SPLIT_K\"],\n        )\n\n    NUM_SMS = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n\n    def persistent_grid(META):\n        return (\n            min(\n                NUM_SMS,\n                triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n            ),\n        )\n\n    if tma_persistent:\n        # used by TMA persistent kernel\n        TMA_SIZE = 128\n        import numpy as np\n\n        # autotune doesn't work with TMA\n        # 
https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312\n\n        BLOCK_M = 128\n        BLOCK_N = 256\n        BLOCK_K = 128\n        GROUP_M = 8\n        num_stages = 3\n        num_warps = 8\n\n        desc_a = np.empty(TMA_SIZE, dtype=np.int8)\n        desc_b = np.empty(TMA_SIZE, dtype=np.int8)\n        desc_c = np.empty(TMA_SIZE, dtype=np.int8)\n        desc_a_scale = np.empty(TMA_SIZE, dtype=np.int8)\n        desc_b_scale = np.empty(TMA_SIZE, dtype=np.int8)\n\n        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(\n            a_tl.data_ptr(),\n            M,\n            K,\n            BLOCK_M,\n            BLOCK_K,\n            a_tl.element_size(),\n            desc_a,\n        )\n        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(\n            b_tl.data_ptr(),\n            N,\n            K,\n            BLOCK_N,\n            BLOCK_K,\n            b_tl.element_size(),\n            desc_b,\n        )\n        triton.runtime.driver.active.utils.fill_2d_tma_descriptor(\n            c.data_ptr(),\n            M,\n            N,\n            BLOCK_M,\n            BLOCK_N,\n            c.element_size(),\n            desc_c,\n        )\n        triton.runtime.driver.active.utils.fill_1d_tma_descriptor(\n            a_scale.data_ptr(),\n            M,\n            BLOCK_M,\n            a_scale.element_size(),\n            desc_a_scale,\n        )\n        triton.runtime.driver.active.utils.fill_1d_tma_descriptor(\n            b_scale.data_ptr(),\n            N,\n            BLOCK_N,\n            b_scale.element_size(),\n            desc_b_scale,\n        )\n        desc_a = torch.tensor(desc_a, device=\"cuda\")\n        desc_b = torch.tensor(desc_b, device=\"cuda\")\n        desc_c = torch.tensor(desc_c, device=\"cuda\")\n        desc_a_scale = torch.tensor(desc_a_scale, device=\"cuda\")\n        desc_b_scale = torch.tensor(desc_b_scale, device=\"cuda\")\n\n        # pyre-ignore[28]:\n        _kernel_matmul_fp8_row_tma_persistent[persistent_grid](\n            desc_a,\n            desc_b,\n            desc_c,\n            M,\n            N,\n            K,\n            desc_a_scale,\n            desc_b_scale,\n            a.stride(0),\n            a.stride(1),\n            b.stride(0),\n            b.stride(1),\n            c.stride(0),\n            c.stride(1),\n            dot_out_dtype=dot_out_dtype_triton,\n            allow_tf32=allow_tf32,\n            fp8_fast_accum=fp8_fast_accum,\n            BLOCK_M=BLOCK_M,\n            BLOCK_N=BLOCK_N,\n            BLOCK_K=BLOCK_K,\n            GROUP_M=GROUP_M,\n            AB_DTYPE=False,\n            NUM_SMS=NUM_SMS,\n            num_stages=num_stages,\n            num_warps=num_warps,\n        )\n        return c\n"
  },
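  {
    "path": "kernels/triton/inference/fp8_rowwise_gemm_reference_sketch.py",
    "content": "# NOTE: Hypothetical illustrative file (name and path invented) added as a sketch; it is not part of\n# the original kernel sources. It gives a minimal eager-mode reference for what the row-wise scaled\n# FP8 matmul above computes, assuming [M, K] / [N, K] inputs and reciprocal per-row scales\n# (A * A_scale recovers the original A), matching the kernel docstring and its CPU fallback.\n# Arithmetic is emulated in bfloat16 so the sketch runs without FP8 hardware support.\nimport torch\n\n\ndef matmul_fp8_row_reference(a, b, a_scale, b_scale):\n    # (a @ b.T) rescaled by the outer product of the per-row scales.\n    out = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T).to(torch.float32)\n    return out * (a_scale[:, None] * b_scale[None, :])\n\n\nif __name__ == '__main__':\n    M, N, K = 4, 8, 16\n    a = torch.randn(M, K)\n    b = torch.randn(N, K)\n    a_scale = torch.rand(M) + 0.5\n    b_scale = torch.rand(N) + 0.5\n    out = matmul_fp8_row_reference(a, b, a_scale, b_scale)\n    print(out.shape)  # torch.Size([4, 8])\n"
  },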
  {
    "path": "kernels/triton/inference/README.md",
    "content": "Triton Inference kernels\n"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/README.md",
    "content": "\n**MoE (Mixture of Experts) GEMM Kernels**\n\nTriton kernels supporting and accelerating MoE inference (Mixtral), contributed by IBM Research.\n\nThese kernels showcase the following optimizations:\n\n* Column-Major Launch Schedule (L2 Cache Optimization)\n* SplitK Work Decomposition (Parallel Work Strategy Optimization)\n\n(A pure-Python sketch of the grouped and column-major launch orderings is included alongside as launch_schedule_sketch.py for illustration.)\n\nSee blog post: https://pytorch.org/blog/accelerating-moe-model/\n\nKernel versions in this directory:\n\n* v0 = grouped MM (vLLM baseline)\n* v1 = SplitK MM\n* v2 = Col Major MM\n\nRunning these kernels requires vLLM to be installed.\n"
  },
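  {
    "path": "kernels/triton/inference/col_major_moe_gemm/launch_schedule_sketch.py",
    "content": "# NOTE: Hypothetical illustrative file (name invented) added as a sketch; it is not part of the\n# original kernels. It is a pure-Python walk-through of the two tile orderings named in the README:\n# the grouped (swizzled) ordering used by the baseline kernel and a simple column-major ordering.\n# It only prints which output tile each program id would compute, so the L2-reuse pattern can be\n# eyeballed; the grid shape and group size below are made up for the demonstration.\n\n\ndef grouped_launch(pid, grid_m, grid_n, group_m):\n    # Same arithmetic as the grouped_launch helper in v1_moe_fused.py.\n    width = group_m * grid_n\n    group_id = pid // width\n    group_size = min(grid_m - group_id * group_m, group_m)\n    pid_m = group_id * group_m + (pid % group_size)\n    pid_n = (pid % width) // group_size\n    return pid_m, pid_n\n\n\ndef column_major(pid, grid_m, grid_n):\n    # Simplified column-major ordering: walk down a column of output tiles before moving right.\n    return pid % grid_m, pid // grid_m\n\n\nif __name__ == '__main__':\n    grid_m, grid_n, group_m = 4, 6, 2\n    print('pid -> grouped (pid_m, pid_n) | col-major (pid_m, pid_n)')\n    for pid in range(grid_m * grid_n):\n        print(pid, grouped_launch(pid, grid_m, grid_n, group_m), column_major(pid, grid_m, grid_n))\n"
  },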
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport pytest\nimport torch\nimport triton\nfrom vllm.model_executor.layers.fused_moe import fused_moe\nfrom vllm.model_executor.layers.activation import SiluAndMul\nfrom v0_moe_fused import fused_moe as fused_moe_grouped\nfrom v2_moe_fused import fused_moe as fused_moe_col\nimport time\n\ndef torch_moe(a, w1, w2, topk_weight, topk_ids):\n    B, D = a.shape\n    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)\n    out = torch.zeros(B * topk_ids.shape[1],\n                      w2.shape[1],\n                      dtype=a.dtype,\n                      device=a.device)\n\n    topk_ids = topk_ids.view(-1)\n    topk_weight = topk_weight.view(-1)\n    for i in range(w1.shape[0]):\n        mask = topk_ids == i\n        if mask.sum():\n            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)\n    return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1)\n\n\ndef test_fused_moe(\n    m: int,\n    n: int,\n    k: int,\n    e: int,\n    topk: int,\n    dtype: torch.dtype,\n):\n    torch.cuda.manual_seed(3227)\n\n    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10\n    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10\n    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10\n\n    score = torch.randn((m, e), device='cuda', dtype=dtype)\n    score = torch.softmax(score, dim=-1)\n\n    topk_weight, topk_ids = torch.topk(score, topk)\n\n    start = time.time()\n    triton_output_gl = fused_moe_grouped(a, w1, w2, topk_weight, topk_ids, False)\n    end = time.time()\n    gl_time = end - start\n    gl_time = gl_time * 1000\n    print(\"Grouped Launch Time (us): \", gl_time)\n\n    start = time.time()\n    triton_output_cm = fused_moe_col(a, w1, w2, topk_weight, topk_ids, False)\n    end = time.time()\n    cm_major_time = end - start\n    cm_major_time = cm_major_time * 1000\n    print(\"Columm Major Time (us): \", cm_major_time)\n\n    torch_base = torch_moe(a, w1, w2, topk_weight, topk_ids)\n    torch.testing.assert_close(triton_output_cm, torch_base, atol=1e-2, rtol=0)\n\n    # print(f\"{triton_output_cm=}\\n\")\n    # print(f\"{triton_output_gl=}\\n\")\n\n    print(f\"Col Major Speedup {((gl_time - cm_major_time)/(gl_time))*100}\")\n\n\nif __name__ == '__main__':\n\n\n    # test_fused_moe(512, 14336//2, 4096, 8, 2, torch.float16)\n\n    @triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=['m'],  # Argument names to use as an x-axis for the plot\n        x_vals=[\n            2**i for i in range(0, 10)\n        ],  # Different possible values for `x_name`\n        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot\n        # Possible values for `line_arg`\n        line_vals=['cm', 'gl'],\n        # Label name for the lines\n        line_names=[\"Fused MoE GEMM Kernel - Column Major\", \"vLLM MoE GEMM Kernel\"],\n\n        # Line styles\n        styles=[('blue', '-'), ('green', '-')],\n        ylabel=\"TFLOPS\",  # Label name for the y-axis\n        plot_name=\"test\",  # Name for the plot, used also as a file name for saving the plot.\n        args={},\n    )\n)\n    def benchmark(m, provider):\n\n        m = m\n        n = 14336//2\n        k = 4096\n        e = 8\n        topk = 2\n\n        
torch.cuda.manual_seed(3227)\n        dtype = torch.float16\n\n        a = torch.randn((m, k), device='cuda', dtype=dtype) / 10\n        w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10\n        w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10\n\n        score = torch.randn((m, e), device='cuda', dtype=dtype)\n        score = torch.softmax(score, dim=-1)\n        topk_weight, topk_ids = torch.topk(score, topk)\n\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == 'cm':\n            ms, min_ms, max_ms = triton.testing.do_bench(lambda: fused_moe_col(a, w1, w2, topk_weight, topk_ids, False), quantiles=quantiles)\n        if provider == 'gl':\n            ms, min_ms, max_ms = triton.testing.do_bench(lambda: fused_moe_grouped(a, w1, w2, topk_weight, topk_ids, False), quantiles=quantiles)\n        perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)\n        return perf(ms), perf(max_ms), perf(min_ms)\n\nbenchmark.run(show_plots=True, print_data=True, save_path='./')\n"
  },
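  {
    "path": "kernels/triton/inference/col_major_moe_gemm/timing_sketch.py",
    "content": "# NOTE: Hypothetical illustrative file (name invented) added as a sketch; it is not part of the\n# original kernels. The perf scripts in this directory time kernels with time.time(); because CUDA\n# launches are asynchronous, a host-side clock can misreport kernel time unless the device is\n# synchronized. This sketch shows two common alternatives: CUDA events and triton.testing.do_bench.\n# The plain fp16 matmul below is a stand-in workload, not one of the MoE kernels.\nimport torch\nimport triton\n\n\ndef cuda_event_time_ms(fn, warmup=10, iters=50):\n    # Average milliseconds per call, measured on-device with CUDA events.\n    for _ in range(warmup):\n        fn()\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n    torch.cuda.synchronize()\n    start.record()\n    for _ in range(iters):\n        fn()\n    end.record()\n    torch.cuda.synchronize()\n    return start.elapsed_time(end) / iters\n\n\nif __name__ == '__main__':\n    a = torch.randn(2048, 2048, device='cuda', dtype=torch.float16)\n    b = torch.randn(2048, 2048, device='cuda', dtype=torch.float16)\n    fn = lambda: a @ b\n    print('cuda events (ms):', cuda_event_time_ms(fn))\n    print('do_bench (ms):', triton.testing.do_bench(fn))\n"
  },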
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/profile_moe.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport pytest\nimport torch\nfrom vllm.model_executor.layers.fused_moe import fused_moe\nfrom vllm.model_executor.layers.activation import SiluAndMul\nfrom v0_moe_fused import fused_moe as fused_moe_base\nfrom triton.kernels.mixtral.v1_moe_fused import fused_moe\nimport time\n\ndef torch_moe(a, w1, w2, topk_weight, topk_ids):\n    B, D = a.shape\n    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)\n    out = torch.zeros(B * topk_ids.shape[1],\n                      w2.shape[1],\n                      dtype=a.dtype,\n                      device=a.device)\n\n    topk_ids = topk_ids.view(-1)\n    topk_weight = topk_weight.view(-1)\n    for i in range(w1.shape[0]):\n        mask = topk_ids == i\n        if mask.sum():\n            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)\n    return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1)\n\n\ndef test_fused_moe(\n    m: int,\n    n: int,\n    k: int,\n    e: int,\n    topk: int,\n    dtype: torch.dtype,\n):\n\n    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10\n    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10\n    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10\n\n    score = torch.randn((m, e), device='cuda', dtype=dtype)\n    score = torch.softmax(score, dim=-1)\n    topk_weight, topk_ids = torch.topk(score, topk)\n\n    triton_output_splitk = fused_moe(a, w1, w2, topk_weight, topk_ids, False)\n    triton_output_base = fused_moe_base(a, w1, w2, topk_weight, topk_ids, False)\n\n\nif __name__ == '__main__':\n\n    test_fused_moe(2, 14336//2, 4096, 8, 2, torch.float16)\n"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/results.html",
    "content": "<html><body>\n<img src=\"test.png\"/>\n</body></html>\n"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/test.csv",
    "content": "m,Fused MoE GEMM Kernel - Column Major,vLLM MoE GEMM Kernel\n1.000000,0.412454,0.259585\n2.000000,0.883064,0.269004\n4.000000,1.751380,0.447645\n8.000000,2.106783,0.571765\n16.000000,4.121877,1.002326\n32.000000,8.259988,1.991226\n64.000000,16.105391,3.879061\n128.000000,29.356460,7.191373\n256.000000,50.550095,12.524316\n512.000000,72.862390,19.934314\n"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/test_moe_gemm.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport pytest\nimport torch\nfrom vllm.model_executor.layers.fused_moe import fused_moe\nfrom vllm.model_executor.layers.activation import SiluAndMul\nfrom v0_moe_fused import fused_moe as fused_moe_v0\nfrom v1_moe_fused import fused_moe as fused_moe_v1\nfrom splitk_moe_fused import fused_moe\nimport time\n\ndef torch_moe(a, w1, w2, topk_weight, topk_ids):\n    B, D = a.shape\n    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)\n    out = torch.zeros(B * topk_ids.shape[1],\n                      w2.shape[1],\n                      dtype=a.dtype,\n                      device=a.device)\n\n    topk_ids = topk_ids.view(-1)\n    topk_weight = topk_weight.view(-1)\n    for i in range(w1.shape[0]):\n        mask = topk_ids == i\n        if mask.sum():\n            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)\n    return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1)\n\n\n@pytest.mark.parametrize(\"m\", [2, 4, 8, 16, 32, 64, 128, 512, 1024, 2048])\n@pytest.mark.parametrize(\"n\", [14336//2])\n@pytest.mark.parametrize(\"k\", [4096])\n@pytest.mark.parametrize(\"e\", [8])\n@pytest.mark.parametrize(\"topk\", [2])\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\ndef test_fused_moe(\n    m: int,\n    n: int,\n    k: int,\n    e: int,\n    topk: int,\n    dtype: torch.dtype,\n):\n\n    torch.cuda.manual_seed(3227)\n    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10\n    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10\n\n    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10\n\n    score = torch.randn((m, e), device='cuda', dtype=dtype)\n    score = torch.softmax(score, dim=-1)\n\n    topk_weight, topk_ids = torch.topk(score, topk)\n\n    start = time.time()\n    triton_output_gl = fused_moe_v0(a, w1, w2, topk_weight, topk_ids, False)\n    end = time.time()\n\n    gl_time = end - start\n    gl_time = gl_time * 1000\n    print(\"Grouped Launch Time (us): \\n\", gl_time)\n\n\n    start = time.time()\n    triton_output_cm = fused_moe_v1(a, w1, w2, topk_weight, topk_ids, False)\n    end = time.time()\n    cm_major_time = end - start\n    cm_major_time = cm_major_time * 1000\n    print(\"Columm Major Time (us): \\n\", cm_major_time)\n\n\n    torch_base = torch_moe(a, w1, w2, topk_weight, topk_ids)\n\n    assert torch.allclose(triton_output_cm, torch_base, atol=1e-2, rtol=0)\n    assert torch.allclose(triton_output_cm, triton_output_gl, atol=1e-2, rtol=0)\n\n    # print(f\"{triton_output_cm=}\\n\")\n    # print(f\"{triton_output_gl=}\\n\")\n    # print(f\"{torch_base=}\\n\")\n\n    print(f\"Col Major Speedup: {((gl_time/cm_major_time))} x\\n\")\n"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/v0_moe_fused.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# Credit:\n# Woosuk vLLM: https://github.com/vllm-project/vllm/blob/3d925165f2b18379640a63fbb42de95440d63b64/vllm/model_executor/layers/fused_moe/fused_moe.py\n\n\"\"\"Fused MoE kernel.\"\"\"\nimport torch\nimport triton\nimport triton.language as tl\nfrom vllm._C import ops\n\n\n@triton.jit\ndef fused_moe_kernel(\n    # Pointers to matrices\n    a_ptr,\n    b_ptr,\n    c_ptr,\n    topk_weights_ptr,\n    sorted_token_ids_ptr,\n    expert_ids_ptr,\n    num_tokens_post_padded_ptr,\n    # Matrix dimensions\n    N,\n    K,\n    EM,\n    num_valid_tokens,\n    # The stride variables represent how much to increase the ptr by when moving by 1\n    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`\n    # by to get the element one row down (A has M rows).\n    stride_am,\n    stride_ak,\n    stride_be,\n    stride_bk,\n    stride_bn,\n    stride_cm,\n    stride_cn,\n    stride_weight,\n    stride_token_id,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n    MUL_ROUTED_WEIGHT: tl.constexpr,\n    top_k: tl.constexpr,\n    compute_type: tl.constexpr,\n):\n    \"\"\"\n    Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.\n\n    Key Parameters:\n    - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.\n    - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.\n    - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,\n        and N is the output feature dimension.\n    - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.\n    - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.\n    This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. 
The sorting of `sorted_token_ids`\n    by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.\n    \"\"\"\n    # -----------------------------------------------------------\n    # Map program ids `pid` to the block of C it should compute.\n    # This is done in a grouped ordering to promote L2 data reuse.\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    # ----------------------------------------------------------\n    # Create pointers for the first blocks of A and B.\n    # We will advance this pointer as we move in the K direction\n    # and accumulate\n    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers\n    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers\n    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n        return\n\n    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n    token_mask = offs_token < num_valid_tokens\n\n    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)\n\n    off_experts = tl.load(expert_ids_ptr + pid_m)\n    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n    # -----------------------------------------------------------\n    # Iterate to compute a block of the C matrix.\n    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n    # of fp32 values for higher accuracy.\n    # `accumulator` will be converted back to fp16 after the loop.\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        # Load the next block of A and B, generate a mask by checking the K dimension.\n        a = tl.load(a_ptrs,\n                    mask=token_mask[:, None] &\n                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n                    other=0.0)\n        b = tl.load(b_ptrs,\n                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n                    other=0.0)\n        # We accumulate along the K dimension.\n        accumulator += tl.dot(a, b)\n        # Advance the ptrs to the next K block.\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    if MUL_ROUTED_WEIGHT:\n        moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight,\n                             mask=token_mask,\n                             other=0)\n        accumulator = accumulator * moe_weight[:, None]\n\n    accumulator = accumulator.to(compute_type)\n    # -----------------------------------------------------------\n    # Write back the block of the output\n    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n        None, :]\n    c_mask = token_mask[:, 
None] & (offs_cn[None, :] < N)\n    tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef moe_align_block_size(\n        topk_ids: torch.Tensor, block_size: int,\n        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):\n    \"\"\"\n    Aligns the token distribution across experts to be compatible with block size for matrix multiplication.\n\n    Parameters:\n    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.\n    - block_size: The block size used in block matrix multiplication.\n    - num_experts: The total number of experts.\n\n    Returns:\n    - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.\n    - expert_ids: A tensor indicating the assigned expert index for each block.\n    - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.\n\n    This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.\n    Padding ensures that during block matrix multiplication, the dimensions align correctly.\n\n    Example:\n    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:\n    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.\n    - As block_size is 4, we pad 1 token for each expert.\n    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].\n    - Then append padding tokens [12, 12, 12, 12] for each block.\n    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].\n        Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.\n    - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.\n    \"\"\"\n    sorted_ids = torch.empty(\n        (topk_ids.numel() + num_experts * (block_size - 1), ),\n        dtype=torch.int32,\n        device=topk_ids.device)\n    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),\n                             dtype=torch.int32,\n                             device=topk_ids.device)\n    sorted_ids.fill_(topk_ids.numel())\n    num_tokens_post_pad = torch.empty((1),\n                                      dtype=torch.int32,\n                                      device=topk_ids.device)\n    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,\n                             expert_ids, num_tokens_post_pad)\n    return sorted_ids, expert_ids, num_tokens_post_pad\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n                            sorted_token_ids: torch.Tensor,\n                            expert_ids: torch.Tensor,\n                            num_tokens_post_padded: torch.Tensor,\n                            mul_routed_weight: bool, top_k: int, config: dict):\n\n    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n    print(f\"Base {config}\\n\")\n\n    fused_moe_kernel[grid](\n        A,\n        B,\n        C,\n        topk_weights,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        B.shape[1],\n        B.shape[2],\n        sorted_token_ids.shape[0],\n        topk_ids.numel(),\n        
A.stride(0),\n        A.stride(1),\n        B.stride(0),\n        B.stride(2),\n        B.stride(1),\n        C.stride(1),\n        C.stride(2),\n        topk_weights.stride(1),\n        sorted_token_ids.stride(0),\n        MUL_ROUTED_WEIGHT=mul_routed_weight,\n        top_k=top_k,\n        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,\n        **config,\n    )\n\n\ndef fused_moe(hidden_states: torch.Tensor,\n              w1: torch.Tensor,\n              w2: torch.Tensor,\n              topk_weights: torch.Tensor,\n              topk_ids: torch.Tensor,\n              inplace=False):\n    \"\"\"\n    This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.\n\n    Parameters:\n    - hidden_states (torch.Tensor): The input tensor to the MoE layer.\n    - w1 (torch.Tensor): The first set of expert weights.\n    - w2 (torch.Tensor): The second set of expert weights.\n    - topk_weights (torch.Tensor): The weights for the top-k selected experts.\n    - topk_ids (torch.Tensor): The indices of the top-k selected experts.\n    - inplace (bool): If True, perform the operation in-place. Defaults to False.\n\n    Returns:\n    - torch.Tensor: The output tensor after applying the MoE layer.\n    \"\"\"\n    # Check constraints.\n    assert hidden_states.shape[1] == w1.shape[2], \"Incompatible dimensions\"\n    assert hidden_states.is_contiguous(), \"Hidden_states must be contiguous\"\n    assert w1.is_contiguous(), \"Expert weights1 must be contiguous\"\n    assert w2.is_contiguous(), \"Expert weights2 must be contiguous\"\n    assert hidden_states.dtype in [torch.float16, torch.bfloat16]\n    M, _ = hidden_states.shape\n    E, N, _ = w1.shape\n\n    config = {\n        'BLOCK_SIZE_M': 64,\n        'BLOCK_SIZE_N': 64,\n        'BLOCK_SIZE_K': 32,\n        'GROUP_SIZE_M': 8\n    }\n\n    if topk_ids.numel() <= w1.shape[0]:\n        config = {\n            'BLOCK_SIZE_M': 16,\n            'BLOCK_SIZE_N': 32,\n            'BLOCK_SIZE_K': 64,\n            'GROUP_SIZE_M': 1\n        }\n\n    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n\n    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(\n        topk_ids, config['BLOCK_SIZE_M'], E)\n\n    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,\n                            topk_weights, topk_ids, sorted_token_ids,\n                            expert_ids, num_tokens_post_padded, False,\n                            topk_ids.shape[1], config)\n\n    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))\n\n    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,\n                            topk_weights, topk_ids, sorted_token_ids,\n                            expert_ids, num_tokens_post_padded, True, 1,\n                            config)\n\n    if inplace:\n        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),\n       
                  dim=1,\n                         out=hidden_states)\n    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),\n                     dim=1)\n"
  },
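  {
    "path": "kernels/triton/inference/col_major_moe_gemm/moe_align_block_size_sketch.py",
    "content": "# NOTE: Hypothetical illustrative file (name invented) added as a sketch; it is not part of the\n# original kernels. It is a pure-Python re-creation of the behaviour described in the\n# moe_align_block_size docstrings, so the padding/sorting example can be run without vLLM. The real\n# kernels call ops.moe_align_block_size from vLLM; this sketch only mirrors the documented output\n# layout for small inputs.\n\n\ndef moe_align_block_size_reference(topk_ids, block_size):\n    # topk_ids: list of per-token lists of expert indices.\n    flat = [expert for row in topk_ids for expert in row]\n    pad = len(flat)  # padding slots use the out-of-range token id numel()\n    sorted_token_ids, expert_ids = [], []\n    for expert in sorted(set(flat)):\n        positions = [i for i, e in enumerate(flat) if e == expert]\n        while len(positions) % block_size != 0:  # pad each expert up to a multiple of block_size\n            positions.append(pad)\n        sorted_token_ids.extend(positions)\n        expert_ids.extend([expert] * (len(positions) // block_size))\n    return sorted_token_ids, expert_ids, len(sorted_token_ids)\n\n\nif __name__ == '__main__':\n    topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]]\n    tokens, experts, padded = moe_align_block_size_reference(topk_ids, block_size=4)\n    print(tokens)   # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]\n    print(experts)  # [1, 2, 3, 4]\n    print(padded)   # 16\n"
  },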
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/v1_moe_fused.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# Credit:\n# Woosuk vLLM: https://github.com/vllm-project/vllm/blob/3d925165f2b18379640a63fbb42de95440d63b64/vllm/model_executor/layers/fused_moe/fused_moe.py\n\n\"\"\"Fused MoE kernel.\"\"\"\nimport torch\nimport triton\nimport triton.language as tl\nfrom vllm._C import ops\nfrom typing import Any, Dict, Optional\nimport functools\nimport json\nimport os\n\n@triton.jit()\ndef grouped_launch(pid,\n                m, n,\n                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):\n\n    grid_m = tl.cdiv(m, block_m)\n    grid_n = tl.cdiv(n, block_n)\n\n    width = group_m * grid_n\n    group_id = pid // width\n    group_size = tl.minimum(grid_m - group_id * group_m, group_m)\n\n    pid_m = group_id * group_m + (pid % group_size)\n    pid_n = (pid % width) // group_size\n\n    return pid_m, pid_n\n\n\n@triton.jit()\ndef fused_moe_kernel_splitk(\n    # Pointers to matrices\n    a_ptr,\n    b_ptr,\n    c_ptr,\n    topk_weights_ptr,\n    sorted_token_ids_ptr,\n    expert_ids_ptr,\n    num_tokens_post_padded_ptr,\n    # Matrix dimensions\n    N,\n    K,\n    EM,\n    num_valid_tokens,\n    # The stride variables represent how much to increase the ptr by when moving by 1\n    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`\n    # by to get the element one row down (A has M rows).\n    stride_am,\n    stride_ak,\n    stride_be,\n    stride_bk,\n    stride_bn,\n    stride_cm,\n    stride_cn,\n    stride_weight,\n    stride_token_id,\n    # Meta-parameters\n    block_m: tl.constexpr,\n    block_n: tl.constexpr,\n    block_k: tl.constexpr,\n    group_m: tl.constexpr,\n    split_k: tl.constexpr,\n    MUL_ROUTED_WEIGHT: tl.constexpr,\n    top_k: tl.constexpr,\n    compute_type: tl.constexpr,\n):\n    \"\"\"\n    Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.\n\n    Key Parameters:\n    - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.\n    - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.\n    - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,\n        and N is the output feature dimension.\n    - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.\n    - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.\n    This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. 
The sorting of `sorted_token_ids`\n    by expert index and padding ensures divisibility by block_m, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.\n    \"\"\"\n    # -----------------------------------------------------------\n    # Map program ids `pid` to the block of C it should compute.\n    # This is done in a grouped ordering to promote L2 data reuse.\n\n    # Scheduling Problem\n\n    pid = tl.program_id(axis=0)\n    pid_k = tl.program_id(axis=1)\n\n    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n\n    # print(\"num_tokens_post_padded: \", num_tokens_post_padded)\n\n    pid_m, pid_n = grouped_launch(pid,\n                                EM, N,\n                                block_m, block_n, group_m)\n\n    total_blocks_k = tl.cdiv(K, block_k*split_k)\n\n    if pid_m * block_m >= num_tokens_post_padded:\n        return\n\n    offs_token_id = pid_m * block_m + tl.arange(0, block_m)\n    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n    token_mask = offs_token < num_valid_tokens\n\n    offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N\n    offs_k = pid_k*block_k + tl.arange(0, block_k)\n    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)\n\n    off_experts = tl.load(expert_ids_ptr + pid_m)\n    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n    accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)\n    for k in range(0, total_blocks_k):\n        a = tl.load(a_ptrs,\n                    mask=token_mask[:, None] & (offs_k[None, :] < K - k * (block_k * split_k)),\n                    other=0.0)\n        b = tl.load(b_ptrs,\n                    mask=offs_k[:, None] < K - k * (block_k * split_k),\n                    other=0.0)\n        # We accumulate along the K dimension.\n        accumulator += tl.dot(a, b)\n        # Advance the ptrs to the next K block.\n        a_ptrs += block_k * stride_ak * split_k\n        b_ptrs += block_k * stride_bk * split_k\n\n    if MUL_ROUTED_WEIGHT:\n        moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight,\n                             mask=token_mask,\n                             other=0)\n        accumulator = accumulator * moe_weight[:, None]\n\n    accumulator = accumulator.to(compute_type)\n    # -----------------------------------------------------------\n    # Write back the block of the output\n    offs_cn = pid_n * block_n + tl.arange(0, block_n)\n    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n    tl.atomic_add(c_ptrs, accumulator, mask=c_mask)\n\n\ndef moe_align_block_size(\n        topk_ids: torch.Tensor, block_size: int,\n        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):\n    \"\"\"\n    Aligns the token distribution across experts to be compatible with block size for matrix multiplication.\n\n    Parameters:\n    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.\n    - block_size: The block size used in block matrix multiplication.\n    - num_experts: The total number of experts.\n\n    Returns:\n    - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.\n    - expert_ids: A tensor indicating the assigned expert index for each block.\n    - num_tokens_post_padded: The total 
number of tokens after padding, ensuring divisibility by block_size.\n\n    This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.\n    Padding ensures that during block matrix multiplication, the dimensions align correctly.\n\n    Example:\n    Given topk_ids = [[2, 3, ], [1, 2], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:\n    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.\n    - As block_size is 4, we pad 1 token for each expert.\n    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].\n    - Then append padding tokens [12, 12, 12, 12] for each block.\n    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].\n        Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.\n    - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.\n    \"\"\"\n    sorted_ids = torch.empty(\n        (topk_ids.numel() + num_experts * (block_size - 1), ),\n        dtype=torch.int32,\n        device=topk_ids.device)\n    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),\n                             dtype=torch.int32,\n                             device=topk_ids.device)\n    sorted_ids.fill_(topk_ids.numel())\n    num_tokens_post_pad = torch.empty((1),\n                                      dtype=torch.int32,\n                                      device=topk_ids.device)\n    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,\n                             expert_ids, num_tokens_post_pad)\n    return sorted_ids, expert_ids, num_tokens_post_pad\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n                            sorted_token_ids: torch.Tensor,\n                            expert_ids: torch.Tensor,\n                            num_tokens_post_padded: torch.Tensor,\n                            mul_routed_weight: bool, top_k: int, config: dict):\n\n    N = B.shape[1] # 14336\n    K = B.shape[2] # 4096\n    EM = sorted_token_ids.shape[0] # 124\n\n    grid = lambda META: (triton.cdiv(EM, META['block_m']) * triton.cdiv(N, META['block_n']), META['split_k'])\n\n    # print(f\"SplitK {config}\\n\")\n    k = fused_moe_kernel_splitk[grid](\n        A,\n        B,\n        C,\n        topk_weights,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded, # 64\n        N,\n        K,\n        EM,\n        topk_ids.numel(),\n        A.stride(0),\n        A.stride(1),\n        B.stride(0),\n        B.stride(2),\n        B.stride(1),\n        C.stride(1),\n        C.stride(2),\n        topk_weights.stride(1),\n        sorted_token_ids.stride(0),\n        MUL_ROUTED_WEIGHT=mul_routed_weight,\n        top_k=top_k,\n        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,\n        **config,\n        num_warps=8,\n    )\n\n    # print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\")\n\n    # with open('split_k_moe_ttir.txt', 'w') as f:\n\n    #     print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n    #     print(\"IR\", k.asm['ttir'], file=f)\n    #     print(\"TTGIR\", k.asm['ttgir'], file=f)\n    #     
print(\"PTX\", k.asm['ptx'], file=f)\n    #     print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n\n\n\ndef fused_moe(hidden_states: torch.Tensor,\n              w1: torch.Tensor,\n              w2: torch.Tensor,\n              topk_weights: torch.Tensor,\n              topk_ids: torch.Tensor,\n              inplace=False):\n    \"\"\"\n    This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.\n\n    Parameters:\n    - hidden_states (torch.Tensor): The input tensor to the MoE layer.\n    - w1 (torch.Tensor): The first set of expert weights.\n    - w2 (torch.Tensor): The second set of expert weights.\n    - topk_weights (torch.Tensor): The weights for the top-k selected experts.\n    - topk_ids (torch.Tensor): The indices of the top-k selected experts.\n    - inplace (bool): If True, perform the operation in-place. Defaults to False.\n\n    Returns:\n    - torch.Tensor: The output tensor after applying the MoE layer.\n    \"\"\"\n    # Check constraints.\n    assert hidden_states.shape[1] == w1.shape[2], \"Incompatible dimensions\"\n    assert hidden_states.is_contiguous(), \"Hidden_states must be contiguous\"\n    assert w1.is_contiguous(), \"Expert weights1 must be contiguous\"\n    assert w2.is_contiguous(), \"Expert weights2 must be contiguous\"\n    assert hidden_states.dtype in [torch.float16, torch.bfloat16]\n    M, _ = hidden_states.shape\n    E, N, _ = w1.shape\n\n\n    # Prefill\n    config_w1 = {\n        'block_m': 32,\n        'block_n': 64,\n        'block_k': 64,\n        'group_m': 8,\n        'split_k': 2,\n    }\n\n    config_w2 = {\n        'block_m': 32,\n        'block_n': 64,\n        'block_k': 64,\n        'group_m': 8,\n        'split_k': 2,\n    }\n\n    # Decoding\n    if topk_ids.numel() <= w1.shape[0]:\n        config_w1 = {\n            'block_m': 16,\n            'block_n': 64,\n            'block_k': 128,\n            'group_m': 8,\n            'split_k' : 2,\n        }\n\n        config_w2 = {\n            'block_m': 16,\n            'block_n': 128,\n            'block_k': 64,\n            'group_m': 8,\n            'split_k': 4,\n        }\n\n    intermediate_cache1 = torch.zeros((M, topk_ids.shape[1], N),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache2 = torch.zeros((M * topk_ids.shape[1], N // 2),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache3 = torch.zeros((M, topk_ids.shape[1], w2.shape[1]),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n\n    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(\n        topk_ids, config_w1['block_m'], E)\n\n    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,\n                            topk_weights, topk_ids, sorted_token_ids,\n                            expert_ids, num_tokens_post_padded, False,\n                            topk_ids.shape[1], config_w1)\n\n\n    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))\n\n    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,\n                            topk_weights, topk_ids, sorted_token_ids,\n                            expert_ids, num_tokens_post_padded, True, 1,\n 
                           config_w2)\n\n    if inplace:\n        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),\n                         dim=1,\n                         out=hidden_states)\n\n    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),\n                     dim=1)\n"
  },
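  {
    "path": "kernels/triton/inference/col_major_moe_gemm/splitk_sketch.py",
    "content": "# NOTE: Hypothetical illustrative file (name invented) added as a sketch; it is not part of the\n# original kernels. It is a tiny eager-mode sketch of the SplitK work decomposition used by\n# v1_moe_fused.py: the K dimension is split across split_k workers, each computes a partial product\n# over its K-slice, and the partial results are summed into the output (the Triton kernel does this\n# with tl.atomic_add, and interleaves K-blocks rather than using the contiguous chunks shown here).\n# Shapes below are arbitrary.\nimport torch\n\n\ndef splitk_matmul(a, b, split_k):\n    # a: [M, K], b: [K, N]; emulate split_k workers accumulating into a shared output.\n    c = torch.zeros(a.shape[0], b.shape[1], dtype=torch.float32)\n    for a_chunk, b_chunk in zip(torch.chunk(a, split_k, dim=1), torch.chunk(b, split_k, dim=0)):\n        c += a_chunk.float() @ b_chunk.float()\n    return c\n\n\nif __name__ == '__main__':\n    a = torch.randn(8, 64)\n    b = torch.randn(64, 16)\n    torch.testing.assert_close(splitk_matmul(a, b, split_k=4), a.float() @ b.float(), atol=1e-4, rtol=1e-4)\n    print('split-k partial sums match the full matmul')\n"
  },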
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# Credit:\n# Woosuk vLLM: https://github.com/vllm-project/vllm/blob/3d925165f2b18379640a63fbb42de95440d63b64/vllm/model_executor/layers/fused_moe/fused_moe.py\n\n\"\"\"Fused MoE kernel.\"\"\"\nimport torch\nimport triton\nimport triton.language as tl\nfrom vllm._C import ops\n\n\n@triton.jit()\ndef col_major(pid,\n              m, n,\n              block_m: tl.constexpr, block_n: tl.constexpr):\n\n    grid_m = tl.cdiv(m, block_m)\n    grid_n = tl.cdiv(n, block_n)\n\n    pid_m = (pid % grid_n)\n    pid_n = pid // grid_m\n\n    return pid_m, pid_n\n\n@triton.jit\ndef fused_moe_kernel(\n    # Pointers to matrices\n    a_ptr,\n    b_ptr,\n    c_ptr,\n    topk_weights_ptr,\n    sorted_token_ids_ptr,\n    expert_ids_ptr,\n    num_tokens_post_padded_ptr,\n    # Matrix dimensions\n    N,\n    K,\n    EM,\n    num_valid_tokens,\n    # The stride variables represent how much to increase the ptr by when moving by 1\n    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`\n    # by to get the element one row down (A has M rows).\n    stride_am,\n    stride_ak,\n    stride_be,\n    stride_bk,\n    stride_bn,\n    stride_cm,\n    stride_cn,\n    stride_weight,\n    stride_token_id,\n    # Meta-parameters\n    block_m: tl.constexpr,\n    block_n: tl.constexpr,\n    block_k: tl.constexpr,\n    MUL_ROUTED_WEIGHT: tl.constexpr,\n    top_k: tl.constexpr,\n    compute_type: tl.constexpr,\n):\n    \"\"\"\n    Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.\n\n    Key Parameters:\n    - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.\n    - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.\n    - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,\n        and N is the output feature dimension.\n    - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.\n    - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.\n    This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. 
The sorting of `sorted_token_ids`\n    by expert index and padding ensures divisibility by block_m, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.\n    \"\"\"\n\n    pid = tl.program_id(axis=0)\n    pid_m, pid_n = col_major(pid,\n                             EM, N,\n                             block_m, block_n,)\n\n    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n    if pid_m * block_m >= num_tokens_post_padded:\n        return\n\n    offs_token_id = pid_m * block_m + tl.arange(0, block_m)\n    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n    token_mask = offs_token < num_valid_tokens\n\n    offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N\n    offs_k = tl.arange(0, block_k)\n    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)\n\n    off_experts = tl.load(expert_ids_ptr + pid_m)\n    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n    # -----------------------------------------------------------\n    # Iterate to compute a block of the C matrix.\n    # We accumulate into a `[block_m, block_n]` block\n    # of fp32 values for higher accuracy.\n    # `accumulator` will be converted back to fp16 after the loop.\n    accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)\n\n    for k in range(0, tl.cdiv(K, block_k)):\n        # Load the next block of A and B, generate a mask by checking the K dimension.\n        a = tl.load(a_ptrs,\n                    mask=token_mask[:, None] &\n                    (offs_k[None, :] < K - k * block_k),\n                    other=0.0)\n        b = tl.load(b_ptrs,\n                    mask=offs_k[:, None] < K - k * block_k,\n                    other=0.0)\n        # We accumulate along the K dimension.\n        accumulator += tl.dot(a, b)\n        # Advance the ptrs to the next K block.\n        a_ptrs += block_k * stride_ak\n        b_ptrs += block_k * stride_bk\n\n    if MUL_ROUTED_WEIGHT:\n        moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight,\n                             mask=token_mask,\n                             other=0)\n        accumulator = accumulator * moe_weight[:, None]\n\n    accumulator = accumulator.to(compute_type)\n    # -----------------------------------------------------------\n    # Write back the block of the output\n    offs_cn = pid_n * block_n + tl.arange(0, block_n)\n    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n        None, :]\n    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n    tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef moe_align_block_size(\n        topk_ids: torch.Tensor, block_size: int,\n        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):\n    \"\"\"\n    Aligns the token distribution across experts to be compatible with block size for matrix multiplication.\n\n    Parameters:\n    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.\n    - block_size: The block size used in block matrix multiplication.\n    - num_experts: The total number of experts.\n\n    Returns:\n    - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.\n    - expert_ids: A tensor indicating the assigned expert index for each block.\n    - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by 
block_size.\n\n    This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.\n    Padding ensures that during block matrix multiplication, the dimensions align correctly.\n\n    Example:\n    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:\n    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.\n    - As block_size is 4, we pad 1 token for each expert.\n    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].\n    - Then append padding tokens [12, 12, 12, 12] for each block.\n    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].\n        Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.\n    - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.\n    \"\"\"\n    sorted_ids = torch.empty(\n        (topk_ids.numel() + num_experts * (block_size - 1), ),\n        dtype=torch.int32,\n        device=topk_ids.device)\n    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),\n                             dtype=torch.int32,\n                             device=topk_ids.device)\n    sorted_ids.fill_(topk_ids.numel())\n    num_tokens_post_pad = torch.empty((1),\n                                      dtype=torch.int32,\n                                      device=topk_ids.device)\n    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,\n                             expert_ids, num_tokens_post_pad)\n    return sorted_ids, expert_ids, num_tokens_post_pad\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n                            sorted_token_ids: torch.Tensor,\n                            expert_ids: torch.Tensor,\n                            num_tokens_post_padded: torch.Tensor,\n                            mul_routed_weight: bool, top_k: int, config: dict):\n\n    EM = sorted_token_ids.shape[0]\n    N = B.shape[1]\n\n    grid = lambda META: (triton.cdiv(EM, META['block_m']) * triton.cdiv(N, META['block_n']), )\n    fused_moe_kernel[grid](\n        A,\n        B,\n        C,\n        topk_weights,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        B.shape[1],\n        B.shape[2],\n        sorted_token_ids.shape[0],\n        topk_ids.numel(),\n        A.stride(0),\n        A.stride(1),\n        B.stride(0),\n        B.stride(2),\n        B.stride(1),\n        C.stride(1),\n        C.stride(2),\n        topk_weights.stride(1),\n        sorted_token_ids.stride(0),\n        MUL_ROUTED_WEIGHT=mul_routed_weight,\n        top_k=top_k,\n        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,\n        **config,\n    )\n\n\ndef fused_moe(hidden_states: torch.Tensor,\n              w1: torch.Tensor,\n              w2: torch.Tensor,\n              topk_weights: torch.Tensor,\n              topk_ids: torch.Tensor,\n              inplace=False):\n    \"\"\"\n    This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.\n\n    Parameters:\n    - hidden_states (torch.Tensor): The input tensor to the MoE layer.\n    - w1 (torch.Tensor): The first set of expert weights.\n    - w2 (torch.Tensor): 
The second set of expert weights.\n    - topk_weights (torch.Tensor): The weights for the top-k selected experts.\n    - topk_ids (torch.Tensor): The indices of the top-k selected experts.\n    - inplace (bool): If True, perform the operation in-place. Defaults to False.\n\n    Returns:\n    - torch.Tensor: The output tensor after applying the MoE layer.\n    \"\"\"\n    # Check constraints.\n    assert hidden_states.shape[1] == w1.shape[2], \"Incompatible dimensions\"\n    assert hidden_states.is_contiguous(), \"Hidden_states must be contiguous\"\n    assert w1.is_contiguous(), \"Expert weights1 must be contiguous\"\n    assert w2.is_contiguous(), \"Expert weights2 must be contiguous\"\n    assert hidden_states.dtype in [torch.float16, torch.bfloat16]\n    M, _ = hidden_states.shape\n    E, N, _ = w1.shape\n\n    config = {\n        'block_m': 64,\n        'block_n': 64,\n        'block_k': 32,\n    }\n\n    if topk_ids.numel() <= w1.shape[0]:\n        config = {\n            'block_m': 16,\n            'block_n': 32,\n            'block_k': 64,\n        }\n\n    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n\n    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(\n        topk_ids, config['block_m'], E)\n\n    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,\n                            topk_weights, topk_ids, sorted_token_ids,\n                            expert_ids, num_tokens_post_padded, False,\n                            topk_ids.shape[1], config)\n\n    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))\n\n    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,\n                            topk_weights, topk_ids, sorted_token_ids,\n                            expert_ids, num_tokens_post_padded, True, 1,\n                            config)\n\n    if inplace:\n        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),\n                         dim=1,\n                         out=hidden_states)\n    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),\n                     dim=1)\n"
  },
  {
    "path": "kernels/triton/inference/flash_attention/stay_attention.py",
    "content": "import triton.language as tl\nimport triton\nimport torch\n\n\n@triton.jit()\ndef stay_attention(\n    q_ptr, k_ptr, v_ptr, o_ptr,\n    stride_b, stride_nh, \n    stride_qs, stride_qh,\n    stride_ks, stride_kh,\n    stride_vs, stride_vh,\n    stride_os, stride_oh,\n    seq_len, head_dim,\n    sm_scale,\n    BLOCK_SEQ: tl.constexpr, \n    BLOCK_HD: tl.constexpr, \n    NUM_SM: tl.constexpr,\n):  \n\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid = tl.program_id(2)\n    \n    qkv_offset = pid_b*stride_b + pid_h*stride_nh\n    num_tiles_seq_len = tl.cdiv(seq_len, BLOCK_SEQ)\n\n    tiles_per_SM = num_tiles_seq_len // NUM_SM\n    if pid < num_tiles_seq_len % NUM_SM:\n        tiles_per_SM += 1\n\n    tile_id = pid - NUM_SM\n    si = -1\n\n    pid_seq_m = 0\n    pid_seq_n = 0\n\n    offs_seq_m = tl.arange(0, BLOCK_SEQ)\n    offs_seq_n = tl.arange(0, BLOCK_SEQ)\n    offs_head = tl.arange(0, BLOCK_HD)\n\n    q_ptrs = q_ptr + qkv_offset + offs_seq_n[:, None]*stride_qs + offs_head[None, :]*stride_qh\n\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_SEQ], dtype=tl.float32) - float(\"inf\")\n    l_i = tl.zeros([BLOCK_SEQ], dtype=tl.float32)\n    qk_scale = sm_scale * 1.44269504\n\n    q = tl.load(q_ptrs)\n    q = (q * qk_scale)\n    \n    pv = tl.zeros([BLOCK_SEQ, BLOCK_HD], dtype=tl.float32)\n    for _ in range(0, num_tiles_seq_len * tiles_per_SM):\n\n        si = tl.where(si == num_tiles_seq_len - 1, 0, si + 1)\n        \n        if si == 0:\n\n            tile_id += NUM_SM\n\n            pid_seq_m = pid // num_tiles_seq_len\n            pid_seq_n = pid % num_tiles_seq_len\n\n            offs_seq_m = pid_seq_m*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n            offs_seq_n = pid_seq_n*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n            offs_head = tl.arange(0, BLOCK_HD)\n\n            q_ptrs = q_ptr + qkv_offset + offs_seq_n[:, None]*stride_qs + offs_head[None, :]*stride_qh\n            \n            qk_scale = sm_scale * 1.44269504\n            q = tl.load(q_ptrs)\n            q = (q * qk_scale)\n        \n        offs_seq_m = si*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n        offs_head = tl.arange(0, BLOCK_HD)\n\n        k_ptrs = k_ptr + qkv_offset + offs_seq_m[:, None]*stride_ks + offs_head[None, :]*stride_kh\n        v_ptrs = v_ptr + qkv_offset + offs_seq_m[:, None]*stride_vs + offs_head[None, :]*stride_vh\n\n        k = tl.load(k_ptrs)\n        v = tl.load(v_ptrs)\n\n        qk = tl.dot(q.to(tl.float16), k.T, out_dtype=tl.float32)\n\n        # -- compute scaling constant ---\n        m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n        alpha = tl.math.exp2(m_i - m_i_new)\n        p = tl.math.exp2(qk - m_i_new[:, None])\n\n        # -- scale and update acc --\n        pv *= alpha[:, None]\n        pv += tl.dot(p.to(tl.float16), v, out_dtype=tl.float32)\n\n        # -- update m_i and l_i --\n        l_i = l_i * alpha + tl.sum(p, 1)\n        m_i = m_i_new\n\n        if si == num_tiles_seq_len - 1:\n\n            offs_seq_n = pid_seq_n*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n            pv = pv / l_i[:, None]\n            o_ptrs = o_ptr + qkv_offset + offs_seq_n[:, None]*stride_os + offs_head[None, :]*stride_oh\n            tl.store(o_ptrs, pv)\n            pv = tl.zeros([BLOCK_SEQ, BLOCK_HD], dtype=tl.float32)\n\n\ndef flash_fn(q, k, v):\n\n    batch, num_heads, seq_len, head_dim = q.shape\n\n    sm_scale = 0.5\n    BLOCK_SEQ = 64\n    BLOCK_HD = 128\n\n    NUM_SM = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n    grid = (batch, num_heads, 
NUM_SM)\n    o = torch.zeros(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda')    \n    stay_attention[grid](q, k, v, o,\n                             q.stride(0), q.stride(1), \n                             q.stride(2), q.stride(3), \n                             k.stride(2), k.stride(3),\n                             v.stride(2), v.stride(3),\n                             o.stride(2), o.stride(3),\n                             seq_len, head_dim,\n                             sm_scale,\n                             BLOCK_SEQ, BLOCK_HD, NUM_SM)\n    return o \n\n\nif __name__ == '__main__':\n\n    torch.manual_seed(0)\n\n    batch, num_heads, seq_len, head_dim = 1, 32, 4096, 128\n\n    q = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10\n    k = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10\n    v = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10\n\n    sm_scale = 0.5\n    p = (q @ k.transpose(2, 3)) * sm_scale\n    p = torch.softmax(p.float(), dim=-1)\n    o_torch = torch.matmul(p.to(torch.float16), v)\n\n    o_triton = flash_fn(q, k, v)\n\n    print(f\"{o_triton=}\")\n    print(f\"{o_torch=}\")\n\n    torch.testing.assert_close(o_triton, o_torch, atol=1e-2, rtol=0)\n\n"
  },
  {
    "path": "kernels/triton/inference/fp8/float8_groupwise_quant.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton import Config\n\n# global constants\nFP8_MAX: tl.constexpr = 448.0\nEPSILON: tl.constexpr = 1e-12\n\n\n@triton.jit\ndef _float8_groupwise_quant_kernel(\n    in_ptr, out_ptr, scale_ptr, BLOCK_SIZE: tl.constexpr\n):\n    \"\"\"\n    Quantizes the input tensor via BLOCK_SIZE groupwise scaling (i.e. 1x 128).\n\n    Results:\n    Stores\n    1 - float8_e4m3fn result in `out_ptr`\n    2 - scaling factor in `scale_ptr`\n\n    \"\"\"\n    pid = tl.program_id(axis=0)\n\n    # load inputs\n    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    x_vec = tl.load(in_ptr + offsets).to(tl.float32)\n\n    # calc max and scale\n    max_val = tl.max(tl.abs(x_vec))\n    safe_scale = tl.maximum(max_val, EPSILON) / FP8_MAX\n    y_vec = x_vec / safe_scale\n\n    # quantize\n    y_clamped = tl.minimum(tl.maximum(y_vec, -FP8_MAX), FP8_MAX)\n    y_fp8 = y_clamped.to(out_ptr.dtype.element_ty)\n\n    # store quantized values and scale\n    tl.store(out_ptr + offsets, y_fp8)\n    tl.store(scale_ptr + pid, safe_scale)\n\n\ndef float8_groupwise_quantize(x: torch.Tensor, block_size=128):\n    \"\"\"\n    Quantizes the input tensor via block_size groupwise scaling (i.e. 1x 128)\n    to torch.float8_e4m3fn format.\n\n    Results:\n    Stores\n    1 - float8_e4m3fn result in `out_ptr`\n    2 - scaling factor in `scale_ptr`\n\n    \"\"\"\n    # verify input tensor\n    x_last_dim_size = x.size(-1)\n\n    # evenly divisible?\n    if x_last_dim_size % block_size != 0:\n        raise ValueError(\n            f\"Input tensor must have a last dimension that is a multiple of {block_size}\"\n        )\n    # contiguous?\n    if x.stride(-1) != 1:\n        x = x.contiguous()\n\n    # allocate output tensors\n    output = torch.empty_like(x, dtype=torch.float8_e4m3fn)\n    scales = x.new_empty(\n        *x.size()[:-1], x_last_dim_size // block_size, dtype=torch.float32\n    )\n    print(f\"{scales.size()=}\")\n\n    grid = lambda meta: (x.numel() // block_size,)\n    _float8_groupwise_quant_kernel[grid](\n        in_ptr=x,\n        out_ptr=output,\n        scale_ptr=scales,\n        BLOCK_SIZE=block_size,\n    )\n\n    return output, scales\n"
  },
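  {
    "path": "kernels/triton/inference/fp8/float8_groupwise_quant_example.py",
    "content": "\"\"\"Minimal usage sketch for `float8_groupwise_quantize` from\nfloat8_groupwise_quant.py in this directory.\n\nIllustrative only: the file name, shapes and tolerances here are editorial\nassumptions rather than settings from the original benchmarks. Assumes a CUDA\nGPU with Triton installed and that the script is run from this directory so\nthe sibling module imports directly.\n\"\"\"\n\nimport torch\n\nfrom float8_groupwise_quant import float8_groupwise_quantize\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    block_size = 128\n\n    # The last dimension must be a multiple of block_size.\n    x = torch.randn(64, 1024, dtype=torch.bfloat16, device=\"cuda\")\n\n    x_fp8, scales = float8_groupwise_quantize(x, block_size=block_size)\n    print(f\"{x_fp8.dtype=}, {x_fp8.shape=}, {scales.shape=}\")\n\n    # Dequantize: multiply each 1x128 group back by its per-group scale.\n    deq = (\n        x_fp8.to(torch.float32).view(*x.shape[:-1], -1, block_size)\n        * scales.unsqueeze(-1)\n    ).view_as(x)\n\n    # Loose tolerances to absorb fp8 e4m3 rounding error.\n    torch.testing.assert_close(deq, x.float(), atol=5e-2, rtol=2e-1)\n    print(\"groupwise fp8 quantize/dequantize round trip OK\")\n"
  },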
  {
    "path": "kernels/triton/inference/fp8/scaled_fp8_gemm.py",
    "content": "import torch\nimport triton\nimport triton.language as tl\nimport time\nimport os\nos.environ['ENABLE_TMA'] = '1'\n\n\n@triton.jit\ndef grouped_launch(pid,\n                m, n,\n                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m)\n    grid_n = tl.cdiv(n, block_n)\n\n    width = group_m * grid_n\n    group_id = pid // width\n    group_size = tl.minimum(grid_m - group_id * group_m, group_m)\n\n    pid_m = group_id * group_m + (pid % group_size)\n    pid_n = (pid % width) // group_size\n\n    return pid_m, pid_n\n\n@triton.jit()\ndef column_major(pid,\n              m, n,\n              block_m: tl.constexpr, block_n: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m) \n\n    pid_m = pid % grid_m\n    pid_n = pid // grid_m\n\n    return pid_m, pid_n\n\n@triton.jit\ndef scaled_gemm_splitk(a_ptr, b_ptr, c_ptr,\n            stride_am, stride_ak,\n            stride_bk, stride_bn,\n            stride_cm, stride_cn,\n            scale_a, scale_b,\n            m, n, k,\n            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,\n            split_k: tl.constexpr, group_m: tl.constexpr):\n    \n    pid = tl.program_id(0)\n    pid_k = tl.program_id(1)\n    grid_k = tl.cdiv(k, block_k*split_k)\n\n    # Column Major produces speedup over Grouped Launch for small-to-medium M\n    pid_m, pid_n = column_major(pid,\n                                m, n,\n                                block_m, block_n)\n\n\n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n    offs_k = pid_k*block_k + tl.arange(0, block_k)\n\n    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)\n\n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n\n    acc = tl.zeros((block_m, block_n), dtype=tl.float32)\n    for k_ in range(0, grid_k):\n        \n        k_remaining = k - k_ * (block_k * split_k)\n\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)\n\n        acc = tl.dot(a, b, acc, out_dtype=tl.float32)\n\n        a_ptrs += block_k * split_k * stride_ak\n        b_ptrs += block_k * split_k * stride_bk\n    \n    # Scaled in SRAM before write back to DRAM\n    acc = scale_a * scale_b * acc\n    acc.to(tl.float16)\n\n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n    \n    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)\n    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]\n    \n    tl.atomic_add(c_ptrs, acc, mask=mask)\n\ndef scaled_mm_splitk(a, b, scale_a: float=1.0, scale_b: float=1.0):\n    assert a.shape[1] == b.shape[0]\n    m, k = a.shape\n    _, n = b.shape\n    \n    block_m = 64\n    block_n = 64\n    block_k = 256\n    num_stages = 3\n    num_warps = 8\n    split_k = 4\n    group_m = 8\n\n    total_blocks_m = triton.cdiv(m, block_m)\n    total_blocks_n = triton.cdiv(n, block_n)\n    total_programs_mn = total_blocks_m * total_blocks_n\n    total_programs_k = split_k\n    \n    grid = (total_programs_mn, total_programs_k)\n\n    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)\n    k = scaled_gemm_splitk[grid](a, b, c,\n                              
a.stride(0), a.stride(1),\n                              b.stride(0), b.stride(1),\n                              c.stride(0), c.stride(1),\n                              scale_a, scale_b,                              \n                              m, n, k,\n                              block_m, block_n, block_k,\n                              split_k, group_m, num_stages=num_stages, num_warps=num_warps)\n\n    return c"
  },
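  {
    "path": "kernels/triton/inference/fp8/scaled_fp8_gemm_example.py",
    "content": "\"\"\"Per-tensor scaled FP8 usage sketch for `scaled_mm_splitk` from\nscaled_fp8_gemm.py in this directory.\n\nIllustrative only: it assumes an FP8-capable GPU (e.g. H100) whose Triton\nbuild accepts float8_e4m3fn inputs to tl.dot, and that the script is run from\nthis directory. The shapes and the simple per-tensor quantization recipe below\nare editorial assumptions, not the original benchmark settings.\n\"\"\"\n\nimport torch\n\nfrom scaled_fp8_gemm import scaled_mm_splitk\n\nFP8_MAX = 448.0  # largest magnitude representable in float8_e4m3fn\n\n\ndef per_tensor_fp8(x: torch.Tensor):\n    \"\"\"Quantize to fp8 with a single per-tensor scale; returns (fp8 tensor, scale).\"\"\"\n    scale = x.abs().max().clamp(min=1e-12) / FP8_MAX\n    return (x / scale).to(torch.float8_e4m3fn), scale.item()\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    m, k, n = 64, 4096, 4096\n\n    a = torch.randn((m, k), dtype=torch.float16, device=\"cuda\")\n    b = torch.randn((k, n), dtype=torch.float16, device=\"cuda\")\n\n    a_fp8, scale_a = per_tensor_fp8(a)\n    b_fp8, scale_b = per_tensor_fp8(b)\n\n    # The kernel rescales its fp32 accumulator by scale_a * scale_b before writeback.\n    c = scaled_mm_splitk(a_fp8, b_fp8, scale_a, scale_b)\n    ref = a @ b\n\n    rel_err = (c.float() - ref.float()).norm() / ref.float().norm()\n    print(f\"relative Frobenius error vs fp16 matmul: {rel_err.item():.4f}\")\n"
  },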
  {
    "path": "kernels/triton/inference/fp8/splitk_gemm_fp8.py",
    "content": "import torch\nimport triton\nimport triton.language as tl\nimport time\nimport os\nos.environ['ENABLE_TMA'] = '1'\n\n@triton.jit\ndef grouped_launch(pid,\n                m, n,\n                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m)\n    grid_n = tl.cdiv(n, block_n)\n\n    width = group_m * grid_n\n    group_id = pid // width\n    group_size = tl.minimum(grid_m - group_id * group_m, group_m)\n\n    pid_m = group_id * group_m + (pid % group_size)\n    pid_n = (pid % width) // group_size\n\n    return pid_m, pid_n\n\n\n@triton.jit()\ndef col_major(pid,\n              m, n,\n              block_m: tl.constexpr, block_n: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m) \n\n    pid_m = pid % grid_m\n    pid_n = pid // grid_m\n\n    return pid_m, pid_n\n\n\n@triton.jit\ndef gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,\n            stride_am, stride_ak,\n            stride_bk, stride_bn,\n            stride_cm, stride_cn,\n            m, n, k,\n            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,\n            split_k: tl.constexpr, group_m: tl.constexpr):\n    \n    pid = tl.program_id(0)\n    pid_k = tl.program_id(1)\n    grid_k = tl.cdiv(k, block_k*split_k)\n\n    pid_m, pid_n = grouped_launch(pid,\n                                  m, n,\n                                  block_m, block_n, group_m)\n\n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n    offs_k = pid_k*block_k + tl.arange(0, block_k)\n\n    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)\n\n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n\n    acc = tl.zeros((block_m, block_n), dtype=tl.float32)\n    for k_ in range(0, grid_k):\n        \n        k_remaining = k - k_ * (block_k * split_k)\n\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)\n\n        acc = tl.dot(a, b, acc, out_dtype=tl.float32)\n\n        a_ptrs += block_k * split_k * stride_ak\n        b_ptrs += block_k * split_k * stride_bk\n\n    acc.to(tl.float16)\n\n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n    \n    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)\n    mask = (offs_m < m)[:, None] & (offs_n < n)[None, :]\n    \n    tl.atomic_add(c_ptrs, acc, mask=mask)\n\ndef gemm_split_k(a, b):\n\n    m, k = a.shape\n    _, n = b.shape\n    \n    block_m = 64\n    block_n = 64\n    block_k = 512\n    num_stages = 3\n    num_warps = 8\n    split_k = 4\n    group_m = 8\n\n    total_blocks_m = triton.cdiv(m, block_m)\n    total_blocks_n = triton.cdiv(n, block_n)\n    total_programs_mn = total_blocks_m * total_blocks_n\n    total_programs_k = split_k\n    \n    grid = (total_programs_mn, total_programs_k)\n\n    # print(f\"problem m size: {m}, tile size m: {block_m}, total blocks m: {total_blocks_m}\")\n    # print(f\"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}\")\n    # print(f\"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}\")\n\n    # print(f\"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} 
x {total_blocks_n} = {total_programs_mn}\")\n    # print(f\"{total_programs_mn=}, {total_programs_k=}\")\n    \n    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)\n    k = gemm_split_k_kernel[grid](a, b, c,\n                              a.stride(0), a.stride(1),\n                              b.stride(0), b.stride(1),\n                              c.stride(0), c.stride(1),\n                              m, n, k,\n                              block_m, block_n, block_k,\n                              split_k, group_m, num_stages=num_stages, num_warps=num_warps)\n    \n    # print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\")\n\n    # with open('matmul_split_k.txt', 'w') as f:\n\n    #     print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n    #     print(\"IR\", k.asm['ttir'], file=f)\n    #     print(\"TTGIR\", k.asm['ttgir'], file=f)\n    #     print(\"PTX\", k.asm['ptx'], file=f)\n    #     print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n\n    return c\n"
  },
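  {
    "path": "kernels/triton/inference/fp8/splitk_gemm_fp8_example.py",
    "content": "\"\"\"Correctness sketch for `gemm_split_k` from splitk_gemm_fp8.py in this\ndirectory.\n\nIllustrative only: it assumes an FP8-capable GPU (e.g. H100) whose Triton\nbuild accepts float8_e4m3fn inputs to tl.dot, and that the script is run from\nthis directory. Shapes and tolerances are editorial assumptions, not the\noriginal benchmark settings.\n\"\"\"\n\nimport torch\n\nfrom splitk_gemm_fp8 import gemm_split_k\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    m, k, n = 64, 4096, 4096\n\n    # Standard-normal values fit comfortably inside the fp8 e4m3 range.\n    a = torch.randn((m, k), dtype=torch.float16, device=\"cuda\").to(torch.float8_e4m3fn)\n    b = torch.randn((k, n), dtype=torch.float16, device=\"cuda\").to(torch.float8_e4m3fn)\n\n    c = gemm_split_k(a, b)\n\n    # The reference upcasts the same fp8 values to fp16, so the comparison only\n    # measures accumulation differences (fp32 accumulation plus fp16 atomic adds).\n    ref = a.to(torch.float16) @ b.to(torch.float16)\n    torch.testing.assert_close(c, ref, atol=5e-1, rtol=1e-2)\n    print(\"split-k fp8 GEMM matches the fp16 reference within tolerance\")\n"
  },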
  {
    "path": "kernels/triton/inference/fp8/tma_gemm.py",
    "content": "import triton\nimport triton.language as tl\nimport numpy as np\nimport torch\n\n@triton.jit\ndef gemm_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #\n                      prob_m, prob_n, prob_k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):\n    \n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(prob_m, block_m)\n    num_pid_k = tl.cdiv(prob_k, block_k)\n    pid_m = pid % num_pid_m\n    pid_n = pid // num_pid_m\n    offs_am = pid_m * block_m\n    offs_bn = pid_n * block_n\n    offs_k = 0\n\n    accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)\n    for kk in range(0, num_pid_k):\n\n        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)\n        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)\n        \n        accumulator = tl.dot(a, b.T, acc=accumulator, out_dtype=tl.float32)\n        offs_k += block_k\n\n    accumulator = accumulator.to(tl.float16)\n    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])\n\n\ndef matmul(a, b, config=None):\n\n    m, _ = a.shape\n    n, k = b.shape\n\n    if config:\n        block_m = config[\"block_m\"]\n        block_n = config[\"block_n\"]\n        block_k = config[\"block_k\"]\n        num_warps = config[\"num_warps\"]\n        num_stages = config[\"num_stages\"]\n    \n    block_m = 64\n    block_n = 64\n    block_k = 256\n    num_warps = 4\n    num_stages = 4\n    TMA_SIZE = 512\n\n    desc_a = np.empty(TMA_SIZE, dtype=np.int8)\n    desc_b = np.empty(TMA_SIZE, dtype=np.int8)\n    desc_c = np.empty(TMA_SIZE, dtype=np.int8)\n\n    c = torch.empty((m, n), dtype=torch.float16, device='cuda')\n    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(),\n                                                            desc_a)\n    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(),\n                                                            desc_b)\n    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(),\n                                                            desc_c)\n    desc_a = torch.tensor(desc_a, device='cuda')\n    desc_b = torch.tensor(desc_b, device='cuda')\n    desc_c = torch.tensor(desc_c, device='cuda')\n\n    total_blocks_m = triton.cdiv(m, block_m)\n    total_blocks_n = triton.cdiv(n, block_n)\n    \n    grid = (total_blocks_m * total_blocks_n, 1, 1)\n    k = gemm_kernel_tma[grid](\n        desc_a, desc_b, desc_c,\n        m, n, k,\n        block_m,\n        block_n,\n        block_k,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    # with open('tma_fp8.ttgir', 'w') as f:\n    #      print(k.asm['ttgir'], file=f)\n\n    # with open('tma_fp8.ptx', 'w') as f:\n    #      print(k.asm['ptx'], file=f)\n\n    return c\n\n\nif __name__ == '__main__':\n\n    M = 128\n    N = 4096\n    K = 4096\n\n    a = torch.randn((M, K), device=\"cuda\", dtype=torch.float16).to(torch.float8_e4m3fn)\n    b = torch.randn((K, N), device=\"cuda\", dtype=torch.float16).to(torch.float8_e4m3fn)\n    b = b.T.contiguous()\n\n    c = matmul(a, b)\n"
  },
  {
    "path": "kernels/triton/inference/gptq/a100_qlinear.py",
    "content": "import triton\nimport triton.language as tl\nimport torch \n\n@triton.jit()\ndef _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,\n                             stride_am, stride_ak,\n                             stride_bk, stride_bn,\n                             stride_cm, stride_cn,\n                             stride_scales_g, stride_scales_n,\n                             stride_zeros_g, stride_zeros_n,\n                             groupsize,\n                             m, n, k,\n                             block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr,\n                             group_size_m: tl.constexpr,\n                             ):\n    \n    pid = tl.program_id(0)\n\n    total_blocks_m = tl.cdiv(m, block_size_m)\n    total_blocks_n = tl.cdiv(n, block_size_n)\n    total_blocks_k = tl.cdiv(k, block_size_k)\n\n    num_blocks_in_group = group_size_m * total_blocks_n\n    group_id = pid // num_blocks_in_group\n    group_size = min(total_blocks_m - group_id * group_size_m, group_size_m)\n\n    pid_m = group_id * group_size_m + (pid % group_size)\n    pid_n = (pid % num_blocks_in_group) // (group_size)\n\n    offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m\n    offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n\n\n    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m)\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n)\n    offs_k = tl.arange(0, block_size_k)\n    \n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)\n    \n    scales_ptrs = scales_ptr + offs_bn * stride_scales_n\n    zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n)\n\n    shifter = (offs_k % 8) * 4\n    zeros_shifter = (offs_bn % 8) * 4\n\n\n    output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32)\n    for k in range(0, total_blocks_k):\n \n        a = tl.load(a_ptrs)\n        b = tl.load(b_ptrs)\n        g_id = k // (groupsize // block_size_k)\n\n        ptr = scales_ptrs + g_id * stride_scales_g\n        scales = tl.load(ptr)\n        \n        ptr = zeros_ptrs + g_id * stride_zeros_g\n        zeros = tl.load(ptr)\n\n        zeros = (zeros >> zeros_shifter) & 0xF\n        zeros = (zeros + 1) * scales\n\n        b = (b >> shifter[:, None]) & 0xF # b -> int32\n        b = b * scales[None, :] - zeros[None, :] # b -> fp16\n        \n        output += tl.dot(a, b)\n        a_ptrs += stride_ak * block_size_k\n        b_ptrs +=  (block_size_k//8) * stride_bk\n    \n    output.to(tl.float16)\n    offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m)\n    offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n)\n    c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)\n    tl.store(c_ptrs, output)\n\nclass a100_qlinear(torch.autograd.Function):\n    def forward(ctx, a, b, scales, zeros):\n\n        m, k = a.shape\n        _, n = b.shape\n\n        quant_groupsize = 128\n        block_size_m = 16 \n        block_size_n = 32 # [N = 4096 // 32] = 128 blocks\n        block_size_k = 256\n        group_size_m = 8\n        num_warps = 4\n        num_stages = 8\n        total_blocks_m = triton.cdiv(m, block_size_m)\n        total_blocks_n = triton.cdiv(n, block_size_n)\n        total_programs  = total_blocks_m * total_blocks_n\n        grid = (total_programs, 1)\n\n        c = 
torch.zeros((m, n), device=b.device, dtype=torch.float16)\n        k = _a100_quantized_matmul[grid](\n            a, b, c, scales, zeros,\n            a.stride(0), a.stride(1),\n            b.stride(0), b.stride(1),\n            c.stride(0), c.stride(1),\n            scales.stride(0), scales.stride(1),\n            zeros.stride(0), zeros.stride(1),\n            quant_groupsize,\n            m, n, k,\n            block_size_m, block_size_n, block_size_k, group_size_m,\n            num_warps = num_warps, num_stages = num_stages,\n        )\n\n        print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\")\n\n        with open('dequant_simple.txt', 'w') as f:\n\n            print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n            print(\"IR\", k.asm['ttir'], file=f)\n            print(\"TTGIR\", k.asm['ttgir'], file=f)\n            print(\"PTX\", k.asm['ptx'], file=f)\n            print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n\n            print(f\"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}\")\n        return c\n        \n\na100_qlinear = a100_qlinear.apply"
  },
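  {
    "path": "kernels/triton/inference/gptq/a100_qlinear_example.py",
    "content": "\"\"\"Shape/usage sketch for the `a100_qlinear` wrapper from a100_qlinear.py in\nthis directory.\n\nIllustrative only: the synthetic tensors below merely exercise the GPTQ W4A16\nlayout the kernel expects (4-bit weights packed 8-per-int32 along K, zeros\npacked 8-per-int32 along N, one scale/zero group per 128 rows of K). They are\nrandom, so the output is not checked against a reference; the same layout also\nmatches the pointer arithmetic in h100_qlinear.py. Assumes a CUDA GPU with\nTriton installed and that the script is run from this directory.\n\"\"\"\n\nimport torch\n\nfrom a100_qlinear import a100_qlinear\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    m, k, n = 16, 4096, 4096\n    groupsize = 128  # must match quant_groupsize inside the wrapper\n    g = k // groupsize\n\n    a = torch.randn((m, k), dtype=torch.float16, device=\"cuda\")\n    qweight = torch.randint(0, 2**31 - 1, (k // 8, n), dtype=torch.int32, device=\"cuda\")\n    scales = 0.01 * torch.randn((g, n), dtype=torch.float16, device=\"cuda\")\n    qzeros = torch.randint(0, 2**31 - 1, (g, n // 8), dtype=torch.int32, device=\"cuda\")\n\n    # The wrapper also prints register/spill stats and dumps IR to dequant_simple.txt.\n    c = a100_qlinear(a, qweight, scales, qzeros)\n    print(f\"{c.shape=}, {c.dtype=}\")\n"
  },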
  {
    "path": "kernels/triton/inference/gptq/benchmark.py",
    "content": "import argparse\nimport time\nimport logging\nfrom tqdm import tqdm\nimport torch\nfrom transformers import AutoTokenizer\nfrom auto_gptq import AutoGPTQForCausalLM \n\n# Configure logging\nlogger = logging.getLogger(__name__)\nlogging.basicConfig(level=logging.INFO)\n\n\n\ndef benchmark_generation_speed(model, tokenizer, prompt, batch_size, device, num_passes=5):\n\n    token_dict = tokenizer([prompt] * batch_size, return_tensors=\"pt\", padding=\"longest\").to(device)\n\n    total_generation_time = 0\n    total_num_generated_tokens = 0\n\n    # Warmup\n    logger.info(\"Starting warmup...\")\n    for _ in tqdm(range(4), desc=\"Warmup\", leave=False):\n        with torch.inference_mode():\n            _ = model.generate(**token_dict, min_length=30, max_length=30)\n\n    logger.info(\"Starting benchmark...\")\n    with tqdm(range(num_passes), desc=\"Benchmark Passes\") as pbar:\n        for pass_num in pbar:\n            token_dict = tokenizer([prompt] * batch_size, return_tensors=\"pt\", padding=\"longest\").to(device)\n\n            start = time.time()\n            with torch.inference_mode():\n                outputs_ids = model.generate(**token_dict, min_length=30, max_length=30)\n            end = time.time()\n\n            generation_time = end - start\n            num_generated_tokens = sum(len(output_ids) for output_ids in outputs_ids) - batch_size * len(token_dict['input_ids'][0])\n            tokens_per_second = num_generated_tokens / generation_time\n\n            total_generation_time += generation_time\n            total_num_generated_tokens += num_generated_tokens\n\n            # Update tqdm post-fix with current iteration results\n            pbar.set_postfix({\"Time (s)\": f\"{generation_time:.2f}\", \"Tokens/s\": f\"{tokens_per_second:.2f}\"})\n\n    # Calculate average statistics\n    avg_generation_time = total_generation_time / num_passes\n    avg_tokens_per_second = total_num_generated_tokens / total_generation_time\n    avg_num_generated_tokens = total_num_generated_tokens / num_passes\n\n    # Log average statistics\n    logger.info(f\"Batch size: {batch_size}, Avg Time: {avg_generation_time:.2f}s, Avg Tokens/s: {avg_tokens_per_second:.2f}, Avg Total tokens: {avg_num_generated_tokens}\")\n    return avg_generation_time, avg_tokens_per_second, avg_num_generated_tokens\n\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Benchmark Llama-70B')\n    parser.add_argument('--use_triton', type=lambda x: (str(x).lower() == 'true'), help='use Triton Kernel')\n    parser.add_argument('--batch_size', type=int, required=True, help='Batch size for the benchmark')\n    args = parser.parse_args()\n\n    device = \"cuda:5\"  \n    quantized_model_dir = '/net/storage149/autofs/css22/ccyang/fm-models/llama-gptq/gptq_output_act0_grp128_bluewiki' \n    \n    tokenizer = AutoTokenizer.from_pretrained(quantized_model_dir, use_fast=True)\n    tokenizer.pad_token = tokenizer.eos_token\n    \n    tokenizer.padding_side = \"left\"\n\n    if args.use_triton:\n        torch.cuda.empty_cache()\n        model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device, inject_fused_attention=False, inject_fused_mlp=False,\n                                                use_triton=args.use_triton, disable_exllamaV2=True, low_cpu_mem_usage=True, warmup_triton=False)\n    else:\n        model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device, inject_fused_attention=False, inject_fused_mlp=False,\n                                     
       use_triton=False, disable_exllamaV2=False, low_cpu_mem_usage=True, warmup_triton=False)\n        \n    model = torch.compile(model, mode=\"reduce-overhead\")\n    benchmark_generation_speed(model, tokenizer, \"auto-gptq is a\", args.batch_size, device)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kernels/triton/inference/gptq/h100_qlinear.py",
    "content": "import triton\nimport triton.language as tl\nimport torch \n\n\n@triton.jit()\ndef _h100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,\n                             stride_am, stride_ak,\n                             stride_bk, stride_bn,\n                             stride_cm, stride_cn,\n                             stride_scales_g, stride_scales_n,\n                             stride_zeros_g, stride_zeros_n,\n                             groupsize,\n                             m, n, k,\n                             block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr,\n                             group_size_m: tl.constexpr,\n                             fp8_fast_accum: tl.constexpr,):\n    \n    pid = tl.program_id(0)\n\n    total_blocks_m = tl.cdiv(m, block_size_m)\n    total_blocks_n = tl.cdiv(n, block_size_n)\n    total_blocks_k = tl.cdiv(k, block_size_k)\n\n    num_blocks_in_group = group_size_m * total_blocks_n\n    group_id = pid // num_blocks_in_group\n    group_size = min(total_blocks_m - group_id * group_size_m, group_size_m)\n\n    pid_m = group_id * group_size_m + (pid % group_size)\n    pid_n = (pid % num_blocks_in_group) // (group_size)\n\n    offs_n = pid_n * block_size_n + tl.arange(0, block_size_n)\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n)\n    offs_k = tl.arange(0, block_size_k)\n\n    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(m,k), strides=(stride_am, stride_ak),\n                                offsets=(pid_m*block_size_m, 0), block_shape=(block_size_m, block_size_k),\n                                order =(1,0))\n    \n\n    b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)\n    scales_ptrs = scales_ptr + offs_bn * stride_scales_n\n    zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n)\n\n    shifter = (offs_k % 8) * 4\n    zeros_shifter = (offs_bn % 8) * 4\n\n    acc = tl.zeros((block_size_m, block_size_n), dtype=tl.float32)\n    for k in range(0, total_blocks_k):\n\n        a = tl.load(a_block_ptr, boundary_check=(0,1))\n        b = tl.load(b_ptrs)\n        g_id = k // (groupsize // block_size_k)\n\n        ptr = scales_ptrs + g_id * stride_scales_g\n\n        scales = tl.load(ptr)\n        ptr = zeros_ptrs + g_id * stride_zeros_g\n        zeros = tl.load(ptr)\n\n        zeros = (zeros >> zeros_shifter) & 0xF\n        zeros = (zeros + 1) * scales\n\n        b = (b >> shifter[:, None]) & 0xF\n        b = b * scales[None, :] - zeros[None, :]\n\n        if fp8_fast_accum:\n            acc = tl.dot(a.to(tl.float), b.to(tl.float8e4nv), acc)\n        else:\n            acc += tl.dot(a,b)\n\n        a_block_ptr = tl.advance(a_block_ptr, (0, block_size_k))\n        b_ptrs += (block_size_k//8) * stride_bk\n\n    acc.to(tl.float16)\n    offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m)\n    offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n)\n\n    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = (offs_cm[:, None] < n) & (offs_cn[None, :] < n)\n    tl.store(c_ptrs, acc, mask=c_mask)\n\n\n    \n\n\nclass h100_qlinear(torch.autograd.Function):\n    def forward(ctx, a, b, scales, zeros):\n\n        m, k = a.shape\n        _, n = b.shape\n\n        quant_groupsize = 128\n        block_size_m = 16\n        block_size_n = 32\n        block_size_k = 256\n        group_size_m = 8\n        num_warps = 4\n        num_stages = 4\n        total_blocks_m = triton.cdiv(m, 
block_size_m)\n        total_blocks_n = triton.cdiv(n, block_size_n)\n        total_programs  = total_blocks_m * total_blocks_n\n        grid = (total_programs, 1)\n        fp8_fast_accum = False\n\n        c = torch.zeros((m, n), device=a.device, dtype=a.dtype)\n        k = _h100_quantized_matmul[grid](\n            a, b, c, scales, zeros,\n            a.stride(0), a.stride(1),\n            b.stride(0), b.stride(1),\n            c.stride(0), c.stride(1),\n            scales.stride(0), scales.stride(1),\n            zeros.stride(0), zeros.stride(1),\n            quant_groupsize,\n            m, n, k,\n            block_size_m, block_size_n, block_size_k, group_size_m, fp8_fast_accum = fp8_fast_accum,\n            num_warps = num_warps, num_stages = num_stages,\n        )\n\n        print(f\"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}\")\n        return c\n        \n\nh100_qlinear = h100_qlinear.apply"
  },
  {
    "path": "kernels/triton/inference/gptq/mixtral/test_dequant_moe_gemm.py",
    "content": "import pytest\nimport torch\nfrom vllm.model_executor.layers.fused_moe import fused_moe\nfrom vllm.model_executor.layers.activation import SiluAndMul\nfrom triton.kernels.gptq.mixtral.w4a16_fused_dequant_gemm import dequant_gemm_moe\nfrom v0_moe_fused import fused_moe as fused_moe_base\nimport time\n\ndef torch_moe(a, w1, w2, topk_weight, topk_ids):\n    B, D = a.shape\n    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)\n    out = torch.zeros(B * topk_ids.shape[1],\n                      w2.shape[1],\n                      dtype=a.dtype,\n                      device=a.device)\n    \n    topk_ids = topk_ids.view(-1)\n    topk_weight = topk_weight.view(-1)\n    for i in range(w1.shape[0]):\n        mask = topk_ids == i\n        if mask.sum():\n            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)\n    return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1)\n\ndef test_dequant_moe(\n    m: int,\n    n: int,\n    k: int,\n    e: int,\n    topk: int,\n):  \n    m = m\n    n = n\n    k = k\n    e = e\n    topk = topk\n    groupsize = 128\n    packed_k_dim = k // 8 \n    packed_n_dim = n // 8\n    g = k // groupsize\n    topk = 2\n\n    a = torch.randn((m, k), dtype=torch.float16, device='cuda')\n    qw1 = torch.randint(0, 5, (e, packed_k_dim, n), device='cuda', dtype=torch.int32)\n    qw2 = torch.randint(0, 5, (e, 2*n, packed_k_dim), device='cuda', dtype=torch.int32)\n    qw1_zeros = torch.randint(0, 5, (e, g, packed_n_dim), device='cuda', dtype=torch.int32)\n    qw2_zeros = torch.randint(0, 5, (e, g, packed_n_dim), device='cuda', dtype=torch.int32)\n    qw1_scales = torch.randn((e, g, n), dtype=torch.float16, device='cuda')\n    qw2_scales = torch.randn((e, g, n), dtype=torch.float16, device='cuda')\n    score = torch.randn((m, e), device='cuda', dtype=torch.float16)\n    score = torch.softmax(score, dim=-1)\n    _, topk_ids = torch.topk(score, topk)\n\n\n    # dtype = torch.float16\n    # a = torch.randn((m, k), device='cuda', dtype=dtype) / 10\n    # w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10\n    # w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10\n\n\n    # score = torch.randn((m, e), device='cuda', dtype=dtype)\n    # score = torch.softmax(score, dim=-1)\n    # topk_weight, topk_ids = torch.topk(score, topk)\n\n    # triton_output_base = fused_moe_base(a, w1, w2, topk_weight, topk_ids, False)\n\n    # print(triton_output_base)\n\n    # breakpoint()\n    c = dequant_gemm_moe(a, \n                     qw1,\n                     qw2,\n                     qw1_scales,\n                     qw2_scales,\n                     qw1_zeros,\n                     qw2_zeros,\n                     topk_ids,\n                    )\n    # print(c)\n    # assert torch.allclose(triton_output_splitk, torch_output, atol=1e-1, rtol=0)\n\nif __name__ == '__main__':\n\n    test_dequant_moe(2, 14336//2, 4096, 8, 2)"
  },
  {
    "path": "kernels/triton/inference/gptq/mixtral/w4a16_fused_dequant_gemm.py",
    "content": "\"\"\"Fused MoE W4A16 Kernel.\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom vllm._C import ops\n\n@triton.jit\ndef print_tensor_dim(tensor, str_name):\n    if tl.program_id(0) == 0 and tl.program_id(1) == 0:\n        tl.static_print(str_name,\" \",tensor.shape,\" \",tensor.dtype)\n@triton.jit\ndef print_value(value):\n    if tl.program_id(0) == 0 and tl.program_id(1) == 0:\n        tl.device_print(str(value))\n\n@triton.jit()\ndef grouped_launch(pid,\n                m, n,\n                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m)\n    grid_n = tl.cdiv(n, block_n)\n\n    width = group_m * grid_n\n    group_id = pid // width\n    group_size = min(grid_m - group_id * group_m, group_m)\n\n    pid_m = group_id * group_m + (pid % group_size)\n    pid_n = (pid % width) // group_size\n\n    return pid_m, pid_n\n\n\n@triton.jit()\ndef col_major(pid,\n              m, n, num_tokens_post_padded,\n              block_m: tl.constexpr, block_n: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m)    \n    grid_n = tl.cdiv(n, block_n)\n    \n    pid_m = (pid % grid_n) \n    pid_n = pid // grid_m\n\n    return pid_m, pid_n\n\n\n@triton.jit()\ndef w4a16_fused_moe_kernel(\n    # Pointers to matrices\n    a_ptr,\n    b_ptr,\n    c_ptr,\n    sorted_token_ids_ptr,\n    expert_ids_ptr,\n    num_tokens_post_padded_ptr,\n    # Quantization Scales and Zeros Ptr\n    scales_ptr,\n    zeros_ptr,\n    # Matrix dimensions\n    N,\n    K,\n    EM,\n    num_valid_tokens,\n    # The stride variables represent how much to increase the ptr by when moving by 1\n    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`\n    # by to get the element one row down (A has M rows).\n    stride_am, stride_ak,\n    stride_be, stride_bk, stride_bn,\n    stride_cm, stride_cn,\n    # Quantization Scales and Zeros Strides\n    stride_scales_e, stride_scales_g, stride_scales_n,\n    stride_zeros_e, stride_zeros_g, stride_zeros_n,\n    # Meta-parameters\n    groupsize: tl.constexpr,\n    top_k: tl.constexpr,\n    block_m: tl.constexpr,\n    block_n: tl.constexpr,\n    block_k: tl.constexpr,\n    group_m: tl.constexpr,\n):\n\n    pid = tl.program_id(0)\n\n    # GEMM Schedule\n    pid_m, pid_n = grouped_launch(pid,\n                                  EM, N,\n                                  block_m, block_n, group_m)\n    grid_k = tl.cdiv(K, block_k)\n    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n    if pid_m * block_m >= num_tokens_post_padded:\n        return\n    \n    # Offset Calculations\n    offs_token_id = pid_m*block_m + tl.arange(0, block_m)\n    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n    offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N # NOTE: No change needed here since weights are packed along K dim\n    offs_k = tl.arange(0, block_k)\n    off_experts = tl.load(expert_ids_ptr + pid_m)\n\n    # Mask for Activations\n    token_mask = offs_token < num_valid_tokens \n\n    # Pointer Calculations\n    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) #NOTE: offs_token[:, None] // top_k -> since each row of activations repeats top_k times\n    b_ptrs = b_ptr + off_experts * stride_be + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) #NOTE: offs_k[:, None] // 8 -> since B is packed along k dim is packed \n\n    # We need to handle the e dim of the scales and zeros pointers\n    # We 
can do this in the same fashion that the stacked expert weight matrix is handled\n    \n    # off_experts = tl.load(expert_ids_ptr + pid_m)\n    # b_ptr + off_experts * stride_be + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)\n\n    scales_ptrs = scales_ptr + off_experts * stride_scales_e + offs_bn * stride_scales_n\n    zeros_ptrs = zeros_ptr + off_experts * stride_zeros_e + ((offs_bn//8) * stride_zeros_n)\n\n    shifter = (offs_k % 8) * 4\n    zeros_shifter = (offs_bn % 8) * 4\n\n    acc = tl.zeros([block_m, block_n], dtype=tl.float32)\n    for k in range(0, grid_k):\n\n        a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * block_k), other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * block_k, other=0.0)\n        \n        g_id = k // (groupsize // block_k)\n        ptr = scales_ptrs + g_id * stride_scales_g\n        \n        scales = tl.load(ptr)\n        ptr = zeros_ptrs + g_id * stride_zeros_g\n        zeros = tl.load(ptr) \n        zeros = (zeros >> zeros_shifter) & 0xF\n        zeros = (zeros + 1) * scales\n\n        b = (b >> shifter[:, None]) & 0xF\n        b = b * scales[None, :] - zeros[None, :]\n\n        acc += tl.dot(a, b)\n\n        a_ptrs += block_k * stride_ak\n        b_ptrs += (block_k // 8) * stride_bk\n    \n    acc.to(tl.float16)\n\n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)\n    tl.store(c_ptrs, acc)\n\n\n\ndef invoke_dequant_gemm_moe(activations: torch.Tensor, \n                            qweight: torch.Tensor, \n                            c: torch.Tensor,\n                            scales: torch.Tensor, \n                            qzeros: torch.Tensor,\n                            topk_ids: torch.Tensor, \n                            sorted_token_ids: torch.Tensor,\n                            expert_ids: torch.Tensor,\n                            num_tokens_post_padded: torch.Tensor,\n                            topk: torch.Tensor,\n                            ):\n    \n    EM = sorted_token_ids.shape[0]\n    N = qweight.shape[1]\n    K = qweight.shape[2]\n    block_m = 32\n    block_n = 32\n    block_k = 32\n    group_m = 8\n    groupsize = 128\n    topk = 2\n\n    if topk_ids.numel() <= qweight.shape[0]:\n            block_m = 16\n            block_n = 128\n            block_k = 128\n            group_m = 8\n\n    total_blocks_m = triton.cdiv(EM, block_m)\n    total_blocks_n = triton.cdiv(N, block_n)\n\n    grid = (total_blocks_m * total_blocks_n,)\n    w4a16_fused_moe_kernel[grid](\n        activations,\n        qweight,\n        c,\n        sorted_token_ids,\n        expert_ids,\n        num_tokens_post_padded,\n        scales,\n        qzeros,\n        N,\n        K,\n        EM,\n        topk_ids.numel(),\n        activations.stride(0), activations.stride(1),\n        qweight.stride(0), qweight.stride(2), qweight.stride(1),\n        c.stride(1), c.stride(2),\n        scales.stride(0), scales.stride(1), scales.stride(2),\n        qzeros.stride(0), qzeros.stride(1), qzeros.stride(2),\n        groupsize=groupsize,\n        top_k=topk,\n        block_m=block_m,\n        block_n=block_n,\n        block_k=block_k,\n        group_m=group_m,\n    )\n\ndef moe_align_block_size(\n        topk_ids: torch.Tensor, block_size: int,\n        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):\n    \"\"\"\n    Aligns the token distribution across 
experts to be compatible with block size for matrix multiplication.\n\n    Parameters:\n    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.\n    - block_size: The block size used in block matrix multiplication.\n    - num_experts: The total number of experts.\n\n    Returns:\n    - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.\n    - expert_ids: A tensor indicating the assigned expert index for each block.\n    - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.\n\n    This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. \n    Padding ensures that during block matrix multiplication, the dimensions align correctly.\n\n    Example:\n    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:\n    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.\n    - As block_size is 4, we pad 1 token for each expert.\n    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].\n    - Then append padding tokens [12, 12, 12, 12] for each block.\n    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. \n        Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.\n    - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.\n    \"\"\"\n    sorted_ids = torch.empty(\n        (topk_ids.numel() + num_experts * (block_size - 1), ),\n        dtype=torch.int32,\n        device=topk_ids.device)\n    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),\n                             dtype=torch.int32,\n                             device=topk_ids.device)\n    sorted_ids.fill_(topk_ids.numel())\n    num_tokens_post_pad = torch.empty((1),\n                                      dtype=torch.int32,\n                                      device=topk_ids.device)\n    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,\n                             expert_ids, num_tokens_post_pad)\n    return sorted_ids, expert_ids, num_tokens_post_pad\n\ndef dequant_gemm_moe(hidden_states: torch.Tensor,\n                    qw1: torch.Tensor,\n                    qw2: torch.Tensor,\n                    scales_qw1: torch.Tensor,\n                    scales_qw2: torch.Tensor,\n                    zeros_qw1: torch.Tensor,\n                    zeros_qw2: torch.Tensor,\n                    topk_ids: torch.Tensor,\n                    ):\n    \n    # Check constraints.\n    # assert hidden_states.shape[1] == qw1.shape[2], \"Incompatible dimensions\"\n    assert hidden_states.is_contiguous(), \"Hidden_states must be contiguous\"\n    assert qw1.is_contiguous(), \"Expert weights1 must be contiguous\"\n    assert qw2.is_contiguous(), \"Expert weights2 must be contiguous\"\n    # assert hidden_states.dtype in [torch.float16, torch.bfloat16]\n    M, _ = hidden_states.shape\n    E, N, _ = qw1.shape\n\n    block_m = 32\n    if topk_ids.numel() <= qw1.shape[0]:\n        block_m = 16\n\n    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache2 = 
torch.empty((M * topk_ids.shape[1], N // 2),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], qw2.shape[1]),\n                                      device=hidden_states.device,\n                                      dtype=hidden_states.dtype)\n\n    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(\n        topk_ids, block_m, E)\n\n    invoke_dequant_gemm_moe(hidden_states, \n                            qw1, \n                            intermediate_cache1,\n                            scales_qw1, \n                            zeros_qw1,\n                            topk_ids, \n                            sorted_token_ids,\n                            expert_ids, \n                            num_tokens_post_padded,\n                            topk_ids.shape[1],)\n    \n    # return torch.sum(intermediate_cache1.view(*intermediate_cache1.shape), dim=1)\n\n    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))\n\n    invoke_dequant_gemm_moe(intermediate_cache2, \n                            qw2, \n                            intermediate_cache3,\n                            scales_qw2,\n                            zeros_qw2,\n                            topk_ids, \n                            sorted_token_ids,\n                            expert_ids, \n                            num_tokens_post_padded, \n                            1,)\n    \n    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),\n                    dim=1)\n    \n"
  },
  {
    "path": "kernels/triton/inference/gptq/small_benchmark_cuda_graphs.py",
    "content": "import torch\nimport triton\nfrom triton import language as tl\nimport sys\nimport marlin \nimport torch.nn as nn\nfrom auto_gptq.utils.import_utils import dynamically_import_QuantLinear\nfrom auto_gptq.modeling._utils import autogptq_post_init\n\n@triton.jit()\ndef swizzle_tile(pid,\n                m, n,\n                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m)\n    grid_n = tl.cdiv(n, block_n)\n\n    width = group_m * grid_n\n    group_id = pid // width\n    group_size = tl.minimum(grid_m - group_id * group_m, group_m)\n\n    pid_m = group_id * group_m + (pid % group_size)\n    pid_n = (pid % width) // group_size\n\n    return pid_m, pid_n\n\n\n@triton.jit()\ndef matmul_data_parallel_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,\n                             stride_am, stride_ak,\n                             stride_bk, stride_bn,\n                             stride_cm, stride_cn,\n                             stride_scales_g, stride_scales_n,\n                             stride_zeros_g, stride_zeros_n,\n                             groupsize,\n                             m, n, k,\n                             block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr,\n                             group_size_m: tl.constexpr,\n                             fp8_fast_accum: tl.constexpr,):\n    \n    pid = tl.program_id(0)\n    total_blocks_m = tl.cdiv(m, block_size_m)\n    total_blocks_n = tl.cdiv(n, block_size_n)\n    total_blocks_k = tl.cdiv(k, block_size_k)\n\n    num_blocks_in_group = group_size_m * total_blocks_n\n    group_id = pid // num_blocks_in_group\n    group_size = min(total_blocks_m - group_id * group_size_m, group_size_m)\n\n    pid_m = group_id * group_size_m + (pid % group_size)\n    pid_n = (pid % num_blocks_in_group) // (group_size)\n\n    offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m\n    offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n\n\n    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m)\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n)\n    offs_k = tl.arange(0, block_size_k)\n    \n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (16, 64)\n    b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)\n    \n    scales_ptrs = scales_ptr + offs_bn * stride_scales_n\n    zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n)\n\n    shifter = (offs_k % 8) * 4\n    zeros_shifter = (offs_bn % 8) * 4\n\n    output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32)\n    for k in range(0, total_blocks_k):\n\n \n        a = tl.load(a_ptrs)\n        b = tl.load(b_ptrs)\n\n        # tl.device_print(\"data parallel b: \", b)\n\n        g_id = k // (groupsize // block_size_k)\n\n        ptr = scales_ptrs + g_id * stride_scales_g\n        scales = tl.load(ptr)\n        \n        ptr = zeros_ptrs + g_id * stride_zeros_g\n        zeros = tl.load(ptr)\n\n        zeros = (zeros >> zeros_shifter) & 0xF\n        zeros = (zeros + 1) * scales\n\n        b = (b >> shifter[:, None]) & 0xF # b is int32\n        b = b * scales[None, :] - zeros[None, :] # b is fp16\n        \n        # output +=  tl.dot(a, b)\n        # output += tl.sum(a, b, axis=0)\n        # print(b.type)\n        # result = a[:, None] * b # (1 x 64 x 64 x 32) x illegal # (NEED A SQUARE MATRIX for B)\n        # b -> 64 x 64 instead 64 
x 32\n\n        output += tl.dot(a, b)\n        # a_block_ptr = tl.advance(a_block_ptr, (0, block_size_k))\n        a_ptrs += stride_ak * block_size_k\n        b_ptrs += (block_size_k//8) * stride_bk\n    \n    output.to(tl.float16)\n    offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m)\n    offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n)\n    c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)\n    tl.store(c_ptrs, output)\n\nclass small_qlinear(torch.autograd.Function):\n    def forward(ctx, a, b, scales, zeros):\n\n        m, k = a.shape\n        _, n = b.shape\n\n        quant_groupsize = 128\n        block_size_m = 64\n        block_size_n = 64 # [N = 4096 // 32] = 128 blocks\n        block_size_k = 64\n        group_size_m = 8\n        num_warps = 4\n        num_stages = 8\n        total_blocks_m = triton.cdiv(m, block_size_m)\n        total_blocks_n = triton.cdiv(n, block_size_n)\n        total_programs  = total_blocks_m * total_blocks_n\n        grid = (total_programs, 1)\n        fp8_fast_accum = False\n\n        c = torch.zeros((m, n), device=b.device, dtype=torch.float16)\n        # output = torch.em\n        k = matmul_data_parallel_kernel[grid](\n            a, b, c, scales, zeros,\n            a.stride(0), a.stride(1),\n            b.stride(0), b.stride(1),\n            c.stride(0), c.stride(1),\n            scales.stride(0), scales.stride(1),\n            zeros.stride(0), zeros.stride(1),\n            quant_groupsize,\n            m, n, k,\n            block_size_m, block_size_n, block_size_k, group_size_m, fp8_fast_accum = fp8_fast_accum,\n            num_warps = num_warps, num_stages = num_stages,\n        )\n\n        print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\")\n\n        with open('dequant_simple.txt', 'w') as f:\n\n            print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n            print(\"IR\", k.asm['ttir'], file=f)\n            print(\"TTGIR\", k.asm['ttgir'], file=f)\n            print(\"PTX\", k.asm['ptx'], file=f)\n            print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n\n            print(f\"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}\")\n        return c\n        \n\nmatmul_data_parallel = small_qlinear.apply\n\n\n@triton.jit()\ndef matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,\n            stride_am, stride_ak,\n            stride_bk, stride_bn,\n            stride_cm, stride_cn,\n            stride_scales_g, stride_scales_n,\n            stride_zeros_g, stride_zeros_n,\n            groupsize,\n            m, n, k,\n            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,\n            group_m: tl.constexpr, split_k: tl.constexpr):\n    \n    pid = tl.program_id(0)\n    pid_k = tl.program_id(1)\n    num_pid_k = tl.cdiv(k, block_k*split_k)\n\n    pid_m, pid_n = swizzle_tile(pid,\n                                m, n,\n                                block_m, block_n, group_m)\n    \n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n    offs_k = pid_k*block_k + tl.arange(0, block_k)\n\n    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) \n\n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr 
+ ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)\n\n    scales_ptrs = scales_ptr + offs_bn * stride_scales_n\n    zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n)\n\n    shifter = (offs_k % 8) * 4\n    zeros_shifter = (offs_bn % 8) * 4\n    \n    acc = tl.zeros((block_m, block_n), dtype=tl.float32)\n    for k in range(0, num_pid_k):\n        \n        a = tl.load(a_ptrs)\n        b = tl.load(b_ptrs)\n        \n        g_id = k // (groupsize // (block_k*split_k)) \n\n        ptr = scales_ptrs + g_id * stride_scales_g\n        scales = tl.load(ptr) # -> 1D naive assumes no reordering\n        \n        ptr = zeros_ptrs + g_id * stride_zeros_g\n        zeros = tl.load(ptr) # -> 1D naive assumes no reordering\n\n        zeros = (zeros >> zeros_shifter) & 0xF\n        zeros = (zeros + 1) * scales\n\n        b = (b >> shifter[:, None]) & 0xF # b is int32\n        b = b * scales[None, :] - zeros[None, :]\n\n        acc += tl.dot(a, b)\n        a_ptrs += block_k * split_k * stride_ak\n        b_ptrs += (block_k//8) * split_k * stride_bk\n\n    acc.to(tl.float16)\n\n    offs_cm = pid_m*block_m + tl.arange(0, block_m)\n    offs_cn = pid_n*block_n + tl.arange(0, block_n)\n\n    c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)\n    tl.atomic_add(c_ptrs, acc)\n\ndef matmul_split_k(a, b, scales, zeros):\n\n    m, k = a.shape\n    _, n = b.shape\n    \n    quant_groupsize = 128\n    block_m = 16\n    block_n = 32\n    block_k = 128\n    group_m = 8\n    num_stages = 3\n    num_warps = 4\n    split_k = 4\n\n    total_blocks_m = triton.cdiv(m, block_m)\n    total_blocks_n = triton.cdiv(n, block_n)\n    total_programs_mn = total_blocks_m * total_blocks_n\n    total_programs_k = split_k\n    \n    grid = (total_programs_mn, total_programs_k)\n\n    # print(f\"problem m size: {m}, tile size m: {block_m}, total blocks m: {total_blocks_m}\")\n    # print(f\"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}\")\n    # print(f\"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}\")\n    # print(f\"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} x {total_blocks_n} = {total_programs_mn}\")\n\n\n    # print(f\"{total_programs_mn=}, {total_programs_k=}\")\n    \n    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)\n    k = matmul_split_k_kernel[grid](a, b, c, scales, zeros,\n                              a.stride(0), a.stride(1),\n                              b.stride(0), b.stride(1),\n                              c.stride(0), c.stride(1),\n                              scales.stride(0), scales.stride(1),\n                              zeros.stride(0), zeros.stride(1),\n                              quant_groupsize,\n                              m, n, k,\n                              block_m, block_n, block_k,\n                              group_m, split_k, num_stages=num_stages, num_warps=num_warps)\n    \n    # print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\")\n\n    # with open('matmul_split_k.txt', 'w') as f:\n\n    #     print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n    #     print(\"IR\", k.asm['ttir'], file=f)\n    #     print(\"TTGIR\", k.asm['ttgir'], file=f)\n    #     print(\"PTX\", k.asm['ptx'], file=f)\n    #     print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n\n    
return c\n\ndef make_tensor(M, N, dtype):\n    if dtype == torch.int32:\n        # Fill with random integers for int32 type\n        res = torch.randint(low=-2**31, high=2**31, size=(M, N), dtype=dtype, device=\"cuda\")\n    else:\n        # Fill with normally distributed random values for other types\n        res = torch.empty((M, N), dtype=dtype, device=\"cuda\")\n        res.normal_(mean=0.0, std=0.5)\n    return res\n\n\ndef gen_quant4(m, n, groupsize=-1):\n    tile = 16\n    maxq = 2 ** 4 - 1\n    w = torch.randn((m, n), dtype=torch.half, device=\"cuda\")\n    if groupsize != -1:\n        w = w.reshape((-1, groupsize, n))\n        w = w.permute(1, 0, 2)\n        w = w.reshape((groupsize, -1))\n    s = torch.max(torch.abs(w), 0, keepdim=True)[0]\n    s *= 2 / maxq\n    w = torch.round(w / s).int()\n    w += (maxq + 1) // 2\n    w = torch.clamp(w, 0, maxq)\n    ref = (w - (maxq + 1) // 2).half() * s\n    if groupsize != -1:\n        def reshape(w):\n            w = w.reshape((groupsize, -1, n))\n            w = w.permute(1, 0, 2)\n            w = w.reshape((m, n)).contiguous()\n            return w\n        ref = reshape(ref)\n        w = reshape(w)\n    s = s.reshape((-1, n)).contiguous()\n    linear = nn.Linear(m, n)\n    linear.weight.data = ref.t()\n    # Workaround to test some special cases that are forbidden by the API\n    layer = marlin.Layer(256, 256, groupsize=groupsize)\n    if groupsize == -1:\n        groupsize = m\n    layer.k = m\n    layer.n = n\n    layer.groupsize = groupsize\n    layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=\"cuda\")\n    layer.s = torch.empty((m // groupsize, n), dtype=torch.half, device=\"cuda\")\n    layer.pack(linear, s.t())\n    q = layer.B\n    s = layer.s\n    return ref, q, s\n\nif __name__ == '__main__':\n\n    m = 16\n    k = 4096\n    n = 4096\n    groupsize = 128\n    g = k // groupsize\n\n    a = make_tensor(m, k, dtype=torch.float16)\n    b = make_tensor(k//8, n, dtype=torch.int32)\n    c = make_tensor(m, n, dtype=torch.float16)\n    workspace = torch.zeros(n//128*16, device=\"cuda\")\n\n    zeros = make_tensor(g, n//8, torch.int32)\n    scales = make_tensor(g, n, torch.float16)\n\n\n    # Marlin\n    # m, n, k = 16, 4096, 4096\n    # A = torch.randn((m, k), dtype=torch.half, device=\"cuda\")\n    # B_ref, B, s = gen_quant4(k, n)\n    # C = torch.zeros((m, n), dtype=torch.half, device=\"cuda\")\n    # workspace = torch.zeros(n // 128*16, device=\"cuda\")\n\n    output_marlin = marlin.mul(a, b, c, scales, workspace, sms=108)\n    output_split_k = matmul_split_k(a, b, scales, zeros)\n    nbits = 4\n    group_size=128\n    disable_exllama=True\n    disable_exllamav2=False\n    use_triton = False\n\n    linear_class = dynamically_import_QuantLinear(\n    disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2,\n    use_triton=use_triton, desc_act=False, group_size=group_size, bits=nbits)\n\n    linear = linear_class(\n    bits=nbits,\n    group_size=group_size,\n    infeatures=k,\n    outfeatures=n,\n    bias=0,\n    )\n\n    device = torch.device('cuda')\n\n    linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)\n    linear.scales = linear.scales + 0.002\n\n    linear = linear.eval().to(device)\n    linear = autogptq_post_init(linear, use_act_order=False)\n\n    b_fake = torch.randn((k, n), dtype=torch.float16, device=\"cuda\")\n\n    # Warmup\n    for i in range(3):\n        linear(a)\n        matmul_split_k(a, b, scales, zeros)\n        torch.matmul(a, 
b_fake)\n\n    # Warm up on a side stream first so lazy initialization (compilation, autotuning,\n    # allocator growth) happens outside of CUDA graph capture.\n    s = torch.cuda.Stream()\n    s.wait_stream(torch.cuda.current_stream())\n\n    with torch.cuda.stream(s):\n        matmul_split_k(a, b, scales, zeros)\n\n    torch.cuda.current_stream().wait_stream(s)\n\n    # Capture a single split-k GEMM launch into a CUDA graph\n    graph = torch.cuda.CUDAGraph()\n\n    with torch.cuda.graph(graph):\n        matmul_split_k(a, b, scales, zeros)\n\n    for i in range(7):\n        torch.matmul(a, b_fake)\n\n    for i in range(7):\n        linear(a)\n\n    for i in range(7):\n        graph.replay()  # replays the captured split-k launch\n\n    for i in range(7):\n        matmul_data_parallel(a, b, scales, zeros)\n\n    for i in range(7):\n        matmul_split_k(a, b, scales, zeros)"
  },
  {
    "path": "kernels/triton/inference/gptq/splitk_dequant_gemm.py",
    "content": "import torch\nimport triton\nfrom triton import language as tl\n# from actual_base_gptq_4 import triton_matmul4\n\n@triton.jit()\ndef swizzle_tile(pid,\n                m, n,\n                block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr):\n    \n    grid_m = tl.cdiv(m, block_m)\n    grid_n = tl.cdiv(n, block_n)\n\n    width = group_m * grid_n\n    group_id = pid // width\n    group_size = tl.minimum(grid_m - group_id * group_m, group_m)\n\n    pid_m = group_id * group_m + (pid % group_size)\n    pid_n = (pid % width) // group_size\n\n    return pid_m, pid_n\n\n@triton.jit()\ndef matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,\n            stride_am, stride_ak,\n            stride_bk, stride_bn,\n            stride_cm, stride_cn,\n            stride_scales_g, stride_scales_n,\n            stride_zeros_g, stride_zeros_n,\n            groupsize,\n            m, n, k,\n            block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr,\n            group_m: tl.constexpr, split_k: tl.constexpr):\n    \n    pid = tl.program_id(0)\n    pid_k = tl.program_id(1)\n    total_blocks_k = tl.cdiv(k, block_k*split_k)\n\n    pid_m, pid_n = swizzle_tile(pid,\n                                m, n,\n                                block_m, block_n, group_m)\n    \n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n    offs_k = pid_k*block_k + tl.arange(0, block_k)\n\n    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m)\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n)\n\n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)\n\n    scales_ptrs = scales_ptr + offs_bn * stride_scales_n\n    zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n)\n\n    shifter = (offs_k % 8) * 4\n    zeros_shifter = (offs_bn % 8) * 4\n    \n    acc = tl.zeros((block_m, block_n), dtype=tl.float32)\n    for k in range(0, total_blocks_k):\n        \n        a = tl.load(a_ptrs)\n        b = tl.load(b_ptrs)\n        \n        g_id = (k * split_k + pid_k) // (groupsize // block_k)\n\n        ptr = scales_ptrs + g_id * stride_scales_g\n        scales = tl.load(ptr)\n        \n        ptr = zeros_ptrs + g_id * stride_zeros_g\n        zeros = tl.load(ptr) \n\n        zeros = (zeros >> zeros_shifter) & 0xF\n        zeros = (zeros + 1) * scales\n\n        b = (b >> shifter[:, None]) & 0xF\n        b = b * scales[None, :] - zeros[None, :]\n\n        acc += tl.dot(a, b)\n        a_ptrs += block_k * split_k * stride_ak\n        b_ptrs += (block_k // 8) * split_k * stride_bk\n\n    acc.to(tl.float16)\n\n    offs_m = pid_m*block_m + tl.arange(0, block_m)\n    offs_n = pid_n*block_n + tl.arange(0, block_n)\n\n    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)\n    tl.atomic_add(c_ptrs, acc, sem='release')\n\ndef matmul_split_k(a, b, scales, zeros):\n\n    m, k = a.shape\n    _, n = b.shape\n    \n    quant_groupsize = 128\n    block_m = 16\n    block_n = 32\n    block_k = 128\n    group_m = 8\n    num_stages = 3\n    num_warps = 4\n    split_k = 4\n\n    total_blocks_m = triton.cdiv(m, block_m)\n    total_blocks_n = triton.cdiv(n, block_n)\n    total_programs_mn = total_blocks_m * total_blocks_n\n    total_programs_k = split_k\n    \n    grid = (total_programs_mn, total_programs_k)\n\n    print(f\"problem m size: {m}, tile size m: 
{block_m}, total blocks m: {total_blocks_m}\")\n    print(f\"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}\")\n    print(f\"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}\")\n\n    print(f\"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} x {total_blocks_n} = {total_programs_mn}\")\n    print(f\"{total_programs_mn=}, {total_programs_k=}\")\n    \n    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)\n    k = matmul_split_k_kernel[grid](a, b, c, scales, zeros,\n                              a.stride(0), a.stride(1),\n                              b.stride(0), b.stride(1),\n                              c.stride(0), c.stride(1),\n                              scales.stride(0), scales.stride(1),\n                              zeros.stride(0), zeros.stride(1),\n                              quant_groupsize,\n                              m, n, k,\n                              block_m, block_n, block_k,\n                              group_m, split_k, num_stages=num_stages, num_warps=num_warps)\n    \n    print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\")\n\n    with open('matmul_split_k.txt', 'w') as f:\n\n        print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n        print(\"IR\", k.asm['ttir'], file=f)\n        print(\"TTGIR\", k.asm['ttgir'], file=f)\n        print(\"PTX\", k.asm['ptx'], file=f)\n        print(f\"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\\n\", file=f)\n\n    return c\n\ndef make_tensor(M, N, dtype):\n    if dtype == torch.int32:\n        # Fill with random integers for int32 type\n        res = torch.randint(low=-2147483648, high=2147483647, size=(M, N), dtype=dtype, device=\"cuda\")\n    else:\n        # Fill with normally distributed random values for other types\n        res = torch.empty((M, N), dtype=dtype, device=\"cuda\")\n        res.normal_(mean=0.0, std=0.5)\n    return res\n\n\nif __name__ == '__main__':\n\n    m = 16\n    k = 4096\n    n = 4096\n    groupsize = 128\n    g = k // groupsize\n\n    a = make_tensor(m, k, dtype=torch.float16)\n    b = make_tensor(k//8, n, dtype=torch.int32)\n    c = make_tensor(m, n, dtype=torch.float16)\n    zeros = make_tensor(g, n//8, torch.int32)\n    scales = make_tensor(g, n, torch.float16)\n    \n    # base = no_autotune(groupsize, a, b, scales, zeros)\n    # print(f\"{base.shape=}, {base[0][0:4]}\")\n\n    # c = custom_qlinear(a, b, scales, zeros)\n    # print(f\"{c.shape=}, {c[0][0:4]}\")\n\n\n    split_k_output = matmul_split_k(a, b, scales, zeros)\n    print(f\"{split_k_output.shape=}, {split_k_output[0][0:4]}\")\n\n\n"
  },
  {
    "path": "kernels/triton/inference/mamba/causal_1d_conv/causal_1d_conv/causal_1d_conv.py",
    "content": "# Copyright (c) 2025, IBM Research\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nfrom typing import Literal, Optional\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 64}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 256}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 128}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 64}, num_stages=3, num_warps=8),\n        triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 32}, num_stages=3, num_warps=8),\n    ],\n    key=[\"seqlen\", \"dim\", \"batch\"],\n)\n@triton.jit()\ndef _causal_conv1d_fwd_kernel(\n    # Pointers to matrices\n    x_ptr,  # (batch, dim, seqlen)\n    w_ptr,  # (dim, width)\n    bias_ptr,\n    initial_states_ptr,\n    o_ptr,  # (batch, dim, seqlen)\n    # Matrix dimensions\n    batch,\n    dim,\n    seqlen,\n    # Strides\n    stride_x_seq,  # stride to get to next sequence,\n    stride_x_dim,  # stride to get to next feature-value,\n    stride_x_token,  # stride to get to next token (same feature-index, same sequence-index)\n    stride_weight_dim,  # stride to get to next dim-axis value\n    stride_weight_width,  # stride to get to next width-axis value\n    stride_istate_seq,\n    stride_istate_dim,\n    stride_istate_token,\n    stride_o_seq,\n    stride_o_dim,\n    stride_o_token,\n    # Meta-parameters\n    HAS_BIAS: tl.constexpr,\n    KERNEL_WIDTH: tl.constexpr,  # maybe using this we don't need 'width'\n    SILU_ACTIVATION: tl.constexpr,\n    HAS_INITIAL_STATES: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    indices_0 = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n    idx_seqs = indices_0 // seqlen\n    idx_tokens = indices_0 % seqlen\n\n    x_base = x_ptr + (idx_seqs * stride_x_seq)[:, None]  # the beginning features at all tokens at all sequences processed by this Triton program\n    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)\n    w_base = w_ptr + (idx_feats * stride_weight_dim)  # first kernel column, configured for weights to handle BLOCK_N features in range\n    load_init_state = False\n    if HAS_INITIAL_STATES:\n        load_init_state = tl.min(idx_tokens) < KERNEL_WIDTH - 1\n        initial_states_base = initial_states_ptr + (idx_seqs * stride_istate_seq)[:, None] + (idx_feats * stride_istate_dim)[None, :]\n\n    # store output data at the corresponding tokens (BLOCK_M of them) and feature-indices (BLOCK_N of them) in these tokens\n    if HAS_BIAS:\n        bias = bias_ptr + idx_feats\n        mask_bias = idx_feats < dim\n        acc = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32)[None, :]  # [BLOCK_N]\n        acc = tl.broadcast_to(acc, (BLOCK_M, BLOCK_N))\n    else:\n        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n    PADDING_W = KERNEL_WIDTH - 1\n    for j in range(KERNEL_WIDTH):\n        idx_x_w = j - 
PADDING_W + idx_tokens  # the token index to multiply with kernel[:, 0], given kernel with width-columns, i.e. kernel[:, 0..(width-1)]\n        x_ptrs = x_base + ((idx_x_w * stride_x_token)[:, None] + (idx_feats * stride_x_dim)[None, :])  # [BLOCK_M, BLOCK_N]\n        mask_x = ((idx_seqs < batch)[:, None]  # sequence-index\n                  & (idx_x_w >= 0)[:, None]  # token-index\n                  & (idx_x_w < seqlen)[:, None]  # token-index\n                  & (idx_feats < dim)[None, :]  # feature-index\n                  )\n        if HAS_INITIAL_STATES:\n            if load_init_state:\n                initial_states_ptrs = initial_states_base + ((idx_x_w + KERNEL_WIDTH - 1) * stride_istate_token)[:, None]  # [BLOCK_M, BLOCK_N]\n                mask_w = (idx_seqs < batch)[:, None] & (idx_x_w < 0)[:, None] & (idx_feats < dim)[None, :]  # sequence-index  # token-index  # feature-index\n                initial_states = tl.load(initial_states_ptrs, mask_w, 0.0)\n            else:\n                initial_states = tl.zeros((BLOCK_M, BLOCK_N), dtype=x_ptr.dtype.element_ty)\n            matrix_x = tl.load(x_ptrs, mask=mask_x, other=initial_states)\n        else:\n            matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)\n\n        w_ptrs = w_base[None, :] + \\\n            (j * stride_weight_width)  # [1, BLOCK_N] tensor\n        mask_w = (idx_feats < dim)[None, :]\n        matrix_w = tl.load(w_ptrs, mask_w, other=0.0)\n        acc += matrix_x * matrix_w\n\n    if SILU_ACTIVATION:\n        acc = acc / (1 + tl.exp(-acc))\n    mask = (\n        (idx_seqs < batch)[:, None]  # sequence-index\n        & (idx_tokens < seqlen)[:, None]  # token-index\n        & (idx_feats < dim)[None, :]  # feature-index\n    )\n    o_ptrs = (\n        o_ptr\n        + (idx_seqs * stride_o_seq)[:, None]\n        + (idx_tokens * stride_o_token)[:, None]\n        + (idx_feats * stride_o_dim)[None, :]\n    )\n\n    tl.store(o_ptrs, acc, mask=mask)\n\n\ndef causal_conv1d_fwd(\n    x: torch.Tensor,\n    weight: torch.Tensor,\n    bias: Optional[torch.Tensor] = None,\n    seq_idx: Optional[torch.Tensor] = None,\n    initial_states: Optional[torch.Tensor] = None,\n    return_final_states: Optional[torch.Tensor] = False,\n    final_states_out: Optional[torch.Tensor] = None,\n    activation: Optional[Literal[\"silu\", \"swish\"]] = None,\n):\n    batch, dim, seqlen = x.shape\n    _, width = weight.shape\n    assert (dim, width) == weight.shape\n    assert x.stride(2) == 1 or x.stride(1) == 1\n    # TODO: we may want to use weight such that weight.stride(dim)==1\n    assert weight.stride(1) == 1\n    # Tensor layout as NHWC is called channel last with 'C' is time-dimension\n    is_channel_last = (x.stride(1) == 1) & (x.stride(2) > 1)\n    stride_w_dim = weight.stride(0)\n    stride_w_width = weight.stride(1)\n    # effort to make data contiguous along dim-axis:\n    weight = weight.transpose(0, 1).contiguous()\n    stride_w_dim = weight.stride(1)\n    stride_w_width = weight.stride(0)\n\n    # assert initial_states is None  # only this for now\n    assert return_final_states is False\n    stride_istate_seq = 0\n    stride_istate_dim = 0\n    stride_istate_token = 0\n    if initial_states is not None:\n        assert (batch, dim, width - 1) == initial_states.shape\n        stride_istate_seq = initial_states.stride(0)\n        stride_istate_dim = initial_states.stride(1)\n        stride_istate_token = initial_states.stride(2)\n        assert stride_istate_dim == 1\n\n    out = torch.empty_like(x)\n\n    if not 
is_channel_last:\n        assert 0, \"Need to run in channel-last layout\"\n    else:\n\n        def grid(META):\n            return (\n                triton.cdiv(batch * seqlen, META[\"BLOCK_M\"]),\n                triton.cdiv(dim, META[\"BLOCK_N\"]),\n            )\n\n        with torch.cuda.device(x.device.index):\n            _causal_conv1d_fwd_kernel[grid](\n                # Pointers to matrices\n                x,\n                weight,\n                bias,\n                initial_states,\n                out,\n                # Matrix dimensions\n                batch,\n                dim,\n                seqlen,\n                # stride\n                x.stride(0),\n                x.stride(1),\n                x.stride(2),\n                stride_w_dim,\n                stride_w_width,\n                stride_istate_seq,\n                stride_istate_dim,\n                stride_istate_token,\n                out.stride(0),\n                out.stride(1),\n                out.stride(2),\n                # META\n                HAS_BIAS=bias is not None,\n                KERNEL_WIDTH=width,\n                SILU_ACTIVATION=activation in [\"silu\", \"swish\"],\n                HAS_INITIAL_STATES=initial_states is not None,\n            )\n    return out\n\n\nclass CausalConv1dFn(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        weight,\n        bias=None,\n        seq_idx=None,\n        initial_states=None,\n        return_final_states: bool = False,\n        final_states_out=None,\n        activation: Optional[Literal[\"silu\", \"swish\"]] = None,\n    ):\n        # NOTE: in fact, 'beta=1' would turn swish into silu - and only silu form is used\n        if x.stride(2) != 1 and x.stride(1) != 1:\n            x = x.contiguous()\n        bias = bias.contiguous() if bias is not None else None\n        if seq_idx is not None:\n            assert initial_states is None, \"initial_states must be None if seq_idx is not None\"\n            assert not return_final_states, \"If seq_idx is not None, we don't return final_states_out\"\n        seq_idx = seq_idx.contiguous() if seq_idx is not None else None\n        if initial_states is not None and ((initial_states.stride(2) != 1) and (initial_states.stride(1) != 1)):\n            initial_states = initial_states.contiguous()\n        if return_final_states:\n            assert (\n                x.stride(1) == 1\n            ), \"Only channel-last layout support returning final_states_out\"\n            if final_states_out is not None:\n                assert (\n                    (final_states_out.stride(2) == 1) or (\n                        final_states_out.stride(1) == 1)\n                )\n            else:\n                batch, dim, seqlen = x.shape\n                width = weight.shape[1]\n                final_states_out = torch.empty(\n                    batch, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)\n        else:\n            final_states_out = None\n        ctx.activation = activation\n        out = causal_conv1d_fwd(\n            x,\n            weight,\n            bias=bias,\n            seq_idx=seq_idx,\n            initial_states=initial_states,\n            return_final_states=return_final_states,\n            final_states_out=final_states_out,\n            activation=ctx.activation,\n        )\n        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)\n        ctx.return_final_states = return_final_states\n        
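# Only request d(initial_states) from the backward kernel when the caller\n        # actually passed initial_states that require grad.\n        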
ctx.return_dinitial_states = initial_states is not None and initial_states.requires_grad\n        return out if not return_final_states else (out, final_states_out)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        \"\"\"dout = dL/dy\n        RETURN: dL/dx, dL/dweight, dL/dbias, ...\n        GIVEN THAT: def forward(ctx, x, weight, bias=None...)\n        \"\"\"\n        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors\n        dfinal_states = args[0] if ctx.return_final_states else None\n        if dout.stride(2) != 1 and dout.stride(1) != 1:\n            dout = dout.contiguous()\n        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the\n        # backward of conv1d with the backward of chunk).\n        # Here we just pass in None and dx will be allocated in the C++ code.\n        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd(\n            x,\n            weight,\n            bias,\n            dout,\n            seq_idx,\n            initial_states,\n            dfinal_states,\n            None,\n            ctx.return_dinitial_states,\n            ctx.activation,\n        )\n        return (\n            dx,\n            dweight,\n            dbias if bias is not None else None,\n            None,\n            dinitial_states if initial_states is not None else None,\n            None,\n            None,\n            None,\n        )\n\n\ndef causal_conv1d_fn(\n    x,  # channel last, i.e. (batch, dim, seqlen)\n    weight,  # (dim, w)\n    bias=None,  # (dim, )scalar\n    seq_idx=None,\n    initial_states=None,\n    return_final_states=False,\n    final_states_out=None,\n    activation: Optional[Literal[\"silu\", \"swish\"]] = None,\n):\n    \"\"\"causal_conv1d_fn.\n\n    :param x: (batch, dim, seqlen) tensor\n    :param weight: (dim, w) tensor\n    :param bias: (dim,) tensor\n    :param activation: [\"silu\", \"swish\"]\n    :param seq_idx=None\n    :param initial_states=None\n    :param return_final_states=False\n    :param final_states_out=None\n\n    Return: (batch, dim, seqlen) tensor\n    \"\"\"\n    if weight.dim() == 3:\n        assert weight.shape[1] == 1\n        weight = rearrange(weight, \"d 1 w -> d w\")\n    return CausalConv1dFn.apply(\n        x,\n        weight,\n        bias,\n        seq_idx,\n        initial_states,\n        return_final_states,\n        final_states_out,\n        activation,\n    )\n"
  },
  {
    "path": "kernels/triton/inference/mamba/causal_1d_conv/tests/test_causal_1d_conv.py",
    "content": "# Copyright (C) 2025, IBM Research.\n# python -m pytest tests/test_causal_conv1d.py\n\nimport sys\nfrom einops import rearrange\nimport pytest\nimport torch.nn.functional as F\nimport torch\nimport math\n\nimport os\nfrom pathlib import Path\n\nbase_path = Path(os.path.abspath(os.path.dirname(os.path.realpath(__file__))))\n\nsys.path.insert(0, str(base_path / \"../causal_1d_conv\"))\n\ntry:\n    from causal_1d_conv import causal_conv1d_fn\nexcept ImportError:\n    raise\n\n\ndef _undecorated_test_causal_conv1d(\n    batch,\n    dim,\n    seqlen,\n    width,\n    has_bias,\n    silu_activation,\n    itype,\n    channel_last,\n    has_initial_states,\n    return_final_states,\n    check_backward,\n):\n    if not channel_last and (has_initial_states or return_final_states):\n        pytest.skip(\"Only channel_last support initial_states or return_final_states\")\n    device = \"cuda\"\n    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)\n    if itype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n    rtolw, atolw = (1e-3, 1e-3)\n    # set seed\n    torch.random.manual_seed(0)\n    if not channel_last:\n        x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[\n            :, 4096: 4096 + dim, :\n        ].requires_grad_()\n    else:\n        x = rearrange(\n            torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096: 4096 + dim],\n            \"b s d -> b d s\",\n        ).requires_grad_()\n    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)\n    if has_bias:\n        bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)\n    else:\n        bias = None\n    if has_initial_states:\n        initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()\n    else:\n        initial_states = None\n    x_ref = x.detach().clone().requires_grad_()\n    weight_ref = weight.detach().clone().requires_grad_()\n    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None\n    initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None\n    activation = None if not silu_activation else \"silu\"\n    out = causal_conv1d_fn(\n        x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, activation=activation\n    )\n    out_ref = causal_conv1d_ref(\n        x_ref,\n        weight_ref,\n        bias_ref,\n        initial_states=initial_states_ref,\n        return_final_states=return_final_states,\n        activation=activation,\n    )\n    if return_final_states:\n        out, final_states = out\n        out_ref, final_states_ref = out_ref\n        print(f\"Final states max diff: {(final_states - final_states_ref).abs().max().item()}\")\n        print(f\"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}\")\n        assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n\n    if return_final_states:\n        out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)\n        out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)\n\n    if check_backward:\n        g = torch.randn_like(out)\n        out.backward(g)\n        
out_ref.backward(g)\n\n        print(f\"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}\")\n        print(f\"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}\")\n        if has_bias:\n            print(f\"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}\")\n        if has_initial_states:\n            print(f\"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}\")\n\n        assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)\n        assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)\n        if has_bias:\n            assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)\n        if has_initial_states:\n            assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)\n    torch.cuda.empty_cache()\n    del x_ref, x, weight, weight_ref, bias, bias_ref, out, out_ref\n\n\ndef causal_conv1d_ref(\n    x,\n    weight,\n    bias=None,\n    initial_states=None,\n    return_final_states=False,\n    final_states_out=None,\n    activation=None,\n):\n    \"\"\"[copied from causal_conv1d/causal_conv1d_interface.py]\n    x: (batch, dim, seqlen)\n    weight: (dim, width)\n    bias: (dim,)\n    initial_states: (batch, dim, width - 1)\n    final_states_out: (batch, dim, width - 1)\n\n    out: (batch, dim, seqlen)\n    \"\"\"\n    if activation not in [None, \"silu\", \"swish\"]:\n        raise NotImplementedError(\"activation must be None, silu, or swish\")\n    dtype_in = x.dtype\n    x = x.to(weight.dtype)\n    seqlen = x.shape[-1]\n    dim, width = weight.shape\n    if initial_states is None:\n        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)\n    else:\n        x = torch.cat([initial_states, x], dim=-1)\n        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)\n    out = out[..., :seqlen]\n    if return_final_states:\n        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in)  # (batch, dim, width - 1)\n        if final_states_out is not None:\n            final_states_out.copy_(final_states)\n        else:\n            final_states_out = final_states\n    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)\n    return out if not return_final_states else (out, final_states_out)\n\n\n@pytest.mark.parametrize(\"batch\", [1, 2, 3, 8, 16, 32, 64])  # END-GOAL\n# @pytest.mark.parametrize(\"batch\", [2])\n@pytest.mark.parametrize(\"dim\", [64, 4096 + 32])  # END-GOAL\n# @pytest.mark.parametrize('dim', [64])\n# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])\n# @pytest.mark.parametrize('seqlen', [128])\n@pytest.mark.parametrize(\n    \"seqlen\", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]\n)  # END-GOAL\n@pytest.mark.parametrize(\"width\", [2, 3, 4, 5])  # END-GOAL\n# @pytest.mark.parametrize('width', [3])\n@pytest.mark.parametrize(\"has_bias\", [False, True])  # END-GOAL\n# @pytest.mark.parametrize('has_bias', [True])\n# @pytest.mark.parametrize('has_bias', [False])\n@pytest.mark.parametrize(\"silu_activation\", [False, True])  # END-GOAL\n# @pytest.mark.parametrize(\"silu_activation\", [True])\n@pytest.mark.parametrize(\"itype\", [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize('itype', [torch.float16])\n# @pytest.mark.parametrize(\"channel_last\", [False, 
True])\n@pytest.mark.parametrize(\"channel_last\", [True])  # END-GOAL\n@pytest.mark.parametrize(\"has_initial_states\", [False, True])  # END-GOAL\n# @pytest.mark.parametrize(\"has_initial_states\", [False])\n# @pytest.mark.parametrize(\"return_final_states\", [False, True]) # END-GOAL\n@pytest.mark.parametrize(\"return_final_states\", [False])\n# @pytest.mark.parametrize('check_backward', [True]) # END-GOAL\n@pytest.mark.parametrize(\"check_backward\", [False])\ndef test_causal_conv1d(\n    batch,\n    dim,\n    seqlen,\n    width,\n    has_bias,\n    silu_activation,\n    itype,\n    channel_last,\n    has_initial_states,\n    return_final_states,\n    check_backward,\n):\n    return _undecorated_test_causal_conv1d(\n        batch,\n        dim,\n        seqlen,\n        width,\n        has_bias,\n        silu_activation,\n        itype,\n        channel_last,\n        has_initial_states,\n        return_final_states,\n        check_backward,\n    )\n"
  },
  {
    "path": "kernels/triton/inference/paged_attention/attention_triton.py",
    "content": "#from einops import rearrange\nimport torch\nimport triton\nimport triton.language as tl\n\n# Credit:\n# vedantroy https://github.com/openai/triton/issues/2200#issuecomment-1815471999\n\n# Expect block table to map\n# logical bid (block id) -> (physical bid, # filled)\n# In tests, it maps: logical pid -> physical bid\n\n@triton.jit\ndef print_tensor_dim(tensor, str_name):\n    if tl.program_id(0) == 0 and tl.program_id(1) == 0:\n        tl.static_print(str_name,\" \",tensor.shape,\" \",tensor.dtype)\n        #tl.static_print('*************** program id: ', tl.program_id(0), tl.program_id(1))\n\n@triton.jit\ndef print_value(value):\n    if tl.program_id(0) == 0 and tl.program_id(1) == 0:\n        tl.device_print(str(value))\n        #tl.static_print('*************** program id: ', tl.program_id(0), tl.program_id(1))\n        #tl.static_print(str_name+\" \")\n\n@triton.jit\ndef print_line(str_line):\n    if tl.program_id(0) == 0 and tl.program_id(1) == 0:\n        print(str_line)\n\n#Paged Attention V1: basic version, has a memory limitation error\n@triton.jit\ndef paged_attention_v1(\n    # need these b/c we can't use view/reshape\n    scratchpad_key_ptr,  # [num_seqs, max_seq_len, num_heads, head_size]\n    scratchpad_value_ptr,  # [num_seqs, max_seq_len, num_heads, head_size]\n    output_ptr,  # [num_seqs, num_query_heads, head_size]\n    query_ptr,  # [num_seqs, num_query_heads, head_size]\n    key_cache_ptr,  # [num_blocks, num_kv_heads, head_size, block_size]\n    value_cache_ptr,  # [num_blocks, num_kv_heads, head_size, block_size]\n    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]\n    context_lens_ptr,  # [num_seqs]\n    scale,  # float32\n    num_seqs,  # int\n    num_heads,  # int\n    cache_block_stride,  # int\n    MAX_SEQ_LEN: tl.constexpr,  # int (same as max_seq_len)\n    BLOCK_SIZE: tl.constexpr,  # int\n    HEAD_SIZE: tl.constexpr,  # int, must be power of 2\n    MAX_NUM_BLOCKS_PER_SEQ: tl.constexpr,  # int, must be power of 2\n):\n    seq_idx = tl.program_id(0).to(tl.int64)\n    head_idx = tl.program_id(1).to(tl.int64)\n     \n    #Compute the offsets of the query using the strides\n    #TODO(amorari) use the strides as returned from tensor.stride() instead \n    query_offset = seq_idx * num_seqs + head_idx * HEAD_SIZE\n    query_head = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE))\n    #print_tensor_dim(query_head, \"query_head\")\n    \n    block_table_offset = seq_idx * MAX_NUM_BLOCKS_PER_SEQ\n    #load the context len for this q vector\n    context_len = tl.load(context_lens_ptr + seq_idx)\n\n    #print_tensor_dim(block_tables_ptr, \"block_tables_ptr\")\n   \n    #iterate on the tokens\n    for tok_idx in range(0, context_len):\n        logical_block_idx = tok_idx // BLOCK_SIZE\n        \n        #physical block starting pointer for token\n        physical_block_idx = tl.load(\n            block_tables_ptr + block_table_offset + logical_block_idx\n        )\n\n        start_of_block_offset = (\n            physical_block_idx.to(tl.int64) * cache_block_stride + head_idx * HEAD_SIZE * BLOCK_SIZE\n        )\n        tok_idx_within_block = tok_idx % BLOCK_SIZE\n        tok_offsets = (\n            start_of_block_offset\n            + BLOCK_SIZE * tl.arange(0, HEAD_SIZE)\n            + tok_idx_within_block\n        )\n\n        #Get all blocks for this token\n        tok_key = tl.load(key_cache_ptr + tok_offsets)\n        tok_value = tl.load(value_cache_ptr + tok_offsets)\n        #print_tensor_dim(tok_key, \"tok_key\")\n        
#print_tensor_dim(tok_value, \"tok_value\")\n\n        #Compute offsets to write in the scratchpad\n        scratchpad_offset = (\n            seq_idx.to(tl.int64) * (MAX_SEQ_LEN * num_heads.to(tl.int64) * HEAD_SIZE)\n            + tok_idx.to(tl.int64) * (num_heads * HEAD_SIZE)\n            + head_idx * HEAD_SIZE\n        )\n        tl.store(\n            scratchpad_key_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_key\n        )\n        tl.store(\n            scratchpad_value_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE),\n            tok_value,\n        )\n\n\n    # TODO: Not sure if this is necessary\n    tl.debug_barrier()\n\n    # scratchpad_key_ptr,  # [num_seqs, max_seq_len, num_heads, head_size]\n    start_seq_offset = (MAX_SEQ_LEN * num_heads * HEAD_SIZE) * seq_idx\n\n    start_tok_offset = start_seq_offset + tl.arange(0, MAX_SEQ_LEN) \\\n        * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE\n\n\n    # [seq_len, head_size]\n    # zero out keys that aren't part of the sequence\n\n    mask = tl.arange(0, MAX_SEQ_LEN)[:, None] < context_len\n    kv_offs = start_tok_offset[:, None] + tl.arange(0, HEAD_SIZE)[None, :]\n    print_tensor_dim(kv_offs, \"kv_offs_v1\")\n    keys = tl.load(scratchpad_key_ptr + kv_offs, mask=mask, other=0.0)\n    print_tensor_dim(keys, \"keys_v1\")\n    values = tl.load(scratchpad_value_ptr + kv_offs, mask=mask, other=0.0)\n    print_tensor_dim(values, \"values_v1\")\n\n    # keys shape  [seq_len x head_size], query shape = [head_size]\n    # Can't do below b/c minimum size on all dimensions is 16\n    # scores = tl.dot(query_head[None, :], keys.T)\n    \n    scores = tl.sum(scale * keys * query_head[None, :], axis=1)\n\n    # This mask is necessary b/c even though we mask out the keys on load\n    # that just results in 0s in the attention dot product, \n    # which then get softmaxed and result in non-zero values \n    # in the softmax output (which is wrong)\n    # -inf guarantees that the softmax output will be 0 for masked values\n    mask = tl.full([MAX_SEQ_LEN], -float('inf'), dtype=tl.float32)\n    cond = tl.arange(0, MAX_SEQ_LEN) < context_len\n    scores_masked = tl.where(cond, scores, mask)\n\n    # do a numerically stable softmax on the scores\n    scores_minus_max = scores_masked - tl.max(scores_masked, axis=0)\n\n    \n    numerator = tl.exp(scores_minus_max)\n    denominator = tl.sum(numerator, axis=0) + float(1e-6)\n    logits = numerator / denominator\n    print_tensor_dim(logits, \"logits_v1\")\n\n    weighted_values = tl.sum(values * logits[:, None], axis=0)\n    print_tensor_dim(weighted_values, \"weighted_values_v1\")\n\n    output_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE\n    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), weighted_values)\n\ndef paged_attention_triton_v1(\n            output,\n            query,\n            key_cache,\n            value_cache,\n            #head_mapping,\n            scale,\n            block_tables,\n            context_lens,\n            block_size,\n            #max_seq_len,\n            #alibi_slopes, \n            num_seqs,\n            num_query_heads,\n            max_seq_len,\n            max_num_blocks_per_seq,\n            head_size\n):\n    scratchpad_key = torch.zeros(\n        (num_seqs, max_seq_len, num_query_heads, head_size),\n        dtype=torch.float32,\n        device=\"cuda\",\n    )\n    \n    scratchpad_value = torch.zeros_like(scratchpad_key)\n\n    paged_attention_v1[(num_seqs, num_query_heads)](\n        
scratchpad_key_ptr=scratchpad_key,\n        scratchpad_value_ptr=scratchpad_value,\n        output_ptr=output,\n        query_ptr=query,\n        key_cache_ptr=key_cache,\n        value_cache_ptr=value_cache,\n        block_tables_ptr=block_tables,\n        context_lens_ptr=context_lens,\n        scale=scale,\n        num_seqs=num_seqs,\n        num_heads=num_query_heads,\n        cache_block_stride=key_cache.stride(0),\n        MAX_SEQ_LEN=max_seq_len,\n        BLOCK_SIZE=block_size,\n        HEAD_SIZE=head_size,\n        MAX_NUM_BLOCKS_PER_SEQ=max_num_blocks_per_seq,\n    )\n\n\n#Paged Attention V2: Iterate on kv vectors to avoid memory limitation error (sram)\n@triton.jit\ndef paged_attention_v2(\n    # need these b/c we can't use view/reshape\n    scratchpad_key_ptr,  # [num_seqs, max_seq_len, num_heads, head_size]\n    scratchpad_value_ptr,  # [num_seqs, max_seq_len, num_heads, head_size]\n    partition_buf_ptr,\n    output_ptr,  # [num_seqs, num_query_heads, head_size]\n    query_ptr,  # [num_seqs, num_query_heads, head_size]\n    key_cache_ptr,  # [num_blocks, num_kv_heads, head_size, block_size]\n    value_cache_ptr,  # [num_blocks, num_kv_heads, head_size, block_size]\n    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]\n    context_lens_ptr,  # [num_seqs]\n    scale,  # float32\n    num_seqs,  # int\n    num_heads,  # int\n    cache_block_stride,  # int\n    num_partitions, #int\n    PARTITION_SIZE: tl.constexpr, #int\n    MAX_SEQ_LEN: tl.constexpr,  # int\n    BLOCK_SIZE: tl.constexpr,  # int\n    HEAD_SIZE: tl.constexpr,  # int, must be power of 2\n    MAX_NUM_BLOCKS_PER_SEQ: tl.constexpr,  # int, must be power of 2\n):\n    seq_idx = tl.program_id(0).to(tl.int64)\n    head_idx = tl.program_id(1).to(tl.int64)\n    partition_idx = tl.program_id(2).to(tl.int64)\n   \n    #Compute the offsets of the query using the strides\n    #TODO(amorari) use the strides as returned from tensor.stride() instead \n    query_offset = seq_idx * num_seqs + head_idx * HEAD_SIZE\n\n    #load one q vector\n    query_head = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE))\n    print_tensor_dim(query_head, \"query_head\")\n    \n    block_table_offset = seq_idx * MAX_NUM_BLOCKS_PER_SEQ\n    #load the context len for this q vector\n    context_len = tl.load(context_lens_ptr + seq_idx)\n    assert(context_len <= MAX_SEQ_LEN)\n\n    #iterate on the tokens in this partition\n    token_start_idx = partition_idx * PARTITION_SIZE\n    token_end_idx = min((partition_idx + 1) * PARTITION_SIZE, context_len)\n    #NOTE: For some sequence, it is possible that context_len < token_start_idx\n    for tok_idx in range(token_start_idx, token_end_idx):\n        logical_block_offset = tok_idx // BLOCK_SIZE\n        \n        #physical block starting pointer for token\n        physical_block_idx = tl.load(\n            block_tables_ptr + block_table_offset + logical_block_offset\n        )\n\n        start_of_block_offset = (\n            physical_block_idx * cache_block_stride + head_idx * HEAD_SIZE * BLOCK_SIZE\n        )\n\n        tok_idx_within_block = tok_idx % BLOCK_SIZE\n        tok_offsets = (\n            start_of_block_offset\n            + BLOCK_SIZE * tl.arange(0, HEAD_SIZE)\n            + tok_idx_within_block\n        )\n\n        tok_key = tl.load(key_cache_ptr + tok_offsets)\n        #print_tensor_dim(tok_key, \"tok_key\")\n        tok_value = tl.load(value_cache_ptr + tok_offsets)\n        #print_tensor_dim(tok_key, \"tok_value\")\n\n        scratchpad_offset = (\n            
seq_idx.to(tl.int64) * (MAX_SEQ_LEN * num_heads.to(tl.int64) * HEAD_SIZE)\n            + tok_idx.to(tl.int64) * (num_heads.to(tl.int64) * HEAD_SIZE)\n            + head_idx * HEAD_SIZE\n        )\n\n        print_tensor_dim(scratchpad_key_ptr, \"scratchpad_key_ptr\")\n        mask=tl.full([HEAD_SIZE], 1, dtype=tl.float32) > 0\n        #store the keys in line\n        tl.store(\n            scratchpad_key_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_key, mask\n        )\n        #store the values in line\n        tl.store(\n            scratchpad_value_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), \n            tok_value, mask\n        )\n\n    # TODO: Not sure if this is necessary\n    tl.debug_barrier()\n\n   \n    #start of the sequence\n    start_seq_offset = (MAX_SEQ_LEN * num_heads.to(tl.int64) * HEAD_SIZE) * seq_idx.to(tl.int64)\n    #offsets with the start of the token\n    start_tok_offsets = start_seq_offset.to(tl.int64) \\\n                    + tl.arange(0, PARTITION_SIZE) * (num_heads.to(tl.int64) * HEAD_SIZE) \\\n                    + head_idx.to(tl.int64) * HEAD_SIZE\n\n    # [seq_len, head_size]\n    # zero out keys that aren't part of the sequence\n\n    mask = tl.arange(0, PARTITION_SIZE)[:, None] < context_len\n    kv_offs = start_tok_offsets[:, None] + tl.arange(0, HEAD_SIZE)[None, :]\n    print_tensor_dim(kv_offs, \"kv_offs_v2\")\n    keys = tl.load(scratchpad_key_ptr + kv_offs, mask=mask, other=0.0)\n    print_tensor_dim(keys, \"keys_v2\")\n\n    # Can't do below b/c minimum size on all dimensions is 16\n    # scores = tl.dot(query_head[None, :], keys.T)\n    scores = tl.sum(scale * keys * query_head[None, :], axis=1)\n    print_tensor_dim(keys, \"scores_v2\")\n\n    partition_buf_offset = start_seq_offset \\\n        + head_idx.to(tl.int64) * HEAD_SIZE + partition_idx.to(tl.int64) * PARTITION_SIZE\n    print_tensor_dim(partition_buf_offset, \"partition_buf_offset_v2\")\n\n    tl.store(partition_buf_ptr + partition_buf_offset + tl.arange(0, PARTITION_SIZE), scores)\n        \n    #weighted_values = tl.zeros(HEAD_SIZE, dtype=tl.float32)\n\n    # This mask is necessary b/c even though we mask out the keys on load\n    # that just results in 0s in the attention dot product, \n    # which then get softmaxed and result in non-zero values \n    # in the softmax output (which is wrong)\n    # -inf guarantees that the softmax output will be 0 for masked values\n    mask = tl.full([PARTITION_SIZE], -float('inf'), dtype=tl.float32)\n    cond = tl.arange(0, PARTITION_SIZE) < context_len\n    scores_masked = tl.where(cond, scores, mask)\n\n    # do a numerically stable softmax on the scores\n    scores_minus_max = scores_masked - tl.max(scores_masked, axis=0)\n    numerator = tl.exp(scores_minus_max)\n    denominator = tl.sum(numerator, axis=0) + float(1e-6)\n\n    logits = numerator / denominator\n    print_tensor_dim(logits, \"logits_v2\")\n\n    values = tl.load(scratchpad_value_ptr + kv_offs, mask=mask, other=0.0)\n    print_tensor_dim(values, \"values_v2\")\n    weighted_values += tl.sum(values * logits[:, None], axis=0)\n    print_tensor_dim(weighted_values, \"weighed_values_v2\")\n\n    #output_offset = seq_idx.to(tl.int64) * (num_heads.to(tl.int64) * HEAD_SIZE) \\\n    #    + head_idx.to(tl.int64) * HEAD_SIZE + seq_partition_idx.to(tl.int64) * PARTITION_SIZE\n\n    #to_store_values=weighted_values.to(tl.float32)\n    #mask = tl.full([HEAD_SIZE], 1, dtype=tl.float32) > 0\n    #tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), to_store_values, 
mask)\n\n\n    output_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE\n    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), weighted_values)\n\n\ndef paged_attention_triton_v2(\n            output,\n            query,\n            key_cache,\n            value_cache,\n            #head_mapping,\n            scale,\n            block_tables,\n            context_lens,\n            block_size,\n            partition_size,\n            #alibi_slopes, \n            num_seqs,\n            num_query_heads,\n            max_seq_len,\n            max_num_blocks_per_seq,\n            head_size\n):\n\n    scratchpad_key = torch.zeros(\n        (num_seqs, max_seq_len, num_query_heads, head_size),\n        dtype=torch.float32,\n        device=\"cuda\",\n    )\n\n    scratchpad_value = torch.zeros_like(scratchpad_key)\n\n    num_partitions = max_seq_len//partition_size\n    assert(max_seq_len % partition_size == 0)\n\n    partition_buf_ptr = torch.zeros((num_seqs,max_seq_len,num_query_heads,head_size),\n                                    dtype=torch.float32,\n                                    device=\"cuda\")\n   \n    #print(f\"started_v2 num_seqs: {num_seqs} num_query_heads: {num_query_heads}\")\n    paged_attention_v2[(num_seqs, num_query_heads, num_partitions)](\n        scratchpad_key_ptr=scratchpad_key,\n        scratchpad_value_ptr=scratchpad_value,\n        partition_buf_ptr=partition_buf_ptr,\n        output_ptr=output,\n        query_ptr=query,\n        key_cache_ptr=key_cache,\n        value_cache_ptr=value_cache,\n        block_tables_ptr=block_tables,\n        context_lens_ptr=context_lens,\n        scale=scale,\n        num_seqs=num_seqs,\n        num_heads=num_query_heads,\n        cache_block_stride=key_cache.stride(0),\n        num_partitions=num_partitions,\n        PARTITION_SIZE=partition_size,\n        MAX_SEQ_LEN=max_seq_len,\n        BLOCK_SIZE=block_size,\n        HEAD_SIZE=head_size,\n        MAX_NUM_BLOCKS_PER_SEQ=max_num_blocks_per_seq,\n    )\n    #print(\"finished_v2\")\n\n"
  },
  {
    "path": "kernels/triton/inference/torch_compile/flash_backward.py",
    "content": "#!/usr/bin/env python\n\"\"\"\nCode copied from https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-attention.py\n\"\"\"\n\n\"\"\"\nFused Attention\n===============\n\nThis is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)\nCredits: OpenAI kernel team, AMD ML Frameworks Triton team\n\nFeatures supported:\n\n1) Fwd with causal masking\n2) Any sequence lengths without padding (currently fwd kernel only)\n3) Support for different sequence lengths for q and k\n4) Nested tensor API currently does not support dropout or bias.\n\nNot currently supported:\n\n1) Non power of two head dims\n\n\"\"\"\n\nimport argparse\nimport random\nimport sys\nimport torch\n\nimport triton\nimport triton.language as tl\n\ntorch_dtype:tl.constexpr = torch.float16\n\n#TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')\n#if TORCH_HAS_FP8E5:\n#    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz\nTORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2')\nif TORCH_HAS_FP8E5:\n    torch_dtype:tl.constexpr = torch.float8_e5m2\n\nclass MetaData():\n    cu_seqlens_q = None\n    cu_seqlens_k = None\n    max_seqlens_q = 0\n    max_seqlens_k = 0\n    bias = None\n    alibi_slopes = None\n    causal = False\n    num_contexts = 0\n    varlen = False\n    dropout_p, return_encoded_softmax = 0.0, False\n\n    def __init__(self, sm_scale=1.0):\n        self.sm_scale = sm_scale\n\n    def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):\n        self.varlen = True\n        self.cu_seqlens_q = cu_seqlens_q\n        self.cu_seqlens_k = cu_seqlens_k\n        # Without \"varlen\", there should still be one sequence.\n        assert len(cu_seqlens_q) >= 2\n        assert len(cu_seqlens_q) == len(cu_seqlens_k)\n        self.num_contexts = len(cu_seqlens_q) - 1\n        for i in range (0, self.num_contexts):\n            self.max_seqlens_q = max(cu_seqlens_q[i+1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q)\n            self.max_seqlens_k = max(cu_seqlens_k[i+1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k)\n\n    def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):\n        assert bias.is_cuda\n        assert bias.dim() == 4\n        assert bias.shape[0] == 1\n        assert bias.shape[2:] == (seqlen_q, seqlen_k)\n        self.bias = bias\n\n    def need_alibi(self, alibi_slopes, batch, nheads):\n        assert alibi_slopes.is_cuda\n        assert alibi_slopes.dim() == 2\n        assert alibi_slopes.shape[0] == batch\n        assert alibi_slopes.shape[1] == nheads\n        self.alibi_slopes = alibi_slopes\n\n    def need_causal(self):\n        self.causal = True\n\n    def need_dropout(dropout_p, return_encoded_softmax):\n        self.dropout_p = dropout_p\n        self.return_encoded_softmax = return_encoded_softmax\n\n    def check_args(self, q, k, v, o):\n        assert q.dim() == k.dim() and q.dim() == v.dim()\n        if self.varlen:\n            assert q.dim() == 3\n            total_q, nheads_q, head_size = q.shape\n            total_k, nheads_k, _ = k.shape\n            assert self.cu_seqlens_q is not None\n            assert self.cu_seqlens_k is not None\n            assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)\n            # TODO: Remove once bias is supported with varlen\n            assert self.bias == None\n            # TODO:Remove once dropout is supported with varlen\n            assert self.dropout_p == 0.0\n            assert not self.return_encoded_softmax\n        else:\n     
       assert q.dim() == 4\n            batch, nheads_q, seqlen_q, head_size = q.shape\n            _, nheads_k, seqlen_k, _ = k.shape\n            assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0\n            assert self.cu_seqlens_q is None and self.cu_seqlens_k is None\n        assert k.shape == v.shape\n        assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]\n        # TODO: Change assert if we support qkl f8 and v f16\n        assert q.dtype == k.dtype and q.dtype == v.dtype\n        assert head_size <= 256\n        assert o.shape == q.shape\n        assert (nheads_q % nheads_k) == 0\n\n@triton.jit\ndef cdiv_fn(x,y):\n    return (x + y - 1) // y\n\n@triton.jit\ndef max_fn(x, y):\n    return tl.math.max(x, y)\n\n@triton.jit\ndef dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):\n    ms = tl.arange(0, m)\n    ns = tl.arange(0, n)\n    return philox_offset + ms[:, None] * stride + ns[None, :]\n\n@triton.jit\ndef dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):\n    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)\n    # TODO: use tl.randint for better performance\n    return tl.rand(philox_seed, rng_offsets)\n\n@triton.jit\ndef dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):\n    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)\n    rng_keep = rng_output > dropout_p\n    return rng_keep\n\n@triton.jit\ndef load_fn(block_ptr, first, second, pad):\n    if first and second:\n        tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad)\n    elif first:\n        tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)\n    elif second:\n        tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)\n    else:\n        tensor = tl.load(block_ptr)\n    return tensor\n\n@triton.jit\ndef _attn_fwd_inner(\n    acc, l_i, m_i, q,\n    K_block_ptr, V_block_ptr,\n    start_m,\n    actual_seqlen_k,\n    actual_seqlen_q,\n    dropout_p,\n    philox_seed,\n    batch_philox_offset,\n    encoded_softmax_block_ptr,\n    block_min, block_max,\n    offs_n_causal,\n    masked_blocks,\n    n_extra_tokens,\n    bias_ptr,\n    alibi_slope,\n    IS_CAUSAL: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    OFFS_M: tl.constexpr,\n    OFFS_N: tl.constexpr,\n    PRE_LOAD_V: tl.constexpr,\n    MASK_STEPS: tl.constexpr,\n    ENABLE_DROPOUT: tl.constexpr,\n    RETURN_ENCODED_SOFTMAX: tl.constexpr,\n    PADDED_HEAD: tl.constexpr\n):\n    # loop over k, v, and update accumulator\n    for start_n in range (block_min, block_max, BLOCK_N):\n        # For padded blocks, we will overrun the tensor size if\n        # we load all BLOCK_N. 
For others, the blocks are all within range.\n        k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), \"zero\")\n        if PRE_LOAD_V:\n            v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, \"zero\")\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        # We start from end of seqlen_k so only the first iteration would need\n        # to be checked for padding if it is not a multiple of block_n\n        # TODO: This can be optimized to only be true for the padded block.\n        if MASK_STEPS:\n            # If this is the last block / iteration, we want to\n            # mask if the sequence length is not a multiple of block size\n            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.\n            # last step might get wasted but that is okay. check if this masking works For\n            # that case.\n            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n                boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)\n                size_n = start_n + OFFS_N[None,:]\n                mask = size_n < boundary_m[:,None]\n                qk = tl.where(mask, qk, float(\"-inf\"))\n        if IS_CAUSAL:\n            causal_boundary = start_n + offs_n_causal\n            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]\n            qk = tl.where(causal_mask, qk, float(\"-inf\"))\n        # -- compute qk ----\n        qk += tl.dot(q, k)\n        if bias_ptr is not None:\n            bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), \"zero\")\n            # While bias is added after multiplying qk with sm_scale,\n            # our optimization to use 2^x instead of e^x results in an additional\n            # scale factor of log2(e) which we must also multiply the bias with.\n            qk += (bias * 1.44269504089)\n           \n        if alibi_slope is not None:\n            # Compute the global position of each token within the sequence\n            global_m_positions = start_m*BLOCK_M + tl.arange(0, BLOCK_M)\n            global_n_positions = start_n + tl.arange(0, BLOCK_N)\n\n            # Compute the relative position using the global positions\n            relative_pos_block = global_m_positions[:,None] + actual_seqlen_k - global_n_positions[None,:] - actual_seqlen_q\n            relative_pos_block = tl.abs(relative_pos_block)\n\n\n            alibi_block = -1 * alibi_slope  * relative_pos_block\n\n            qk += (alibi_block * 1.44269504089) # scale factor of log2(e)\n\n        # softmax\n        m_ij = tl.maximum(m_i, tl.max(qk,1))\n        qk = qk - m_ij[:, None]\n        p = tl.math.exp2(qk)\n\n        # CAVEAT: Must update l_ij before applying dropout\n        l_ij = tl.sum(p, 1)\n        if ENABLE_DROPOUT:\n            philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N\n            keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k)\n            if RETURN_ENCODED_SOFTMAX:\n                tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty))\n            p = tl.where(keep, p, 0.0)\n        elif RETURN_ENCODED_SOFTMAX:\n            tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty))\n        # -- update output accumulator --\n        alpha = tl.math.exp2(m_i - m_ij)\n        acc = acc * alpha[:, None]\n        if not PRE_LOAD_V:\n          
  v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, \"zero\")\n        # -- update m_i and l_i\n        l_i = l_i * alpha + l_ij\n        # update m_i and l_i\n        m_i = m_ij\n        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n        if bias_ptr is not None:\n            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))\n        if RETURN_ENCODED_SOFTMAX:\n            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N))\n    return acc, l_i, m_i\n\n\n@triton.jit\ndef attn_fwd(\n    Q, K, V, bias, sm_scale, L, Out,\n    stride_qz, stride_qh, stride_qm, stride_qk,\n    stride_kz, stride_kh, stride_kn, stride_kk,\n    stride_vz, stride_vh, stride_vk, stride_vn,\n    stride_oz, stride_oh, stride_om, stride_on,\n    stride_bz, stride_bh, stride_bm, stride_bn,\n    stride_az, stride_ah,\n    cu_seqlens_q, cu_seqlens_k,\n    dropout_p, philox_seed, philox_offset_base, encoded_softmax,\n    hq, hk,\n    alibi_slopes,\n    ACTUAL_BLOCK_DMODEL:tl.constexpr,\n    MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr,\n    VARLEN: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n    PRE_LOAD_V: tl.constexpr,\n    BIAS_TYPE: tl.constexpr,\n    ENABLE_DROPOUT: tl.constexpr,\n    RETURN_ENCODED_SOFTMAX: tl.constexpr,\n    USE_ALIBI: tl.constexpr,\n    BATCH_SIZE: tl.constexpr,\n):\n    start_m = tl.program_id(0)\n    off_h_q = tl.program_id(1)\n    off_z = tl.program_id(2)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    if VARLEN:\n        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n        # We have a one-size-fits-all grid in id(0). Some seqlens might be too\n        # small for all start_m so for those we return early.\n        if start_m * BLOCK_M > seqlen_q:\n            return\n        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n    else:\n        cu_seqlens_q_start = 0\n        cu_seqlens_k_start = 0\n        seqlen_q = MAX_SEQLENS_Q\n        seqlen_k = MAX_SEQLENS_K\n\n    # Now we compute whether we need to exit early due to causal masking.\n    # This is because for seqlen_q > seqlen_k, M rows of the attn scores\n    # are completely masked, resulting in 0s written to the output, and\n    # inf written to LSE. We don't need to do any GEMMs in this case.\n    # This block of code determines what N is, and if this WG is operating\n    # on those M rows.\n    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)\n    if (IS_CAUSAL):\n        # If seqlen_q == seqlen_k, the attn scores are a square matrix.\n        # If seqlen_q != seqlen_k, attn scores are rectangular which means\n        # the causal mask boundary is bottom right aligned, and ends at either\n        # the top edge (seqlen_q < seqlen_k) or left edge.\n        # This captures the decrease in n_blocks if we have a rectangular attn matrix\n        n_blocks_seqlen = cdiv_fn(\n            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q,\n            BLOCK_N\n        )\n        # This is what adjusts the block_max for the current WG, only\n        # if IS_CAUSAL. 
Otherwise we want to always iterate through all n_blocks\n        n_blocks = min(n_blocks, n_blocks_seqlen)\n        # If we have no blocks after adjusting for seqlen deltas, this WG is part of\n        # the blocks that are all 0. We exit early.\n        if n_blocks <= 0:\n            o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh\n            O_block_ptr = tl.make_block_ptr(\n                base=Out + o_offset,\n                shape=(seqlen_q, BLOCK_DMODEL),\n                strides=(stride_om, stride_on),\n                offsets=(start_m * BLOCK_M, 0),\n                block_shape=(BLOCK_M, BLOCK_DMODEL),\n                order=(1, 0)\n            )\n            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)\n            # We still need to write 0s to the result\n            tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))\n            l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m\n            # We store inf to LSE, not -inf because in the bwd pass, we subtract this\n            # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks.\n            l = tl.full([BLOCK_M], value=float(\"inf\"), dtype=tl.float32)\n            tl.store(l_ptrs, l)\n            # TODO: Should dropout and return encoded softmax be handled here too?\n            return\n\n    is_mqa = hq != hk\n    off_h_k = off_h_q % hk if is_mqa else off_h_q\n    need_padding = False\n    n_extra_tokens = 0\n    if seqlen_k < BLOCK_N:\n        need_padding = True\n        n_extra_tokens = BLOCK_N - seqlen_k\n    elif seqlen_k % BLOCK_N:\n        need_padding = True\n        n_extra_tokens = seqlen_k % BLOCK_N\n    padded_head = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)\n\n    # Compute pointers for all the tensors used in this kernel.\n    q_offset = off_z * stride_qz +  off_h_q * stride_qh + cu_seqlens_q_start * stride_qm\n    Q_block_ptr = tl.make_block_ptr(\n        base=Q + q_offset,\n        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n        strides=(stride_qm, stride_qk),\n        offsets=(start_m * BLOCK_M, 0),\n        block_shape=(BLOCK_M, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn\n    K_block_ptr = tl.make_block_ptr(\n        base=K + k_offset,\n        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),\n        strides=(stride_kk, stride_kn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_DMODEL, BLOCK_N),\n        order=(0, 1)\n    )\n    v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk\n    V_block_ptr = tl.make_block_ptr(\n        base=V + v_offset,\n        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),\n        strides=(stride_vk, stride_vn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_N, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    if BIAS_TYPE != 0:\n        b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs\n        bias_ptr = tl.make_block_ptr(\n            base=bias + b_offset,\n            shape=(seqlen_q, seqlen_k),\n            strides=(stride_bm, stride_bn),\n            offsets=(start_m * BLOCK_M, 0),\n            block_shape=(BLOCK_M, BLOCK_N),\n            order=(1, 0),\n        )\n    else:\n        bias_ptr = None\n\n    if USE_ALIBI != 0:\n        a_offset = off_z * stride_az +  off_h_q * stride_ah \n        alibi_slope = tl.load(alibi_slopes + a_offset)\n    else:\n        alibi_slope = None\n\n    
if ENABLE_DROPOUT:\n        batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k\n    else:\n        batch_philox_offset = 0\n    # We can ask to return the dropout mask without actually doing any dropout. In\n    # this case, we return an invalid pointer so indicate the mask is not valid.\n    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.\n    if RETURN_ENCODED_SOFTMAX:\n        encoded_softmax_block_ptr = tl.make_block_ptr(\n                base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,\n                shape=(seqlen_q, seqlen_k),\n                strides=(seqlen_k, 1),\n                offsets=(start_m * BLOCK_M, 0),\n                block_shape=(BLOCK_M, BLOCK_N),\n                order=(1, 0)\n                )\n    else:\n        encoded_softmax_block_ptr = 0\n    # initialize pointer to m and l\n    m_i = tl.full([BLOCK_M], float(\"-inf\"), dtype=tl.float32)\n    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not\n    # have native e^x support in HW.\n    qk_scale = sm_scale * 1.44269504089\n    # Q is loaded once at the beginning and shared by all N blocks.\n    q = load_fn(Q_block_ptr, True, padded_head, \"zero\")\n    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n    # Here we compute how many full and masked blocks we have.\n    padded_block_k = n_extra_tokens != 0\n    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)\n    if IS_CAUSAL:\n        # There are always at least BLOCK_M // BLOCK_N masked blocks.\n        # Additionally there might be one more due to dissimilar seqlens.\n        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)\n    else:\n        # Padding on Q does not need to be masked in the FA loop.\n        masked_blocks = padded_block_k\n    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.\n    # In this case we might exceed n_blocks so pick the min.\n    masked_blocks = min(masked_blocks, n_blocks)\n    n_full_blocks = n_blocks - masked_blocks\n    block_min = 0\n    block_max = n_blocks * BLOCK_N\n    # Compute for full blocks. Here we set causal to false regardless of its actual\n    # value because there is no masking. 
Similarly we do not need padding.\n    if n_full_blocks > 0:\n        block_max = (n_blocks - masked_blocks) * BLOCK_N\n        acc, l_i, m_i = _attn_fwd_inner(\n            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,\n            start_m, seqlen_k, seqlen_q,\n            dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _\n            block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope,\n            # IS_CAUSAL, ....\n            False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,\n            # _, MASK_STEPS, ...\n            PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head\n        )\n        block_min = block_max\n        block_max = n_blocks * BLOCK_N\n\n    tl.debug_barrier()\n    # Remaining blocks, if any, are full / not masked.\n    if (masked_blocks > 0):\n        if IS_CAUSAL:\n            offs_n_causal = offs_n + (seqlen_q - seqlen_k)\n        else:\n            offs_n_causal = 0\n        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N))\n        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0))\n        if bias_ptr is not None:\n            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N))\n        if RETURN_ENCODED_SOFTMAX:\n            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n                                                   (0, n_full_blocks))\n        acc, l_i, m_i = _attn_fwd_inner(\n            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,\n            start_m, seqlen_k, seqlen_q,\n            dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n            block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, alibi_slope,\n            IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,\n            # _, MASK_STEPS, ...\n            PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head\n        )\n    # epilogue\n    acc = acc / l_i[:, None]\n    if ENABLE_DROPOUT:\n        acc = acc / (1 - dropout_p)\n    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,\n    # then we have one block with a row of all NaNs which come from computing\n    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here\n    # and store 0s where there are NaNs as these rows should've been zeroed out.\n    end_m_idx = (start_m + 1) * BLOCK_M\n    start_m_idx = start_m * BLOCK_M\n    causal_start_idx = seqlen_q - seqlen_k\n    acc = acc.to(Out.type.element_ty)\n    if IS_CAUSAL:\n        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:\n            out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32)\n            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)\n            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]\n            z = 0.0\n            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))\n    # write back LSE\n    l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m\n    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.\n    # This is only true for the last M block. 
For others, overflow_size will be -ve\n    overflow_size = end_m_idx - seqlen_q\n    if overflow_size > 0:\n        boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)\n        # This is a > check because mask being 0 blocks the store.\n        l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)\n        tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)\n    else:\n        tl.store(l_ptrs, m_i + tl.math.log2(l_i))\n\n    # write back O\n    o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh\n    O_block_ptr = tl.make_block_ptr(\n        base=Out + o_offset,\n        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n        strides=(stride_om, stride_on),\n        offsets=(start_m * BLOCK_M, 0),\n        block_shape=(BLOCK_M, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    # Need boundary check on this to make sure the padding from the\n    # Q and KV tensors in both dims are not part of what we store back.\n    # TODO: Do the boundary check optionally.\n    tl.store(O_block_ptr, acc, boundary_check=(0,1))\n\n\n\ndef attention(q, k, v, sm_scale):\n\n    o = torch.empty_like(q, dtype=v.dtype)\n\n    batch, nheads_q, seqlen_q, head_size = q.shape\n    _, nheads_k, seqlen_k, _ = k.shape\n\n    max_seqlens_q = seqlen_q\n\n    q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))\n    k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))\n    v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))\n    o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))\n\n    # Get closest power of 2 over or equal to 32.\n    unpadded_head_dims = {32, 64, 128, 256}\n    if head_size not in unpadded_head_dims:\n        padded_d_model = None\n        for i in unpadded_head_dims:\n            if i > head_size:\n                padded_d_model = i\n                break\n        assert padded_d_model is not None\n    else:\n        padded_d_model = head_size\n\n\n    # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4)\n\n    BLOCK_M = 128\n    BLOCK_N = 128\n    PRE_LOAD_V = False\n    num_stages = 1\n    num_warps = 4\n\n    grid = (triton.cdiv(max_seqlens_q, BLOCK_M), nheads_q, batch)\n\n    # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference.  We zero this out\n    # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according\n    # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing\n    # only.  
This return holds no useful output aside from debugging.\n\n    encoded_softmax = None\n\n    M = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)\n\n    # Seed the RNG so we get reproducible results for testing.\n    philox_seed = 0x1BF52\n    philox_offset = 0x1D4B42\n    \n    bias_strides = (0,0,0,0)\n    alibi_strides = (0, 0)\n    \n    attn_fwd[grid](\n        q, k, v, None, sm_scale, M, o,\n        *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides,\n        None, None,\n        BLOCK_M=BLOCK_M,\n        PRE_LOAD_V=PRE_LOAD_V,\n        BLOCK_N=BLOCK_N,\n        dropout_p=0.0,\n        philox_seed=philox_seed,\n        philox_offset_base=philox_offset,\n        encoded_softmax=encoded_softmax,\n        hq=nheads_q, hk=nheads_k,\n        alibi_slopes = None,\n        ACTUAL_BLOCK_DMODEL=head_size,\n        MAX_SEQLENS_Q=seqlen_q, \n        MAX_SEQLENS_K=seqlen_k,\n        IS_CAUSAL=False, ########################\n        VARLEN=False,\n        BLOCK_DMODEL=padded_d_model,\n        BIAS_TYPE=0,\n        USE_ALIBI=0,\n        ENABLE_DROPOUT=False,\n        RETURN_ENCODED_SOFTMAX=False,\n        BATCH_SIZE= q.shape[0],\n    )\n    return o\n\n\n\n\n@triton.jit\ndef _attn_bwd_preprocess(\n    Out, DO,\n    Delta,\n    stride_oz, stride_oh, stride_om, stride_on,\n    stride_doz, stride_doh, stride_dom, stride_don,\n    seqlen_q,\n    head_dim,\n    BLOCK_M: tl.constexpr,\n    D_HEAD: tl.constexpr,\n):\n    # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n    # off_n = tl.arange(0, D_HEAD)\n    off_m = tl.program_id(0) * BLOCK_M\n    off_h = tl.program_id(1) # head index\n    off_z = tl.program_id(2) # batch index\n    num_h = tl.num_programs(1)\n    o_offset = off_h * stride_oh + off_z * stride_oz\n    O_block_ptr = tl.make_block_ptr(\n        base=Out + o_offset,\n        shape=(seqlen_q, head_dim),\n        strides=(stride_om, stride_on),\n        offsets=(off_m, 0),\n        block_shape=(BLOCK_M, D_HEAD),\n        order=(1, 0)\n    )\n    do_offset = off_h * stride_doh + off_z * stride_doz\n    DO_block_ptr = tl.make_block_ptr(\n        base=DO + do_offset,\n        shape=(seqlen_q, head_dim),\n        strides=(stride_dom, stride_don),\n        offsets=(off_m, 0),\n        block_shape=(BLOCK_M, D_HEAD),\n        order=(1, 0)\n    )\n    # load\n    # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    o = tl.load(O_block_ptr, boundary_check=(0,1), padding_option=\"zero\").to(tl.float32)\n    do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option=\"zero\").to(tl.float32)\n    # compute\n    delta = tl.sum(o * do, axis=1)\n    # write-back, shape (q.shape[0] * q.shape[1], q.shape[2])\n    off_zh = off_z * num_h + off_h * 1\n    # Check for OOB accesses\n    delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M)\n    overflow = off_m + BLOCK_M - seqlen_q\n    if overflow > 0:\n        boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32)\n        mask = boundary > tl.arange(0, BLOCK_M)\n        tl.store(delta_ptrs, delta, mask=mask)\n    else:\n        tl.store(delta_ptrs, delta)\n\n@triton.jit\ndef _bwd_kernel_dk_dv(\n                   dk, dv,\n                   Q, k, v, sm_scale, alibi_slope,\n                   DO,\n                   M, D,\n                   # shared by Q/K/V/DO.\n                   stride_tok, stride_d,\n                   H, N_CTX, BLOCK_M1: 
tl.constexpr,\n                   BLOCK_N1: tl.constexpr,\n                   BLOCK_DMODEL: tl.constexpr,\n                   # Filled in by the wrapper.\n                   start_n, start_m, num_steps,\n                   MASK: tl.constexpr):\n    offs_m = start_m + tl.arange(0, BLOCK_M1)\n    offs_n = start_n + tl.arange(0, BLOCK_N1)\n    offs_k = tl.arange(0, BLOCK_DMODEL)\n    QT_block_ptr = tl.make_block_ptr(\n        base=Q,\n        shape=(BLOCK_DMODEL, N_CTX),\n        strides=(stride_d, stride_tok),\n        offsets=(0, start_m),\n        block_shape=(BLOCK_DMODEL, BLOCK_M1),\n        order=(0,1)\n    )\n    DO_block_ptr = tl.make_block_ptr(\n        base=DO,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_m, 0),\n        block_shape=(BLOCK_M1, BLOCK_DMODEL),\n        order=(1,0)\n    )\n    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.\n    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)\n    curr_m = start_m\n    step_m = BLOCK_M1\n    for blk_idx in range(num_steps):\n        qT = tl.load(QT_block_ptr)\n        # Load m before computing qk to reduce pipeline stall.\n        offs_m = curr_m + tl.arange(0, BLOCK_M1)\n        m = tl.load(M + offs_m)\n        kqT = tl.dot(k, qT)\n        if alibi_slope is not None:\n            alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True)\n            kqT += alibi_block * 1.44269504089\n\n        pT = tl.math.exp2(kqT - m[None, :])\n        # Autoregressive masking.\n        if MASK:\n            mask = (offs_m[None, :] >= offs_n[:, None])\n            pT = tl.where(mask, pT, 0.0)\n        do = tl.load(DO_block_ptr)\n        # Compute dV.\n        ppT = pT\n        ppT = ppT.to(tl.bfloat16)\n        dv += tl.dot(ppT, do)\n        # D (= delta) is pre-divided by ds_scale.\n        Di = tl.load(D + offs_m)\n        # Compute dP and dS.\n        dpT = tl.dot(v, tl.trans(do))\n        dsT = pT * (dpT - Di[None, :])\n        dsT = dsT.to(tl.bfloat16)\n        dk += tl.dot(dsT, tl.trans(qT))\n        # Increment pointers.\n        curr_m += step_m\n        QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))\n        DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))\n    return dk, dv\n\n@triton.jit\ndef _bwd_kernel_dq(dq, q, K, V,\n                 do, m, D, alibi_slope,\n                 # shared by Q/K/V/DO.\n                 stride_tok, stride_d,\n                 H, N_CTX,\n                 BLOCK_M2: tl.constexpr,\n                 BLOCK_N2: tl.constexpr,\n                 BLOCK_DMODEL: tl.constexpr,\n                 # Filled in by the wrapper.\n                 start_m, start_n, num_steps,\n                 MASK: tl.constexpr):\n    offs_m = start_m + tl.arange(0, BLOCK_M2)\n    offs_n = start_n + tl.arange(0, BLOCK_N2)\n    offs_k = tl.arange(0, BLOCK_DMODEL)\n    KT_block_ptr = tl.make_block_ptr(\n        base=K,\n        shape=(BLOCK_DMODEL, N_CTX),\n        strides=(stride_d, stride_tok),\n        offsets=(0, start_n),\n        block_shape=(BLOCK_DMODEL, BLOCK_N2),\n        order=(0, 1)\n    )\n    VT_block_ptr = tl.make_block_ptr(\n        base=V,\n        shape=(BLOCK_DMODEL, N_CTX),\n        strides=(stride_d, stride_tok),\n        offsets=(0, start_n),\n        block_shape=(BLOCK_DMODEL, BLOCK_N2),\n        order=(0, 1)\n    )\n    # D (= delta) is pre-divided by ds_scale.\n    Di = tl.load(D + offs_m)\n    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.\n    tl.static_assert(BLOCK_M2 % 
BLOCK_N2 == 0)\n    curr_n = start_n\n    step_n = BLOCK_N2\n    for blk_idx in range(num_steps):\n        kT = tl.load(KT_block_ptr)\n        qk = tl.dot(q, kT)\n        p = tl.math.exp2(qk - m)\n        # Autoregressive masking.\n        if MASK:\n            offs_n = curr_n + tl.arange(0, BLOCK_N2)\n            mask = (offs_m[:, None] >= offs_n[None, :])\n            p = tl.where(mask, p, 0.0)\n        # Compute dP and dS.\n        vT = tl.load(VT_block_ptr)\n        dp = tl.dot(do, vT).to(tl.float32)\n        ds = p * (dp - Di[:, None])\n        ds = ds.to(tl.bfloat16)\n        # Compute dQ.0.\n        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.\n        dq += tl.dot(ds, tl.trans(kT))\n        # Increment pointers.\n        curr_n += step_n\n        KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))\n        VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))\n    return dq\n\n@triton.jit\ndef _attn_bwd(Q, K, V, sm_scale, alibi_slopes,\n              DO,\n              DQ, DK, DV,\n              M, D,\n              # shared by Q/K/V/DO.\n              stride_z, stride_h, stride_tok, stride_d,\n              # H = 16, N_CTX = 1024\n              H, N_CTX,\n              BLOCK_DMODEL: tl.constexpr,\n              BLOCK_M1: tl.constexpr,\n              BLOCK_N1: tl.constexpr,\n              BLOCK_M2: tl.constexpr,\n              BLOCK_N2: tl.constexpr,\n              BLK_SLICE_FACTOR: tl.constexpr,\n              USE_ALIBI: tl.constexpr):\n    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)\n\n    bhid = tl.program_id(2)\n    off_chz = (bhid * N_CTX).to(tl.int64)\n    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)\n    pid = tl.program_id(0)\n\n    # offset pointers for batch/head\n    Q += adj\n    K += adj\n    V += adj\n    DO += adj\n    DQ += adj\n    DK += adj\n    DV += adj\n    M += off_chz\n    D += off_chz\n\n    offs_k = tl.arange(0, BLOCK_DMODEL)\n\n    start_n = pid * BLOCK_N1\n    # This assignment is important. It is what allows us to pick the diagonal\n    # blocks. 
Later, when we want to do the lower triangular, we update start_m\n    # after the first dkdv call.\n    start_m = start_n\n\n    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR\n    offs_n = start_n + tl.arange(0, BLOCK_N1)\n\n    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)\n    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)\n\n    K_block_ptr = tl.make_block_ptr(\n        base=K,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_n, 0),\n        block_shape=(BLOCK_N1, BLOCK_DMODEL),\n        order=(1, 0),\n    )\n    V_block_ptr = tl.make_block_ptr(\n        base=V,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_n, 0),\n        block_shape=(BLOCK_N1, BLOCK_DMODEL),\n        order=(1, 0),\n    )\n\n    # load K and V: they stay in SRAM throughout the inner loop for dkdv.\n    k = tl.load(K_block_ptr)\n    v = tl.load(V_block_ptr)\n\n    if USE_ALIBI:\n        a_offset = bhid\n        alibi_slope = tl.load(alibi_slopes + a_offset)\n    else:\n        alibi_slope = None\n\n    # compute dK and dV for blocks close to the diagonal that need to be masked\n    num_steps = BLOCK_N1 // MASK_BLOCK_M1\n    dk, dv = _bwd_kernel_dk_dv(\n                            dk, dv,\n                            Q, k, v, sm_scale, alibi_slope,\n                            DO,\n                            M, D,\n                            stride_tok, stride_d,\n                            H, N_CTX,\n                            MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,\n                            start_n, start_m, num_steps,\n                            MASK=True\n                            )\n\n    # compute dK and dV for blocks that don't need masking further from the diagonal\n    start_m += num_steps * MASK_BLOCK_M1\n    num_steps = (N_CTX - start_m) // BLOCK_M1\n\n    dk, dv = _bwd_kernel_dk_dv(\n        dk, dv,\n        Q, k, v, sm_scale, alibi_slope,\n        DO,\n        M, D,\n        stride_tok, stride_d,\n        H, N_CTX,\n        BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,\n        start_n, start_m, num_steps,\n        MASK=False\n    )\n\n    DV_block_ptrs = tl.make_block_ptr(\n        base=DV,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_n, 0),\n        block_shape=(BLOCK_N1, BLOCK_DMODEL),\n        order=(1,0)\n    )\n    tl.store(DV_block_ptrs, dv.to(v.dtype))\n\n    # Write back dK.\n    dk *= sm_scale\n    DK_block_ptrs = tl.make_block_ptr(\n        base=DK,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_n, 0),\n        block_shape=(BLOCK_N1, BLOCK_DMODEL),\n        order=(1,0)\n    )\n    tl.store(DK_block_ptrs, dk.to(k.dtype))\n\n    # THIS BLOCK DOES DQ:\n    start_m = pid * BLOCK_M2\n    end_n = start_m + BLOCK_M2\n\n    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR\n    offs_m = start_m + tl.arange(0, BLOCK_M2)\n\n    Q_block_ptr = tl.make_block_ptr(\n        base=Q,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_m, 0),\n        block_shape=(BLOCK_M2, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n\n    DO_block_ptr = tl.make_block_ptr(\n        base=DO,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_m, 0),\n        block_shape=(BLOCK_M2, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    q = tl.load(Q_block_ptr)\n    do = 
tl.load(DO_block_ptr)\n    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)\n\n    m = tl.load(M + offs_m)\n    m = m[:, None]\n\n    # Compute dQ for masked (diagonal) blocks.\n    # NOTE: This code scans each row of QK^T backward (from right to left,\n    # but inside each call to _attn_bwd_dq, from left to right), but that's\n    # not due to anything important.  I just wanted to reuse the loop\n    # structure for dK & dV above as much as possible.\n    num_steps = BLOCK_M2 // MASK_BLOCK_N2\n    dq = _bwd_kernel_dq(dq, q, K, V,\n                      do, m, D, alibi_slope,\n                      stride_tok, stride_d,\n                      H, N_CTX,\n                      BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,\n                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,\n                      MASK=True\n                      )\n    end_n -= num_steps * MASK_BLOCK_N2\n    # stage 2\n    num_steps = end_n // BLOCK_N2\n    dq = _bwd_kernel_dq(dq, q, K, V,\n                      do, m, D, alibi_slope,\n                      stride_tok, stride_d,\n                      H, N_CTX,\n                      BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,\n                      start_m, end_n - num_steps * BLOCK_N2, num_steps,\n                      MASK=False\n                      )\n    # Write back dQ.\n    DQ_block_ptr = tl.make_block_ptr(\n        base=DQ,\n        shape=(N_CTX, BLOCK_DMODEL),\n        strides=(stride_tok, stride_d),\n        offsets=(start_m, 0),\n        block_shape=(BLOCK_M2, BLOCK_DMODEL),\n        order=(1, 0)\n    )\n    dq *= LN2\n    tl.store(DQ_block_ptr, dq.to(q.dtype))\n\n@torch.library.custom_op(\"triton::flash_bwd\", mutates_args=())\ndef flash_bwd(q: torch.Tensor, k: torch.Tensor, v:torch.Tensor, o: torch.Tensor, M:torch.Tensor, do: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n\n    BLOCK = 128\n    sm_scale = q.shape[-1] ** -0.5\n    batch, nheads_q, seqlen_q, head_size = q.shape\n    _, nheads_k, seqlen_k, _ = k.shape\n    Lk = k.shape[-1]\n    max_seqlens_q = seqlen_q\n    padded_d_model = head_size\n\n    # assert do.is_contiguous()\n    assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()\n    seqlen_q = q.shape[2]\n    dq = torch.empty_like(q)\n    dk = torch.empty_like(k)\n    dv = torch.empty_like(v)\n\n    BATCH, N_HEAD, N_CTX = q.shape[:3]\n    PRE_BLOCK = 128\n    NUM_WARPS, NUM_STAGES = 4, 1\n    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32\n\n    BLK_SLICE_FACTOR = 2\n    RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)\n    arg_k = k\n    arg_k = arg_k * (sm_scale * RCP_LN2)\n\n    assert N_CTX % PRE_BLOCK == 0\n\n    delta = torch.empty_like(M)\n\n    grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0])\n    _attn_bwd_preprocess[grid_preprocess](\n        o, do, delta,\n        o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n        do.stride(0), do.stride(1), do.stride(2), do.stride(3),\n        seqlen_q,\n        head_dim=Lk,\n        BLOCK_M=BLOCK, D_HEAD=padded_d_model,\n    )\n\n    grid = (triton.cdiv(N_CTX, BLOCK_N1), 1, BATCH * N_HEAD)\n    _attn_bwd[grid](\n        q, arg_k, v, sm_scale, None, do, dq, dk, dv,\n        M, delta,\n        q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n        N_HEAD, N_CTX,\n        BLOCK_DMODEL=padded_d_model,\n        BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,\n        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,\n        USE_ALIBI=False,\n        num_warps=NUM_WARPS, 
num_stages=NUM_STAGES,\n    )\n\n    return dq, dk, dv\n\n\n@flash_bwd.register_fake\ndef _(q, k, v, o, M, do):\n\n    dq = torch.empty_like(q)\n    dk = torch.empty_like(k)\n    dv = torch.empty_like(v)\n\n    return dq, dk, dv\n\n@torch.library.custom_op(\"triton::flash\", mutates_args=())\ndef flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, M: torch.Tensor) -> torch.Tensor:\n\n    sm_scale = q.shape[-1] ** -0.5\n\n    batch, nheads_q, seqlen_q, head_size = q.shape\n    _, nheads_k, seqlen_k, _ = k.shape\n\n    max_seqlens_q = seqlen_q\n\n    q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))\n    k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))\n    v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))\n    o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))\n\n    # Get closest power of 2 over or equal to 32.\n    unpadded_head_dims = {32, 64, 128, 256}\n    if head_size not in unpadded_head_dims:\n        padded_d_model = None\n        for i in unpadded_head_dims:\n            if i > head_size:\n                padded_d_model = i\n                break\n        assert padded_d_model is not None\n    else:\n        padded_d_model = head_size\n\n    # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4)\n    BLOCK_M = 128\n    BLOCK_N = 128\n    PRE_LOAD_V = False\n\n    grid = (triton.cdiv(max_seqlens_q, BLOCK_M), nheads_q, batch)\n\n    # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference.  We zero this out\n    # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according\n    # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing\n    # only.  
This return holds no useful output aside from debugging.\n    \n    encoded_softmax = None\n\n    # Seed the RNG so we get reproducible results for testing.\n    philox_seed = 0x1BF52\n    philox_offset = 0x1D4B42\n    \n    bias_strides = (0, 0, 0, 0)\n    alibi_strides = (0, 0)\n    \n    attn_fwd[grid](\n        q, k, v, None, sm_scale, M, o,\n        *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides,\n        None, None,\n        BLOCK_M=BLOCK_M,\n        PRE_LOAD_V=PRE_LOAD_V,\n        BLOCK_N=BLOCK_N,\n        dropout_p=0.0,\n        philox_seed=philox_seed,\n        philox_offset_base=philox_offset,\n        encoded_softmax=encoded_softmax,\n        hq=nheads_q, hk=nheads_k,\n        alibi_slopes = None,\n        ACTUAL_BLOCK_DMODEL=head_size,\n        MAX_SEQLENS_Q=seqlen_q, \n        MAX_SEQLENS_K=seqlen_k,\n        IS_CAUSAL=True, ########################\n        VARLEN=False,\n        BLOCK_DMODEL=padded_d_model,\n        BIAS_TYPE=0,\n        USE_ALIBI=0,\n        ENABLE_DROPOUT=False,\n        RETURN_ENCODED_SOFTMAX=False,\n        BATCH_SIZE= q.shape[0],\n    )\n    out = o.clone()\n    return out\n\n@flash.register_fake\ndef _(q, k, v, o, M):\n    return torch.empty_like(q, dtype=v.dtype)\n\n   \ndef setup_context(ctx, inputs, output) -> None:\n    q, k, v, o, M = inputs\n    ctx.save_for_backward(q, k, v, o, M)\n\ndef backward(ctx, do):\n\n    q, k, v, o, M = ctx.saved_tensors\n    dq, dk, dv = flash_bwd(q, k, v, o, M, do)\n\n    return dq, dk, dv, None, None\n\nflash.register_autograd(backward, setup_context=setup_context)\n\nif __name__ == \"__main__\":\n\n    b, nh, s, hd = 1, 32, 128, 128\n\n    # Use bfloat16 so the inputs match the dtype the backward kernels cast to (tl.bfloat16).\n    q = torch.randn(b, nh, s, hd, dtype=torch.bfloat16, device='cuda').requires_grad_()\n    k = torch.randn(b, nh, s, hd, dtype=torch.bfloat16, device='cuda').requires_grad_()\n    v = torch.randn(b, nh, s, hd, dtype=torch.bfloat16, device='cuda').requires_grad_()\n\n    sm_scale = q.shape[-1] ** -0.5\n\n    # flash() takes pre-allocated output and logsumexp buffers; the kernel fills them in place.\n    o_buf = torch.empty_like(q, dtype=v.dtype)\n    M_buf = torch.empty((b, nh, s), device=q.device, dtype=torch.float32)\n\n    @torch.compile(fullgraph=True)\n    def f(q, k, v):\n        return flash(q, k, v, o_buf, M_buf)\n    \n    o = f(q, k, v)\n    print(f\"{o=}\")\n\n    dout = torch.randn_like(q)\n    \n    o.backward(dout)\n\n    tri_dq = q.grad.clone()\n    tri_dk = k.grad.clone()\n    tri_dv = v.grad.clone()"
  },
  {
    "path": "kernels/triton/training/README.md",
    "content": "Triton training kernels\n"
  },
  {
    "path": "kernels/triton/training/fused_softmax/README.md",
    "content": "Fused Softmax in Triton, supporting both inference (fwd) and training (fwd/backward). \n\nPerf testing on A100:\n\n<img width=\"790\" alt=\"fused_softmax_a100\" src=\"https://github.com/lessw2020/applied-ai/assets/46302957/c8930c44-9960-4353-83be-e8f6e4b24d96\">\n"
  },
  {
    "path": "kernels/triton/training/fused_softmax/softmax.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\n# ---- Fused Softmax written in Triton ------\n# Extra Credits:\n# Triton Softmax Tutorial\n# LucidRains Triton_Transformers\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom torch import autograd\n\ndef _get_num_warps(block_size: int)-> int:\n    num_warps = 4\n    if block_size > 2047:\n        num_warps = 8\n    if block_size > 4095:\n        num_warps=16\n    return num_warps\n\n@triton.jit\ndef _softmax_kernel_fwd(\n    output_ptr,\n    output_row_stride,\n    input_ptr,\n    input_row_stride,\n    n_cols,\n    block_size: tl.constexpr,\n):\n    # setup input location\n    row_index = tl.program_id(0)\n    input_row_ptr = input_ptr + (row_index * input_row_stride)\n    col_offsets = tl.arange(0, block_size)\n    input_ptrs = input_row_ptr + col_offsets\n    rw_mask = col_offsets < n_cols\n    row = tl.load(input_ptrs, mask = rw_mask, other=float(\"-inf\"))\n\n    # safe softmax proper\n    safe_row = row - tl.max(row, axis=0)\n    numerator = tl.exp(safe_row)\n    denom = tl.sum(numerator, axis=0)\n    sm_out = numerator / denom\n\n    # write results to HBM\n    out_row_ptr = output_ptr + (row_index * output_row_stride)\n    out_row_ptrs = out_row_ptr + col_offsets\n    tl.store(out_row_ptrs, sm_out, mask = rw_mask)\n\n\n@triton.jit\ndef _softmax_kernel_bwd(\n    output_ptr,\n    stride_output_row,\n    grad_ptr,\n    stride_grad_row,\n    input_ptr,\n    stride_input_row,\n    n_cols,\n    block_size: tl.constexpr,\n\n):\n    # setup input locations - need both grad and input access\n    row_index = tl.program_id(0)\n\n    input_row_ptr = input_ptr + (row_index * stride_input_row)\n    grad_row_ptr = grad_ptr + (row_index * stride_grad_row)\n\n    col_offsets = tl.arange(0,block_size)\n    rw_mask = col_offsets < n_cols\n\n    input_row_ptrs = input_row_ptr + col_offsets\n    grad_row_ptrs = grad_row_ptr + col_offsets\n\n\n    probs_row =tl.load(input_row_ptrs, mask=rw_mask, other = 0)\n    grads_row = tl.load(grad_row_ptrs, mask = rw_mask, other=0)\n\n    # compute derivatives\n    dx = probs_row * grads_row\n    dsm_out = dx - probs_row * (tl.sum(dx, axis=0))\n\n    # write to HBM\n    output_row_ptr = output_ptr + (row_index * stride_output_row)\n    output_ptrs = output_row_ptr + col_offsets\n    tl.store(output_ptrs, dsm_out, mask=rw_mask)\n\n\nclass triton_softmax(autograd.Function):\n    @staticmethod\n    def forward(ctx, x):\n        orig_shape = x.shape\n        x = x.view(-1, orig_shape[-1])\n        nrows, ncols = x.shape\n\n        block_size = triton.next_power_of_2(ncols)\n        num_warps = _get_num_warps(block_size)\n\n        res = torch.empty_like(x)\n        grid = (nrows,)\n\n        _softmax_kernel_fwd[grid](\n            res,\n            res.stride(0),\n            x,\n            x.stride(0),\n            ncols,\n            block_size=block_size,\n            num_warps=num_warps,\n\n        )\n\n        if x.requires_grad:\n            ctx.save_for_backward(res)\n        return res.view(*orig_shape)\n\n    @staticmethod\n    def backward(ctx, grad_probs):\n        orig_shape = grad_probs.shape\n        probs, = ctx.saved_tensors\n\n        grad_probs = grad_probs.view(-1, orig_shape[-1])\n        nrows, ncols = grad_probs.shape\n\n        block_size = triton.next_power_of_2(ncols)\n        num_warps = 
_get_num_warps(block_size)\n\n        dx = torch.empty_like(probs)\n        grid = (nrows,)\n\n        _softmax_kernel_bwd[grid](\n            dx,\n            dx.stride(0),\n            probs,\n            probs.stride(0),\n            grad_probs,\n            grad_probs.stride(0),\n            ncols,\n            block_size=block_size,\n            num_warps=num_warps,\n\n        )\n        return dx.view(*orig_shape), None\n\nfused_softmax = triton_softmax.apply\n\nif __name__ == '__main__':\n    sample = torch.tensor([[1,2,3,4,5], [5,4,3,2,1]], dtype = torch.float32, device=\"cuda\", requires_grad=True)\n    from torch.nn.functional import softmax as torch_softmax\n    res_torch = torch_softmax(sample, dim=1)\n    res_triton = fused_softmax(sample)\n\n    torch.testing.assert_close(res_torch, res_triton, rtol=0, atol=1e-4)\n\n    # backward: Tensor.backward() returns None, so compare the accumulated input gradients instead\n    dout = torch.randn_like(sample)\n    res_torch.backward(dout)\n    grad_torch = sample.grad.detach().clone()\n    sample.grad = None\n    res_triton.backward(dout)\n    grad_triton = sample.grad.detach().clone()\n\n    torch.testing.assert_close(grad_triton, grad_torch, rtol=0, atol=1e-4)\n"
  },
  {
    "path": "kernels/triton/training/rms_norm/fused_rms_norm.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# Credit\n# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py\n# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n\n# pylint: skip-file\n# flake8: noqa\n\nimport math\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_warps=1),\n        triton.Config({}, num_warps=2),\n        triton.Config({}, num_warps=4),\n        triton.Config({}, num_warps=8),\n        triton.Config({}, num_warps=16),\n        triton.Config({}, num_warps=32),\n    ],\n    key=[\"N\"],\n)\n@triton.jit\ndef _rms_norm_fwd_kernel(\n    X,\n    stride_x,\n    Y,\n    stride_y,\n    W,\n    Rstd,\n    eps,\n    M,  # num rows\n    N,  # num cols\n    block_N: tl.constexpr,\n):\n\n    row = tl.program_id(0)\n    cols = tl.arange(0, block_N)\n\n    # Load input data and weights\n    mask = cols < N\n    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)\n    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)\n\n    # Compute mean and variance\n    # xbar = tl.sum(x, axis=0) / tl.max(tl.sum(mask, axis=0), 1)\n    xbar = tl.where(cols < N, x, 0.0)\n    var = tl.sum(xbar * xbar, axis=0) / N\n    rstd = 1 / tl.sqrt(var + eps)\n\n    # Store the reciprocal standard deviation\n    tl.store(Rstd + row, rstd)\n\n    # Normalize and apply linear transformation\n    x_hat = x * rstd\n    y = x_hat * w\n\n    # Write output\n    tl.store(Y + row * stride_y + cols, y, mask=mask)\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_warps=1),\n        triton.Config({}, num_warps=2),\n        triton.Config({}, num_warps=4),\n        triton.Config({}, num_warps=8),\n        triton.Config({}, num_warps=16),\n        triton.Config({}, num_warps=32),\n    ],\n    key=[\"N\"],\n)\n@triton.jit\ndef _rms_norm_bwd_kernel_sm(\n    X,\n    stride_x,\n    W,\n    DY,\n    stride_dy,\n    DX,\n    stride_dx,\n    Rstd,\n    DW,\n    eps,\n    M,  # num rows\n    N,  # num cols\n    rows_per_program,\n    block_N: tl.constexpr,\n):\n    row_block_id = tl.program_id(0)\n    row_start = row_block_id * rows_per_program\n    cols = tl.arange(0, block_N)\n    mask = cols < N\n\n    # Load weights\n    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)\n\n    # Accumulate gradients for weights\n    dw = tl.zeros((block_N,), dtype=tl.float32)\n\n    row_end = min(row_start + rows_per_program, M)\n    for row in range(row_start, row_end):\n        # Load input, output gradient, and reciprocal standard deviation\n        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)\n        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)\n        rstd = tl.load(Rstd + row)\n\n        # Compute normalized input and gradients\n        x_hat = x * rstd\n        wdy = w * dy\n        dw += dy * x_hat\n        c1 = tl.sum(x_hat * wdy, axis=0) / N\n        dx = (wdy - x_hat * c1) * rstd\n\n        # Store input gradient\n        tl.store(DX + row * stride_dx + cols, dx, mask=mask)\n\n    # Store weight gradients\n    tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n\n\n\"\"\"\n# using the sm count to 
determine the number of rows per program\n# appears to be slightly faster than this bwd kernel below.\n@triton.jit\ndef _rms_norm_bwd_kernel(\n    X,  stride_x,\n    W,\n    DY, stride_dy,\n    DX, stride_dx,\n    Rstd,\n    DW,\n    eps,\n    M,  # num rows\n    N,  # num cols\n    rows_per_program,\n    block_N: tl.constexpr,\n):\n    row_block_id = tl.program_id(0)\n    row_start = row_block_id * rows_per_program\n    cols = tl.arange(0, block_N)\n    mask = cols < N\n\n    # Load weights\n    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)\n\n    # Accumulate gradients for weights\n    dw = tl.zeros((block_N,), dtype=tl.float32)\n\n    row_end = min(row_start + rows_per_program, M)\n    for row in range(row_start, row_end):\n        # Load input, output gradient, and reciprocal standard deviation\n        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)\n        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)\n        rstd = tl.load(Rstd + row)\n\n        # Compute normalized input and gradients\n        x_hat = x * rstd\n        wdy = w * dy\n        dw += dy * x_hat\n        c1 = tl.sum(x_hat * wdy, axis=0) / N\n        dx = (wdy - x_hat * c1) * rstd\n\n        # Store input gradient\n        tl.store(DX + row * stride_dx + cols, dx, mask=mask)\n\n    # Store weight gradients\n    tl.store(DW + cols, dw, mask=mask)\n\"\"\"\n\n\nclass ttt_RMSNorm(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, weight, eps):\n        x_shape_start = x.shape\n\n        # Flatten input\n        x = x.reshape(-1, x.shape[-1])\n        if x.stride(-1) != 1:\n            x = x.contiguous()\n        if weight.stride(-1) != 1:\n            weight = weight.contiguous()\n\n        M, N = x.shape\n        y = torch.empty_like(x)\n        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n\n        max_size = 65536 // x.element_size()\n        block_N = min(max_size, triton.next_power_of_2(N))\n\n        if N > block_N:\n            raise ValueError(f\"N {N} must be <= {block_N=}\")\n\n        grid = lambda meta: (M,)\n        _rms_norm_fwd_kernel[grid](\n            x,\n            x.stride(0),\n            y,\n            y.stride(0),\n            weight,\n            rstd,\n            eps,\n            M,\n            N,\n            block_N,\n        )\n\n        ctx.eps = eps\n        ctx.save_for_backward(x, weight, rstd)\n        ctx.x_shape_start = x_shape_start\n\n        y = y.reshape(x_shape_start)\n        return y\n\n    @staticmethod\n    def backward(ctx, dy):\n        x, weight, rstd = ctx.saved_tensors\n        eps = ctx.eps\n        x_shape_start = ctx.x_shape_start\n\n        # Flatten input and output gradients\n        dy = dy.reshape(-1, dy.shape[-1])\n        if dy.stride(-1) != 1:\n            dy = dy.contiguous()\n\n        M, N = dy.shape\n        dx = torch.empty_like(x)\n        dw = torch.empty_like(weight)\n\n        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n\n        max_size = 65536 // x.element_size()\n        block_N = min(max_size, triton.next_power_of_2(N))\n        rows_per_sm = math.ceil(M / sm_count)\n\n        if N > block_N:\n            raise ValueError(f\"N {N} must be <= {block_N=}\")\n\n        grid = lambda meta: (sm_count,)\n        _rms_norm_bwd_kernel_sm[grid](\n            x,\n            x.stride(0),\n            weight,\n            dy,\n    
        dy.stride(0),\n            dx,\n            dx.stride(0),\n            rstd,\n            _dw,\n            eps,\n            M,\n            N,\n            rows_per_sm,\n            block_N,\n        )\n        dw = _dw.sum(0).to(weight.dtype)\n        dx = dx.reshape(x_shape_start)\n        return dx, dw, None\n\n\n\"\"\"\n    # this is an alternative approach - but it seems to be just slightly slower than sm approach.\n    @staticmethod\n    def backward(ctx, dy):\n        x, weight, rstd = ctx.saved_tensors\n        eps = ctx.eps\n        x_shape_start = ctx.x_shape_start\n\n        # Flatten input and output gradients\n        dy = dy.reshape(-1, dy.shape[-1])\n        if dy.stride(-1) != 1:\n            dy = dy.contiguous()\n\n        M, N = dy.shape\n        dx = torch.empty_like(x)\n        dw = torch.empty_like(weight)\n\n        max_size = 65536 // x.element_size()\n        block_N = min(max_size, triton.next_power_of_2(N))\n        rows_per_program = 1024\n\n        if N > block_N:\n            raise ValueError(f\"N {N} must be <= {block_N=}\")\n\n        grid = lambda meta: (triton.cdiv(M, rows_per_program),)\n        _rms_norm_bwd_kernel[grid](\n            x, x.stride(0),\n            weight,\n            dy, dy.stride(0),\n            dx, dx.stride(0),\n            rstd,\n            dw,\n            eps,\n            M, N,\n            rows_per_program,\n            block_N,\n        )\n        dx = dx.reshape(x_shape_start)\n        return dx, dw, None\n    \"\"\"\n\n\ndef fused_rms_norm_fn(\n    x,\n    weight,\n    eps=1e-6,\n):\n    return ttt_RMSNorm.apply(\n        x,\n        weight,\n        eps,\n    )\n\n\nclass FusedRMSNorm(torch.nn.Module):\n    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)\n\n    def forward(\n        self,\n        x,\n    ):\n        return fused_rms_norm_fn(\n            x,\n            self.weight,\n            eps=self.eps,\n        )\n"
  },
  {
    "path": "kernels/triton/tutorials/README.md",
    "content": "Triton tutorials\n"
  },
  {
    "path": "readme.md",
    "content": "\n### Applied AI repo\nFor experiments and research on Applied AI.\n\n### Projects\n\n#### Kernels\n\nHousing a variety of Triton and CUDA kernels for training and inference.\n\nInference kernels = no backward pass support.\n\n##### Triton Kernels\n\n#### 1 - Triton - MoE (Mixtral) GEMM for accelerating inference. Uses col major access pattern to increase locality.\n\n<img width=\"556\" alt=\"moe_gemm_a100\" src=\"https://github.com/meta-pytorch/applied-ai/assets/46302957/9eece843-b5e1-4250-a98a-3ae79dff1bc3\">\n\n\n#### 2 - Triton - Fused Softmax for both training and inference.\n\n<img width=\"556\" alt=\"softmax_fused\" src=\"https://github.com/meta-pytorch/applied-ai/assets/46302957/de11686b-4c17-4696-857a-4f56488d6df3\">\n\n#### 3 - Triton - Fused RMSNorm for both training and inference. \n[Fused RMSNorm Kernel](https://github.com/meta-pytorch/applied-ai/blob/main/kernels/triton/training/rms_norm/fused_rms_norm.py)\n\n#### Other projects from Applied AI\n\n1. [CUDA Mode](https://github.com/cuda-mode) - Reading group for learning CUDA programming - ([Discord](https://discord.gg/cudamode), [Lecture Materials](https://github.com/cuda-mode/lectures), [Lecture recordings](https://www.youtube.com/@CUDAMODE))\n2. [llama-recipes](https://github.com/meta-llama/llama-recipes) - Recipes for fine-tuning and inference for Llama model series\n3. NeurIPS'23 [LLM Efficiency Challenge](https://llm-efficiency-challenge.github.io/) - 1LLM + 1GPU + 1Day competition - ([website](https://llm-efficiency-challenge.github.io/), [code](https://github.com/llm-efficiency-challenge), [NeurIPS Workshop recordings](https://neurips.cc/virtual/2023/competition/66594))\n\n### Papers and Publications\n\n1. PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation [paper](https://pytorch.org/assets/pytorch2-2.pdf)\n2. Accelerating a Triton Fused Kernel for W4A16 Quantized Inference with SplitK Work Decomposition [paper](https://ai.meta.com/research/publications/accelerating-a-triton-fused-kernel-for-w4a16-quantized-inference-with-splitk-work-decomposition/)\n3. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel [paper](https://arxiv.org/abs/2304.11277)\n4. Sustainable AI: Environmental Implications, Challenges and Opportunities [paper](https://arxiv.org/abs/2111.00364)\n\n\n\n### License\nThe applied-ai repo is released under the [BSD 3](LICENSE) license.\n"
  },
  {
    "path": "tutorials/triton/kernels/__init__.py",
    "content": "\n"
  },
  {
    "path": "tutorials/triton/kernels/flash_attention_fwd.py",
    "content": "# flash forward v2\n"
  },
  {
    "path": "tutorials/triton/kernels/fused_softmax.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# ---- Fused Softmax written in Triton ------\n# Extra Credits:\n# Triton Softmax Tutorial\n# LucidRains Triton_Transformers\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom torch import autograd\n\ndef _get_num_warps(block_size: int)-> int:\n    num_warps = 4\n    if block_size > 2047:\n        num_warps = 8\n    if block_size > 4095:\n        num_warps=16\n    return num_warps\n\n@triton.jit\ndef _softmax_kernel_fwd(\n    output_ptr,\n    output_row_stride,\n    input_ptr,\n    input_row_stride,\n    n_cols,\n    block_size: tl.constexpr,\n):\n    # setup input location\n    row_index = tl.program_id(0)\n    input_row_ptr = input_ptr + (row_index * input_row_stride)\n    col_offsets = tl.arange(0, block_size)\n    input_ptrs = input_row_ptr + col_offsets\n    rw_mask = col_offsets < n_cols\n    row = tl.load(input_ptrs, mask = rw_mask, other=float(\"-inf\"))\n\n    # safe softmax proper\n    safe_row = row - tl.max(row, axis=0)\n    numerator = tl.exp(safe_row)\n    denom = tl.sum(numerator, axis=0)\n    sm_out = numerator / denom\n\n    # write results to HBM\n    out_row_ptr = output_ptr + (row_index * output_row_stride)\n    out_row_ptrs = out_row_ptr + col_offsets\n    tl.store(out_row_ptrs, sm_out, mask = rw_mask)\n\n\n@triton.jit\ndef _softmax_kernel_bwd(\n    output_ptr, \n    stride_output_row,\n    grad_ptr, \n    stride_grad_row,\n    input_ptr,\n    stride_input_row,\n    n_cols,\n    block_size: tl.constexpr,\n\n):\n    # setup input locations - need both grad and input access\n    row_index = tl.program_id(0)\n\n    input_row_ptr = input_ptr + (row_index * stride_input_row)\n    grad_row_ptr = grad_ptr + (row_index * stride_grad_row)\n\n    col_offsets = tl.arange(0,block_size)\n    rw_mask = col_offsets < n_cols\n\n    input_row_ptrs = input_row_ptr + col_offsets\n    grad_row_ptrs = grad_row_ptr + col_offsets\n\n\n    probs_row =tl.load(input_row_ptrs, mask=rw_mask, other = 0)\n    grads_row = tl.load(grad_row_ptrs, mask = rw_mask, other=0)\n\n    # compute derivatives\n    dx = probs_row * grads_row\n    dsm_out = dx - probs_row * (tl.sum(dx, axis=0))\n\n    # write to HBM\n    output_row_ptr = output_ptr + (row_index * stride_output_row)\n    output_ptrs = output_row_ptr + col_offsets\n    tl.store(output_ptrs, dsm_out, mask=rw_mask)\n\n\nclass triton_softmax(autograd.Function):\n    @staticmethod\n    def forward(ctx, x):\n        orig_shape = x.shape\n        x = x.view(-1, orig_shape[-1])\n        nrows, ncols = x.shape\n\n        block_size = triton.next_power_of_2(ncols)\n        num_warps = _get_num_warps(block_size)\n\n        res = torch.empty_like(x)\n        grid = (nrows,)\n\n        _softmax_kernel_fwd[grid](\n            res,\n            res.stride(0),\n            x,\n            x.stride(0),\n            ncols,\n            block_size=block_size,\n            num_warps=num_warps,\n\n        )\n\n        if x.requires_grad:\n            ctx.save_for_backward(res)\n        return res.view(*orig_shape)\n    \n    @staticmethod\n    def backward(ctx, grad_probs):\n        orig_shape = grad_probs.shape\n        probs, = ctx.saved_tensors\n\n        grad_probs = grad_probs.view(-1, orig_shape[-1])\n        nrows, ncols = grad_probs.shape\n\n        block_size = triton.next_power_of_2(ncols)\n        num_warps = _get_num_warps(block_size)\n\n        dx = torch.empty_like(probs)\n        grid = (nrows,)\n\n        _softmax_kernel_bwd[grid](\n        
    dx,\n            dx.stride(0),\n            probs,\n            probs.stride(0),\n            grad_probs,\n            grad_probs.stride(0),\n            ncols,\n            block_size=block_size,\n            num_warps=num_warps,\n\n        )\n        return dx.view(*orig_shape), None\n\nfused_softmax = triton_softmax.apply\n\nif __name__ == '__main__':\n    sample = torch.tensor([[1,2,3,4,5], [5,4,3,2,1]], dtype = torch.float32, device=\"cuda\", requires_grad=True)\n    from torch.nn.functional import softmax as torch_softmax\n    res_torch = torch_softmax(sample, dim=1)\n    res_triton = fused_softmax(sample)\n\n    torch.testing.assert_close(res_torch, res_triton, rtol=0, atol=1e-4)\n\n    # backward: Tensor.backward() returns None, so compare the accumulated input gradients instead\n    dout = torch.randn_like(sample)\n    res_torch.backward(dout)\n    grad_torch = sample.grad.detach().clone()\n    sample.grad = None\n    res_triton.backward(dout)\n    grad_triton = sample.grad.detach().clone()\n\n    torch.testing.assert_close(grad_triton, grad_torch, rtol=0, atol=1e-4)\n"
  },
  {
    "path": "tutorials/triton/kernels/readme.md",
    "content": "Triton tutorials\n\n1 - Vector Add - Starting tutorial on simple first kernel  \n2 - Fused Softmax - Full fused softmax with both forward and backward (training ready)\n"
  },
  {
    "path": "tutorials/triton/kernels/vector_add.py",
    "content": "# coding up a Triton vector addition kernel\n# links to\n\nimport triton\nimport triton.language as tl \nimport torch\n\n@triton.jit\ndef kernel_vector_addition(a_ptr, b_ptr, out_ptr, \n                           num_elems: tl.constexpr, \n                           block_size: tl.constexpr):\n    \n    pid = tl.program_id(axis = 0)\n    \n    block_start = pid * block_size  # 0 * 2 = 0, 1 * 2 = 2, \n    thread_offsets = block_start + tl.arange(0, block_size)\n    mask = thread_offsets < num_elems\n    a_pointers = tl.load(a_ptr + thread_offsets, mask = mask)\n    b_pointers = tl.load(b_ptr + thread_offsets, mask = mask)\n    res = a_pointers + b_pointers\n    tl.store(out_ptr + thread_offsets, res, mask=mask)\n\n\ndef ceil_div(x: int, y: int)-> int:\n    return ((x+y-1)// y)\n\ndef vector_addition(a: torch.tensor, b: torch.tensor)-> torch.tensor:\n    output_buffer = torch.empty_like(a)\n    assert a.is_cuda() and b.is_cuda()\n    num_elems = a.numel()\n    assert num_elems == b.numel() # todo - handle mismatched sizes\n\n    block_size = 128 \n    grid_size = ceil_div(num_elems, block_size)\n    grid = (grid_size,)\n\n    k2 = kernel_vector_addition[grid](a, b, output_buffer,\n                                      num_elems, \n                                      block_size)\n    \n    return output_buffer\n"
  },
  {
    "path": "tutorials/triton/tests/test_softmax.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\nimport pytest\nimport torch\nimport sys\nsys.path.append('..')\nfrom triton_kernels.softmax import fused_softmax\n\nfrom test_utils import assert_expected, set_rng_seed, gpu_test\n\n@pytest.fixture(autouse=True)\ndef set_seed():\n    set_rng_seed(2020)\n\n\n@gpu_test()\nclass TestForwardSoftMax:\n    \n    def test_forward_2D_float32(self,):\n        # float32\n        seq_len = 768\n\n        sample_constant_float32 = torch.ones((seq_len, seq_len), dtype=torch.float32, device='cuda')\n        sample_random_float32 = torch.randn_like(sample_constant_float32)\n\n        expected_out_constant32 = torch.softmax(sample_constant_float32, dim=1)\n        expected_out_random32 = torch.softmax(sample_random_float32, dim=1)\n\n        triton_out_c32 = fused_softmax(sample_constant_float32)\n        triton_out_random32 = fused_softmax(sample_random_float32)\n\n        assert_expected(triton_out_c32, expected_out_constant32 )\n        assert_expected(triton_out_random32, expected_out_random32)\n\n    def test_forward_2D_bfloat16(self,):\n        # bfloat16\n        seq_len = 2048\n        sample_constant_bf16 = torch.ones((seq_len, seq_len), dtype=torch.bfloat16, device='cuda')\n        sample_random_bf16  = torch.randn_like(sample_constant_bf16)\n\n        expected_out_c_bf16 = torch.softmax(sample_constant_bf16, dim=1)\n        expected_out_rand_bf16 = torch.softmax(sample_random_bf16, dim=1)\n\n        triton_out_c_bf16 = fused_softmax(sample_constant_bf16)\n        triton_out_rand_bf16 = fused_softmax(sample_random_bf16)\n\n        assert_expected(triton_out_c_bf16, expected_out_c_bf16 )\n        assert_expected(triton_out_rand_bf16, expected_out_rand_bf16)\n    \n    def test_forward_3D_bfloat16(self,):\n        # bfloat16\n        seq_len = 2048\n        batch = 12\n\n        sample_constant_bf16 = torch.ones((batch, seq_len, seq_len), dtype=torch.bfloat16, device='cuda')\n        sample_random_bf16  = torch.randn_like(sample_constant_bf16)\n\n        expected_out_c_bf16 = torch.softmax(sample_constant_bf16, dim=1)\n        expected_out_rand_bf16 = torch.softmax(sample_random_bf16, dim=1)\n\n        triton_out_c_bf16 = fused_softmax(sample_constant_bf16)\n        triton_out_rand_bf16 = fused_softmax(sample_random_bf16)\n\n        assert_expected(triton_out_c_bf16, expected_out_c_bf16, atol=1e-2 )\n        assert_expected(triton_out_rand_bf16, expected_out_rand_bf16, atol=1e-2)\n\n\n@gpu_test()\nclass TestBackwardSoftMax:\n    \n    def test_backward_2D(self,):\n        seq_len = 1024\n\n        sample_constant_float32 = torch.ones((seq_len, seq_len), dtype=torch.float32, device='cuda', requires_grad=True)\n        sample_random_float32 = torch.randn_like(sample_constant_float32, requires_grad=True)\n\n        expected_fwd_constant32 = torch.softmax(sample_constant_float32, dim=1)\n        expected_fwd_random32 = torch.softmax(sample_random_float32, dim=1)\n\n        triton_fwd_c32 = fused_softmax(sample_constant_float32)\n        triton_fwd_random32 = fused_softmax(sample_random_float32)\n\n        dout = torch.randn_like(sample_constant_float32)\n\n        expected_bwd_c32 = expected_fwd_constant32.backward(dout)\n        expected_bwd_r32 = expected_fwd_random32.backward(dout)\n\n        triton_bwd_c32 = triton_fwd_c32.backward(dout)\n        triton_bwd_r32 = triton_fwd_random32.backward(dout)\n\n\n        assert_expected(triton_bwd_c32, expected_bwd_c32 )\n        assert_expected(triton_bwd_r32, 
expected_bwd_r32)\n\n    def test_bwd_3D(self,):\n        seq_len = 2048\n        batch = 4\n\n        sample_constant_float32 = torch.ones((batch, seq_len, seq_len), dtype=torch.float32, device='cuda', requires_grad=True)\n        sample_random_float32 = torch.randn_like(sample_constant_float32, requires_grad=True)\n\n        expected_fwd_constant32 = torch.softmax(sample_constant_float32, dim=1)\n        expected_fwd_random32 = torch.softmax(sample_random_float32, dim=1)\n\n        triton_fwd_c32 = fused_softmax(sample_constant_float32)\n        triton_fwd_random32 = fused_softmax(sample_random_float32)\n\n        dout = torch.randn_like(sample_constant_float32)\n\n        expected_bwd_c32 = expected_fwd_constant32.backward(dout)\n        expected_bwd_r32 = expected_fwd_random32.backward(dout)\n\n        triton_bwd_c32 = triton_fwd_c32.backward(dout)\n        triton_bwd_r32 = triton_fwd_random32.backward(dout)\n\n\n        assert_expected(triton_bwd_c32, expected_bwd_c32 )\n        assert_expected(triton_bwd_r32, expected_bwd_r32)\n"
  },
  {
    "path": "tutorials/triton/tests/test_utils.py",
    "content": "from pathlib import Path\nfrom typing import Any, Dict, NamedTuple, Optional, Tuple, Union\n\nimport pytest\nimport torch\nimport torch.distributed as dist\nfrom torch import Tensor, nn\n\n\ndef assert_expected(\n    actual: Any,\n    expected: Any,\n    rtol: Optional[float] = 0,\n    atol: Optional[float] = 1e-4,\n    check_device=True,\n):\n    torch.testing.assert_close(\n        actual,\n        expected,\n        rtol=rtol,\n        atol=atol,\n        check_device=check_device,\n        msg=f\"actual: {actual}, expected: {expected}\",\n    )\n\ndef set_rng_seed(seed):\n    \"\"\"Sets the seed for pytorch random number generators\"\"\"\n    torch.manual_seed(seed)\n\n\ndef gpu_test(gpu_count: int = 1):\n    \"\"\"\n    Annotation for GPU tests, skipping the test if the\n    required amount of GPU is not available\n    \"\"\"\n    message = f\"Not enough GPUs to run the test: required {gpu_count}\"\n    return pytest.mark.skipif(torch.cuda.device_count() < gpu_count, reason=message)\n"
  }
]