Full Code of pytorch-labs/applied-ai for AI

main 2391954b1998 cached
119 files
547.0 KB
155.6k tokens
317 symbols
1 requests
Download .txt
Showing preview only (585K chars total). Download the full file or copy to clipboard to get everything.
Repository: pytorch-labs/applied-ai
Branch: main
Commit: 2391954b1998
Files: 119
Total size: 547.0 KB

Directory structure:
gitextract_gcbwaf39/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── assets/
│   └── images/
│       ├── dev-discuss-asynctp/
│       │   └── readme.md
│       └── readme.md
├── dev/
│   ├── sr/
│   │   ├── .gitignore
│   │   ├── readme.md
│   │   ├── setup.py
│   │   ├── src/
│   │   │   ├── stochastic_rounding.cu
│   │   │   ├── stochastic_rounding.hpp
│   │   │   └── stochastic_rounding_cuda.cu
│   │   ├── test.md
│   │   ├── tests/
│   │   │   ├── benchmark.py
│   │   │   └── core_unit_tests.py
│   │   ├── usage.py
│   │   └── usage2.py
│   └── triton_groupGEMM/
│       ├── groupgemm.py
│       ├── testing/
│       │   ├── base_testing.py
│       │   └── unit_tests.py
│       ├── tma_utils.py
│       └── triton_tutorial_groupgemm.py
├── kernels/
│   ├── MoE/
│   │   └── group_GEMM/
│   │       └── triton/
│   │           ├── readme.md
│   │           ├── testing/
│   │           │   ├── fast_verification.py
│   │           │   └── pytorch_reference_backwards.py
│   │           ├── tgroup_gemm_backwards.py
│   │           ├── tgroup_gemm_forward.py
│   │           └── utils/
│   │               └── tma_utils.py
│   ├── blackwell/
│   │   ├── cute_gemm_01/
│   │   │   ├── Makefile
│   │   │   ├── build/
│   │   │   │   └── temp.linux-x86_64-cpython-312/
│   │   │   │       ├── .ninja_deps
│   │   │   │       ├── .ninja_log
│   │   │   │       ├── build.ninja
│   │   │   │       ├── sm100_gemm.o
│   │   │   │       └── sm100_gemm_pytorch.o
│   │   │   ├── dist/
│   │   │   │   └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg
│   │   │   ├── driver.py
│   │   │   ├── setup.py
│   │   │   ├── sm100_gemm.cu
│   │   │   ├── sm100_gemm.egg-info/
│   │   │   │   ├── PKG-INFO
│   │   │   │   ├── SOURCES.txt
│   │   │   │   ├── dependency_links.txt
│   │   │   │   ├── not-zip-safe
│   │   │   │   ├── requires.txt
│   │   │   │   └── top_level.txt
│   │   │   ├── sm100_gemm.h
│   │   │   └── sm100_gemm_pytorch.cpp
│   │   └── cute_gemm_02_tma/
│   │       ├── build/
│   │       │   └── temp.linux-x86_64-cpython-312/
│   │       │       ├── .ninja_deps
│   │       │       ├── .ninja_log
│   │       │       ├── build.ninja
│   │       │       ├── sm100_gemm.o
│   │       │       └── sm100_gemm_pytorch.o
│   │       ├── dist/
│   │       │   └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg
│   │       ├── driver.py
│   │       ├── setup.py
│   │       ├── sm100_gemm.cu
│   │       ├── sm100_gemm.egg-info/
│   │       │   ├── PKG-INFO
│   │       │   ├── SOURCES.txt
│   │       │   ├── dependency_links.txt
│   │       │   ├── not-zip-safe
│   │       │   ├── requires.txt
│   │       │   └── top_level.txt
│   │       ├── sm100_gemm.h
│   │       └── sm100_gemm_pytorch.cpp
│   ├── cuda/
│   │   ├── cutlass_gemm/
│   │   │   ├── broadcast_load_epilogue_c3x.hpp
│   │   │   ├── common.hpp
│   │   │   ├── cutlass.cpp
│   │   │   ├── cutlass_kernel.cu
│   │   │   ├── readme.md
│   │   │   ├── setup.py
│   │   │   └── test_cutlass_gemm.py
│   │   ├── inference/
│   │   │   ├── README.md
│   │   │   └── hadamard_transform/
│   │   │       ├── hadamard_transform.cpp
│   │   │       ├── hadamard_transform_cuda.cu
│   │   │       ├── setup.py
│   │   │       └── test.py
│   │   ├── training/
│   │   │   └── README.md
│   │   └── tutorials/
│   │       ├── README.md
│   │       └── flash2.cu
│   ├── needs_perf_help/
│   │   ├── fp8_gemm_bench.py
│   │   └── fp8_rowwise_tma_persistent.py
│   └── triton/
│       ├── inference/
│       │   ├── README.md
│       │   ├── col_major_moe_gemm/
│       │   │   ├── README.md
│       │   │   ├── perf_test_moe.py
│       │   │   ├── profile_moe.py
│       │   │   ├── results.html
│       │   │   ├── test.csv
│       │   │   ├── test_moe_gemm.py
│       │   │   ├── v0_moe_fused.py
│       │   │   ├── v1_moe_fused.py
│       │   │   └── v2_moe_fused.py
│       │   ├── flash_attention/
│       │   │   └── stay_attention.py
│       │   ├── fp8/
│       │   │   ├── float8_groupwise_quant.py
│       │   │   ├── scaled_fp8_gemm.py
│       │   │   ├── splitk_gemm_fp8.py
│       │   │   └── tma_gemm.py
│       │   ├── gptq/
│       │   │   ├── a100_qlinear.py
│       │   │   ├── benchmark.py
│       │   │   ├── h100_qlinear.py
│       │   │   ├── mixtral/
│       │   │   │   ├── test_dequant_moe_gemm.py
│       │   │   │   └── w4a16_fused_dequant_gemm.py
│       │   │   ├── small_benchmark_cuda_graphs.py
│       │   │   └── splitk_dequant_gemm.py
│       │   ├── mamba/
│       │   │   └── causal_1d_conv/
│       │   │       ├── causal_1d_conv/
│       │   │       │   └── causal_1d_conv.py
│       │   │       └── tests/
│       │   │           └── test_causal_1d_conv.py
│       │   ├── paged_attention/
│       │   │   └── attention_triton.py
│       │   └── torch_compile/
│       │       └── flash_backward.py
│       ├── training/
│       │   ├── README.md
│       │   ├── fused_softmax/
│       │   │   ├── README.md
│       │   │   └── softmax.py
│       │   └── rms_norm/
│       │       └── fused_rms_norm.py
│       └── tutorials/
│           └── README.md
├── readme.md
└── tutorials/
    └── triton/
        ├── kernels/
        │   ├── __init__.py
        │   ├── flash_attention_fwd.py
        │   ├── fused_softmax.py
        │   ├── readme.md
        │   └── vector_add.py
        └── tests/
            ├── test_softmax.py
            └── test_utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pyc
**/.ipynb_checkpoints


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to Applied AI
We want to make contributing to this project as easy and transparent as
possible.

## Our Development Process
... (in particular how this is synced with internal changes to the project)

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## Coding Style
* 2 spaces for indentation rather than tabs
* 80 character line length
* ...

## License
By contributing to applied-ai, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.


================================================
FILE: LICENSE
================================================
Copyright 2024 Meta

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: assets/images/dev-discuss-asynctp/readme.md
================================================
This folder is for hosting the images for the AsyncTP public post at:    
[https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)


================================================
FILE: assets/images/readme.md
================================================
Folder for housing images for the readmes.


================================================
FILE: dev/sr/.gitignore
================================================
*.o
*.ninja
*.txt
*.egg-info
*.ninja-deps
*.ninja-log/
*.so
dist/
build/


================================================
FILE: dev/sr/readme.md
================================================
Branch for stochastic rounding kernel
Currently processes 4 elements per thread to leverage rand4


================================================
FILE: dev/sr/setup.py
================================================

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='stochastic_rounding_cuda',
    version='0.1.021825',
    ext_modules=[
        CUDAExtension('stochastic_rounding_cuda', [
            'src/stochastic_rounding.cu',
            'src/stochastic_rounding_cuda.cu'
        ],
        extra_compile_args={
            'cxx': ['-O3'],
            'nvcc': [
                '-O3',
                '--expt-relaxed-constexpr',  # better template support
                #'-gencode=arch=compute_70,code=sm_70',  # Volta
                #'-gencode=arch=compute_75,code=sm_75',  # Turing
                #'-gencode=arch=compute_80,code=sm_80'   # Amper
                #'-gencode=arch=compute_86,code=sm_86'   # Ampere
                '-gencode=arch=compute_90,code=sm_90',  # Hopper
            ]
        })
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)


================================================
FILE: dev/sr/src/stochastic_rounding.cu
================================================

#include <pybind11/pybind11.h>
#include "stochastic_rounding.hpp"
#include <random>

namespace py = pybind11;

__host__ int getOptimalBlockSize() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    return std::min(prop.maxThreadsPerBlock, 256);
}

torch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad) {
    TORCH_CHECK(input.is_cuda(), "Input tensor must be on CUDA device");
    TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input tensor must be float32");

    const int threads_per_block = 256;
    const int num_elements = input.numel();
    const int elements_per_thread = 4;

    const int min_blocks = (num_elements + elements_per_thread * threads_per_block - 1) /
                          (elements_per_thread * threads_per_block);

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int blocks_per_sm = 4;
    const int min_blocks_for_sms = prop.multiProcessorCount * blocks_per_sm;
    const int num_blocks = std::max(min_blocks, min_blocks_for_sms);

    auto options = torch::TensorOptions()
                      .dtype(torch::kBFloat16)
                      .device(input.device())
                      .requires_grad(requires_grad);
    auto output = torch::empty_like(input, options);

    std::random_device rd;
    std::mt19937_64 gen(rd());
    std::uniform_int_distribution<unsigned long long> dis;
    const unsigned long long seed = dis(gen);

    stochastic_round_bf16<<<num_blocks, threads_per_block>>>(
        input.data_ptr<float>(),
        reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
        num_elements,
        seed);

    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "CUDA kernel execution failed: ", cudaGetErrorString(err));

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("stochastic_round_bf16",
          static_cast<torch::Tensor (*)(torch::Tensor, bool)>(&stochastic_round_bf16_cuda),
          "Stochastic rounding to BFloat16",
          py::arg("input"),
          py::arg("requires_grad") = false);
}


================================================
FILE: dev/sr/src/stochastic_rounding.hpp
================================================

#pragma once
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <vector_types.h>
#include <torch/extension.h>
#include <pybind11/pybind11.h>

namespace philox {
    constexpr unsigned int W32_0   = 0x9E3779B9;
    constexpr unsigned int W32_1   = 0xBB67AE85;
    constexpr unsigned int M0      = 0xD2511F53;
    constexpr unsigned int M1      = 0xCD9E8D57;
    constexpr int ROUNDS          = 7;
}

// Forward declarations
class PhiloxGenerator {
public:
    __device__ __forceinline__ PhiloxGenerator();
    __device__ __forceinline__ void init(const unsigned long long seed, const unsigned int thread_id);
    __device__ __forceinline__ uint4 next();
private:
    uint2 key;
    uint4 counter;
    static __device__ __forceinline__ uint2 mulhilo(const unsigned int a, const unsigned int b);
    static __device__ __forceinline__ uint4 round(uint4 ctr, uint2 key);
};

// CUDA kernel declaration
__global__ void stochastic_round_bf16(
    float *__restrict__ input,
    __nv_bfloat16 *__restrict__ output,
    const int size,
    const unsigned long long seed);

// Host functions
__host__ int getOptimalBlockSize();
torch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad = false);


================================================
FILE: dev/sr/src/stochastic_rounding_cuda.cu
================================================
 #include "stochastic_rounding.hpp"
#include <cstdint>

// Philox RNG implementation

__device__ __forceinline__ PhiloxGenerator::PhiloxGenerator() :
    key(make_uint2(0, 0)),
    counter(make_uint4(0, 0, 0, 0)) {}

__device__ __forceinline__ void PhiloxGenerator::init(const unsigned long long seed, const unsigned int thread_id) {
    key.x = static_cast<unsigned int>(seed);
    key.y = static_cast<unsigned int>(seed >> 32);
    counter = make_uint4(thread_id, 0, 0, 0);
    __threadfence_block();
}

__device__ __forceinline__ uint2 PhiloxGenerator::mulhilo(const unsigned int a, const unsigned int b) {
    uint2 result;
    unsigned long long prod;
    asm("mul.wide.u32 %0, %1, %2;" : "=l"(prod) : "r"(a), "r"(b));
    result.x = static_cast<unsigned int>(prod);
    result.y = static_cast<unsigned int>(prod >> 32);
    return result;
}

__device__ __forceinline__ uint4 PhiloxGenerator::round(uint4 ctr, uint2 key) {
    const uint2 mul0 = mulhilo(philox::M0, ctr.x);
    const uint2 mul1 = mulhilo(philox::M1, ctr.z);

    return make_uint4(
        mul1.y ^ ctr.y ^ key.x,
        mul1.x,
        mul0.y ^ ctr.w ^ key.y,
        mul0.x
    );
}

__device__ __forceinline__ uint4 PhiloxGenerator::next() {
    uint4 ctr = counter;
    uint2 k = key;

    #pragma unroll
    for (int i = 0; i < philox::ROUNDS; ++i) {
        ctr = round(ctr, k);
        k.x += philox::W32_0;
        k.y += philox::W32_1;
    }

    counter.x += 4;
    return ctr;
}

__device__ __forceinline__ __nv_bfloat16 float_to_bf16_stochastic(const float value, const uint32_t rand) {
    const uint32_t val_bits = __float_as_uint(value);
    const uint32_t rounding_bits = val_bits & 0xFFFF;
    uint32_t result = val_bits & 0xFFFF0000u;
    result += (rand & 0xFFFF) < rounding_bits ? 0x10000u : 0;
    return __float2bfloat16(__uint_as_float(result));
}

__device__ __forceinline__ void float4_to_bf16_stochastic(
    const float4& values,
    uint4& rand_vals,
    __nv_bfloat16* output) {

    float vals[4] = {values.x, values.y, values.z, values.w};
    uint32_t rands[4] = {rand_vals.x, rand_vals.y, rand_vals.z, rand_vals.w};

    #pragma unroll
    for (int i = 0; i < 4; i++) {
        output[i] = float_to_bf16_stochastic(vals[i], rands[i]);
    }
}

__global__ void stochastic_round_bf16(
    float *__restrict__ input,
    __nv_bfloat16 *__restrict__ output,
    const int size,
    const unsigned long long seed) {

    PhiloxGenerator rng;
    rng.init(seed, blockIdx.x * blockDim.x + threadIdx.x);

    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    int stride = blockDim.x * gridDim.x * 4;

    float4 values;
    __nv_bfloat16 local_output[4];

    // Process full vectors of 4 elements
    for (; idx <= size - 4; idx += stride) {
        values = *reinterpret_cast<float4*>(&input[idx]);
        uint4 rand = rng.next();
        float4_to_bf16_stochastic(values, rand, local_output);

        for (int j = 0; j < 4; j++) {
            output[idx + j] = local_output[j];
        }
    }

    // Handle remaining elements
    if (idx < size) {
        float remaining_values[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        int remainder = size - idx;

        for (int j = 0; j < remainder; j++) {
            remaining_values[j] = input[idx + j];
        }

        values.x = remaining_values[0];
        values.y = remaining_values[1];
        values.z = remaining_values[2];
        values.w = remaining_values[3];

        uint4 rand = rng.next();
        float4_to_bf16_stochastic(values, rand, local_output);

        for (int j = 0; j < remainder; j++) {
            output[idx + j] = local_output[j];
        }
    }
}


================================================
FILE: dev/sr/test.md
================================================
(tkdev11) [less@devgpu115.cco2 ~/local/applied-ai/dev/sr (sr_kernel)]$ python usage.py
Launching kernel with blocks=1, threads_per_block=256, num_elements=12
Input tensor: tensor([ 0.3282, -0.4513, -1.0612,  0.1446, -0.8440, -1.4669, -0.7135, -0.6183,
        -2.2411,  2.1464,  1.4772, -1.3564], device='cuda:0')
Output tensor: tensor([ 0.3281, -0.4512, -1.0625,  0.1445, -0.8438, -1.4688, -0.7109, -0.6172,
        -2.2344,  2.1406,  1.4766, -1.3516], device='cuda:0',
       dtype=torch.bfloat16)
Output tensor dtype: torch.bfloat16
Success!


================================================
FILE: dev/sr/tests/benchmark.py
================================================
import torch
import stochastic_rounding_cuda
import numpy as np
import time
from tabulate import tabulate
import argparse

def measure_performance(func, input_tensor, warmup=0, repeats=1):
    """Measure performance of a function with proper CUDA synchronization"""
    # Warmup
    for _ in range(warmup):
        output = func(input_tensor)

    torch.cuda.synchronize()
    start = time.perf_counter()

    for _ in range(repeats):
        output = func(input_tensor)

    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time = (end - start) / repeats
    elements_per_second = input_tensor.numel() / avg_time
    return avg_time, elements_per_second

def benchmark_sizes(sizes= [1000, 10000, 100000, 1000000, 10000000, (10000000*10), (10000000*100)]):
    #[ 50,000,000]): #
    """Benchmark different input sizes"""
    results = []

    for size in sizes:
        # Create input tensor
        x = torch.randn(size, device='cuda')

        # Measure stochastic rounding
        time_stoch, throughput_stoch = measure_performance(
            stochastic_rounding_cuda.stochastic_round_bf16, x)

        # Measure regular BF16 casting
        time_regular, throughput_regular = measure_performance(
            lambda t: t.to(torch.bfloat16), x)

        results.append([
            size,
            time_stoch * 1000,  # convert to ms
            throughput_stoch / 1e6,  # convert to GElements/s
            time_regular * 1000,
            throughput_regular / 1e6,
            throughput_regular / throughput_stoch  # speedup
        ])

    print("\nSize Comparison:")
    print(tabulate(results,
                  headers=['Size', 'Stoch Time (ms)', 'Stoch ME/s',
                          'Regular Time (ms)', 'Regular ME/s', 'Casting faster by'],
                  floatfmt='.3f'))

def benchmark_shapes(total_size=1000000):
    """Benchmark different tensor shapes with same total size"""
    shapes = [
        (total_size,),           # 1D
        (1000, total_size//1000),  # 2D
        (100, 100, total_size//10000),  # 3D
    ]

    results = []
    for shape in shapes:
        x = torch.randn(*shape, device='cuda')
        time_stoch, throughput_stoch = measure_performance(
            stochastic_rounding_cuda.stochastic_round_bf16, x)

        results.append([
            'x'.join(str(d) for d in shape),
            time_stoch * 1000,
            throughput_stoch / 1e9
        ])

    print("\nShape Comparison (same total size):")
    print(tabulate(results,
                  headers=['Shape', 'Time (ms)', 'GElements/s'],
                  floatfmt='.3f'))

def stress_test(duration=10):
    """Run a stress test for specified duration"""
    print(f"\nRunning stress test for {duration} seconds...")

    size = 1000000
    x = torch.randn(size, device='cuda')
    start_time = time.time()
    iterations = 0

    while time.time() - start_time < duration:
        stochastic_rounding_cuda.stochastic_round_bf16(x)
        iterations += 1

    print(f"Completed {iterations} iterations without errors")
    print(f"Average throughput: {(iterations * size) / (duration * 1e9):.2f} GElements/s")

def memory_test(max_size=1e9):
    """Test memory scaling"""
    sizes = np.logspace(3, min(9, np.log10(max_size)), num=7, dtype=int)
    results = []

    for size in sizes:
        try:
            torch.cuda.empty_cache()
            x = torch.randn(size, device='cuda')
            torch.cuda.synchronize()

            # Measure peak memory during operation
            torch.cuda.reset_peak_memory_stats()
            _ = stochastic_rounding_cuda.stochastic_round_bf16(x)
            torch.cuda.synchronize()

            peak_memory = torch.cuda.max_memory_allocated() / 1e6  # MB
            results.append([size, peak_memory])

        except RuntimeError as e:
            print(f"Out of memory at size {size}")
            break

    print("\nMemory Usage:")
    print(tabulate(results,
                  headers=['Size', 'Peak Memory (MB)'],
                  floatfmt='.2f'))

def main():
    parser = argparse.ArgumentParser(description='Benchmark stochastic rounding')
    parser.add_argument('--sizes', action='store_true', help='Run size benchmarks')
    parser.add_argument('--shapes', action='store_true', help='Run shape benchmarks')
    parser.add_argument('--stress', action='store_true', help='Run stress test')
    parser.add_argument('--memory', action='store_true', help='Run memory test')
    parser.add_argument('--all', action='store_true', help='Run all benchmarks')

    args = parser.parse_args()

    # Print device information
    device = torch.cuda.get_device_properties(0)
    print(f"\nRunning on: {device.name}")
    print(f"Compute Capability: {device.major}.{device.minor}")


    if args.all or args.sizes:
        benchmark_sizes()

    if args.all or args.shapes:
        benchmark_shapes()

    if args.all or args.stress:
        stress_test()

    if args.all or args.memory:
        memory_test()

if __name__ == '__main__':
    main()


================================================
FILE: dev/sr/tests/core_unit_tests.py
================================================
import torch
import numpy as np
from collections import Counter
import unittest
import stochastic_rounding_cuda
import time

class TestStochasticRounding(unittest.TestCase):
    def setup(self):
        # Ensure deterministic behavior for some tests
        torch.manual_seed(42)
        np.random.seed(42)

    def _test_rounding_statistics_helper(self, value, lower_value, upper_value, tensor_size=10000, rounds=100):
        """Helper method for testing stochastic rounding statistics"""
        print(f"\nInput value: {value}")
        MAX_VARIANCE = 0.03
        x = torch.full((tensor_size,), value, device='cuda')
        torch.cuda.manual_seed(42)

        # Single round test - isolate and show the round up and round down values
        single_result = stochastic_rounding_cuda.stochastic_round_bf16(x)
        print(f"Possible rounded values: {torch.unique(single_result)}")

        # Multiple rounds
        results = torch.empty((rounds, tensor_size), device='cuda', dtype=torch.bfloat16)
        for i in range(rounds):
            results[i] = stochastic_rounding_cuda.stochastic_round_bf16(x)

        prob_up = (results == upper_value).float().mean().item()
        print(f"Kernel's probability of rounding up: {prob_up:.4f}")

        distance_to_lower = abs(value - lower_value)
        total_distance = upper_value - lower_value
        expected_prob = distance_to_lower / total_distance
        print(f"Expected probability: {expected_prob:.4f}")

        self.assertTrue(abs(prob_up - expected_prob) < MAX_VARIANCE)

    def test_special_values(self):
        """Test handling of special values like inf, -inf, nan"""
        special_values = torch.tensor([float('inf'), float('-inf'), float('nan'), 0.0, -0.0],
                                    device='cuda')
        rounded = stochastic_rounding_cuda.stochastic_round_bf16(special_values)

        # Check inf and -inf are preserved
        self.assertTrue(torch.isinf(rounded[0]))
        self.assertTrue(torch.isinf(rounded[1]))
        self.assertTrue(rounded[0] > 0)
        self.assertTrue(rounded[1] < 0)

        # Check nan is preserved
        self.assertTrue(torch.isnan(rounded[2]))

        # Check zeros are preserved
        self.assertEqual(rounded[3].item(), 0.0)
        self.assertEqual(rounded[4].item(), 0.0)

    def test_small_values(self):
        """Test handling of small values near zero"""
        small_values = torch.tensor([1e-38, -1e-38, 1e-20, -1e-20], device='cuda')
        rounded = stochastic_rounding_cuda.stochastic_round_bf16(small_values)

        # Check that very small values are handled properly
        self.assertTrue(torch.all(torch.isfinite(rounded)))

    def test_vectorized_loading(self):
        """Test if vectorized loading works correctly for different tensor sizes"""
        sizes = [4, 8, 9, 16, 32, 100]  # Test various sizes including non-aligned

        for size in sizes:
            x = torch.linspace(1, size, size, device='cuda')
            rounded = stochastic_rounding_cuda.stochastic_round_bf16(x)

            # Check output size matches input
            self.assertEqual(rounded.size(0), size)

            # Check dtype
            self.assertEqual(rounded.dtype, torch.bfloat16)

    def test_large_values(self):
        """Test handling of large values"""
        large_values = torch.tensor([1e38, -1e38, 1e20, -1e20], device='cuda')
        rounded = stochastic_rounding_cuda.stochastic_round_bf16(large_values)

        # Values should be preserved approximately in BF16 range
        self.assertTrue(torch.all(torch.isfinite(rounded)))

    def test_rounding_statistics(self):
        """Test if rounding probabilities match expected distribution"""
        self._test_rounding_statistics_helper(2.1999969482421875, 2.1875, 2.2031)

    def test_rounding_statistics_2(self):
        """Test stochastic rounding with different BF16 boundary values"""
        self._test_rounding_statistics_helper(1.7999992370605469, 1.7969, 1.8047)

    def test_rounding_statistics_small(self):
        """Test stochastic rounding for number between 0 and 1"""
        self._test_rounding_statistics_helper(0.7499847412109375, 0.7480, 0.7500)

    def test_rounding_statistics_large(self):
        """Test stochastic rounding for large number, over 100"""
        self._test_rounding_statistics_helper(128.99998474121094, 128.875, 129.000)



if __name__ == '__main__':
    unittest.main(verbosity=2)


================================================
FILE: dev/sr/usage.py
================================================
import torch
import stochastic_rounding_cuda

# Create input tensor
input_tensor = torch.randn(12, device='cuda', dtype=torch.float32)

# Apply stochastic rounding
output_tensor = stochastic_rounding_cuda.stochastic_round_bf16(input_tensor)
print(f"Input tensor: {input_tensor}")
print(f"Output tensor: {output_tensor}")
print(f"Output tensor dtype: {output_tensor.dtype}")
print(f"Success!")

'''
# Test tensor
x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00, -1.3683e+00,
                  4.0467e-01, 1.0759e-03, 2.8418e-01, -4.9392e-01,
                  8.7239e-01, -9.0545e-01, 1.1134e+00, 0],  # -2.6872e+00
                device='cuda')

# Convert to BF16
y = stochastic_rounding_cuda.stochastic_round_bf16(x)
print(f"Input: {x}")
print(f"Output: {y}")
'''


================================================
FILE: dev/sr/usage2.py
================================================
import torch
import stochastic_rounding_cuda

# Test tensor
x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00], device='cuda')

# Compare with regular rounding
y_normal = x.to(torch.bfloat16)
y_stochastic = stochastic_rounding_cuda.stochastic_round_bf16(x)

print(f"Input: {x}")
print(f"Normal BF16: {y_normal}")
print(f"Stochastic BF16: {y_stochastic}")


================================================
FILE: dev/triton_groupGEMM/groupgemm.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import functools
from typing import Optional

import tma_utils as utils

import torch

import triton
import triton.language as tl
from triton.runtime import driver  # @manual


_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [64, 128]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]

_AMD_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
            "waves_per_eu": waves_per_cu,
            "matrix_instr_nonkdim": matrix_instr_nonkdim,
        },
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for block_size_m in [32, 64, 128]
    for block_size_n in [32, 64, 128, 256]
    for block_size_k in [128, 256]
    for num_stages in [1, 2]
    for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
    for matrix_instr_nonkdim in [16]
]


def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    device = torch.cuda.current_device()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    if dtsize is None:
        dtsize = named_args["c_ptr"].element_size()
    if dtype is None:
        dtype = named_args["c_ptr"].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]
        if torch.version.hip:
            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
        else:
            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 32 if torch.version.hip else 64
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load N tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 32 if torch.version.hip else 64
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs


@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_grouped_gemm(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    tidx = tl.program_id(0)

    dtype: tl.dtype = c_ptr.dtype.element_ty
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        M_start_offset = M_end_offset
        m_size = tl.load(m_sizes + g)
        M_end_offset = M_start_offset + m_size

        if m_size > 0:
            N_start_offset = g * N
            n_size = N
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr + M_start_offset * N,
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        accumulator += tl.dot(a, b.T)
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                if USE_TMA_STORE:
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                tidx += NUM_SMS

            iterated_tiles += num_tiles


TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv


@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
)
@triton.jit
def _kernel_grouped_gemm_fp8_rowwise(
    a_desc_ptr,
    a_scale_ptr,
    b_desc_ptr,
    b_scale_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    tidx = tl.program_id(0)

    dtype = TT_FP8_DTYPE
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        M_start_offset = M_end_offset
        m_size = tl.load(m_sizes + g)
        M_end_offset = M_start_offset + m_size

        if m_size > 0:
            N_start_offset = g * N
            n_size = N
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr + M_start_offset * N,
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        accumulator += tl.dot(a, b.T)
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                a_scale = tl.load(
                    a_scale_ptr + M_start_offset + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                )
                b_scale = tl.load(
                    b_scale_ptr + N_start_offset + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                )
                c = accumulator.to(tl.float32) * a_scale * b_scale

                if USE_TMA_STORE:
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        c.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                else:
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                tidx += NUM_SMS

            iterated_tiles += num_tiles


def _grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    w_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if not utils.HAS_TMA_DESC:
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M, K = x.shape
    N = w.shape[0] // G
    assert K == w.shape[1]

    y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    USE_TMA_LOAD = not torch.version.hip
    USE_TMA_STORE = False

    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = utils.TmaAutoTuneHelper()
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        workspace = torch.empty(
            NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N * G,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )

        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M)
    if x_scale is not None and w_scale is not None:
        assert x_scale.is_contiguous()
        assert w_scale.is_contiguous()
        _kernel_grouped_gemm_fp8_rowwise[grid](
            desc_x,
            x_scale,
            desc_w,
            w_scale,
            y,
            workspace,
            m_sizes,
            G,
            M_BUCKET,
            N,
            K,
            NUM_SMS,
            USE_TMA_LOAD,
            USE_TMA_STORE,
        )
    else:
        assert x_scale is None
        assert w_scale is None
        _kernel_grouped_gemm[grid](
            desc_x,
            desc_w,
            y,
            workspace,
            m_sizes,
            G,
            M_BUCKET,
            N,
            K,
            NUM_SMS,
            USE_TMA_LOAD,
            USE_TMA_STORE,
        )

    return y


def grouped_gemm(
    x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor
) -> torch.Tensor:
    return _grouped_gemm(x, w, m_sizes)


def grouped_gemm_fp8_rowwise(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
) -> torch.Tensor:
    return _grouped_gemm(x, w, m_sizes, x_scale, w_scale)


================================================
FILE: dev/triton_groupGEMM/testing/base_testing.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

# Configure logging to print to stdout
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

import os
import sys
import unittest
from typing import Tuple

import torch

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


if torch.cuda.is_available():
    # from fp8_gemm import quantize_fp8_row
    from groupgemm import grouped_gemm  # , grouped_gemm_fp8_rowwise
    from tma_utils import HAS_TMA_DESC


@unittest.skipIf(
    not torch.cuda.is_available()
    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9
    or not HAS_TMA_DESC,
    "Skip when H100 or TMA is not available",
)
class TestGroupedGEMM(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(0)

    """def test_grouped_gemm_fp8_rowwise(self) -> None:
        def _test_grouped_gemm_fp8_rowwise(
            shape: Tuple[int, int, int, int],
            device: torch.device,
        ) -> None:
            G, M, N, K = shape
            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
            m_ends, _ = torch.sort(
                torch.randint(
                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
                )
            )
            m_ends = m_ends.tolist()
            m_starts = [0] + m_ends
            m_ends = m_ends + [M]
            m_sizes = torch.tensor(
                [m_ends[i] - m_starts[i] for i in range(G)], device=device
            ).to(torch.int32)

            a_fp8, a_scale = quantize_fp8_row(a)
            b_fp8, b_scale = quantize_fp8_row(b)

            result = grouped_gemm_fp8_rowwise(
                a_fp8,
                b_fp8,
                m_sizes,
                a_scale,
                b_scale,
            )
            self.assertTrue(result.shape == (M, N))

            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
            # Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation.
            for g in range(G):
                m_start = m_starts[g]
                m_end = m_ends[g]
                n_start = g * N
                n_end = (g + 1) * N

                expected_result[m_start:m_end, :] = (
                    a_fp8[m_start:m_end, :].to(torch.float32)
                    @ b_fp8[n_start:n_end, :].to(torch.float32).T
                    * a_scale[m_start:m_end][:, None]
                    * b_scale[n_start:n_end][None, :]
                ).to(torch.bfloat16)

            torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)

        for G in (1, 4, 16):
            for M in (64, 512):
                logging.info(f"Testing FP8 GMM with G={G}, M={M}")
                _test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device("cuda"))
        """

    def test_grouped_gemm_bf16(self) -> None:
        def _test_grouped_gemm_bf16(
            shape: Tuple[int, int, int, int],
            device: torch.device,
        ) -> None:
            G, M, N, K = shape
            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
            m_ends, _ = torch.sort(
                torch.randint(
                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
                )
            )
            m_ends = m_ends.tolist()
            m_starts = [0] + m_ends
            m_ends = m_ends + [M]
            m_sizes = torch.tensor(
                [m_ends[i] - m_starts[i] for i in range(G)], device=device
            ).to(torch.int32)

            result = grouped_gemm(
                a,
                b,
                m_sizes,
            )
            self.assertTrue(result.shape == (M, N))

            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
            for g in range(G):
                m_start = m_starts[g]
                m_end = m_ends[g]
                expected_result[m_start:m_end, :] = (
                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
                )

            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

        for G in (1, 4, 16):
            for M in (64, 512):
                logging.info(f"Testing BF16 GMM with G={G}, M={M}")
                _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))


if __name__ == "__main__":
    unittest.main(exit=False)


================================================
FILE: dev/triton_groupGEMM/testing/unit_tests.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm


import logging
import unittest
from typing import Tuple

import torch

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from groupgemm import grouped_gemm


class TestGroupedGEMM(unittest.TestCase):
    def test_grouped_gemm_bf16(self) -> None:
        def _test_grouped_gemm_bf16(
            shape: Tuple[int, int, int, int],
            device: torch.device,
        ) -> None:
            G, M, N, K = shape
            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
            m_ends, _ = torch.sort(
                torch.randint(
                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
                )
            )
            m_ends = m_ends.tolist()
            m_starts = [0] + m_ends
            m_ends = m_ends + [M]
            m_sizes = torch.tensor(
                [m_ends[i] - m_starts[i] for i in range(G)], device=device
            ).to(torch.int32)
            result = grouped_gemm(
                a,
                b,
                m_sizes,
            )
            self.assertTrue(result.shape == (M, N))
            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
            for g in range(G):
                m_start = m_starts[g]
                m_end = m_ends[g]
                expected_result[m_start:m_end, :] = (
                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
                )
            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

        for G in (1, 4, 16):
            for M in (64, 512):
                logging.info(f"Testing BF16 GMM with G={G}, M={M}")
                _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))

    def test_grouped_gemm_bf16_various_dimensions(self) -> None:
        """Test grouped_gemm with bf16 precision and various dimensions"""

        def _test_grouped_gemm_bf16(
            shape: Tuple[int, int, int, int],
            device: torch.device,
        ) -> None:
            G, M, N, K = shape
            a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
            b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
            m_ends, _ = torch.sort(
                torch.randint(
                    low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
                )
            )
            m_ends = m_ends.tolist()
            m_starts = [0] + m_ends
            m_ends = m_ends + [M]
            m_sizes = torch.tensor(
                [m_ends[i] - m_starts[i] for i in range(G)], device=device
            ).to(torch.int32)
            result = grouped_gemm(
                a,
                b,
                m_sizes,
            )
            self.assertTrue(result.shape == (M, N))
            expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
            for g in range(G):
                m_start = m_starts[g]
                m_end = m_ends[g]
                expected_result[m_start:m_end, :] = (
                    a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
                )
            torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

        for G in (4, 8):
            for M in (128, 256):
                for N, K in [(128, 256), (256, 128), (64, 64)]:
                    logging.info(f"Testing BF16 GMM with G={G}, M={M}, N={N}, K={K}")
                    _test_grouped_gemm_bf16((G, M, N, K), torch.device("cuda"))

    def test_grouped_gemm_bf16_edge_cases(self) -> None:
        """Test grouped_gemm with bfloat16 for various edge cases"""
        device = torch.device("cuda")

        # Test with G=1 (single group case)
        G, M, N, K = 1, 32, 32, 32
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([M], device=device).to(torch.int32)
        result = grouped_gemm(a, b, m_sizes)
        expected_result = a @ b.T
        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

        # Test with uneven group sizes
        G, M, N, K = 3, 100, 32, 32
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([25, 50, 25], device=device).to(torch.int32)
        result = grouped_gemm(a, b, m_sizes)
        self.assertTrue(result.shape == (M, N))
        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
        m_start = 0
        for g in range(G):
            m_end = m_start + m_sizes[g].item()
            expected_result[m_start:m_end, :] = (
                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
            )
            m_start = m_end
        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

        # Test with extremely small matrices
        G, M, N, K = 2, 8, 8, 8
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([4, 4], device=device).to(torch.int32)
        result = grouped_gemm(a, b, m_sizes)
        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
        m_start = 0
        for g in range(G):
            m_end = m_start + m_sizes[g].item()
            expected_result[m_start:m_end, :] = (
                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
            )
            m_start = m_end
        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

        # Test with large group count but small matrix sizes
        G, M, N, K = 32, 128, 16, 16
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.ones(G, device=device).to(torch.int32) * (M // G)
        m_sizes[-1] += M % G  # Adjust the last group size to account for remainder
        result = grouped_gemm(a, b, m_sizes)
        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
        m_start = 0
        for g in range(G):
            m_end = m_start + m_sizes[g].item()
            expected_result[m_start:m_end, :] = (
                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
            )
            m_start = m_end
        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

    def test_grouped_gemm_bf16_invalid_inputs(self) -> None:
        """Test grouped_gemm with invalid inputs to ensure proper error handling"""
        device = torch.device("cuda")

        # Test with mismatched dimensions
        G, M, N, K = 2, 64, 32, 32
        a = torch.randn(
            M, K + 1, dtype=torch.bfloat16, device=device
        )  # Wrong K dimension
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([32, 32], device=device).to(torch.int32)

        with self.assertRaises(RuntimeError):
            grouped_gemm(a, b, m_sizes)

        # Test with mismatched G and m_sizes length
        G, M, N, K = 2, 64, 32, 32
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([32, 32, 32], device=device).to(
            torch.int32
        )  # Too many groups

        with self.assertRaises((RuntimeError, ValueError, IndexError)):
            grouped_gemm(a, b, m_sizes)

        # Test with incorrect sum of m_sizes
        G, M, N, K = 2, 64, 32, 32
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([32, 40], device=device).to(torch.int32)  # Sum > M

        with self.assertRaises((RuntimeError, ValueError, IndexError)):
            grouped_gemm(a, b, m_sizes)

        # Test with negative m_sizes values
        G, M, N, K = 2, 64, 32, 32
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([40, -8], device=device).to(
            torch.int32
        )  # Negative group size

        with self.assertRaises((RuntimeError, ValueError)):
            grouped_gemm(a, b, m_sizes)

    def test_grouped_gemm_bf16_deterministic(self) -> None:
        """Test that grouped_gemm produces deterministic results with the same inputs"""
        G, M, N, K = 4, 128, 64, 64
        device = torch.device("cuda")

        # Fix the random seed for reproducibility
        torch.manual_seed(42)

        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([32, 32, 32, 32], device=device).to(torch.int32)

        # First run
        result1 = grouped_gemm(a, b, m_sizes)

        # Second run with same inputs
        result2 = grouped_gemm(a, b, m_sizes)

        # Results should be exactly the same
        self.assertTrue(torch.all(result1 == result2))

    def test_grouped_gemm_bf16_large_matrices(self) -> None:
        """Test grouped_gemm with larger matrices to stress test performance and stability"""
        device = torch.device("cuda")

        # Test with large matrices but fewer groups
        G, M, N, K = 2, 2048, 512, 1024
        a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
        m_sizes = torch.tensor([1024, 1024], device=device).to(torch.int32)

        result = grouped_gemm(a, b, m_sizes)
        expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)

        m_start = 0
        for g in range(G):
            m_end = m_start + m_sizes[g].item()
            expected_result[m_start:m_end, :] = (
                a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
            )
            m_start = m_end

        torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)


if __name__ == "__main__":
    unittest.main(argv=["first-arg-is-ignored"], exit=False)


================================================
FILE: dev/triton_groupGEMM/tma_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm

import sys

import torch
import triton  # @manual

import triton.language as tl  # @manual


def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
    """
    Maps torch dtype to triton dtype.

    Args:
        dtype (torch.dtype): input dtype.

    Returns:
        tl.dtype: triton dtype.
    """
    if dtype == torch.float16:
        return tl.float16
    elif dtype == torch.bfloat16:
        return tl.bfloat16
    elif dtype == torch.float32:
        return tl.float32
    elif dtype == torch.int32:
        return tl.int32
    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
        return tl.float8e4nv
    else:
        raise ValueError(f"Unsupported dtype {dtype}")


# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

if HAS_TMA_DESC:
    print(
        "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
        file=sys.stderr,
    )
else:
    print(
        "Missing: This group gemm code will not run without TMA descriptor support....",
        file=sys.stderr,
    )
    raise NotImplementedError("grouped Gemm without TMA is not supported")


class TmaAutoTuneHelper:

    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
    class KernelParamWrapper:
        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            return self.desc.data_ptr()

    TMA_SIZE = 128

    def __init__(self):
        self.fill_1d_tma_descriptor_inner = (
            triton.runtime.driver.active.utils.fill_1d_tma_descriptor
        )
        self.fill_2d_tma_descriptor_inner = (
            triton.runtime.driver.active.utils.fill_2d_tma_descriptor
        )
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    # Call this method outside of the lambda function for grid size
    def init_tma_descriptor(self, name):
        if HAS_TMA_DESC:
            self.descriptors[name] = torch.empty(
                TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
            )
        else:
            self.cuda_descriptors[name] = torch.empty(
                TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
            )

    # Call this method inside the lambda function for grid size
    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            self.fill_1d_tma_descriptor_inner(
                ptr, dim, block_dim, element_size, desc_x.data_ptr()
            )
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            self.fill_1d_tma_descriptor_inner(
                ptr, dim, block_dim, element_size, buf_x.data_ptr()
            )
            desc_x.copy_(buf_x, non_blocking=True)

    # Call this method inside the lambda function for grid size
    def fill_2d_tma_descriptor(
        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
    ):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            self.fill_2d_tma_descriptor_inner(
                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
            )
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            self.fill_2d_tma_descriptor_inner(
                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
            )
            desc_x.copy_(buf_x, non_blocking=True)

    def get_tma_descriptor_kernel_param(self, name):
        if HAS_TMA_DESC:
            assert self.descriptors[name] is not None
            return self.KernelParamWrapper(self.descriptors[name])
        else:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]


================================================
FILE: dev/triton_groupGEMM/triton_tutorial_groupgemm.py
================================================
"""
Group GEMM
============================
This group gemm kernel launches a fixed number of CTA to compute a group
of gemms. The scheduling is static and we do it on device.
"""

# Copyright (c) 2023 - 2025 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# from: https://github.com/triton-lang/triton/blob/main/python/tutorials/08-grouped-gemm.py

from typing import Optional

import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def num_sms():
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return 148


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 84,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 128,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 84,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "NUM_SM": 128,
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64,
                "NUM_SM": num_sms(),
            }
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64,
                "NUM_SM": num_sms(),
            }
        ),
    ],
    key=["group_size"],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # assume full tile for now
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            # assumes full tile for now
            tl.store(c_ptrs, c)

            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_fn(group_A, group_B):
    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        K, N = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    # we use a fixed number of CTA, and it's auto-tunable
    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
    )

    return group_C


tma_configs = [
    triton.Config(
        {"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK},
        num_stages=s,
        num_warps=w,
    )
    for BM in [128]
    for BN in [128, 256]
    for BK in [64, 128]
    for s in ([3, 4])
    for w in [4, 8]
]


@triton.autotune(
    tma_configs,
    key=["group_a_ptrs", "group_b_ptrs", "gropup_c_ptrs", "group_size"],
)
@triton.jit
def grouped_matmul_tma_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    dtype = tl.float8e4nv if FP8 else tl.float16
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))

            a_desc = tl._experimental_make_tensor_descriptor(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
            )

            b_desc = tl._experimental_make_tensor_descriptor(
                b_ptr,
                shape=[gn, gk],
                strides=[ldb, 1],
                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
            )
            c_desc = tl._experimental_make_tensor_descriptor(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            )

            # iterate through the tiles in the current gemm problem
            while (
                tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
            ):
                k = gk
                # figure out tile coordinates
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx = tile_idx_in_gemm // num_n_tiles
                tile_n_idx = tile_idx_in_gemm % num_n_tiles

                # do regular gemm here
                offs_am = tile_m_idx * BLOCK_SIZE_M
                offs_bn = tile_n_idx * BLOCK_SIZE_N

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
                    accumulator += tl.dot(a, b.T)

                offs_cm = tile_m_idx * BLOCK_SIZE_M
                offs_cn = tile_n_idx * BLOCK_SIZE_N

                c = accumulator.to(dtype)
                c_desc.store([offs_cm, offs_cn], c)

                # go to the next tile by advancing NUM_SM
                tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_tma_fn(group_A, group_B):

    assert supports_tma()

    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[1]
        M, K = A.shape
        N, K = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

    # we use a fixed number of CTA, and it's auto-tunable

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_tma_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
        FP8=torch.float8_e4m3fn == group_A[0].dtype,
        NUM_SM=num_sms(),
    )
    return group_C


group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
group_A = []
group_B = []
group_B_T = []
assert len(group_m) == len(group_n)
assert len(group_n) == len(group_k)
group_size = len(group_m)
for i in range(group_size):
    M = group_m[i]
    N = group_n[i]
    K = group_k[i]
    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
    B_T = B.T.contiguous()
    group_A.append(A)
    group_B.append(B)
    group_B_T.append(B_T)

tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]
for i in range(group_size):
    assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0)

if supports_tma():
    tri_tma_out = group_gemm_tma_fn(group_A, group_B_T)
    for i in range(group_size):
        assert torch.allclose(ref_out[i], tri_tma_out[i], atol=1e-2, rtol=0)


# only launch the kernel, no tensor preparation here to remove all overhead
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_kernel[grid](
        a_ptrs,
        b_ptrs,
        c_ptrs,
        sizes,
        lds,
        group_size,
    )


def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, dtype):
    grid = lambda META: (META["NUM_SM"],)
    grouped_matmul_tma_kernel[grid](
        a_ptrs,
        b_ptrs,
        c_ptrs,
        sizes,
        lds,
        group_size,
        FP8=torch.float8_e4m3fn == dtype,
        NUM_SM=num_sms(),
    )


def torch_perf_fn(group_A, group_B):
    for a, b in zip(group_A, group_B):
        torch.matmul(a, b)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=["N"],
        x_vals=[2**i for i in range(7, 11)],  # different possible values for `x_name`
        line_arg="provider",
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg``
        line_vals=["cublas", "triton"] + (["triton-tma"] if supports_tma() else []),
        # label name for the lines
        line_names=["cuBLAS", "Triton"] + (["Triton + TMA"] if supports_tma() else []),
        # line styles
        styles=[("green", "-"), ("blue", "-")]
        + ([("red", "-")] if supports_tma() else []),
        ylabel="runtime(ms)",  # label name for the y-axis
        plot_name="group-gemm-performance",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    )
)
def benchmark_square_matrices(N, provider):
    group_size = 4
    group_A = []
    group_B = []
    group_B_T = []
    A_addrs = []
    B_addrs = []
    B_T_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
        B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
        C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
        B_T = B.T.contiguous()
        group_A.append(A)
        group_B.append(B)
        group_B_T.append(B_T)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        B_T_addrs.append(B_T.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [N, N, N]
        g_lds += [N, N, N]

    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "cublas":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles
        )
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_perf_fn(
                d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size
            ),
            quantiles=quantiles,
        )
    if provider == "triton-tma":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_tma_perf_fn(
                d_a_ptrs,
                d_b_t_ptrs,
                d_c_ptrs,
                d_g_sizes,
                d_g_lds,
                group_size,
                dtype=torch.float16,
            ),
            quantiles=quantiles,
        )
    return ms, max_ms, min_ms


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=["M"],
        x_vals=[2**i for i in range(7, 11)],  # different possible values for `x_name`
        line_arg="provider",
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg``
        line_vals=["cublas", "triton"] + (["triton-tma"] if supports_tma() else []),
        # label name for the lines
        line_names=["cuBLAS", "Triton"] + (["Triton + TMA"] if supports_tma() else []),
        # line styles
        styles=[("green", "-"), ("blue", "-")]
        + ([("red", "-")] if supports_tma() else []),
        ylabel="runtime(ms)",  # label name for the y-axis
        plot_name="group-gemm-performance-m-8192-k-8192",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    )
)
def benchmark_batches(M, provider):
    N = 8192
    K = 8192
    group_size = 4
    group_A = []
    group_B = []
    group_B_T = []
    A_addrs = []
    B_addrs = []
    B_T_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    g_T_lds = []
    group_C = []
    for i in range(group_size):
        A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
        B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
        C = torch.empty((M, N), device=DEVICE, dtype=torch.float16)
        B_T = B.T.contiguous()
        group_A.append(A)
        group_B.append(B)
        group_B_T.append(B_T)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        B_T_addrs.append(B_T.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
        g_T_lds += [A.stride(0), B_T.stride(0), C.stride(0)]

    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    d_g_t_lds = torch.tensor(g_T_lds, dtype=torch.int32, device=DEVICE)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "cublas":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles
        )
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_perf_fn(
                d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size
            ),
            quantiles=quantiles,
        )
    if provider == "triton-tma":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_tma_perf_fn(
                d_a_ptrs,
                d_b_t_ptrs,
                d_c_ptrs,
                d_g_sizes,
                d_g_t_lds,
                group_size,
                dtype=torch.float16,
            ),
            quantiles=quantiles,
        )
    return ms, max_ms, min_ms


benchmark_square_matrices.run(show_plots=True, print_data=True)
benchmark_batches.run(show_plots=True, print_data=True)


================================================
FILE: kernels/MoE/group_GEMM/triton/readme.md
================================================
##  Experimental

Triton Group GEMM for supporting MoE training. 


================================================
FILE: kernels/MoE/group_GEMM/triton/testing/fast_verification.py
================================================
import logging

import torch

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# import the reference implementations
from pytorch_reference_backwards import (
    _compute_grad_w_pytorch,
    _compute_grad_x_pytorch,
    _pytorch_fallback_backward,
    _pytorch_reference_backward,
)

# Import the grouped GEMM modules
from tgrouped_gemm_backwards import grouped_gemm_backward
from tgrouped_gemm_forward import grouped_gemm_forward as grouped_gemm


def test_backward_pass():
    """
    A simple test for the grouped GEMM backward pass with detailed error handling.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Test parameters
        G = 20  # Number of groups
        M = 1024  # Input dimension
        N = 512  # Output dimension per group
        K = 256  # Hidden dimension

        # Create input and weight tensors
        x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
        w = torch.randn(
            N * G, K, dtype=torch.bfloat16, device=device, requires_grad=True
        )

        # Create group sizes
        m_sizes = torch.zeros(G, device=device, dtype=torch.int32)
        base_size = M // G
        remainder = M % G

        for i in range(G):
            m_sizes[i] = base_size + (1 if i < remainder else 0)

        # Log the setup
        print(f"Test setup - G: {G}, M: {M}, N: {N}, K: {K}")
        print(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")
        logging.info(f"Group sizes: {m_sizes}")

        # Step 1: Run forward pass
        logging.info("Running forward pass")
        result = grouped_gemm(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        # Create a gradient for backpropagation
        grad_output = torch.randn_like(result)
        logging.info(f"Created gradient with shape: {grad_output.shape}")

        # Step 2: Run backward pass directly
        logging.info("Running backward pass directly")
        grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

        # Verify gradient shapes
        logging.info(
            f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
        )

        # Step 3: Verify gradient computation using PyTorch's autograd
        # First create autograd-enabled tensors
        x_autograd = x.detach().clone().requires_grad_(True)
        w_autograd = w.detach().clone().requires_grad_(True)

        # Create a PyTorch reference implementation to compare against
        logging.info("Running PyTorch reference implementation")

        # Compute reference result
        reference_result = torch.zeros_like(result)
        m_start = 0
        for g in range(G):
            m_size = m_sizes[g].item()
            m_end = m_start + m_size
            n_start = g * N
            n_end = (g + 1) * N

            if m_size > 0:
                reference_result[m_start:m_end, n_start:n_end] = (
                    x_autograd[m_start:m_end, :] @ w_autograd[n_start:n_end, :].T
                )

            m_start = m_end

        # Backpropagate using PyTorch
        reference_result.backward(grad_output)

        # Compare gradients
        logging.info("Comparing gradients with PyTorch reference")
        grad_x_error = (grad_x - x_autograd.grad).abs().max().item()
        grad_w_error = (grad_w - w_autograd.grad).abs().max().item()

        logging.info(
            f"Maximum gradient error - grad_x: {grad_x_error}, grad_w: {grad_w_error}"
        )

        # Check if gradients are close using allclose
        rtol = 1e-2  # Relative tolerance for bfloat16
        atol = 1e-2  # Absolute tolerance for bfloat16

        grad_x_close = torch.allclose(grad_x, x_autograd.grad, rtol=rtol, atol=atol)
        if not grad_x_close:
            logging.warning("FAILED: Gradient mismatch detected in grad_x")
        else:
            logging.info(
                "✓ SUCCESS! grad_X matches the PyTorch reference (allclose check passed)"
            )

        grad_w_close = torch.allclose(grad_w, w_autograd.grad, rtol=rtol, atol=atol)
        if not grad_w_close:
            logging.warning("FAILED: Gradient mismatch detected in grad_w")
        else:
            logging.info(
                "✓ SUCCESS! grad_W matches the PyTorch reference (allclose check passed)"
            )

        logging.info(
            f"Gradients allclose check - grad_x: {grad_x_close}, grad_w: {grad_w_close}"
        )

        if grad_x_close and grad_w_close:
            logging.info(
                "✓ SUCCESS: Gradients match the PyTorch reference (allclose check passed)"
            )
        else:
            logging.error("✗ FAILURE: Gradient mismatch detected in allclose check")

        # Additional diagnostics (for failed cases or in general)
        if True:  # not grad_x_close:
            # Find where the largest differences are
            diff_x = (grad_x - x_autograd.grad).abs()
            max_idx_x = diff_x.argmax().item()
            flat_idx_x = max_idx_x
            idx_x = np.unravel_index(flat_idx_x, grad_x.shape)
            logging.error(
                f"Largest grad_x difference at {idx_x}: "
                f"{grad_x[idx_x].item()} vs {x_autograd.grad[idx_x].item()}"
            )
            # Count zeros
            zeros_grad_x = (grad_x == 0).sum().item()
            zeros_autograd_x = (x_autograd.grad == 0).sum().item()
            logging.error(
                f"Zeros in grad_x: {zeros_grad_x}/{grad_x.numel()} ({zeros_grad_x/grad_x.numel()*100:.2f}%)"
            )
            logging.error(
                f"Zeros in x_autograd.grad: {zeros_autograd_x}/{x_autograd.grad.numel()} ({zeros_autograd_x/x_autograd.grad.numel()*100:.2f}%)"
            )

        if True:  # not grad_w_close:
            # Find where the largest differences are
            diff_w = (grad_w - w_autograd.grad).abs()
            max_idx_w = diff_w.argmax().item()
            flat_idx_w = max_idx_w
            idx_w = np.unravel_index(flat_idx_w, grad_w.shape)
            logging.error(
                f"Largest grad_w difference at {idx_w}: "
                f"{grad_w[idx_w].item()} vs {w_autograd.grad[idx_w].item()}"
            )
            # Count zeros
            zeros_grad_w = (grad_w == 0).sum().item()
            zeros_autograd_w = (w_autograd.grad == 0).sum().item()
            logging.error(
                f"Zeros in grad_w: {zeros_grad_w}/{grad_w.numel()} ({zeros_grad_w/grad_w.numel()*100:.2f}%)"
            )
            logging.error(
                f"Zeros in w_autograd.grad: {zeros_autograd_w}/{w_autograd.grad.numel()} ({zeros_autograd_w/w_autograd.grad.numel()*100:.2f}%)"
            )

        return grad_x_close and grad_w_close

    except Exception as e:
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False


if __name__ == "__main__":
    print("Running test_backward_pass")
    logging.debug("Running test_backward_pass")
    # Add numpy import for unravel_index
    import numpy as np

    success = test_backward_pass()
    logging.info(f"Test {'succeeded' if success else 'failed'}")


================================================
FILE: kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

# This is a series of helper functions for grouped GEMM backward that compute the gradients
# using eager PyTorch operations. They are used as a verification reference for the Triton kernels.
# They can also used as a fallback when the Triton kernels cannot be used, though lets hope that is not needed.


def _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x):
    """
    Compute grad_x using pure PyTorch operations with FP32 precision
    """
    G = m_sizes.shape[0]
    M, K = grad_x.shape
    N = w.shape[0] // G

    # Zero out the output tensor first
    grad_x.zero_()

    # Store original dtype and convert to float32 for computation
    orig_dtype = grad_x.dtype
    grad_output_fp32 = grad_output.float()
    w_fp32 = w.float()
    grad_x_fp32 = torch.zeros_like(grad_x, dtype=torch.float32)

    # Process each group separately
    m_start = 0
    for g in range(G):
        m_size = m_sizes[g].item()
        if m_size > 0:
            m_end = m_start + m_size
            n_start = g * N
            n_end = (g + 1) * N

            # Get slices for this group
            grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end]
            w_slice = w_fp32[n_start:n_end]

            # Process in chunks for better precision on large matrices
            CHUNK_SIZE = 256
            for chunk_start in range(0, m_size, CHUNK_SIZE):
                chunk_end = min(chunk_start + CHUNK_SIZE, m_size)
                chunk_size = chunk_end - chunk_start

                # Compute matrix multiplication with higher precision
                grad_output_chunk = grad_output_slice[chunk_start:chunk_end]
                result_chunk = torch.matmul(
                    grad_output_chunk.double(), w_slice.double()
                )

                # Store the result
                grad_x_fp32[m_start + chunk_start : m_start + chunk_end].copy_(
                    result_chunk.float()
                )

        m_start = m_end

    # Convert back to original dtype
    grad_x.copy_(grad_x_fp32.to(orig_dtype))


def _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w):
    """
    Compute grad_w using pure PyTorch operations with FP64 precision for better accuracy.
    """
    G = m_sizes.shape[0]
    N_times_G, K = grad_w.shape
    N = N_times_G // G

    # Zero out the output tensor first
    grad_w.zero_()

    # Store original dtype and convert to float32 for computation
    orig_dtype = grad_w.dtype
    grad_output_fp32 = grad_output.float()
    x_fp32 = x.float()
    grad_w_fp32 = torch.zeros_like(grad_w, dtype=torch.float32)

    # Handle potential K dimension mismatches
    K_x = x.shape[1]
    min_K = min(K, K_x)

    # Process each group separately
    m_start = 0
    for g in range(G):
        m_size = m_sizes[g].item()
        if m_size > 0:
            m_end = m_start + m_size
            n_start = g * N
            n_end = (g + 1) * N

            # Get slices for this group
            grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end]
            x_slice = x_fp32[m_start:m_end, :min_K]

            # Process in chunks for better precision
            CHUNK_SIZE = 32
            result = torch.zeros(
                (grad_output_slice.shape[1], min_K),
                dtype=torch.float64,
                device=grad_output_slice.device,
            )

            for chunk_start in range(0, m_size, CHUNK_SIZE):
                chunk_end = min(chunk_start + CHUNK_SIZE, m_size)

                # Get chunks
                grad_output_chunk = grad_output_slice[chunk_start:chunk_end].double()
                x_chunk = x_slice[chunk_start:chunk_end].double()

                # Matrix multiplication in FP64
                chunk_result = torch.matmul(grad_output_chunk.t(), x_chunk)
                result += chunk_result

            # Handle K dimension padding if needed
            if K > min_K:
                temp_result = torch.zeros(
                    (grad_output_slice.shape[1], K),
                    dtype=torch.float32,
                    device=grad_output_slice.device,
                )
                temp_result[:, :min_K] = result.float()
                grad_w_fp32[n_start:n_end].copy_(temp_result)
            else:
                grad_w_fp32[n_start:n_end].copy_(result.float())

        m_start = m_end

    # Convert back to original dtype
    grad_w.copy_(grad_w_fp32.to(orig_dtype))


def _pytorch_fallback_backward(grad_output, x, w, m_sizes):
    """
    Pure PyTorch implementation of grouped GEMM backward with high precision.
    Used as a fallback when the Triton kernels cannot be used.
    """
    logging.info(
        "WARNING:  Using PyTorch fallback for grouped GEMM backward with high precision"
    )

    # Ensure inputs are contiguous
    x = x.contiguous()
    w = w.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    # Allocate output tensors
    grad_x = torch.zeros_like(x)
    grad_w = torch.zeros_like(w)

    # Compute gradients using the helper functions
    _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x)
    _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w)

    return grad_x, grad_w


def _pytorch_reference_backward(grad_output, x, w, m_sizes):
    """
    Pure PyTorch implementation of grouped GEMM backward for validation.
    Simple version that's easy to verify but may be less numerically accurate
    for large matrices.
    """
    # Create output gradients
    grad_x = torch.zeros_like(x)
    grad_w = torch.zeros_like(w)

    # Compute group-by-group
    G = m_sizes.shape[0]
    N = w.shape[0] // G

    m_start = 0
    for g in range(G):
        m_size = m_sizes[g].item()
        if m_size > 0:
            m_end = m_start + m_size
            n_start = g * N
            n_end = (g + 1) * N

            # Compute gradients
            grad_x[m_start:m_end] = torch.matmul(
                grad_output[m_start:m_end, n_start:n_end], w[n_start:n_end]
            )
            grad_w[n_start:n_end] = torch.matmul(
                grad_output[m_start:m_end, n_start:n_end].t(), x[m_start:m_end]
            )

        m_start += m_size

    return grad_x, grad_w


# ========== End helper functions ==========


================================================
FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Tuple

import torch
import triton
import triton.language as tl
from tma_utils import TmaAutoTuneHelper

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

"""
Backward pass for grouped GEMM with Triton, where grouping is N*G
We are computing gradients with respect to both the input (`grad_x`) and the weights (`grad_w`).
"""


# =============== Start Triton Kernels ===============
@triton.jit
def _kernel_grouped_gemm_backward_x_scheduled(
    grad_y_ptr,  # grad of dl/dY [M, N*G]
    w_t_ptr,  # w transposed [K, N*G]
    grad_x_ptr,  # output of kernel [M, K]
    group_offsets_ptr,  # Pre-computed group offsets [G+1]
    workspace,  # Workspace for TMA descriptors
    G,  # Number of groups
    M,  # Total M dimension size
    N,  # N per group
    K,  # K dimension size
    stride_go_m,
    stride_go_n,
    stride_w_n,
    stride_w_k,
    stride_gx_m,
    stride_gx_k,
    NUM_SMS,
    USE_TMA_LOAD: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False,
    BLOCK_SIZE_M: tl.constexpr = 64,
    BLOCK_SIZE_N: tl.constexpr = 64,
    BLOCK_SIZE_K: tl.constexpr = 64,
    GROUP_SIZE_M: tl.constexpr = 8,
    EVEN_K: tl.constexpr = True,
) -> None:
    """
    Scheduled grouped GEMM backward for X with TMA support.

    For each group g, computes: grad_x[g] = grad_y[g] @ w_t[g].T

    Where:
    - grad_y is [M, N*G]
    - w_t is [K, N*G] (transposed from [N*G, K])
    - grad_x is [M, K]
    """
    # Get coordinates for the current program
    tidx = tl.program_id(axis=0)
    dtype = grad_x_ptr.dtype.element_ty
    TMA_SIZE: tl.constexpr = 128

    # Initialize workspace pointer if using TMA store
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    # Calculate work distribution parameters
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = num_pid_m * num_pid_k

    # Process all assigned work items
    pid = tidx
    while pid < G * num_pid_in_group:
        # Calculate work distribution for this pid
        group_id = pid // num_pid_in_group
        pid_in_group = pid % num_pid_in_group
        pid_m = pid_in_group % num_pid_m
        pid_k = pid_in_group // num_pid_m

        # Get group boundaries
        valid_group = group_id < G
        group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0)
        group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0)
        group_size = group_end - group_start

        # Calculate a mask for valid processing (valid group and non-empty)
        valid_work = valid_group & (group_size > 0)

        # Only process if we have valid work
        if valid_work:
            # Compute offsets for this group
            n_start = group_id * N

            # Block dimensions
            m_block_offset = pid_m * BLOCK_SIZE_M
            k_block_offset = pid_k * BLOCK_SIZE_K

            # Setup TMA descriptor for output if using TMA
            if USE_TMA_STORE:
                m_size = tl.minimum(
                    BLOCK_SIZE_M, group_end - (group_start + m_block_offset)
                )
                if m_size > 0:
                    tl.extra.cuda.experimental_device_tensormap_create2d(
                        desc_ptr=c_desc_ptr,
                        global_address=grad_x_ptr
                        + (group_start + m_block_offset) * stride_gx_m
                        + k_block_offset * stride_gx_k,
                        load_size=[
                            m_size,
                            tl.minimum(BLOCK_SIZE_K, K - k_block_offset),
                        ],
                        global_size=[m_size, K],
                        element_ty=dtype,
                    )
                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Initialize offsets for this block
            offs_m = group_start + m_block_offset + tl.arange(0, BLOCK_SIZE_M)

            # For K dimension, optimize memory access if EVEN_K is True
            offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K)

            # Create masks
            m_mask = offs_m < group_end
            k_mask = offs_k < K

            # Initialize accumulator
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

            # Loop over the reduction dimension (N)
            # Use smaller steps to improve precision and avoid numerical issues
            for n_offset in range(0, N, BLOCK_SIZE_N):
                # Handle boundary conditions for the reduction dimension
                n_size = tl.minimum(BLOCK_SIZE_N, N - n_offset)
                offs_n = n_start + n_offset + tl.arange(0, BLOCK_SIZE_N)
                n_mask = offs_n < (n_start + N)

                # Fixed stride formats to ensure consistent memory access
                grad_y_block = tl.load(
                    grad_y_ptr
                    + offs_m[:, None] * stride_go_m
                    + offs_n[None, :] * stride_go_n,
                    mask=m_mask[:, None] & n_mask[None, :],
                    other=0.0,
                )

                # Load w_t [K, N*G] block with correct strides
                w_t_block = tl.load(
                    w_t_ptr
                    + offs_k[:, None] * stride_w_k
                    + offs_n[None, :] * stride_w_n,
                    mask=k_mask[:, None] & n_mask[None, :],
                    other=0.0,
                )

                # grad_y @ w_t.T
                # Allow TF32 if K is even and divisible by 8
                if EVEN_K:
                    accumulator += tl.dot(
                        grad_y_block.to(tl.float32),
                        w_t_block.to(tl.float32).T,
                        allow_tf32=True,
                    )
                else:
                    accumulator += tl.dot(
                        grad_y_block.to(tl.float32),
                        w_t_block.to(tl.float32).T,
                        allow_tf32=False,
                    )

            # Store result to grad_x with explicit strides
            if USE_TMA_STORE:
                # TMA store
                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(dtype),
                    [0, 0],  # Starting offset in the output block
                )
            else:
                # Standard store
                tl.store(
                    grad_x_ptr
                    + offs_m[:, None] * stride_gx_m
                    + offs_k[None, :] * stride_gx_k,
                    accumulator.to(dtype),
                    mask=m_mask[:, None] & k_mask[None, :],
                )

        pid = pid + NUM_SMS


@triton.jit
def _kernel_grouped_gemm_backward_w_scheduled(
    x_t_ptr,  # x transposed [K, M]
    grad_y_ptr,  # grad of dl/dY [M, N*G]
    grad_w_ptr,  # output of kernel (grad_w) [N*G, K]
    group_offsets_ptr,  # Pre-computed group offsets [G+1]
    workspace,  # Workspace for TMA descriptors
    G,  # Number of groups
    M,  # Total M dimension size
    N,  # N per group
    K,  # K dimension size
    stride_x_m,
    stride_x_k,
    stride_go_m,
    stride_go_n,
    stride_gw_n,
    stride_gw_k,
    NUM_SMS,
    USE_TMA_LOAD: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False,
    BLOCK_SIZE_N: tl.constexpr = 64,
    BLOCK_SIZE_K: tl.constexpr = 64,
    BLOCK_SIZE_M: tl.constexpr = 32,
    GROUP_SIZE_N: tl.constexpr = 8,
    EVEN_K: tl.constexpr = True,
) -> None:
    """
    Scheduled implementation of grouped GEMM backward for W with TMA support.

    For each group g, computes:
        grad_w[g] = grad_y[g].T @ x[g]

    Where:
    - x_t is [K, M] (transposed from [M, K])
    - grad_y is [M, N*G]
    - grad_w is [N*G, K]
    """
    # Define coordinates for the current program
    tidx = tl.program_id(axis=0)
    dtype = grad_w_ptr.dtype.element_ty
    TMA_SIZE: tl.constexpr = 128

    # Initialize workspace pointer if using TMA store
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    # Calculate work distribution parameters
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = num_pid_n * num_pid_k

    # Process all assigned work items
    pid = tidx
    while pid < G * num_pid_in_group:
        # Calculate work distribution for this pid
        group_id = pid // num_pid_in_group
        pid_in_group = pid % num_pid_in_group
        pid_n = pid_in_group % num_pid_n
        pid_k = pid_in_group // num_pid_n

        # Get group boundaries
        valid_group = group_id < G
        group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0)
        group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0)
        group_size = group_end - group_start

        # Calculate a mask for valid processing (valid group and non-empty)
        valid_work = valid_group & (group_size > 0)

        # Only process if we have valid work
        if valid_work:
            # Compute offsets for this group
            n_start = group_id * N

            # Block dimensions
            n_block_offset = pid_n * BLOCK_SIZE_N
            k_block_offset = pid_k * BLOCK_SIZE_K

            # Setup TMA descriptor for output if using TMA
            if USE_TMA_STORE:
                n_size = tl.minimum(BLOCK_SIZE_N, N - n_block_offset)
                if n_size > 0:
                    tl.extra.cuda.experimental_device_tensormap_create2d(
                        desc_ptr=c_desc_ptr,
                        global_address=grad_w_ptr
                        + (n_start + n_block_offset) * stride_gw_n
                        + k_block_offset * stride_gw_k,
                        load_size=[
                            n_size,
                            tl.minimum(BLOCK_SIZE_K, K - k_block_offset),
                        ],
                        global_size=[n_size, K],
                        element_ty=dtype,
                    )
                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Initialize offsets for this block
            offs_n = n_start + n_block_offset + tl.arange(0, BLOCK_SIZE_N)
            offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K)

            # Create masks
            n_mask = offs_n < (n_start + N)
            k_mask = offs_k < K

            # Initialize accumulator
            accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

            # Loop over the reduction dimension (M) with smaller steps to avoid overflow
            for m_offset in range(0, group_size, BLOCK_SIZE_M):
                # Handle boundary conditions for the reduction dimension
                m_size = tl.minimum(BLOCK_SIZE_M, group_size - m_offset)
                offs_m = group_start + m_offset + tl.arange(0, BLOCK_SIZE_M)
                m_mask = offs_m < group_end

                # Load grad_y [M, N*G] block with explicit strides
                grad_y_block = tl.load(
                    grad_y_ptr
                    + offs_m[:, None] * stride_go_m
                    + offs_n[None, :] * stride_go_n,
                    mask=m_mask[:, None] & n_mask[None, :],
                    other=0.0,
                )

                # Load x_t [K, M] block with explicit strides
                x_t_block = tl.load(
                    x_t_ptr
                    + offs_k[:, None] * stride_x_k
                    + offs_m[None, :] * stride_x_m,
                    mask=k_mask[:, None] & m_mask[None, :],
                    other=0.0,
                )

                # Matrix multiplication: (grad_y_block.T @ x_t_block.T)
                if EVEN_K:
                    accumulator += tl.dot(
                        grad_y_block.to(
                            tl.float32
                        ).T,  # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M]
                        x_t_block.to(
                            tl.float32
                        ).T,  # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K]
                        allow_tf32=True,
                    )
                else:
                    accumulator += tl.dot(
                        grad_y_block.to(
                            tl.float32
                        ).T,  # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M]
                        x_t_block.to(
                            tl.float32
                        ).T,  # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K]
                        allow_tf32=False,
                    )

            # Store result to grad_w with explicit strides
            if USE_TMA_STORE:
                # TMA store
                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(dtype),
                    [0, 0],  # Starting offset in the output block
                )
            else:
                # Standard store with explicit strides
                tl.store(
                    grad_w_ptr
                    + offs_n[:, None] * stride_gw_n
                    + offs_k[None, :] * stride_gw_k,
                    accumulator.to(dtype),
                    mask=n_mask[:, None] & k_mask[None, :],
                )

        pid = pid + NUM_SMS


# ========== End Triton kernels ==========

# ========== Begin grouped_gemm_backward cover function ==========

def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Backward pass for grouped matrix multiplication using scheduled kernels with TMA support.

    Args:
        grad_output: Gradient with respect to output, shape [M, N*G]
        x: Input tensor from forward pass, shape [M, K]
        w: Weight tensor from forward pass, shape [N*G, K]
        m_sizes: Group sizes tensor, shape [G]

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)
    """
    logging.info("Starting grouped_gemm_backward with TMA-enabled scheduling")

    # Check CUDA availability
    if not torch.cuda.is_available():
        logging.error("CUDA not available for backward pass")
        raise RuntimeError("CUDA not available for backward pass")
        # return _pytorch_fallback_backward(grad_output, x, w, m_sizes)

    # Get GPU parameters - TODO: this can use PyTorch cached info...
    device_props = torch.cuda.get_device_properties("cuda")
    NUM_SMS = device_props.multi_processor_count

    # Check TMA support
    has_tma = hasattr(tl.extra, "cuda") and device_props.major >= 9

    if has_tma:
        logging.info(f"TMA support detected on GPU with {NUM_SMS} SMs")
        USE_TMA_LOAD = True  # TODO - this does nothing atm..removed to focus on numerical correctness first.
        USE_TMA_STORE = False
    else:
        logging.warning("TMA support not detected, disabling TMA optimizations")
        USE_TMA_LOAD = False
        USE_TMA_STORE = False

    # Validate input dimensions
    G = m_sizes.shape[0]
    M, K_x = x.shape
    N_times_G, K_w = w.shape

    # Check that K dimensions match
    if K_x != K_w:
        logging.warning(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
        raise ValueError("K dimensions must match for grouped GEMM backward")
        # return _pytorch_fallback_backward(grad_output, x, w, m_sizes)

    try:
        # Ensure contiguous tensors
        grad_output = grad_output.contiguous()
        x = x.contiguous()
        w = w.contiguous()
        m_sizes = m_sizes.contiguous()

        # Allocate output tensors
        grad_x = torch.zeros_like(x)
        grad_w = torch.zeros_like(w)

        # Determine N per group
        # N*G is the second dimension size of grad_output
        N = N_times_G // G

        # Set stride values
        # Direct access pattern for grad_output tensor
        stride_go_m = grad_output.stride(0)  # grad_output in M dimension
        stride_go_n = grad_output.stride(1)  # grad_output in N dimension

        # Pattern match the transposed weight tensor
        stride_w_n = 1  # transposed weights in N dimension
        stride_w_k = N * G  # transposed weights in K dimension

        # Pattern match the output grad_x tensor
        stride_gx_m = grad_x.stride(0)  # grad_x in M dimension
        stride_gx_k = grad_x.stride(1)  # Sgrad_x in K dimension

        # Pattern match the transposed x tensor
        stride_x_m = 1  # Stride for transposed x in M dimension
        stride_x_k = M  # Stride for transposed x in K dimension

        # Pattern match the output grad_w tensor
        stride_gw_n = grad_w.stride(0)  # grad_w in N dimension
        stride_gw_k = grad_w.stride(1)  # grad_w in K dimension

        # Pre-compute group offsets for indexing
        group_offsets = torch.zeros(G + 1, device=m_sizes.device, dtype=torch.int32)
        m_offset = 0
        for g in range(G):
            group_offsets[g] = m_offset
            m_offset += m_sizes[g].item()
        group_offsets[G] = m_offset  # Total M

        # Check if K dimension is even (optimize memory access patterns)
        EVEN_K = (K_x % 8) == 0
        logging.info(f"EVEN_K optimization enabled: {EVEN_K} (K={K_x})")

        # Transpose x and w for backward computation
        x_t = x.T.contiguous()  # Shape: [K, M]
        w_t = w.T.contiguous()  # Shape: [K, N*G]

        # Allocate workspace for TMA descriptors if needed
        if USE_TMA_LOAD or USE_TMA_STORE:
            workspace = torch.empty((NUM_SMS * 128), device=x.device, dtype=torch.uint8)
        else:
            # Empty tensor when TMA is not used
            workspace = torch.empty(0, device=x.device, dtype=torch.uint8)

        # Set block sizes based on K dimension
        # For larger K, use smaller blocks to reduce register pressure
        BLOCK_SIZE = 64 if K_x <= 64 else 32

        BLOCK_SIZE_K = BLOCK_SIZE
        BLOCK_SIZE_M = BLOCK_SIZE
        BLOCK_SIZE_N = BLOCK_SIZE

        # Determine maximum size needed and set the grid size
        num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
        num_pid_k = triton.cdiv(K_x, BLOCK_SIZE_K)
        num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)

        # Compute total number of blocks needed for each kernel
        total_blocks_x = G * num_pid_m * num_pid_k
        total_blocks_w = G * num_pid_n * num_pid_k

        try:
            logging.info("Computing grad_x with TMA-enabled kernel")

            # Fixed grid size based on SM count
            grid = (NUM_SMS,)

            _kernel_grouped_gemm_backward_x_scheduled[grid](
                grad_output,
                w_t,  # Using transposed weights
                grad_x,
                group_offsets,
                workspace,
                G,
                M,
                N,
                K_x,
                stride_go_m,
                stride_go_n,
                stride_w_n,
                stride_w_k,
                stride_gx_m,
                stride_gx_k,
                NUM_SMS,
                USE_TMA_LOAD,
                USE_TMA_STORE,
                BLOCK_SIZE_M=BLOCK_SIZE_M,
                BLOCK_SIZE_N=BLOCK_SIZE_N,
                BLOCK_SIZE_K=BLOCK_SIZE_K,
                EVEN_K=EVEN_K,
            )
            logging.info(
                "Kernel run success: grad_X computation successful with TMA-enabled kernel"
            )
        except Exception as e:
            logging.error(f"FAILED: Error in TMA-enabled backward_x kernel: {e}")
            logging.info("WARNING: Falling back to PyTorch for grad_x")
            _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x)

        try:
            logging.info("Computing grad_w with TMA-enabled kernel")

            # Fixed grid size based on SM count
            grid = (NUM_SMS,)

            _kernel_grouped_gemm_backward_w_scheduled[grid](
                x_t,  # Using transposed inputs
                grad_output,
                grad_w,
                group_offsets,
                workspace,
                G,
                M,
                N,
                K_w,
                stride_x_m,
                stride_x_k,
                stride_go_m,
                stride_go_n,
                stride_gw_n,
                stride_gw_k,
                NUM_SMS,
                USE_TMA_LOAD,
                USE_TMA_STORE,
                BLOCK_SIZE_N=BLOCK_SIZE_N,
                BLOCK_SIZE_K=BLOCK_SIZE_K,
                BLOCK_SIZE_M=BLOCK_SIZE_M,
                EVEN_K=EVEN_K,
            )
            logging.info(
                "Kernel run success - grad_W computation successful with TMA-enabled kernel"
            )
        except Exception as e:
            logging.error(f"FAILED: Error in TMA-enabled backward_w kernel: {e}")
            logging.info("WARNING: Falling back to PyTorch for grad_w")
            # _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w)

        return grad_x, grad_w
    except Exception as e:
        logging.error(f"Error in grouped_gemm_backward: {e}")
        # return _pytorch_fallback_backward(grad_output, x, w, m_sizes)


================================================
FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# This is copied from FBGEMM, with some modifications.  Not kept in sync, Original code:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

import functools
from typing import Optional

import tma_utils as utils

import torch

import triton
import triton.language as tl
from triton.runtime import driver  # @manual

"""
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [64, 128]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]

_AMD_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
            "waves_per_eu": waves_per_cu,
            "matrix_instr_nonkdim": matrix_instr_nonkdim,
        },
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for block_size_m in [32, 64, 128]
    for block_size_n in [32, 64, 128, 256]
    for block_size_k in [128, 256]
    for num_stages in [1, 2]
    for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
    for matrix_instr_nonkdim in [16]
]


def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    device = torch.cuda.current_device()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    if dtsize is None:
        dtsize = named_args["c_ptr"].element_size()
    if dtype is None:
        dtype = named_args["c_ptr"].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]
        if torch.version.hip:
            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
        else:
            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 32 if torch.version.hip else 64
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load N tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 32 if torch.version.hip else 64
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs


@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
"""


@triton.jit
def _kernel_grouped_gemm(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,  # N is per group
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    tidx = tl.program_id(0)

    dtype: tl.dtype = c_ptr.dtype.element_ty
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        M_start_offset = M_end_offset
        m_size = tl.load(m_sizes + g)
        M_end_offset = M_start_offset + m_size

        if m_size > 0:
            # Compute for this group
            N_start_offset = g * N
            n_size = N  # N is already per group

            # Calculate the number of tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Set up TMA descriptor for output
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr
                    + M_start_offset * (N * G)
                    + N_start_offset,  # Offset to this group's output
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                tl.static_assert(K % BLOCK_SIZE_K == 0)

                if USE_TMA_LOAD:
                    # Use TMA to load input and weight blocks
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)

                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        # Load input block [M, K]
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )

                        # Load weight block [N, K]
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )

                        # Compute matrix multiplication
                        accumulator += tl.dot(a, b.T)
                else:
                    # Manual load without TMA
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)

                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )

                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )

                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        # Load with bounds checking
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)

                        # Compute matrix multiplication
                        accumulator += tl.dot(a, b.T)

                        # Update pointers for next block
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                # Store result
                if USE_TMA_STORE:
                    # Store using TMA
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)

                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                else:
                    # Manual store
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

                    c = accumulator.to(c_ptr.dtype.element_ty)

                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None])
                        * (N * G)  # Row stride is N*G
                        + (
                            N_start_offset + offs_bn[None, :]
                        ),  # Column offset to this group's N
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )

                tidx += NUM_SMS  # Move to next tile

            iterated_tiles += num_tiles


TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv


"""@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
)
"""


@triton.jit
def _kernel_grouped_gemm_fp8_rowwise(
    a_desc_ptr,
    a_scale_ptr,
    b_desc_ptr,
    b_scale_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,  # N is per group
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    tidx = tl.program_id(0)

    dtype = TT_FP8_DTYPE
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    iterated_tiles = 0
    for g in tl.range(G):
        # Move across groups
        M_start_offset = M_end_offset
        m_size = tl.load(m_sizes + g)
        M_end_offset = M_start_offset + m_size

        if m_size > 0:
            # Compute for this group
            N_start_offset = g * N
            n_size = N  # N is already per group

            # Calculate the number of tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # Set up TMA descriptor for output
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr
                    + M_start_offset * (N * G)
                    + N_start_offset,  # Offset to this group's output
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                tl.static_assert(K % BLOCK_SIZE_K == 0)

                if USE_TMA_LOAD:
                    # Use TMA to load input and weight blocks with FP8 support
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)

                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        # Load input block [M, K] with FP8
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )

                        # Load weight block [N, K] with FP8
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )

                        # Compute matrix multiplication
                        accumulator += tl.dot(a, b.T)
                else:
                    # Manual load without TMA for FP8
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)

                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )

                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )

                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        # Load with bounds checking
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)

                        # Compute matrix multiplication
                        accumulator += tl.dot(a, b.T)

                        # Update pointers for next block
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                # Load FP8 scales
                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

                a_scale = tl.load(
                    a_scale_ptr + M_start_offset + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                )

                b_scale = tl.load(
                    b_scale_ptr + N_start_offset + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                )

                # Apply scales to result
                c = accumulator.to(tl.float32) * a_scale * b_scale

                # Store result
                if USE_TMA_STORE:
                    # Store using TMA
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)

                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        c.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                else:
                    # Manual store
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None])
                        * (N * G)  # Row stride is N*G
                        + (
                            N_start_offset + offs_bn[None, :]
                        ),  # Column offset to this group's N
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )

                tidx += NUM_SMS  # Move to next tile

            iterated_tiles += num_tiles


def _grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    w_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if not utils.HAS_TMA_DESC:
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M, K = x.shape
    N_times_G = w.shape[0]

    # Ensure N is per group
    assert (
        N_times_G % G == 0
    ), f"Weight dimension ({N_times_G}) must be divisible by groups ({G})"
    N = N_times_G // G

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Create output tensor with correct shape [M, N*G]
    y = torch.empty((M, N_times_G), device=x.device, dtype=torch.bfloat16)

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    USE_TMA_LOAD = True  # not torch.version.hip
    USE_TMA_STORE = True

    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = utils.TmaAutoTuneHelper()
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        workspace = torch.empty(
            NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
            device=x.device,
            dtype=torch.uint8,
        )

    # Skip autotuning - use fixed grid size
    grid_size = (min(NUM_SMS, 4),)  # Use smaller grid for small inputs
    M_BUCKET = triton.next_power_of_2(M)

    try:

        if USE_TMA_LOAD and desc_helper is not None:
            # Fixed block sizes that work well for most cases
            BLOCK_SIZE_M = 64
            BLOCK_SIZE_N = 64
            BLOCK_SIZE_K = 32

            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M,
                K,
                BLOCK_SIZE_M,
                BLOCK_SIZE_K,
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N_times_G,
                K,
                BLOCK_SIZE_N,
                BLOCK_SIZE_K,
                w.element_size(),
            )
    except Exception as e:
        print(f"Error in TMA descriptor setup: {e}")

    if x_scale is not None and w_scale is not None:
        assert x_scale.is_contiguous()
        assert w_scale.is_contiguous()
        # Call kernel directly without autotuning
        _kernel_grouped_gemm_fp8_rowwise[grid_size](
            desc_x,
            x_scale,
            desc_w,
            w_scale,
            y,
            workspace,
            m_sizes,
            G,
            M_BUCKET,
            N,  # N is per group
            K,
            NUM_SMS,
            USE_TMA_LOAD,
            USE_TMA_STORE,
            BLOCK_SIZE_M=64,  # Fixed block sizes
            BLOCK_SIZE_N=64,
            BLOCK_SIZE_K=32,
        )
    else:
        assert x_scale is None
        assert w_scale is None
        # Call kernel directly without autotuning
        _kernel_grouped_gemm[grid_size](
            desc_x,
            desc_w,
            y,
            workspace,
            m_sizes,
            G,
            M_BUCKET,
            N,  # N is per group
            K,
            NUM_SMS,
            USE_TMA_LOAD,
            USE_TMA_STORE,
            BLOCK_SIZE_M=64,  # Fixed block sizes
            BLOCK_SIZE_N=64,
            BLOCK_SIZE_K=32,
        )

    # Verify the output shape
    expected_output_shape = (M, N_times_G)
    assert y.shape == expected_output_shape, (
        f"Output shape mismatch: got {y.shape}, " f"expected {expected_output_shape}"
    )

    return y


def grouped_gemm_forward(
    x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor
) -> torch.Tensor:
    return _grouped_gemm(x, w, m_sizes)


def grouped_gemm_fp8_rowwise(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
) -> torch.Tensor:
    return _grouped_gemm(x, w, m_sizes, x_scale, w_scale)


================================================
FILE: kernels/MoE/group_GEMM/triton/utils/tma_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm

import sys

import torch
import triton  # @manual

import triton.language as tl  # @manual


def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
    """
    Maps torch dtype to triton dtype.

    Args:
        dtype (torch.dtype): input dtype.

    Returns:
        tl.dtype: triton dtype.
    """
    if dtype == torch.float16:
        return tl.float16
    elif dtype == torch.bfloat16:
        return tl.bfloat16
    elif dtype == torch.float32:
        return tl.float32
    elif dtype == torch.int32:
        return tl.int32
    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
        return tl.float8e4nv
    else:
        raise ValueError(f"Unsupported dtype {dtype}")


# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

if HAS_TMA_DESC:
    print(
        "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
        file=sys.stderr,
    )
else:
    print(
        "Missing: This group gemm code will not run without TMA descriptor support....",
        file=sys.stderr,
    )
    raise NotImplementedError("grouped Gemm without TMA is not supported")


class TmaAutoTuneHelper:

    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
    class KernelParamWrapper:
        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            return self.desc.data_ptr()

    TMA_SIZE = 128

    def __init__(self):
        self.fill_1d_tma_descriptor_inner = (
            triton.runtime.driver.active.utils.fill_1d_tma_descriptor
        )
        self.fill_2d_tma_descriptor_inner = (
            triton.runtime.driver.active.utils.fill_2d_tma_descriptor
        )
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    # Call this method outside of the lambda function for grid size
    def init_tma_descriptor(self, name):
        if HAS_TMA_DESC:
            self.descriptors[name] = torch.empty(
                TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
            )
        else:
            self.cuda_descriptors[name] = torch.empty(
                TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
            )

    # Call this method inside the lambda function for grid size
    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            self.fill_1d_tma_descriptor_inner(
                ptr, dim, block_dim, element_size, desc_x.data_ptr()
            )
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            self.fill_1d_tma_descriptor_inner(
                ptr, dim, block_dim, element_size, buf_x.data_ptr()
            )
            desc_x.copy_(buf_x, non_blocking=True)

    # Call this method inside the lambda function for grid size
    def fill_2d_tma_descriptor(
        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
    ):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            self.fill_2d_tma_descriptor_inner(
                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
            )
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            self.fill_2d_tma_descriptor_inner(
                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
            )
            desc_x.copy_(buf_x, non_blocking=True)

    def get_tma_descriptor_kernel_param(self, name):
        if HAS_TMA_DESC:
            assert self.descriptors[name] is not None
            return self.KernelParamWrapper(self.descriptors[name])
        else:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]


================================================
FILE: kernels/blackwell/cute_gemm_01/Makefile
================================================

# Makefile for SM100 GEMM PyTorch Extension

# Set these paths according to your installation
CUTLASS_PATH ?= /path/to/cutlass
CUDA_HOME ?= $(shell python -c "import torch; print(torch.utils.cpp_extension.CUDA_HOME)")

# Build the extension
build:
	CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace

# Install the extension
install:
	CUTLASS_PATH=$(CUTLASS_PATH) pip install .

# Clean build artifacts
clean:
	rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so

# Test the installation
test:
	python python_interface.py

# Check CUDA device capability
check_device:
	python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')"

.PHONY: build install clean test check_device


================================================
FILE: kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_log
================================================
# ninja log v5
0	15279	1748131038212164071	/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o	1163be77f63db063
6	13596	1748131241209889865	/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o	79aa61597088743a
8	13684	1748132015451659084	/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o	89ead7aaccf82852


================================================
FILE: kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/build.ninja
================================================
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda-12.8/bin/nvcc

cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_dlink_post_cflags = 
sycl_dlink_post_cflags = 
ldflags = 

rule compile
  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
  depfile = $out.d
  deps = gcc

rule cuda_compile
  depfile = $out.d
  deps = gcc
  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags







build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm.cu
build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm_pytorch.cpp










================================================
FILE: kernels/blackwell/cute_gemm_01/driver.py
================================================
# ==============================================================================
# python_interface.py - High-level Python interface
# ==============================================================================


import torch

try:
    import sm100_gemm  # The compiled extension - this has to go after import torch...but auto-formatting is blocking
except ImportError:
    print("❌ SM100 not ready!")
    raise ImportError(
        "SM100 not ready! Please build the extension using `python setup.py install`"
    )


def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0):
    """
    Perform GEMM using SM100 optimized kernel: D = alpha * A @ B^T + beta * C

    Args:
        A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16
        B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16
        C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32
                                   If None, creates zero tensor
        alpha (float): Scaling factor for A @ B^T
        beta (float): Scaling factor for C

    Returns:
        torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32

    Note:
        - A and B are K-major (transposed in BLAS terms)
        - C and D are N-major (row-major)
        - All tensors must be on CUDA
        - M must be multiple of 128, N multiple of 256, K multiple of 64
    """

    # Input validation
    assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}"
    assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}"
    assert A.is_cuda and B.is_cuda, "A and B must be on CUDA"
    assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous"

    M, K = A.shape
    N, K_B = B.shape
    assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}"

    # Check alignment requirements
    assert M % 128 == 0, f"M={M} must be multiple of 128"
    assert N % 256 == 0, f"N={N} must be multiple of 256"
    assert K % 64 == 0, f"K={K} must be multiple of 64"

    # Create C if not provided
    if C is None:
        C = torch.zeros(M, N, dtype=torch.float32, device=A.device)
    else:
        assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}"
        assert C.is_cuda, "C must be on CUDA"
        assert C.is_contiguous(), "C must be contiguous"
        assert C.shape == (
            M,
            N,
        ), f"C shape {C.shape} must match output shape ({M}, {N})"

    # Call the extension
    return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta)


def benchmark_sm100_vs_torch(
    M=1024, N=2048, K=256, num_warmup=1, num_trials=10
):  # M=512, N=1024, K=256, num_warmup=10, num_trials=100):
    """
    Benchmark SM100 GEMM against PyTorch's native GEMM
    """
    # Ensure dimensions are aligned
    M = ((M + 127) // 128) * 128
    N = ((N + 255) // 256) * 256
    K = ((K + 63) // 64) * 64

    print(f"Benchmarking GEMM with shape: ({M}, {N}, {K})")

    # Create test tensors
    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
    B = torch.randn(N, K, dtype=torch.float16, device="cuda")
    C = torch.randn(M, N, dtype=torch.float16, device="cuda")
    C32 = C.to(torch.float32).clone()

    # Keep A and B as FP16 for PyTorch
    A_fp16 = A
    B_fp16 = B

    # Warmup
    for _ in range(num_warmup):
        # PyTorch GEMM (using FP16)
        torch_result = torch.addmm(C, A_fp16, B_fp16.T)

        # SM100 GEMM
        sm100_result = sm100_gemm_f16(A, B, C32)

    torch.cuda.synchronize()

    # Benchmark PyTorch
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(num_trials):
        torch_result = torch.addmm(C, A_fp16, B_fp16.T)
    end.record()
    torch.cuda.synchronize()
    torch_time = start.elapsed_time(end) / num_trials

    # Benchmark SM100
    start.record()
    for _ in range(num_trials):
        sm100_result = sm100_gemm_f16(A, B, C32)
    end.record()
    torch.cuda.synchronize()
    sm100_time = start.elapsed_time(end) / num_trials

    # Check correctness
    max_diff = torch.max(torch.abs(torch_result - sm100_result.to(torch.float16)))
    rel_error = max_diff / torch.max(torch.abs(torch_result))

    # Calculate FLOPS
    flops = 2 * M * N * K  # Multiply-add operations
    torch_tflops = flops / (torch_time * 1e-3) / 1e12
    sm100_tflops = flops / (sm100_time * 1e-3) / 1e12

    print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)")
    print(f"SM100 time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)")
    print(f"Speedup: {torch_time/sm100_time:.2f}x")
    print(f"Max difference: {max_diff:.6f}")
    print(f"Relative error: {rel_error:.6f}")

    return {
        "torch_time": torch_time,
        "sm100_time": sm100_time,
        "speedup": torch_time / sm100_time,
        "torch_tflops": torch_tflops,
        "sm100_tflops": sm100_tflops,
        "max_diff": max_diff.item(),
        "rel_error": rel_error.item(),
    }


# Example usage and test
if __name__ == "__main__":
    # Test basic functionality
    print("Testing SM100 GEMM...")

    M, N, K = 512, 1024, 256
    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
    B = torch.randn(N, K, dtype=torch.float16, device="cuda")
    C = torch.randn(M, N, dtype=torch.float32, device="cuda")

    # Test the GEMM
    result = sm100_gemm_f16(A, B, C, alpha=1.0, beta=0.5)
    print(f"Result shape: {result.shape}, dtype: {result.dtype}")

    # Run benchmark
    print("\nRunning benchmark...")
    benchmark_results = benchmark_sm100_vs_torch(M, N, K)

# ==============================================================================
# Makefile for easy building
# ==============================================================================
'''
MAKEFILE_CONTENT = """
# Makefile for SM100 GEMM PyTorch Extension

# Set these paths according to your installation
CUTLASS_PATH ?= /path/to/cutlass
CUDA_HOME ?= $(shell python -c "import torch; print(torch.utils.cpp_extension.CUDA_HOME)")

# Build the extension
build:
	CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace

# Install the extension
install:
	CUTLASS_PATH=$(CUTLASS_PATH) pip install .

# Clean build artifacts
clean:
	rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so

# Test the installation
test:
	python python_interface.py

# Check CUDA device capability
check_device:
	python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')"

.PHONY: build install clean test check_device
"""

# Write Makefile
with open("Makefile", "w") as f:
    f.write(MAKEFILE_CONTENT)

print("Setup files created!")
print("To build:")
print("1. Set CUTLASS_PATH environment variable to your CUTLASS installation")
print("2. Run: make build")
print("3. Test: make test")
'''


================================================
FILE: kernels/blackwell/cute_gemm_01/setup.py
================================================
# setup.py
import os

import pybind11
import torch
from pybind11 import get_cmake_dir
from pybind11.setup_helpers import build_ext, Pybind11Extension
from setuptools import Extension, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

# IMPORTANT: The following two lines are the only ones you need to change
# Get CUTLASS path (you'll need to set this to your CUTLASS installation)
CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/home/less/local/cutlas40")

# CUDA and PyTorch paths
cuda_home = torch.utils.cpp_extension.CUDA_HOME
pytorch_includes = torch.utils.cpp_extension.include_paths()

ext_modules = [
    CUDAExtension(
        name="sm100_gemm",
        sources=[
            "sm100_gemm_pytorch.cpp",  # PyTorch bindings (C++)
            "sm100_gemm.cu",  # CUDA kernel implementation
        ],
        include_dirs=[
            # PyTorch includes
            *pytorch_includes,
            # CUTLASS includes
            f"{CUTLASS_PATH}/include",
            f"{CUTLASS_PATH}/tools/util/include",
            # CUDA includes
            f"{cuda_home}/include",
        ],
        library_dirs=[
            f"{cuda_home}/lib64",
        ],
        libraries=["cuda", "cudart"],
        extra_compile_args={
            "cxx": [
                "-O3",
                "-std=c++17",
                "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED",
                "-DCUTE_SM100_ENABLED",
            ],
            "nvcc": [
                "-O3",
                "-std=c++17",
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "-gencode=arch=compute_100a,code=sm_100a",  # SM100 architecture
                "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED",
                "-DCUTE_SM100_ENABLED",
                "--use_fast_math",
                "-Xcompiler=-fPIC",
                "-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1",  # Enable TCGEN05_TMEM
            ],
        },
        extra_link_args=["-lcuda", "-lcudart"],
        language="c++",
    )
]

setup(
    name="sm100_gemm",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    zip_safe=False,
    python_requires=">=3.8",
    install_requires=["torch>=1.12.0"],
)


================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.cu
================================================
// sm100_gemm_kernel.cu - CUDA kernel implementation
#include "sm100_gemm.h"

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/half.h>
#include <cutlass/util/print_error.hpp>

#include <cute/algorithm/cooperative_copy.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/tensor.hpp>

using namespace cute;

// Shared storage structure
template <class TypeA, class TypeB, class ASmemLayout, class BSmemLayout>
struct SharedStorage {
  alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
  alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
  alignas(16) cute::uint64_t mma_barrier;
  alignas(16) cute::uint32_t tmem_base_ptr;

  CUTE_DEVICE constexpr auto tensor_sA() {
    return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{});
  }
  CUTE_DEVICE constexpr auto tensor_sB() {
    return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{});
  }
};

// Device kernel
template <class SharedStorage, class ATensor, class BTensor, class CTensor,
          class DTensor, class MmaTiler_MNK, class TiledMMA,
          class ClusterShape_MNK, class Alpha, class Beta>
__global__ static void
gemm_device(ATensor mA, BTensor mB, CTensor mC, DTensor mD,
            MmaTiler_MNK mma_tiler, TiledMMA tiled_mma,
            ClusterShape_MNK cluster_shape, Alpha alpha, Beta beta) {
  // Step 1: The Prologue
  Layout cluster_layout_vmnk = tiled_divide(
      make_layout(cluster_shape), make_tile(typename TiledMMA::AtomThrID{}));

  auto mma_coord_vmnk =
      make_coord(blockIdx.x % size<0>(cluster_layout_vmnk),
                 blockIdx.x / size<0>(cluster_layout_vmnk), blockIdx.y, _);

  auto mma_coord = select<1, 2, 3>(mma_coord_vmnk);
  Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X, _1>{});
  Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step<X, _1, _1>{});
  Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1, _1, X>{});
  Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1, _1, X>{});

  // SMEM allocation
  extern __shared__ char shared_memory[];
  SharedStorage &shared_storage =
      *reinterpret_cast<SharedStorage *>(shared_memory);

  Tensor tCsA = shared_storage.tensor_sA();
  Tensor tCsB = shared_storage.tensor_sB();

  // MMA partitioning
  auto mma_v = get<0>(mma_coord_vmnk);
  ThrMMA cta_mma = tiled_mma.get_slice(mma_v);
  Tensor tCgA = cta_mma.partition_A(gA);
  Tensor tCgB = cta_mma.partition_B(gB);
  Tensor tCgC = cta_mma.partition_C(gC);
  Tensor tCgD = cta_mma.partition_C(gD);

  // Fragment allocation
  Tensor tCrA = cta_mma.make_fragment_A(tCsA);
  Tensor tCrB = cta_mma.make_fragment_B(tCsB);
  Tensor tCtAcc = cta_mma.make_fragment_C(tCgC);

  uint32_t elect_one_thr = cute::elect_one_sync();
  uint32_t elect_one_warp = (threadIdx.x / 32 == 0);

  using TmemAllocator = cute::TMEM::Allocator1Sm;
  TmemAllocator tmem_allocator{};

  // TMEM allocation
  if (elect_one_warp) {
    tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,
                            &shared_storage.tmem_base_ptr);
  }
  __syncthreads();
  tCtAcc.data() = shared_storage.tmem_base_ptr;

  // Barrier initialization
  if (elect_one_warp && elect_one_thr) {
    cute::initialize_barrier(shared_storage.mma_barrier, 1);
  }
  int mma_barrier_phase_bit = 0;
  __syncthreads();

  // Step 2: The Mainloop
  tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;

  for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) {
    // Load A and B tiles
    cooperative_copy<128>(threadIdx.x, tCgA(_, _, _, k_tile), tCsA);
    cooperative_copy<128>(threadIdx.x, tCgB(_, _, _, k_tile), tCsB);

    __syncthreads();

    // Execute MMAs
    if (elect_one_warp) {
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCtAcc);
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
      }
      cutlass::arch::umma_arrive(&shared_storage.mma_barrier);
    }
    cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
    mma_barrier_phase_bit ^= 1;
  }

  // Step 3: The Epilogue
  TiledCopy tiled_t2r_copy =
      make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
  ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);

  Tensor tDgC = thr_t2r_copy.partition_D(tCgC);
  Tensor tDrC = make_fragment_like(tDgC);
  copy(tDgC, tDrC);

  Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc);
  Tensor tDgD = thr_t2r_copy.partition_D(tCgD);
  using AccType = typename decltype(tCtAcc)::value_type;
  Tensor tDrAcc = make_tensor<AccType>(shape(tDgD));
  copy(tiled_t2r_copy, tDtAcc, tDrAcc);

  // AXPBY and store result
  axpby(alpha, tDrAcc, beta, tDrC);
  copy(tDrC, tDgD);

  __syncthreads();

  // Cleanup TMEM
  if (elect_one_warp) {
    tmem_allocator.release_allocation_lock();
    tmem_allocator.free(shared_storage.tmem_base_ptr,
                        TmemAllocator::Sm100TmemCapacityColumns);
  }
}

// Host function that launches the kernel
cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,
                                  int M, int N, int K, float alpha, float beta,
                                  cudaStream_t stream) {
  // Define types
  using TypeA = cutlass::half_t;
  using TypeB = cutlass::half_t;
  using TypeC = float;
  using TypeD = float;

  // Create layouts (K-major for A and B, N-major for C and D)
  auto layout_A = make_layout(make_shape(M, K), make_stride(K, Int<1>{}));
  auto layout_B = make_layout(make_shape(N, K), make_stride(K, Int<1>{}));
  auto layout_C = make_layout(make_shape(M, N), make_stride(N, Int<1>{}));
  auto layout_D = layout_C;

  // Create CuTe tensors
  auto mA =
      make_tensor(make_gmem_ptr(reinterpret_cast<TypeA *>(d_A)), layout_A);
  auto mB =
      make_tensor(make_gmem_ptr(reinterpret_cast<TypeB *>(d_B)), layout_B);
  auto mC =
      make_tensor(make_gmem_ptr(reinterpret_cast<TypeC *>(d_C)), layout_C);
  auto mD =
      make_tensor(make_gmem_ptr(reinterpret_cast<TypeD *>(d_D)), layout_D);

  // Create TiledMMA
  TiledMMA tiled_mma =
      make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, 128, 256,
                                          UMMA::Major::K, UMMA::Major::K>{});

  // Define MMA tiler sizes
  auto bM = tile_size<0>(tiled_mma);            // 128
  auto bN = tile_size<1>(tiled_mma);            // 256
  auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // 64
  auto mma_tiler = make_shape(bM, bN, bK);

  // Check alignment
  if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) {
    return cudaErrorInvalidValue;
  }

  // Create SMEM layouts
  auto mma_shape_A = partition_shape_A(
      tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
  auto mma_shape_B = partition_shape_B(
      tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));

  auto sA_layout =
      UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
  auto sB_layout =
      UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);

  using SMEMStorage =
      SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;

  // Cluster configuration
  auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});

  // Launch parameters
  dim3 dimBlock(128);
  dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape),
                  size<2>(cluster_shape));
  dim3 dimGrid(ceil_div(M, int(bM)), ceil_div(N, int(bN)));
  int smemBytes = sizeof(SMEMStorage);

  // Get kernel pointer
  auto *kernel_ptr =
      &gemm_device<SMEMStorage, decltype(mA), decltype(mB), decltype(mC),
                   decltype(mD), decltype(mma_tiler), decltype(tiled_mma),
                   decltype(cluster_shape), float, float>;

  // Set kernel attributes
  cudaError_t error = cudaFuncSetAttribute(
      kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);
  if (error != cudaSuccess) {
    return error;
  }

  // Launch kernel
  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster,
                                         smemBytes};
  cutlass::Status status = cutlass::launch_kernel_on_cluster(
      params, (void const *)kernel_ptr, mA, mB, mC, mD, mma_tiler, tiled_mma,
      cluster_shape, alpha, beta);

  return (status == cutlass::Status::kSuccess) ? cudaSuccess
                                               : cudaErrorLaunchFailure;
}

#else

cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,
                                  int M, int N, int K, float alpha, float beta,
                                  cudaStream_t stream) {
  return cudaErrorNotSupported;
}

#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED


================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/PKG-INFO
================================================
Metadata-Version: 2.4
Name: sm100_gemm
Version: 0.0.0
Requires-Python: >=3.8
Requires-Dist: torch>=1.12.0
Dynamic: requires-dist
Dynamic: requires-python


================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/SOURCES.txt
================================================
setup.py
sm100_gemm.cu
sm100_gemm_pytorch.cpp
sm100_gemm.egg-info/PKG-INFO
sm100_gemm.egg-info/SOURCES.txt
sm100_gemm.egg-info/dependency_links.txt
sm100_gemm.egg-info/not-zip-safe
sm100_gemm.egg-info/requires.txt
sm100_gemm.egg-info/top_level.txt

================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/dependency_links.txt
================================================



================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/not-zip-safe
================================================



================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/requires.txt
================================================
torch>=1.12.0


================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/top_level.txt
================================================
sm100_gemm


================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.h
================================================
// sm100_gemm_kernel.h - Header file for CUDA kernel
#pragma once

#include <cuda_runtime.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta * C
 *
 * @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major)
 * @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major)
 * @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major)
 * @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major)
 * @param M Number of rows in A and C/D
 * @param N Number of rows in B and columns in C/D
 * @param K Number of columns in A and B
 * @param alpha Scaling factor for A @ B^T
 * @param beta Scaling factor for C
 * @param stream CUDA stream (currently unused, for future async support)
 *
 * @return cudaSuccess on success, error code otherwise
 *
 * Requirements:
 * - M must be multiple of 128
 * - N must be multiple of 256
 * - K must be multiple of 64
 * - All pointers must be valid device memory
 * - Tensors must be contiguous with specified layouts
 */
cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,
                                  int M, int N, int K, float alpha, float beta,
                                  cudaStream_t stream = 0);

#ifdef __cplusplus
}
#endif


================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp
================================================
// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "sm100_gemm.h"

// Check if SM100 support is available at compile time
bool is_sm100_supported() {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  return true;
#else
  return false;
#endif
}

// Check if current GPU supports SM100 at runtime
bool check_sm100_device() {
  int device;
  cudaGetDevice(&device);

  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, device);
  if (error != cudaSuccess) {
    return false;
  }

  // Check for SM100 architecture (compute capability 10.0a)
  return (props.major == 10 && props.minor == 0);
}

torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B,
                             const torch::Tensor &C, float alpha = 1.0f,
                             float beta = 0.0f) {

  // Check compile-time support
  TORCH_CHECK(
      is_sm100_supported(),
      "SM100 support not compiled. Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED");

  // Check runtime device support
  TORCH_CHECK(check_sm100_device(),
              "Current GPU does not support SM100 architecture (requires "
              "compute capability 10.0a)");

  // Input validation
  TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor");
  TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor");
  TORCH_CHECK(C.device().is_cuda(), "C must be a CUDA tensor");
  TORCH_CHECK(A.dtype() == torch::kFloat16, "A must be float16");
  TORCH_CHECK(B.dtype() == torch::kFloat16, "B must be float16");
  TORCH_CHECK(C.dtype() == torch::kFloat32, "C must be float32");
  TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
  TORCH_CHECK(B.is_contiguous(), "B must be contiguous");
  TORCH_CHECK(C.is_contiguous(), "C must be contiguous");
  TORCH_CHECK(A.dim() == 2, "A must be 2D");
  TORCH_CHECK(B.dim() == 2, "B must be 2D");
  TORCH_CHECK(C.dim() == 2, "C must be 2D");

  // Get dimensions
  int64_t M = A.size(0);
  int64_t K = A.size(1);
  int64_t N = B.size(0);
  int64_t K_B = B.size(1);

  TORCH_CHECK(K == K_B, "Inner dimensions must match: A.shape[1]=", K,
              ", B.shape[1]=", K_B);
  TORCH_CHECK(C.size(0) == M && C.size(1) == N, "C dimensions (", C.size(0),
              ", ", C.size(1), ") must match output shape (", M, ", ", N, ")");

  // Check alignment requirements for SM100
  TORCH_CHECK(M % 128 == 0, "M=", M, " must be multiple of 128");
  TORCH_CHECK(N % 256 == 0, "N=", N, " must be multiple of 256");
  TORCH_CHECK(K % 64 == 0, "K=", K, " must be multiple of 64");

  // Check size limits (avoid overflow in int conversion)
  TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX,
              "Dimensions too large for int conversion");

  // Create output tensor
  auto D = torch::empty_like(C);

  // Set CUDA device guard
  const auto device = A.device();
  c10::cuda::CUDAGuard device_guard(device);

  // Get current CUDA stream
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream();

  // Launch the kernel
  cudaError_t error = launch_sm100_gemm_f16(
      A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(),
      static_cast<int>(M), static_cast<int>(N), static_cast<int>(K), alpha,
      beta, stream);

  // Check for launch errors
  TORCH_CHECK(error == cudaSuccess,
              "SM100 GEMM kernel launch failed: ", cudaGetErrorString(error));

  // Check for kernel execution errors
  C10_CUDA_CHECK(cudaGetLastError());

  return D;
}

// Utility functions for debugging and information
torch::Tensor get_device_info() {
  int device;
  cudaGetDevice(&device);

  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, device);

  // Return device info as a tensor (for easy Python access)
  auto info = torch::zeros({4}, torch::kInt32);
  auto accessor = info.accessor<int32_t, 1>();

  accessor[0] = props.major;          // Compute capability major
  accessor[1] = props.minor;          // Compute capability minor
  accessor[2] = is_sm100_supported(); // Compile-time support
  accessor[3] = check_sm100_device(); // Runtime device support

  return info;
}

std::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {
  // Return properly aligned dimensions for SM100
  int64_t aligned_M = ((M + 127) / 128) * 128;
  int64_t aligned_N = ((N + 255) / 256) * 256;
  int64_t aligned_K = ((K + 63) / 64) * 64;

  return {aligned_M, aligned_N, aligned_K};
}

torch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,
                                    torch::ScalarType dtype,
                                    torch::Device device) {
  // Create a tensor with SM100-aligned dimensions
  TORCH_CHECK(shape.size() == 2, "Shape must be 2D");

  auto aligned_shape =
      get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? shape[2] : 64);

  if (shape.size() == 2) {
    return torch::zeros({aligned_shape[0], aligned_shape[1]},
                        torch::TensorOptions().dtype(dtype).device(device));
  } else {
    return torch::zeros({aligned_shape[0], aligned_shape[2]},
                        torch::TensorOptions().dtype(dtype).device(device));
  }
}

// Python bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.doc() = "SM100 GEMM PyTorch Extension";

  // Main GEMM function
  m.def("sm100_gemm_f16", &sm100_gemm_f16,
        "SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + "
        "beta * C",
        py::arg("A"), py::arg("B"), py::arg("C"), py::arg("alpha") = 1.0f,
        py::arg("beta") = 0.0f);

  // Utility functions
  m.def("is_sm100_supported", &is_sm100_supported,
        "Check if SM100 support was compiled in");

  m.def("check_sm100_device", &check_sm100_device,
        "Check if current GPU supports SM100 architecture");

  m.def("get_device_info", &get_device_info,
        "Get device compute capability and SM100 support info");

  m.def("get_aligned_shape", &get_aligned_shape,
        "Get SM100-aligned dimensions for given shape", py::arg("M"),
        py::arg("N"), py::arg("K"));

  m.def("create_aligned_tensor", &create_aligned_tensor,
        "Create tensor with SM100-aligned dimensions", py::arg("shape"),
        py::arg("dtype"), py::arg("device"));

  // Constants for alignment requirements
  m.attr("MMA_TILE_M") = 128;
  m.attr("MMA_TILE_N") = 256;
  m.attr("MMA_TILE_K") = 64;
}


================================================
FILE: kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_log
================================================
# ninja log v5
1	15202	1748185895110710199	/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o	342153d32d365f0b
7	78	1748186494782816813	/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o	342153d32d365f0b
6	15086	1748186805607894090	/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o	342153d32d365f0b
6	14058	1748187024415643408	/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o	6c5f77cfca7cfb81


================================================
FILE: kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/build.ninja
================================================
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda-12.8/bin/nvcc

cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_dlink_post_cflags = 
sycl_dlink_post_cflags = 
ldflags = 

rule compile
  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
  depfile = $out.d
  deps = gcc

rule cuda_compile
  depfile = $out.d
  deps = gcc
  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags







build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu
build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp










================================================
FILE: kernels/blackwell/cute_gemm_02_tma/driver.py
================================================
# python_interface.py - High-level Python interface with TMA support

import torch

try:
    import sm100_gemm  # The compiled extension
except ImportError:
    print("❌ SM100 not ready!")
    raise ImportError(
        "SM100 not ready! Please build the extension using `python setup.py install`"
    )


def check_sm100_compatibility():
    """Check if SM100 is supported and available"""
    compile_support = sm100_gemm.is_sm100_supported()
    device_support = sm100_gemm.check_sm100_device()

    info = sm100_gemm.get_device_info()
    major, minor, compile_flag, device_flag = info.tolist()

    print(f"Device compute capability: {major}.{minor}")
    print(f"Compile-time SM100 support: {bool(compile_flag)}")
    print(f"Runtime SM100 device support: {bool(device_flag)}")

    if not compile_support:
        print(
            "❌ SM100 support not compiled in. Rebuild with CUTLASS_ARCH_MMA_SM100_SUPPORTED"
        )
    elif not device_support:
        print("❌ Current GPU does not support SM100 (need compute capability 10.0a)")
    else:
        print("✅ SM100 with TMA ready!")

    return compile_support and device_support


def sm100_gemm_f16_tma(A, B, C=None, alpha=1.0, beta=0.0, check_alignment=True):
    """
    Perform GEMM using SM100 optimized kernel with TMA: D = alpha * A @ B^T + beta * C

    Args:
        A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16
        B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16
        C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32
                                   If None, creates zero tensor
        alpha (float): Scaling factor for A @ B^T
        beta (float): Scaling factor for C
        check_alignment (bool): Whether to check and suggest aligned dimensions

    Returns:
        torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32

    Note:
        - Uses TMA (Tensor Memory Accelerator) for efficient memory transfers
        - A and B are K-major (transposed in BLAS terms)
        - C and D are N-major (row-major)
        - All tensors must be on CUDA
        - M must be multiple of 128, N multiple of 256, K multiple of 64
    """

    # Input validation
    assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}"
    assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}"
    assert A.is_cuda and B.is_cuda, "A and B must be on CUDA"
    assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous"

    M, K = A.shape
    N, K_B = B.shape
    assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}"

    # Check or fix alignment requirements
    if check_alignment:
        aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K)

        if M != aligned_M or N != aligned_N or K != aligned_K:
            print(f"Warning: Dimensions ({M}, {N}, {K}) not aligned for SM100")
            print(
                f"Suggested aligned dimensions: ({aligned_M}, {aligned_N}, {aligned_K})"
            )
            print("Consider padding tensors or use create_aligned_tensors()")

    # Strict alignment check
    assert (
        M % sm100_gemm.MMA_TILE_M == 0
    ), f"M={M} must be multiple of {sm100_gemm.MMA_TILE_M}"
    assert (
        N % sm100_gemm.MMA_TILE_N == 0
    ), f"N={N} must be multiple of {sm100_gemm.MMA_TILE_N}"
    assert (
        K % sm100_gemm.MMA_TILE_K == 0
    ), f"K={K} must be multiple of {sm100_gemm.MMA_TILE_K}"

    # Create C if not provided
    if C is None:
        C = torch.zeros(M, N, dtype=torch.float32, device=A.device)
    else:
        assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}"
        assert C.is_cuda, "C must be on CUDA"
        assert C.is_contiguous(), "C must be contiguous"
        assert C.shape == (
            M,
            N,
        ), f"C shape {C.shape} must match output shape ({M}, {N})"

    # Call the extension (now uses TMA internally)
    return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta)


# Keep the old name for compatibility
sm100_gemm_f16 = sm100_gemm_f16_tma


def create_aligned_tensors(
    M, N, K, device="cuda", dtype_AB=torch.float16, dtype_C=torch.float32
):
    """
    Create properly aligned tensors for SM100 GEMM with TMA

    Returns:
        tuple: (A, B, C) tensors with aligned dimensions
    """
    aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K)

    A = torch.zeros(aligned_M, aligned_K, dtype=dtype_AB, device=device)
    B = torch.zeros(aligned_N, aligned_K, dtype=dtype_AB, device=device)
    C = torch.zeros(aligned_M, aligned_N, dtype=dtype_C, device=device)

    return A, B, C


def pad_to_aligned(tensor, target_shape=None, dim_requirements=None):
    """
    Pad tensor to meet SM100 alignment requirements

    Args:
        tensor: Input tensor to pad
        target_shape: Specific target shape (optional)
        dim_requirements: Tuple of (M_align, N_align, K_align) requirements

    Returns:
        Padded tensor and padding info for later unpadding
    """
    if dim_requirements is None:
        dim_requirements = (
            sm100_gemm.MMA_TILE_M,
            sm100_gemm.MMA_TILE_N,
            sm100_gemm.MMA_TILE_K,
        )

    if tensor.dim() == 2:
        M, N = tensor.shape

        if target_shape:
            target_M, target_N = target_shape
        else:
            target_M = (
                (M + dim_requirements[0] - 1) // dim_requirements[0]
            ) * dim_requirements[0]
            target_N = (
                (N + dim_requirements[1] - 1) // dim_requirements[1]
            ) * dim_requirements[1]

        pad_M = target_M - M
        pad_N = target_N - N

        # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
        padded = torch.nn.functional.pad(tensor, (0, pad_N, 0, pad_M))

        return padded, (M, N, pad_M, pad_N)
    else:
        raise ValueError("Only 2D tensors supported")


def unpad_result(tensor, padding_info):
    """Remove padding from result tensor"""
    orig_M, orig_N, pad_M, pad_N = padding_info
    return tensor[:orig_M, :orig_N]


def benchmark_sm100_vs_torch(
    M=512,
    N=1024,
    K=256,
    num_warmup=1,
    num_trials=10,
    auto_align=True,
    compare_tma=True,
):
    """
    Benchmark SM100 GEMM with TMA against PyTorch's native GEMM
    """
    # Ensure dimensions are aligned
    if auto_align:
        M = (
            (M + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M
        ) * sm100_gemm.MMA_TILE_M
        N = (
            (N + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N
        ) * sm100_gemm.MMA_TILE_N
        K = (
            (K + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K
        ) * sm100_gemm.MMA_TILE_K

    print(f"Benchmarking GEMM with TMA for shape: ({M}, {N}, {K})")

    # Check SM100 availability
    if not check_sm100_compatibility():
        print("SM100 not available, skipping benchmark")
        return None

    # Create test tensors
    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
    B = torch.randn(N, K, dtype=torch.float16, device="cuda")
    C = torch.randn(M, N, dtype=torch.float32, device="cuda")

    # PyTorch baseline (using mixed precision)
    A_fp32 = A.float()
    B_fp32 = B.float()

    # Warmup
    for _ in range(num_warmup):
        # PyTorch GEMM
        torch_result = torch.addmm(C, A_fp32, B_fp32.T)

        # SM100 GEMM with TMA
        sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)

    torch.cuda.synchronize()

    # Benchmark PyTorch
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # warmup
    torch_result = torch.addmm(C, A_fp32, B_fp32.T)

    start.record()
    for _ in range(num_trials):
        torch_result = torch.addmm(C, A_fp32, B_fp32.T)
    end.record()
    torch.cuda.synchronize()
    torch_time = start.elapsed_time(end) / num_trials

    # Benchmark SM100 with TMA
    # warmup
    sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)

    start.record()
    for _ in range(num_trials):
        sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)
    end.record()
    torch.cuda.synchronize()
    sm100_time = start.elapsed_time(end) / num_trials

    # Check correctness
    max_diff = torch.max(torch.abs(torch_result - sm100_result))
    rel_error = max_diff / torch.max(torch.abs(torch_result))

    # Calculate FLOPS
    flops = 2 * M * N * K  # Multiply-add operations
    torch_tflops = flops / (torch_time * 1e-3) / 1e12
    sm100_tflops = flops / (sm100_time * 1e-3) / 1e12

    print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)")
    print(f"SM100+TMA time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)")
    print(f"Speedup: {torch_time/sm100_time:.2f}x")
    print(f"Max difference: {max_diff:.6f}")
    print(f"Relative error: {rel_error:.6f}")
    print(f"🚀 TMA provides efficient memory transfers for large matrices!")

    return {
        "torch_time": torch_time,
        "sm100_time": sm100_time,
        "speedup": torch_time / sm100_time,
        "torch_tflops": torch_tflops,
        "sm100_tflops": sm100_tflops,
        "max_diff": max_diff.item(),
        "rel_error": rel_error.item(),
    }


# Neural network layer implementations with TMA
class SM100LinearTMA(torch.nn.Module):
    """
    Linear layer using SM100 GEMM with TMA for forward pass
    """

    def __init__(self, in_features, out_features, bias=True, device="cuda"):
        super().__init__()

        # Align dimensions
        self.orig_in_features = in_features
        self.orig_out_features = out_features

        aligned_in = (
            (in_features + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K
        ) * sm100_gemm.MMA_TILE_K
        aligned_out = (
            (out_features + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N
        ) * sm100_gemm.MMA_TILE_N

        self.in_features = aligned_in
        self.out_features = aligned_out

        # Parameters (with padding)
        self.weight = torch.nn.Parameter(
            torch.randn(aligned_out, aligned_in, dtype=torch.float16, device=device)
            * 0.1
        )

        if bias:
            self.bias = torch.nn.Parameter(
                torch.zeros(aligned_out, dtype=torch.float32, device=device)
            )
        else:
            self.register_parameter("bias", None)

        print(
            f"SM100LinearTMA: {in_features} -> {out_features} (aligned: {aligned_in} -> {aligned_out})"
        )
        print("🚀 Using TMA for efficient memory transfers")

    def forward(self, x):
        # Pad input if necessary
        batch_size = x.size(0)

        # Align batch size
        aligned_batch = (
            (batch_size + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M
        ) * sm100_gemm.MMA_TILE_M

        if x.size(1) != self.in_features or batch_size != aligned_batch:
            x_padded = torch.zeros(
                aligned_batch, self.in_features, dtype=torch.float16, device=x.device
            )
            x_padded[:batch_size, : self.orig_in_features] = x
            x = x_padded

        # Prepare bias
        if self.bias is not None:
            C = (
                self.bias.unsqueeze(0)
                .expand(aligned_batch, self.out_features)
                .contiguous()
            )
            beta = 1.0
        else:
            C = torch.zeros(
                aligned_batch, self.out_features, dtype=torch.float32, device=x.device
            )
            beta = 0.0

        # SM100 GEMM with TMA: output = x @ weight^T + bias
        output = sm100_gemm_f16_tma(
            x, self.weight, C, alpha=1.0, beta=beta, check_alignment=False
        )

        # Remove padding
        return output[:batch_size, : self.orig_out_features]


def benchmark_tma_vs_cooperative_copy(M=512, N=1024, K=256, num_trials=50):
    """
    TMA addition
    """

    results = benchmark_sm100_vs_torch(M, N, K, num_trials=num_trials)

    if results:
        print(f"\nTMA-accelerated SM100 GEMM achieved:")
        print(f"   Performance: {results['sm100_tflops']:.2f} TFLOPS")
        print(f"   Speedup: {results['speedup']:.2f}x over PyTorch")
        print(f"   Memory efficiency: Hardware-optimized transfers")


def stress_test_large_matrices():
    """
    Test TMA performance with large matrices that benefit most from TMA
    """
    print("\n=== Large Matrix Stress Test with TMA ==="
Download .txt
gitextract_gcbwaf39/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── assets/
│   └── images/
│       ├── dev-discuss-asynctp/
│       │   └── readme.md
│       └── readme.md
├── dev/
│   ├── sr/
│   │   ├── .gitignore
│   │   ├── readme.md
│   │   ├── setup.py
│   │   ├── src/
│   │   │   ├── stochastic_rounding.cu
│   │   │   ├── stochastic_rounding.hpp
│   │   │   └── stochastic_rounding_cuda.cu
│   │   ├── test.md
│   │   ├── tests/
│   │   │   ├── benchmark.py
│   │   │   └── core_unit_tests.py
│   │   ├── usage.py
│   │   └── usage2.py
│   └── triton_groupGEMM/
│       ├── groupgemm.py
│       ├── testing/
│       │   ├── base_testing.py
│       │   └── unit_tests.py
│       ├── tma_utils.py
│       └── triton_tutorial_groupgemm.py
├── kernels/
│   ├── MoE/
│   │   └── group_GEMM/
│   │       └── triton/
│   │           ├── readme.md
│   │           ├── testing/
│   │           │   ├── fast_verification.py
│   │           │   └── pytorch_reference_backwards.py
│   │           ├── tgroup_gemm_backwards.py
│   │           ├── tgroup_gemm_forward.py
│   │           └── utils/
│   │               └── tma_utils.py
│   ├── blackwell/
│   │   ├── cute_gemm_01/
│   │   │   ├── Makefile
│   │   │   ├── build/
│   │   │   │   └── temp.linux-x86_64-cpython-312/
│   │   │   │       ├── .ninja_deps
│   │   │   │       ├── .ninja_log
│   │   │   │       ├── build.ninja
│   │   │   │       ├── sm100_gemm.o
│   │   │   │       └── sm100_gemm_pytorch.o
│   │   │   ├── dist/
│   │   │   │   └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg
│   │   │   ├── driver.py
│   │   │   ├── setup.py
│   │   │   ├── sm100_gemm.cu
│   │   │   ├── sm100_gemm.egg-info/
│   │   │   │   ├── PKG-INFO
│   │   │   │   ├── SOURCES.txt
│   │   │   │   ├── dependency_links.txt
│   │   │   │   ├── not-zip-safe
│   │   │   │   ├── requires.txt
│   │   │   │   └── top_level.txt
│   │   │   ├── sm100_gemm.h
│   │   │   └── sm100_gemm_pytorch.cpp
│   │   └── cute_gemm_02_tma/
│   │       ├── build/
│   │       │   └── temp.linux-x86_64-cpython-312/
│   │       │       ├── .ninja_deps
│   │       │       ├── .ninja_log
│   │       │       ├── build.ninja
│   │       │       ├── sm100_gemm.o
│   │       │       └── sm100_gemm_pytorch.o
│   │       ├── dist/
│   │       │   └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg
│   │       ├── driver.py
│   │       ├── setup.py
│   │       ├── sm100_gemm.cu
│   │       ├── sm100_gemm.egg-info/
│   │       │   ├── PKG-INFO
│   │       │   ├── SOURCES.txt
│   │       │   ├── dependency_links.txt
│   │       │   ├── not-zip-safe
│   │       │   ├── requires.txt
│   │       │   └── top_level.txt
│   │       ├── sm100_gemm.h
│   │       └── sm100_gemm_pytorch.cpp
│   ├── cuda/
│   │   ├── cutlass_gemm/
│   │   │   ├── broadcast_load_epilogue_c3x.hpp
│   │   │   ├── common.hpp
│   │   │   ├── cutlass.cpp
│   │   │   ├── cutlass_kernel.cu
│   │   │   ├── readme.md
│   │   │   ├── setup.py
│   │   │   └── test_cutlass_gemm.py
│   │   ├── inference/
│   │   │   ├── README.md
│   │   │   └── hadamard_transform/
│   │   │       ├── hadamard_transform.cpp
│   │   │       ├── hadamard_transform_cuda.cu
│   │   │       ├── setup.py
│   │   │       └── test.py
│   │   ├── training/
│   │   │   └── README.md
│   │   └── tutorials/
│   │       ├── README.md
│   │       └── flash2.cu
│   ├── needs_perf_help/
│   │   ├── fp8_gemm_bench.py
│   │   └── fp8_rowwise_tma_persistent.py
│   └── triton/
│       ├── inference/
│       │   ├── README.md
│       │   ├── col_major_moe_gemm/
│       │   │   ├── README.md
│       │   │   ├── perf_test_moe.py
│       │   │   ├── profile_moe.py
│       │   │   ├── results.html
│       │   │   ├── test.csv
│       │   │   ├── test_moe_gemm.py
│       │   │   ├── v0_moe_fused.py
│       │   │   ├── v1_moe_fused.py
│       │   │   └── v2_moe_fused.py
│       │   ├── flash_attention/
│       │   │   └── stay_attention.py
│       │   ├── fp8/
│       │   │   ├── float8_groupwise_quant.py
│       │   │   ├── scaled_fp8_gemm.py
│       │   │   ├── splitk_gemm_fp8.py
│       │   │   └── tma_gemm.py
│       │   ├── gptq/
│       │   │   ├── a100_qlinear.py
│       │   │   ├── benchmark.py
│       │   │   ├── h100_qlinear.py
│       │   │   ├── mixtral/
│       │   │   │   ├── test_dequant_moe_gemm.py
│       │   │   │   └── w4a16_fused_dequant_gemm.py
│       │   │   ├── small_benchmark_cuda_graphs.py
│       │   │   └── splitk_dequant_gemm.py
│       │   ├── mamba/
│       │   │   └── causal_1d_conv/
│       │   │       ├── causal_1d_conv/
│       │   │       │   └── causal_1d_conv.py
│       │   │       └── tests/
│       │   │           └── test_causal_1d_conv.py
│       │   ├── paged_attention/
│       │   │   └── attention_triton.py
│       │   └── torch_compile/
│       │       └── flash_backward.py
│       ├── training/
│       │   ├── README.md
│       │   ├── fused_softmax/
│       │   │   ├── README.md
│       │   │   └── softmax.py
│       │   └── rms_norm/
│       │       └── fused_rms_norm.py
│       └── tutorials/
│           └── README.md
├── readme.md
└── tutorials/
    └── triton/
        ├── kernels/
        │   ├── __init__.py
        │   ├── flash_attention_fwd.py
        │   ├── fused_softmax.py
        │   ├── readme.md
        │   └── vector_add.py
        └── tests/
            ├── test_softmax.py
            └── test_utils.py
Download .txt
SYMBOL INDEX (317 symbols across 52 files)

FILE: dev/sr/src/stochastic_rounding.hpp
  type philox (line 9) | namespace philox {
  class PhiloxGenerator (line 18) | class PhiloxGenerator {

FILE: dev/sr/tests/benchmark.py
  function measure_performance (line 8) | def measure_performance(func, input_tensor, warmup=0, repeats=1):
  function benchmark_sizes (line 27) | def benchmark_sizes(sizes= [1000, 10000, 100000, 1000000, 10000000, (100...
  function benchmark_shapes (line 59) | def benchmark_shapes(total_size=1000000):
  function stress_test (line 84) | def stress_test(duration=10):
  function memory_test (line 100) | def memory_test(max_size=1e9):
  function main (line 128) | def main():

FILE: dev/sr/tests/core_unit_tests.py
  class TestStochasticRounding (line 8) | class TestStochasticRounding(unittest.TestCase):
    method setup (line 9) | def setup(self):
    method _test_rounding_statistics_helper (line 14) | def _test_rounding_statistics_helper(self, value, lower_value, upper_v...
    method test_special_values (line 40) | def test_special_values(self):
    method test_small_values (line 59) | def test_small_values(self):
    method test_vectorized_loading (line 67) | def test_vectorized_loading(self):
    method test_large_values (line 81) | def test_large_values(self):
    method test_rounding_statistics (line 89) | def test_rounding_statistics(self):
    method test_rounding_statistics_2 (line 93) | def test_rounding_statistics_2(self):
    method test_rounding_statistics_small (line 97) | def test_rounding_statistics_small(self):
    method test_rounding_statistics_large (line 101) | def test_rounding_statistics_large(self):

FILE: dev/triton_groupGEMM/groupgemm.py
  function early_config_prune (line 61) | def early_config_prune(configs, named_args, dtsize=None, dtype=None, **k...
  function _kernel_grouped_gemm (line 131) | def _kernel_grouped_gemm(
  function _kernel_grouped_gemm_fp8_rowwise (line 270) | def _kernel_grouped_gemm_fp8_rowwise(
  function _grouped_gemm (line 407) | def _grouped_gemm(
  function grouped_gemm (line 518) | def grouped_gemm(
  function grouped_gemm_fp8_rowwise (line 524) | def grouped_gemm_fp8_rowwise(

FILE: dev/triton_groupGEMM/testing/base_testing.py
  class TestGroupedGEMM (line 39) | class TestGroupedGEMM(unittest.TestCase):
    method setUp (line 40) | def setUp(self) -> None:
    method test_grouped_gemm_bf16 (line 98) | def test_grouped_gemm_bf16(self) -> None:

FILE: dev/triton_groupGEMM/testing/unit_tests.py
  class TestGroupedGEMM (line 23) | class TestGroupedGEMM(unittest.TestCase):
    method test_grouped_gemm_bf16 (line 24) | def test_grouped_gemm_bf16(self) -> None:
    method test_grouped_gemm_bf16_various_dimensions (line 63) | def test_grouped_gemm_bf16_various_dimensions(self) -> None:
    method test_grouped_gemm_bf16_edge_cases (line 105) | def test_grouped_gemm_bf16_edge_cases(self) -> None:
    method test_grouped_gemm_bf16_invalid_inputs (line 168) | def test_grouped_gemm_bf16_invalid_inputs(self) -> None:
    method test_grouped_gemm_bf16_deterministic (line 214) | def test_grouped_gemm_bf16_deterministic(self) -> None:
    method test_grouped_gemm_bf16_large_matrices (line 235) | def test_grouped_gemm_bf16_large_matrices(self) -> None:

FILE: dev/triton_groupGEMM/tma_utils.py
  function map_dtype_to_triton (line 18) | def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
  class TmaAutoTuneHelper (line 58) | class TmaAutoTuneHelper:
    class KernelParamWrapper (line 61) | class KernelParamWrapper:
      method __init__ (line 62) | def __init__(self, desc):
      method tma_desc_cpu_ptr (line 65) | def tma_desc_cpu_ptr(self):
    method __init__ (line 70) | def __init__(self):
    method init_tma_descriptor (line 83) | def init_tma_descriptor(self, name):
    method fill_1d_tma_descriptor (line 94) | def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_si...
    method fill_2d_tma_descriptor (line 110) | def fill_2d_tma_descriptor(
    method get_tma_descriptor_kernel_param (line 127) | def get_tma_descriptor_kernel_param(self, name):

FILE: dev/triton_groupGEMM/triton_tutorial_groupgemm.py
  function is_cuda (line 41) | def is_cuda():
  function supports_tma (line 45) | def supports_tma():
  function num_sms (line 49) | def num_sms():
  function grouped_matmul_kernel (line 109) | def grouped_matmul_kernel(
  function group_gemm_fn (line 187) | def group_gemm_fn(group_A, group_B):
  function grouped_matmul_tma_kernel (line 250) | def grouped_matmul_tma_kernel(
  function group_gemm_tma_fn (line 346) | def group_gemm_tma_fn(group_A, group_B):
  function triton_perf_fn (line 433) | def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
  function triton_tma_perf_fn (line 445) | def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, d...
  function torch_perf_fn (line 459) | def torch_perf_fn(group_A, group_B):
  function benchmark_square_matrices (line 484) | def benchmark_square_matrices(N, provider):
  function benchmark_batches (line 567) | def benchmark_batches(M, provider):

FILE: kernels/MoE/group_GEMM/triton/testing/fast_verification.py
  function test_backward_pass (line 23) | def test_backward_pass():

FILE: kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py
  function _compute_grad_x_pytorch (line 15) | def _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x):
  function _compute_grad_w_pytorch (line 68) | def _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w):
  function _pytorch_fallback_backward (line 139) | def _pytorch_fallback_backward(grad_output, x, w, m_sizes):
  function _pytorch_reference_backward (line 165) | def _pytorch_reference_backward(grad_output, x, w, m_sizes):

FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py
  function _kernel_grouped_gemm_backward_x_scheduled (line 28) | def _kernel_grouped_gemm_backward_x_scheduled(
  function _kernel_grouped_gemm_backward_w_scheduled (line 202) | def _kernel_grouped_gemm_backward_w_scheduled(
  function grouped_gemm_backward (line 382) | def grouped_gemm_backward(

FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py
  function _kernel_grouped_gemm (line 137) | def _kernel_grouped_gemm(
  function _kernel_grouped_gemm_fp8_rowwise (line 312) | def _kernel_grouped_gemm_fp8_rowwise(
  function _grouped_gemm (line 485) | def _grouped_gemm(
  function grouped_gemm_forward (line 626) | def grouped_gemm_forward(
  function grouped_gemm_fp8_rowwise (line 632) | def grouped_gemm_fp8_rowwise(

FILE: kernels/MoE/group_GEMM/triton/utils/tma_utils.py
  function map_dtype_to_triton (line 18) | def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
  class TmaAutoTuneHelper (line 58) | class TmaAutoTuneHelper:
    class KernelParamWrapper (line 61) | class KernelParamWrapper:
      method __init__ (line 62) | def __init__(self, desc):
      method tma_desc_cpu_ptr (line 65) | def tma_desc_cpu_ptr(self):
    method __init__ (line 70) | def __init__(self):
    method init_tma_descriptor (line 83) | def init_tma_descriptor(self, name):
    method fill_1d_tma_descriptor (line 94) | def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_si...
    method fill_2d_tma_descriptor (line 110) | def fill_2d_tma_descriptor(
    method get_tma_descriptor_kernel_param (line 127) | def get_tma_descriptor_kernel_param(self, name):

FILE: kernels/blackwell/cute_gemm_01/driver.py
  function sm100_gemm_f16 (line 17) | def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0):
  function benchmark_sm100_vs_torch (line 70) | def benchmark_sm100_vs_torch(

FILE: kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp
  function is_sm100_supported (line 11) | bool is_sm100_supported() {
  function check_sm100_device (line 20) | bool check_sm100_device() {
  function sm100_gemm_f16 (line 34) | torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor...
  function get_device_info (line 109) | torch::Tensor get_device_info() {
  function get_aligned_shape (line 128) | std::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {
  function create_aligned_tensor (line 137) | torch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,
  function PYBIND11_MODULE (line 156) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: kernels/blackwell/cute_gemm_02_tma/driver.py
  function check_sm100_compatibility (line 14) | def check_sm100_compatibility():
  function sm100_gemm_f16_tma (line 38) | def sm100_gemm_f16_tma(A, B, C=None, alpha=1.0, beta=0.0, check_alignmen...
  function create_aligned_tensors (line 114) | def create_aligned_tensors(
  function pad_to_aligned (line 132) | def pad_to_aligned(tensor, target_shape=None, dim_requirements=None):
  function unpad_result (line 175) | def unpad_result(tensor, padding_info):
  function benchmark_sm100_vs_torch (line 181) | def benchmark_sm100_vs_torch(
  class SM100LinearTMA (line 284) | class SM100LinearTMA(torch.nn.Module):
    method __init__ (line 289) | def __init__(self, in_features, out_features, bias=True, device="cuda"):
    method forward (line 324) | def forward(self, x):
  function benchmark_tma_vs_cooperative_copy (line 363) | def benchmark_tma_vs_cooperative_copy(M=512, N=1024, K=256, num_trials=50):
  function stress_test_large_matrices (line 377) | def stress_test_large_matrices():
  function check_sm100_compatibility (line 528) | def check_sm100_compatibility():
  function sm100_gemm_f16 (line 552) | def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0, check_alignment=Tr...
  function create_aligned_tensors (line 623) | def create_aligned_tensors(
  function pad_to_aligned (line 641) | def pad_to_aligned(tensor, target_shape=None, dim_requirements=None):
  function unpad_result (line 684) | def unpad_result(tensor, padding_info):
  function benchmark_sm100_vs_torch (line 690) | def benchmark_sm100_vs_torch(
  class SM100Linear (line 781) | class SM100Linear(torch.nn.Module):
    method __init__ (line 786) | def __init__(self, in_features, out_features, bias=True, device="cuda"):
    method forward (line 820) | def forward(self, x):

FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp
  function is_sm100_supported (line 11) | bool is_sm100_supported() {
  function check_sm100_device (line 20) | bool check_sm100_device() {
  function sm100_gemm_f16 (line 34) | torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor...
  function get_device_info (line 109) | torch::Tensor get_device_info() {
  function get_aligned_shape (line 128) | std::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {
  function create_aligned_tensor (line 137) | torch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,
  function PYBIND11_MODULE (line 156) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: kernels/cuda/cutlass_gemm/broadcast_load_epilogue_c3x.hpp
  type cutlass::epilogue::fusion (line 60) | namespace cutlass::epilogue::fusion {
    type Sm90RowOrScalarBroadcast (line 75) | struct Sm90RowOrScalarBroadcast {
      type SharedStorage (line 82) | struct SharedStorage {
      type Arguments (line 89) | struct Arguments {
      method Params (line 98) | static constexpr Params
      method get_workspace_size (line 104) | static size_t
      method initialize_workspace (line 110) | static cutlass::Status
      method CUTLASS_HOST_DEVICE (line 116) | CUTLASS_HOST_DEVICE
      method CUTLASS_DEVICE (line 127) | CUTLASS_DEVICE bool
      method CUTLASS_DEVICE (line 132) | CUTLASS_DEVICE bool
      method CUTLASS_DEVICE (line 137) | CUTLASS_DEVICE bool
      type ProducerLoadCallbacks (line 143) | struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
        method CUTLASS_DEVICE (line 154) | CUTLASS_DEVICE void
      method get_producer_load_callbacks (line 174) | CUTLASS_DEVICE auto
      type ConsumerStoreCallbacks (line 191) | struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
        method CUTLASS_DEVICE (line 202) | CUTLASS_DEVICE void
        method CUTLASS_DEVICE (line 218) | CUTLASS_DEVICE Array<Element, FragmentSize>
      method get_consumer_store_callbacks (line 235) | CUTLASS_DEVICE auto
    type Sm90ColOrScalarBroadcast (line 261) | struct Sm90ColOrScalarBroadcast {
      type SharedStorage (line 269) | struct SharedStorage { }
      type Arguments (line 274) | struct Arguments {
      method Params (line 283) | static constexpr Params
      method get_workspace_size (line 289) | static size_t
      method initialize_workspace (line 295) | static cutlass::Status
      method CUTLASS_DEVICE (line 301) | CUTLASS_DEVICE bool
      method CUTLASS_DEVICE (line 306) | CUTLASS_DEVICE bool
      method CUTLASS_DEVICE (line 311) | CUTLASS_DEVICE bool
      method CUTLASS_HOST_DEVICE (line 316) | CUTLASS_HOST_DEVICE
      method CUTLASS_HOST_DEVICE (line 319) | CUTLASS_HOST_DEVICE
      method get_producer_load_callbacks (line 326) | CUTLASS_DEVICE auto
      type ConsumerStoreCallbacks (line 332) | struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
        method CUTLASS_DEVICE (line 343) | CUTLASS_DEVICE void
        method CUTLASS_DEVICE (line 356) | CUTLASS_DEVICE Array<Element, FragmentSize>
      method get_consumer_store_callbacks (line 375) | CUTLASS_DEVICE auto

FILE: kernels/cuda/cutlass_gemm/common.hpp
  function next_pow_2 (line 15) | inline uint32_t next_pow_2(uint32_t const num) {
  function get_cuda_max_shared_memory_per_block_opt_in (line 20) | inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {

FILE: kernels/cuda/cutlass_gemm/cutlass.cpp
  function cutlass_scaled_mm (line 7) | torch::Tensor cutlass_scaled_mm(torch::Tensor a, torch::Tensor b, torch:...
  function PYBIND11_MODULE (line 17) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: kernels/cuda/inference/hadamard_transform/hadamard_transform.cpp
  function is_power_of_two (line 11) | constexpr bool is_power_of_two(uint32_t x) {
  function hadamard_transform (line 15) | torch::Tensor hadamard_transform(at::Tensor& in, bool inplace) {
  function PYBIND11_MODULE (line 58) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: kernels/cuda/inference/hadamard_transform/test.py
  function get_scale (line 24) | def get_scale(size):
  function truth_hadamard_transform_inplace (line 34) | def truth_hadamard_transform_inplace(a: torch.Tensor, truth_hadamards):
  function test_hadamard_transform_inplace_rowmajor (line 42) | def test_hadamard_transform_inplace_rowmajor(a: torch.Tensor):
  function check_correctness (line 48) | def check_correctness(m, elem_c, a, result, truth, atol=1e-2, rtol=0):

FILE: kernels/needs_perf_help/fp8_gemm_bench.py
  function bench (line 28) | def bench(cuda_graph: bool, rowwise_tma: bool=True) -> None:
  function bf16_bench (line 107) | def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
  function scale_row_bench (line 114) | def scale_row_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
  function row_gemm_bench (line 136) | def row_gemm_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
  function row_gemm_bench_tma (line 159) | def row_gemm_bench_tma(x: Tensor, w: Tensor) -> Callable[[], Tensor]:

FILE: kernels/needs_perf_help/fp8_rowwise_tma_persistent.py
  function get_fp8_constants (line 27) | def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
  function convert_fp8_type (line 46) | def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:
  function init_to_zero (line 60) | def init_to_zero(name):
  function get_configs_io_bound (line 64) | def get_configs_io_bound() -> List[Config]:
  function _kernel_matmul_fp8_row_tma_persistent (line 108) | def _kernel_matmul_fp8_row_tma_persistent(
  function matmul_fp8_row (line 235) | def matmul_fp8_row(

FILE: kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py
  function torch_moe (line 16) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
  function test_fused_moe (line 33) | def test_fused_moe(
  function benchmark (line 99) | def benchmark(m, provider):

FILE: kernels/triton/inference/col_major_moe_gemm/profile_moe.py
  function torch_moe (line 15) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
  function test_fused_moe (line 32) | def test_fused_moe(

FILE: kernels/triton/inference/col_major_moe_gemm/test_moe_gemm.py
  function torch_moe (line 16) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
  function test_fused_moe (line 39) | def test_fused_moe(

FILE: kernels/triton/inference/col_major_moe_gemm/v0_moe_fused.py
  function fused_moe_kernel (line 18) | def fused_moe_kernel(
  function moe_align_block_size (line 138) | def moe_align_block_size(
  function invoke_fused_moe_kernel (line 183) | def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.T...
  function fused_moe (line 222) | def fused_moe(hidden_states: torch.Tensor,

FILE: kernels/triton/inference/col_major_moe_gemm/v1_moe_fused.py
  function grouped_launch (line 21) | def grouped_launch(pid,
  function fused_moe_kernel_splitk (line 39) | def fused_moe_kernel_splitk(
  function moe_align_block_size (line 150) | def moe_align_block_size(
  function invoke_fused_moe_kernel (line 195) | def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.T...
  function fused_moe (line 249) | def fused_moe(hidden_states: torch.Tensor,

FILE: kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py
  function col_major (line 18) | def col_major(pid,
  function fused_moe_kernel (line 31) | def fused_moe_kernel(
  function moe_align_block_size (line 136) | def moe_align_block_size(
  function invoke_fused_moe_kernel (line 181) | def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.T...
  function fused_moe (line 220) | def fused_moe(hidden_states: torch.Tensor,

FILE: kernels/triton/inference/flash_attention/stay_attention.py
  function stay_attention (line 7) | def stay_attention(
  function flash_fn (line 107) | def flash_fn(q, k, v):

FILE: kernels/triton/inference/fp8/float8_groupwise_quant.py
  function _float8_groupwise_quant_kernel (line 21) | def _float8_groupwise_quant_kernel(
  function float8_groupwise_quantize (line 53) | def float8_groupwise_quantize(x: torch.Tensor, block_size=128):

FILE: kernels/triton/inference/fp8/scaled_fp8_gemm.py
  function grouped_launch (line 10) | def grouped_launch(pid,
  function column_major (line 27) | def column_major(pid,
  function scaled_gemm_splitk (line 39) | def scaled_gemm_splitk(a_ptr, b_ptr, c_ptr,
  function scaled_mm_splitk (line 94) | def scaled_mm_splitk(a, b, scale_a: float=1.0, scale_b: float=1.0):

FILE: kernels/triton/inference/fp8/splitk_gemm_fp8.py
  function grouped_launch (line 9) | def grouped_launch(pid,
  function col_major (line 27) | def col_major(pid,
  function gemm_split_k_kernel (line 40) | def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
  function gemm_split_k (line 90) | def gemm_split_k(a, b):

FILE: kernels/triton/inference/fp8/tma_gemm.py
  function gemm_kernel_tma (line 7) | def gemm_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
  function matmul (line 32) | def matmul(a, b, config=None):

FILE: kernels/triton/inference/gptq/a100_qlinear.py
  function _a100_quantized_matmul (line 6) | def _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
  class a100_qlinear (line 77) | class a100_qlinear(torch.autograd.Function):
    method forward (line 78) | def forward(ctx, a, b, scales, zeros):

FILE: kernels/triton/inference/gptq/benchmark.py
  function benchmark_generation_speed (line 15) | def benchmark_generation_speed(model, tokenizer, prompt, batch_size, dev...
  function main (line 59) | def main():

FILE: kernels/triton/inference/gptq/h100_qlinear.py
  function _h100_quantized_matmul (line 7) | def _h100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
  class h100_qlinear (line 87) | class h100_qlinear(torch.autograd.Function):
    method forward (line 88) | def forward(ctx, a, b, scales, zeros):

FILE: kernels/triton/inference/gptq/mixtral/test_dequant_moe_gemm.py
  function torch_moe (line 9) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
  function test_dequant_moe (line 25) | def test_dequant_moe(

FILE: kernels/triton/inference/gptq/mixtral/w4a16_fused_dequant_gemm.py
  function print_tensor_dim (line 9) | def print_tensor_dim(tensor, str_name):
  function print_value (line 13) | def print_value(value):
  function grouped_launch (line 18) | def grouped_launch(pid,
  function col_major (line 36) | def col_major(pid,
  function w4a16_fused_moe_kernel (line 50) | def w4a16_fused_moe_kernel(
  function invoke_dequant_gemm_moe (line 153) | def invoke_dequant_gemm_moe(activations: torch.Tensor,
  function moe_align_block_size (line 211) | def moe_align_block_size(
  function dequant_gemm_moe (line 255) | def dequant_gemm_moe(hidden_states: torch.Tensor,

FILE: kernels/triton/inference/gptq/small_benchmark_cuda_graphs.py
  function swizzle_tile (line 11) | def swizzle_tile(pid,
  function matmul_data_parallel_kernel (line 29) | def matmul_data_parallel_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
  class small_qlinear (line 109) | class small_qlinear(torch.autograd.Function):
    method forward (line 110) | def forward(ctx, a, b, scales, zeros):
  function matmul_split_k_kernel (line 161) | def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
  function matmul_split_k (line 228) | def matmul_split_k(a, b, scales, zeros):
  function make_tensor (line 281) | def make_tensor(M, N, dtype):
  function gen_quant4 (line 292) | def gen_quant4(m, n, groupsize=-1):

FILE: kernels/triton/inference/gptq/splitk_dequant_gemm.py
  function swizzle_tile (line 7) | def swizzle_tile(pid,
  function matmul_split_k_kernel (line 24) | def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
  function matmul_split_k (line 91) | def matmul_split_k(a, b, scales, zeros):
  function make_tensor (line 143) | def make_tensor(M, N, dtype):

FILE: kernels/triton/inference/mamba/causal_1d_conv/causal_1d_conv/causal_1d_conv.py
  function _causal_conv1d_fwd_kernel (line 27) | def _causal_conv1d_fwd_kernel(
  function causal_conv1d_fwd (line 121) | def causal_conv1d_fwd(
  class CausalConv1dFn (line 203) | class CausalConv1dFn(torch.autograd.Function):
    method forward (line 205) | def forward(
    method backward (line 259) | def backward(ctx, dout, *args):
  function causal_conv1d_fn (line 295) | def causal_conv1d_fn(

FILE: kernels/triton/inference/mamba/causal_1d_conv/tests/test_causal_1d_conv.py
  function _undecorated_test_causal_conv1d (line 24) | def _undecorated_test_causal_conv1d(
  function causal_conv1d_ref (line 117) | def causal_conv1d_ref(
  function test_causal_conv1d (line 183) | def test_causal_conv1d(

FILE: kernels/triton/inference/paged_attention/attention_triton.py
  function print_tensor_dim (line 14) | def print_tensor_dim(tensor, str_name):
  function print_value (line 20) | def print_value(value):
  function print_line (line 27) | def print_line(str_line):
  function paged_attention_v1 (line 33) | def paged_attention_v1(
  function paged_attention_triton_v1 (line 158) | def paged_attention_triton_v1(
  function paged_attention_v2 (line 206) | def paged_attention_v2(
  function paged_attention_triton_v2 (line 358) | def paged_attention_triton_v2(

FILE: kernels/triton/inference/torch_compile/flash_backward.py
  class MetaData (line 43) | class MetaData():
    method __init__ (line 55) | def __init__(self, sm_scale=1.0):
    method set_varlen_params (line 58) | def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
    method need_bias (line 70) | def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):
    method need_alibi (line 77) | def need_alibi(self, alibi_slopes, batch, nheads):
    method need_causal (line 84) | def need_causal(self):
    method need_dropout (line 87) | def need_dropout(dropout_p, return_encoded_softmax):
    method check_args (line 91) | def check_args(self, q, k, v, o):
  function cdiv_fn (line 120) | def cdiv_fn(x,y):
  function max_fn (line 124) | def max_fn(x, y):
  function dropout_offsets (line 128) | def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
  function dropout_rng (line 134) | def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
  function dropout_mask (line 140) | def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
  function load_fn (line 146) | def load_fn(block_ptr, first, second, pad):
  function _attn_fwd_inner (line 158) | def _attn_fwd_inner(
  function attn_fwd (line 270) | def attn_fwd(
  function attention (line 553) | def attention(q, k, v, sm_scale):
  function _attn_bwd_preprocess (line 637) | def _attn_bwd_preprocess(
  function _bwd_kernel_dk_dv (line 691) | def _bwd_kernel_dk_dv(
  function _bwd_kernel_dq (line 761) | def _bwd_kernel_dq(dq, q, K, V,
  function _attn_bwd (line 821) | def _attn_bwd(Q, K, V, sm_scale, alibi_slopes,
  function flash_bwd (line 1016) | def flash_bwd(q: torch.Tensor, k: torch.Tensor, v:torch.Tensor, o: torch...
  function _ (line 1074) | def _(q, k, v, o, M, do):
  function flash (line 1083) | def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Te...
  function _ (line 1159) | def _(q, k, v, o, M):
  function setup_context (line 1163) | def setup_context(ctx, inputs, output) -> torch.Tensor:
  function backward (line 1167) | def backward(ctx, do):
  function f (line 1187) | def f(q, k, v):

FILE: kernels/triton/training/fused_softmax/softmax.py
  function _get_num_warps (line 18) | def _get_num_warps(block_size: int)-> int:
  function _softmax_kernel_fwd (line 27) | def _softmax_kernel_fwd(
  function _softmax_kernel_bwd (line 56) | def _softmax_kernel_bwd(
  class triton_softmax (line 93) | class triton_softmax(autograd.Function):
    method forward (line 95) | def forward(ctx, x):
    method backward (line 122) | def backward(ctx, grad_probs):

FILE: kernels/triton/training/rms_norm/fused_rms_norm.py
  function _rms_norm_fwd_kernel (line 33) | def _rms_norm_fwd_kernel(
  function _rms_norm_bwd_kernel_sm (line 83) | def _rms_norm_bwd_kernel_sm(
  class ttt_RMSNorm (line 181) | class ttt_RMSNorm(torch.autograd.Function):
    method forward (line 183) | def forward(ctx, x, weight, eps):
    method backward (line 225) | def backward(ctx, dy):
  function fused_rms_norm_fn (line 313) | def fused_rms_norm_fn(
  class FusedRMSNorm (line 325) | class FusedRMSNorm(torch.nn.Module):
    method __init__ (line 326) | def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, ...
    method reset_parameters (line 333) | def reset_parameters(self):
    method forward (line 336) | def forward(

FILE: tutorials/triton/kernels/fused_softmax.py
  function _get_num_warps (line 15) | def _get_num_warps(block_size: int)-> int:
  function _softmax_kernel_fwd (line 24) | def _softmax_kernel_fwd(
  function _softmax_kernel_bwd (line 53) | def _softmax_kernel_bwd(
  class triton_softmax (line 90) | class triton_softmax(autograd.Function):
    method forward (line 92) | def forward(ctx, x):
    method backward (line 119) | def backward(ctx, grad_probs):

FILE: tutorials/triton/kernels/vector_add.py
  function kernel_vector_addition (line 9) | def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
  function ceil_div (line 24) | def ceil_div(x: int, y: int)-> int:
  function vector_addition (line 27) | def vector_addition(a: torch.tensor, b: torch.tensor)-> torch.tensor:

FILE: tutorials/triton/tests/test_softmax.py
  function set_seed (line 13) | def set_seed():
  class TestForwardSoftMax (line 18) | class TestForwardSoftMax:
    method test_forward_2D_float32 (line 20) | def test_forward_2D_float32(self,):
    method test_forward_2D_bfloat16 (line 36) | def test_forward_2D_bfloat16(self,):
    method test_forward_3D_bfloat16 (line 51) | def test_forward_3D_bfloat16(self,):
  class TestBackwardSoftMax (line 70) | class TestBackwardSoftMax:
    method test_backward_2D (line 72) | def test_backward_2D(self,):
    method test_bwd_3D (line 96) | def test_bwd_3D(self,):

FILE: tutorials/triton/tests/test_utils.py
  function assert_expected (line 10) | def assert_expected(
  function set_rng_seed (line 26) | def set_rng_seed(seed):
  function gpu_test (line 31) | def gpu_test(gpu_count: int = 1):
Condensed preview — 119 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (587K chars).
[
  {
    "path": ".gitignore",
    "chars": 28,
    "preview": "*.pyc\n**/.ipynb_checkpoints\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 3537,
    "preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1445,
    "preview": "# Contributing to Applied AI\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Our D"
  },
  {
    "path": "LICENSE",
    "chars": 1449,
    "preview": "Copyright 2024 Meta\n\nRedistribution and use in source and binary forms, with or without modification, are permitted prov"
  },
  {
    "path": "assets/images/dev-discuss-asynctp/readme.md",
    "chars": 297,
    "preview": "This folder is for hosting the images for the AsyncTP public post at:    \n[https://discuss.pytorch.org/t/distributed-w-t"
  },
  {
    "path": "assets/images/readme.md",
    "chars": 43,
    "preview": "Folder for housing images for the readmes.\n"
  },
  {
    "path": "dev/sr/.gitignore",
    "chars": 73,
    "preview": "*.o\n*.ninja\n*.txt\n*.egg-info\n*.ninja-deps\n*.ninja-log/\n*.so\ndist/\nbuild/\n"
  },
  {
    "path": "dev/sr/readme.md",
    "chars": 98,
    "preview": "Branch for stochastic rounding kernel\nCurrently processes 4 elements per thread to leverage rand4\n"
  },
  {
    "path": "dev/sr/setup.py",
    "chars": 929,
    "preview": "\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name='stoc"
  },
  {
    "path": "dev/sr/src/stochastic_rounding.cu",
    "chars": 2183,
    "preview": "\n#include <pybind11/pybind11.h>\n#include \"stochastic_rounding.hpp\"\n#include <random>\n\nnamespace py = pybind11;\n\n__host__"
  },
  {
    "path": "dev/sr/src/stochastic_rounding.hpp",
    "chars": 1217,
    "preview": "\n#pragma once\n#include <cuda_bf16.h>\n#include <cuda_runtime.h>\n#include <vector_types.h>\n#include <torch/extension.h>\n#i"
  },
  {
    "path": "dev/sr/src/stochastic_rounding_cuda.cu",
    "chars": 3637,
    "preview": " #include \"stochastic_rounding.hpp\"\n#include <cstdint>\n\n// Philox RNG implementation\n\n__device__ __forceinline__ PhiloxG"
  },
  {
    "path": "dev/sr/test.md",
    "chars": 545,
    "preview": "(tkdev11) [less@devgpu115.cco2 ~/local/applied-ai/dev/sr (sr_kernel)]$ python usage.py\nLaunching kernel with blocks=1, t"
  },
  {
    "path": "dev/sr/tests/benchmark.py",
    "chars": 5028,
    "preview": "import torch\nimport stochastic_rounding_cuda\nimport numpy as np\nimport time\nfrom tabulate import tabulate\nimport argpars"
  },
  {
    "path": "dev/sr/tests/core_unit_tests.py",
    "chars": 4443,
    "preview": "import torch\nimport numpy as np\nfrom collections import Counter\nimport unittest\nimport stochastic_rounding_cuda\nimport t"
  },
  {
    "path": "dev/sr/usage.py",
    "chars": 772,
    "preview": "import torch\nimport stochastic_rounding_cuda\n\n# Create input tensor\ninput_tensor = torch.randn(12, device='cuda', dtype="
  },
  {
    "path": "dev/sr/usage2.py",
    "chars": 359,
    "preview": "import torch\nimport stochastic_rounding_cuda\n\n# Test tensor\nx = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00], devi"
  },
  {
    "path": "dev/triton_groupGEMM/groupgemm.py",
    "chars": 17760,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "dev/triton_groupGEMM/testing/base_testing.py",
    "chars": 4906,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "dev/triton_groupGEMM/testing/unit_tests.py",
    "chars": 10701,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "dev/triton_groupGEMM/tma_utils.py",
    "chars": 4501,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "dev/triton_groupGEMM/triton_tutorial_groupgemm.py",
    "chars": 21412,
    "preview": "\"\"\"\nGroup GEMM\n============================\nThis group gemm kernel launches a fixed number of CTA to compute a group\nof "
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/readme.md",
    "chars": 66,
    "preview": "##  Experimental\n\nTriton Group GEMM for supporting MoE training. \n"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/testing/fast_verification.py",
    "chars": 7319,
    "preview": "import logging\n\nimport torch\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO, format=\"%(asctime)s - %(l"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py",
    "chars": 6472,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py",
    "chars": 21469,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py",
    "chars": 21277,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/MoE/group_GEMM/triton/utils/tma_utils.py",
    "chars": 4501,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/Makefile",
    "chars": 764,
    "preview": "\n# Makefile for SM100 GEMM PyTorch Extension\n\n# Set these paths according to your installation\nCUTLASS_PATH ?= /path/to/"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_log",
    "chars": 473,
    "preview": "# ninja log v5\n0\t15279\t1748131038212164071\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/build.ninja",
    "chars": 3098,
    "preview": "ninja_required_version = 1.3\ncxx = c++\nnvcc = /usr/local/cuda-12.8/bin/nvcc\n\ncflags = -pthread -B /home/less/.conda/envs"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/driver.py",
    "chars": 6895,
    "preview": "# ==============================================================================\n# python_interface.py - High-level Pyth"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/setup.py",
    "chars": 2229,
    "preview": "# setup.py\nimport os\n\nimport pybind11\nimport torch\nfrom pybind11 import get_cmake_dir\nfrom pybind11.setup_helpers import"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.cu",
    "chars": 8809,
    "preview": "// sm100_gemm_kernel.cu - CUDA kernel implementation\n#include \"sm100_gemm.h\"\n\n#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORT"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/PKG-INFO",
    "chars": 154,
    "preview": "Metadata-Version: 2.4\nName: sm100_gemm\nVersion: 0.0.0\nRequires-Python: >=3.8\nRequires-Dist: torch>=1.12.0\nDynamic: requi"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/SOURCES.txt",
    "chars": 247,
    "preview": "setup.py\nsm100_gemm.cu\nsm100_gemm_pytorch.cpp\nsm100_gemm.egg-info/PKG-INFO\nsm100_gemm.egg-info/SOURCES.txt\nsm100_gemm.eg"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/dependency_links.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/not-zip-safe",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/requires.txt",
    "chars": 14,
    "preview": "torch>=1.12.0\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/top_level.txt",
    "chars": 11,
    "preview": "sm100_gemm\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm.h",
    "chars": 1305,
    "preview": "// sm100_gemm_kernel.h - Header file for CUDA kernel\n#pragma once\n\n#include <cuda_runtime.h>\n\n#ifdef __cplusplus\nextern "
  },
  {
    "path": "kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp",
    "chars": 6465,
    "preview": "// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext."
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_log",
    "chars": 664,
    "preview": "# ninja log v5\n1\t15202\t1748185895110710199\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.lin"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/build.ninja",
    "chars": 3126,
    "preview": "ninja_required_version = 1.3\ncxx = c++\nnvcc = /usr/local/cuda-12.8/bin/nvcc\n\ncflags = -pthread -B /home/less/.conda/envs"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/driver.py",
    "chars": 30452,
    "preview": "# python_interface.py - High-level Python interface with TMA support\n\nimport torch\n\ntry:\n    import sm100_gemm  # The co"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/setup.py",
    "chars": 2229,
    "preview": "# setup.py\nimport os\n\nimport pybind11\nimport torch\nfrom pybind11 import get_cmake_dir\nfrom pybind11.setup_helpers import"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu",
    "chars": 12042,
    "preview": "// sm100_gemm_kernel.cu - CUDA kernel implementation with TMA\n#include \"sm100_gemm.h\"\n\n#if defined(CUTLASS_ARCH_MMA_SM10"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/PKG-INFO",
    "chars": 154,
    "preview": "Metadata-Version: 2.4\nName: sm100_gemm\nVersion: 0.0.0\nRequires-Python: >=3.8\nRequires-Dist: torch>=1.12.0\nDynamic: requi"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/SOURCES.txt",
    "chars": 247,
    "preview": "setup.py\nsm100_gemm.cu\nsm100_gemm_pytorch.cpp\nsm100_gemm.egg-info/PKG-INFO\nsm100_gemm.egg-info/SOURCES.txt\nsm100_gemm.eg"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/dependency_links.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/not-zip-safe",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/requires.txt",
    "chars": 14,
    "preview": "torch>=1.12.0\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/top_level.txt",
    "chars": 11,
    "preview": "sm100_gemm\n"
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.h",
    "chars": 1355,
    "preview": "// sm100_gemm_kernel.h - Header file for CUDA kernel\n#pragma once\n\n#include <cuda_runtime.h>\n\n#ifdef __cplusplus\nextern "
  },
  {
    "path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp",
    "chars": 6469,
    "preview": "// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext."
  },
  {
    "path": "kernels/cuda/cutlass_gemm/broadcast_load_epilogue_c3x.hpp",
    "chars": 15348,
    "preview": "/***************************************************************************************************\n * Copyright (c) 20"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/common.hpp",
    "chars": 807,
    "preview": "#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include <climits>\n\n/**\n * Helper function for checking CUTLASS errors\n */\n#d"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/cutlass.cpp",
    "chars": 764,
    "preview": "#include <torch/extension.h>\n#include<torch/all.h>\n\nvoid cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const&"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/cutlass_kernel.cu",
    "chars": 18444,
    "preview": "// clang-format will break include orders\n// clang-format off\n#include <cudaTypedefs.h>\n\n#if defined CUDA_VERSION && CUD"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/readme.md",
    "chars": 160,
    "preview": "Currently the CPP extension builds with Cutlass 3.5.1 (credit to @SamirMoustafa for the update).  \n3.6 will fail atm due"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/setup.py",
    "chars": 1139,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name='cutla"
  },
  {
    "path": "kernels/cuda/cutlass_gemm/test_cutlass_gemm.py",
    "chars": 490,
    "preview": "from pingpong_gemm import cutlass_scaled_mm\nimport torch\n\nm, k, n = 16, 4096, 4096\ndtype = torch.float8_e4m3fn\nout_dtype"
  },
  {
    "path": "kernels/cuda/inference/README.md",
    "chars": 13,
    "preview": "cuda kernels\n"
  },
  {
    "path": "kernels/cuda/inference/hadamard_transform/hadamard_transform.cpp",
    "chars": 2100,
    "preview": "#include <torch/extension.h>\n#include <pybind11/pybind11.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGu"
  },
  {
    "path": "kernels/cuda/inference/hadamard_transform/hadamard_transform_cuda.cu",
    "chars": 38515,
    "preview": "#include <torch/extension.h>\n#include <stdint.h>\n#include <cuda_runtime.h>\n#include <mma.h>\n#include <cuda/annotated_ptr"
  },
  {
    "path": "kernels/cuda/inference/hadamard_transform/setup.py",
    "chars": 978,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nversions = [\n    \"-gen"
  },
  {
    "path": "kernels/cuda/inference/hadamard_transform/test.py",
    "chars": 4743,
    "preview": "import torch\nimport faster_hadamard_transform\nimport scipy.linalg\nimport math\n\n# set to false to check performance\ncorre"
  },
  {
    "path": "kernels/cuda/training/README.md",
    "chars": 35,
    "preview": "kernels with backward pass support\n"
  },
  {
    "path": "kernels/cuda/tutorials/README.md",
    "chars": 15,
    "preview": "CUDA tutorials\n"
  },
  {
    "path": "kernels/cuda/tutorials/flash2.cu",
    "chars": 1545,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/needs_perf_help/fp8_gemm_bench.py",
    "chars": 5124,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/needs_perf_help/fp8_rowwise_tma_persistent.py",
    "chars": 12923,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/README.md",
    "chars": 25,
    "preview": "Triton Inference kernels\n"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/README.md",
    "chars": 497,
    "preview": "\n**MoE (Mixture of Experts) GEMM Kernels**\n\n\nTriton kernel supporting and accelerating MoE inference (Mixtral).\nThis ker"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py",
    "chars": 4397,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/profile_moe.py",
    "chars": 1821,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/results.html",
    "chars": 52,
    "preview": "<html><body>\n<image src=\"test.png\"/>\n</body></html>\n"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/test.csv",
    "chars": 345,
    "preview": "m,Fused MoE GEMM Kernel - Column Major,vLLM MoE GEMM Kernel\n1.000000,0.412454,0.259585\n2.000000,0.883064,0.269004\n4.0000"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/test_moe_gemm.py",
    "chars": 2795,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/v0_moe_fused.py",
    "chars": 13030,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/v1_moe_fused.py",
    "chars": 13671,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py",
    "chars": 12229,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/flash_attention/stay_attention.py",
    "chars": 4721,
    "preview": "import triton.language as tl\nimport triton\nimport torch\n\n\n@triton.jit()\ndef stay_attention(\n    q_ptr, k_ptr, v_ptr, o_p"
  },
  {
    "path": "kernels/triton/inference/fp8/float8_groupwise_quant.py",
    "chars": 2352,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/inference/fp8/scaled_fp8_gemm.py",
    "chars": 3840,
    "preview": "import torch\nimport triton\nimport triton.language as tl\nimport time\nimport os\nos.environ['ENABLE_TMA'] = '1'\n\n\n@triton.j"
  },
  {
    "path": "kernels/triton/inference/fp8/splitk_gemm_fp8.py",
    "chars": 4537,
    "preview": "import torch\nimport triton\nimport triton.language as tl\nimport time\nimport os\nos.environ['ENABLE_TMA'] = '1'\n\n@triton.ji"
  },
  {
    "path": "kernels/triton/inference/fp8/tma_gemm.py",
    "chars": 3199,
    "preview": "import triton\nimport triton.language as tl\nimport numpy as np\nimport torch\n\n@triton.jit\ndef gemm_kernel_tma(a_desc_ptr, "
  },
  {
    "path": "kernels/triton/inference/gptq/a100_qlinear.py",
    "chars": 4633,
    "preview": "import triton\nimport triton.language as tl\nimport torch \n\n@triton.jit()\ndef _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, "
  },
  {
    "path": "kernels/triton/inference/gptq/benchmark.py",
    "chars": 3723,
    "preview": "import argparse\nimport time\nimport logging\nfrom tqdm import tqdm\nimport torch\nfrom transformers import AutoTokenizer\nfro"
  },
  {
    "path": "kernels/triton/inference/gptq/h100_qlinear.py",
    "chars": 4341,
    "preview": "import triton\nimport triton.language as tl\nimport torch \n\n\n@triton.jit()\ndef _h100_quantized_matmul(a_ptr, b_ptr, c_ptr,"
  },
  {
    "path": "kernels/triton/inference/gptq/mixtral/test_dequant_moe_gemm.py",
    "chars": 2876,
    "preview": "import pytest\nimport torch\nfrom vllm.model_executor.layers.fused_moe import fused_moe\nfrom vllm.model_executor.layers.ac"
  },
  {
    "path": "kernels/triton/inference/gptq/mixtral/w4a16_fused_dequant_gemm.py",
    "chars": 11959,
    "preview": "\"\"\"Fused MoE W4A16 Kernel.\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom vllm._C import ops\n\n@triton."
  },
  {
    "path": "kernels/triton/inference/gptq/small_benchmark_cuda_graphs.py",
    "chars": 14410,
    "preview": "import torch\nimport triton\nfrom triton import language as tl\nimport sys\nimport marlin \nimport torch.nn as nn\nfrom auto_g"
  },
  {
    "path": "kernels/triton/inference/gptq/splitk_dequant_gemm.py",
    "chars": 6121,
    "preview": "import torch\nimport triton\nfrom triton import language as tl\n# from actual_base_gptq_4 import triton_matmul4\n\n@triton.ji"
  },
  {
    "path": "kernels/triton/inference/mamba/causal_1d_conv/causal_1d_conv/causal_1d_conv.py",
    "chars": 12632,
    "preview": "# Copyright (c) 2025, IBM Research\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange"
  },
  {
    "path": "kernels/triton/inference/mamba/causal_1d_conv/tests/test_causal_1d_conv.py",
    "chars": 7734,
    "preview": "# Copyright (C) 2025, IBM Research.\n# python -m pytest tests/test_causal_conv1d.py\n\nimport sys\nfrom einops import rearra"
  },
  {
    "path": "kernels/triton/inference/paged_attention/attention_triton.py",
    "chars": 15747,
    "preview": "#from einops import rearrange\nimport torch\nimport triton\nimport triton.language as tl\n\n# Credit:\n# vedantroy https://git"
  },
  {
    "path": "kernels/triton/inference/torch_compile/flash_backward.py",
    "chars": 44315,
    "preview": "#!/usr/bin/env python\n\"\"\"\nCode copied from https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-att"
  },
  {
    "path": "kernels/triton/training/README.md",
    "chars": 24,
    "preview": "Triton training kernels\n"
  },
  {
    "path": "kernels/triton/training/fused_softmax/README.md",
    "chars": 253,
    "preview": "Fused Softmax in Triton, supporting both inference (fwd) and training (fwd/backward). \n\nPerf testing on A100:\n\n<img widt"
  },
  {
    "path": "kernels/triton/training/fused_softmax/softmax.py",
    "chars": 4418,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "kernels/triton/training/rms_norm/fused_rms_norm.py",
    "chars": 8996,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# This software may be used and distributed according to the terms "
  },
  {
    "path": "kernels/triton/tutorials/README.md",
    "chars": 17,
    "preview": "Triton tutorials\n"
  },
  {
    "path": "readme.md",
    "chars": 2373,
    "preview": "\n### Applied AI repo\nFor experiments and research on Applied AI.\n\n### Projects\n\n#### Kernels\n\nHousing a variety of Trito"
  },
  {
    "path": "tutorials/triton/kernels/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "tutorials/triton/kernels/flash_attention_fwd.py",
    "chars": 19,
    "preview": "# flash forward v2\n"
  },
  {
    "path": "tutorials/triton/kernels/fused_softmax.py",
    "chars": 4292,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# ---- Fused Softmax written in Triton ----"
  },
  {
    "path": "tutorials/triton/kernels/readme.md",
    "chars": 165,
    "preview": "Triton tutorials\n\n1 - Vector Add - Starting tutorial on simple first kernel  \n2 - Fused Softmax - Full fused softmax wit"
  },
  {
    "path": "tutorials/triton/kernels/vector_add.py",
    "chars": 1291,
    "preview": "# coding up a Triton vector addition kernel\n# links to\n\nimport triton\nimport triton.language as tl \nimport torch\n\n@trito"
  },
  {
    "path": "tutorials/triton/tests/test_softmax.py",
    "chars": 4467,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\nimport pytest\nimport torch\nimport sys\nsys.p"
  },
  {
    "path": "tutorials/triton/tests/test_utils.py",
    "chars": 951,
    "preview": "from pathlib import Path\nfrom typing import Any, Dict, NamedTuple, Optional, Tuple, Union\n\nimport pytest\nimport torch\nim"
  }
]

// ... and 8 more files (download for full content)

About this extraction

This page contains the full source code of the pytorch-labs/applied-ai GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 119 files (547.0 KB), approximately 155.6k tokens, and a symbol index with 317 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!