Repository: pytorch-labs/applied-ai
Branch: main
Commit: 2391954b1998
Files: 119
Total size: 547.0 KB
Directory structure:
gitextract_gcbwaf39/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── assets/
│ └── images/
│ ├── dev-discuss-asynctp/
│ │ └── readme.md
│ └── readme.md
├── dev/
│ ├── sr/
│ │ ├── .gitignore
│ │ ├── readme.md
│ │ ├── setup.py
│ │ ├── src/
│ │ │ ├── stochastic_rounding.cu
│ │ │ ├── stochastic_rounding.hpp
│ │ │ └── stochastic_rounding_cuda.cu
│ │ ├── test.md
│ │ ├── tests/
│ │ │ ├── benchmark.py
│ │ │ └── core_unit_tests.py
│ │ ├── usage.py
│ │ └── usage2.py
│ └── triton_groupGEMM/
│ ├── groupgemm.py
│ ├── testing/
│ │ ├── base_testing.py
│ │ └── unit_tests.py
│ ├── tma_utils.py
│ └── triton_tutorial_groupgemm.py
├── kernels/
│ ├── MoE/
│ │ └── group_GEMM/
│ │ └── triton/
│ │ ├── readme.md
│ │ ├── testing/
│ │ │ ├── fast_verification.py
│ │ │ └── pytorch_reference_backwards.py
│ │ ├── tgroup_gemm_backwards.py
│ │ ├── tgroup_gemm_forward.py
│ │ └── utils/
│ │ └── tma_utils.py
│ ├── blackwell/
│ │ ├── cute_gemm_01/
│ │ │ ├── Makefile
│ │ │ ├── build/
│ │ │ │ └── temp.linux-x86_64-cpython-312/
│ │ │ │ ├── .ninja_deps
│ │ │ │ ├── .ninja_log
│ │ │ │ ├── build.ninja
│ │ │ │ ├── sm100_gemm.o
│ │ │ │ └── sm100_gemm_pytorch.o
│ │ │ ├── dist/
│ │ │ │ └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg
│ │ │ ├── driver.py
│ │ │ ├── setup.py
│ │ │ ├── sm100_gemm.cu
│ │ │ ├── sm100_gemm.egg-info/
│ │ │ │ ├── PKG-INFO
│ │ │ │ ├── SOURCES.txt
│ │ │ │ ├── dependency_links.txt
│ │ │ │ ├── not-zip-safe
│ │ │ │ ├── requires.txt
│ │ │ │ └── top_level.txt
│ │ │ ├── sm100_gemm.h
│ │ │ └── sm100_gemm_pytorch.cpp
│ │ └── cute_gemm_02_tma/
│ │ ├── build/
│ │ │ └── temp.linux-x86_64-cpython-312/
│ │ │ ├── .ninja_deps
│ │ │ ├── .ninja_log
│ │ │ ├── build.ninja
│ │ │ ├── sm100_gemm.o
│ │ │ └── sm100_gemm_pytorch.o
│ │ ├── dist/
│ │ │ └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg
│ │ ├── driver.py
│ │ ├── setup.py
│ │ ├── sm100_gemm.cu
│ │ ├── sm100_gemm.egg-info/
│ │ │ ├── PKG-INFO
│ │ │ ├── SOURCES.txt
│ │ │ ├── dependency_links.txt
│ │ │ ├── not-zip-safe
│ │ │ ├── requires.txt
│ │ │ └── top_level.txt
│ │ ├── sm100_gemm.h
│ │ └── sm100_gemm_pytorch.cpp
│ ├── cuda/
│ │ ├── cutlass_gemm/
│ │ │ ├── broadcast_load_epilogue_c3x.hpp
│ │ │ ├── common.hpp
│ │ │ ├── cutlass.cpp
│ │ │ ├── cutlass_kernel.cu
│ │ │ ├── readme.md
│ │ │ ├── setup.py
│ │ │ └── test_cutlass_gemm.py
│ │ ├── inference/
│ │ │ ├── README.md
│ │ │ └── hadamard_transform/
│ │ │ ├── hadamard_transform.cpp
│ │ │ ├── hadamard_transform_cuda.cu
│ │ │ ├── setup.py
│ │ │ └── test.py
│ │ ├── training/
│ │ │ └── README.md
│ │ └── tutorials/
│ │ ├── README.md
│ │ └── flash2.cu
│ ├── needs_perf_help/
│ │ ├── fp8_gemm_bench.py
│ │ └── fp8_rowwise_tma_persistent.py
│ └── triton/
│ ├── inference/
│ │ ├── README.md
│ │ ├── col_major_moe_gemm/
│ │ │ ├── README.md
│ │ │ ├── perf_test_moe.py
│ │ │ ├── profile_moe.py
│ │ │ ├── results.html
│ │ │ ├── test.csv
│ │ │ ├── test_moe_gemm.py
│ │ │ ├── v0_moe_fused.py
│ │ │ ├── v1_moe_fused.py
│ │ │ └── v2_moe_fused.py
│ │ ├── flash_attention/
│ │ │ └── stay_attention.py
│ │ ├── fp8/
│ │ │ ├── float8_groupwise_quant.py
│ │ │ ├── scaled_fp8_gemm.py
│ │ │ ├── splitk_gemm_fp8.py
│ │ │ └── tma_gemm.py
│ │ ├── gptq/
│ │ │ ├── a100_qlinear.py
│ │ │ ├── benchmark.py
│ │ │ ├── h100_qlinear.py
│ │ │ ├── mixtral/
│ │ │ │ ├── test_dequant_moe_gemm.py
│ │ │ │ └── w4a16_fused_dequant_gemm.py
│ │ │ ├── small_benchmark_cuda_graphs.py
│ │ │ └── splitk_dequant_gemm.py
│ │ ├── mamba/
│ │ │ └── causal_1d_conv/
│ │ │ ├── causal_1d_conv/
│ │ │ │ └── causal_1d_conv.py
│ │ │ └── tests/
│ │ │ └── test_causal_1d_conv.py
│ │ ├── paged_attention/
│ │ │ └── attention_triton.py
│ │ └── torch_compile/
│ │ └── flash_backward.py
│ ├── training/
│ │ ├── README.md
│ │ ├── fused_softmax/
│ │ │ ├── README.md
│ │ │ └── softmax.py
│ │ └── rms_norm/
│ │ └── fused_rms_norm.py
│ └── tutorials/
│ └── README.md
├── readme.md
└── tutorials/
└── triton/
├── kernels/
│ ├── __init__.py
│ ├── flash_attention_fwd.py
│ ├── fused_softmax.py
│ ├── readme.md
│ └── vector_add.py
└── tests/
├── test_softmax.py
└── test_utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
**/.ipynb_checkpoints
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to Applied AI
We want to make contributing to this project as easy and transparent as
possible.
## Our Development Process
... (in particular how this is synced with internal changes to the project)
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## Coding Style
* 2 spaces for indentation rather than tabs
* 80 character line length
* ...
## License
By contributing to applied-ai, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
================================================
FILE: LICENSE
================================================
Copyright 2024 Meta
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: assets/images/dev-discuss-asynctp/readme.md
================================================
This folder is for hosting the images for the AsyncTP public post at:
[https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)
================================================
FILE: assets/images/readme.md
================================================
Folder for housing images for the readmes.
================================================
FILE: dev/sr/.gitignore
================================================
*.o
*.ninja
*.txt
*.egg-info
*.ninja-deps
*.ninja-log/
*.so
dist/
build/
================================================
FILE: dev/sr/readme.md
================================================
Branch for the stochastic rounding kernel.
Currently processes four elements per thread so that each Philox call's four random 32-bit values (rand4) are fully used.
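For orientation, a minimal Python sketch of the rounding rule this kernel implements (illustrative only, not part of the extension API; the CUDA kernel applies the same rule per element using Philox-generated random bits): the float32 bit pattern is truncated to its top 16 bits (the bfloat16 pattern), and the discarded low 16 bits are compared against 16 random bits to decide whether to round up, so the rounded value is unbiased in expectation.

import random
import struct

def stochastic_round_bf16_scalar(x: float) -> float:
    # Reinterpret the float32 as its 32-bit pattern.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lower = bits & 0xFFFF        # bits that bf16 truncation would discard
    upper = bits & 0xFFFF0000    # truncated bf16 bit pattern (still stored as float32 bits)
    # Round up with probability lower / 2**16, mirroring the kernel's
    # (rand & 0xFFFF) < rounding_bits test.
    if random.getrandbits(16) < lower:
        upper += 0x10000
    return struct.unpack("<f", struct.pack("<I", upper))[0]

# Averaging many rounded samples recovers the input in expectation.
x = 0.3282
samples = [stochastic_round_bf16_scalar(x) for _ in range(10000)]
print(sum(samples) / len(samples))  # close to 0.3282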
================================================
FILE: dev/sr/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='stochastic_rounding_cuda',
version='0.1.021825',
ext_modules=[
CUDAExtension('stochastic_rounding_cuda', [
'src/stochastic_rounding.cu',
'src/stochastic_rounding_cuda.cu'
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': [
'-O3',
'--expt-relaxed-constexpr', # better template support
#'-gencode=arch=compute_70,code=sm_70', # Volta
#'-gencode=arch=compute_75,code=sm_75', # Turing
#'-gencode=arch=compute_80,code=sm_80', # Ampere
#'-gencode=arch=compute_86,code=sm_86' # Ampere
'-gencode=arch=compute_90,code=sm_90', # Hopper
]
})
],
cmdclass={
'build_ext': BuildExtension
}
)
================================================
FILE: dev/sr/src/stochastic_rounding.cu
================================================
#include <pybind11/pybind11.h>
#include "stochastic_rounding.hpp"
#include <random>
namespace py = pybind11;
__host__ int getOptimalBlockSize() {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
return std::min(prop.maxThreadsPerBlock, 256);
}
torch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad) {
TORCH_CHECK(input.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input tensor must be float32");
const int threads_per_block = 256;
const int num_elements = input.numel();
const int elements_per_thread = 4;
const int min_blocks = (num_elements + elements_per_thread * threads_per_block - 1) /
(elements_per_thread * threads_per_block);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
const int blocks_per_sm = 4;
const int min_blocks_for_sms = prop.multiProcessorCount * blocks_per_sm;
const int num_blocks = std::max(min_blocks, min_blocks_for_sms);
auto options = torch::TensorOptions()
.dtype(torch::kBFloat16)
.device(input.device())
.requires_grad(requires_grad);
auto output = torch::empty_like(input, options);
std::random_device rd;
std::mt19937_64 gen(rd());
std::uniform_int_distribution<unsigned long long> dis;
const unsigned long long seed = dis(gen);
stochastic_round_bf16<<<num_blocks, threads_per_block>>>(
input.data_ptr<float>(),
reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
num_elements,
seed);
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess,
"CUDA kernel execution failed: ", cudaGetErrorString(err));
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("stochastic_round_bf16",
static_cast<torch::Tensor (*)(torch::Tensor, bool)>(&stochastic_round_bf16_cuda),
"Stochastic rounding to BFloat16",
py::arg("input"),
py::arg("requires_grad") = false);
}
================================================
FILE: dev/sr/src/stochastic_rounding.hpp
================================================
#pragma once
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <vector_types.h>
#include <torch/extension.h>
#include <pybind11/pybind11.h>
namespace philox {
constexpr unsigned int W32_0 = 0x9E3779B9;
constexpr unsigned int W32_1 = 0xBB67AE85;
constexpr unsigned int M0 = 0xD2511F53;
constexpr unsigned int M1 = 0xCD9E8D57;
constexpr int ROUNDS = 7;
}
// Forward declarations
class PhiloxGenerator {
public:
__device__ __forceinline__ PhiloxGenerator();
__device__ __forceinline__ void init(const unsigned long long seed, const unsigned int thread_id);
__device__ __forceinline__ uint4 next();
private:
uint2 key;
uint4 counter;
static __device__ __forceinline__ uint2 mulhilo(const unsigned int a, const unsigned int b);
static __device__ __forceinline__ uint4 round(uint4 ctr, uint2 key);
};
// CUDA kernel declaration
__global__ void stochastic_round_bf16(
float *__restrict__ input,
__nv_bfloat16 *__restrict__ output,
const int size,
const unsigned long long seed);
// Host functions
__host__ int getOptimalBlockSize();
torch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad = false);
================================================
FILE: dev/sr/src/stochastic_rounding_cuda.cu
================================================
#include "stochastic_rounding.hpp"
#include <cstdint>
// Philox RNG implementation
__device__ __forceinline__ PhiloxGenerator::PhiloxGenerator() :
key(make_uint2(0, 0)),
counter(make_uint4(0, 0, 0, 0)) {}
__device__ __forceinline__ void PhiloxGenerator::init(const unsigned long long seed, const unsigned int thread_id) {
key.x = static_cast<unsigned int>(seed);
key.y = static_cast<unsigned int>(seed >> 32);
counter = make_uint4(thread_id, 0, 0, 0);
__threadfence_block();
}
__device__ __forceinline__ uint2 PhiloxGenerator::mulhilo(const unsigned int a, const unsigned int b) {
uint2 result;
unsigned long long prod;
asm("mul.wide.u32 %0, %1, %2;" : "=l"(prod) : "r"(a), "r"(b));
result.x = static_cast<unsigned int>(prod);
result.y = static_cast<unsigned int>(prod >> 32);
return result;
}
__device__ __forceinline__ uint4 PhiloxGenerator::round(uint4 ctr, uint2 key) {
const uint2 mul0 = mulhilo(philox::M0, ctr.x);
const uint2 mul1 = mulhilo(philox::M1, ctr.z);
return make_uint4(
mul1.y ^ ctr.y ^ key.x,
mul1.x,
mul0.y ^ ctr.w ^ key.y,
mul0.x
);
}
__device__ __forceinline__ uint4 PhiloxGenerator::next() {
uint4 ctr = counter;
uint2 k = key;
#pragma unroll
for (int i = 0; i < philox::ROUNDS; ++i) {
ctr = round(ctr, k);
k.x += philox::W32_0;
k.y += philox::W32_1;
}
counter.x += 4;
return ctr;
}
__device__ __forceinline__ __nv_bfloat16 float_to_bf16_stochastic(const float value, const uint32_t rand) {
const uint32_t val_bits = __float_as_uint(value);
const uint32_t rounding_bits = val_bits & 0xFFFF;
uint32_t result = val_bits & 0xFFFF0000u;
result += (rand & 0xFFFF) < rounding_bits ? 0x10000u : 0;
return __float2bfloat16(__uint_as_float(result));
}
__device__ __forceinline__ void float4_to_bf16_stochastic(
const float4& values,
uint4& rand_vals,
__nv_bfloat16* output) {
float vals[4] = {values.x, values.y, values.z, values.w};
uint32_t rands[4] = {rand_vals.x, rand_vals.y, rand_vals.z, rand_vals.w};
#pragma unroll
for (int i = 0; i < 4; i++) {
output[i] = float_to_bf16_stochastic(vals[i], rands[i]);
}
}
__global__ void stochastic_round_bf16(
float *__restrict__ input,
__nv_bfloat16 *__restrict__ output,
const int size,
const unsigned long long seed) {
PhiloxGenerator rng;
rng.init(seed, blockIdx.x * blockDim.x + threadIdx.x);
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
int stride = blockDim.x * gridDim.x * 4;
float4 values;
__nv_bfloat16 local_output[4];
// Process full vectors of 4 elements
for (; idx <= size - 4; idx += stride) {
values = *reinterpret_cast<float4*>(&input[idx]);
uint4 rand = rng.next();
float4_to_bf16_stochastic(values, rand, local_output);
for (int j = 0; j < 4; j++) {
output[idx + j] = local_output[j];
}
}
// Handle remaining elements
if (idx < size) {
float remaining_values[4] = {0.0f, 0.0f, 0.0f, 0.0f};
int remainder = size - idx;
for (int j = 0; j < remainder; j++) {
remaining_values[j] = input[idx + j];
}
values.x = remaining_values[0];
values.y = remaining_values[1];
values.z = remaining_values[2];
values.w = remaining_values[3];
uint4 rand = rng.next();
float4_to_bf16_stochastic(values, rand, local_output);
for (int j = 0; j < remainder; j++) {
output[idx + j] = local_output[j];
}
}
}
================================================
FILE: dev/sr/test.md
================================================
(tkdev11) [less@devgpu115.cco2 ~/local/applied-ai/dev/sr (sr_kernel)]$ python usage.py
Launching kernel with blocks=1, threads_per_block=256, num_elements=12
Input tensor: tensor([ 0.3282, -0.4513, -1.0612, 0.1446, -0.8440, -1.4669, -0.7135, -0.6183,
-2.2411, 2.1464, 1.4772, -1.3564], device='cuda:0')
Output tensor: tensor([ 0.3281, -0.4512, -1.0625, 0.1445, -0.8438, -1.4688, -0.7109, -0.6172,
-2.2344, 2.1406, 1.4766, -1.3516], device='cuda:0',
dtype=torch.bfloat16)
Output tensor dtype: torch.bfloat16
Success!
================================================
FILE: dev/sr/tests/benchmark.py
================================================
import torch
import stochastic_rounding_cuda
import numpy as np
import time
from tabulate import tabulate
import argparse
def measure_performance(func, input_tensor, warmup=0, repeats=1):
"""Measure performance of a function with proper CUDA synchronization"""
# Warmup
for _ in range(warmup):
output = func(input_tensor)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(repeats):
output = func(input_tensor)
torch.cuda.synchronize()
end = time.perf_counter()
avg_time = (end - start) / repeats
elements_per_second = input_tensor.numel() / avg_time
return avg_time, elements_per_second
def benchmark_sizes(sizes= [1000, 10000, 100000, 1000000, 10000000, (10000000*10), (10000000*100)]):
#[ 50,000,000]): #
"""Benchmark different input sizes"""
results = []
for size in sizes:
# Create input tensor
x = torch.randn(size, device='cuda')
# Measure stochastic rounding
time_stoch, throughput_stoch = measure_performance(
stochastic_rounding_cuda.stochastic_round_bf16, x)
# Measure regular BF16 casting
time_regular, throughput_regular = measure_performance(
lambda t: t.to(torch.bfloat16), x)
results.append([
size,
time_stoch * 1000, # convert to ms
throughput_stoch / 1e6, # convert to MElements/s
time_regular * 1000,
throughput_regular / 1e6,
throughput_regular / throughput_stoch # speedup
])
print("\nSize Comparison:")
print(tabulate(results,
headers=['Size', 'Stoch Time (ms)', 'Stoch ME/s',
'Regular Time (ms)', 'Regular ME/s', 'Casting faster by'],
floatfmt='.3f'))
def benchmark_shapes(total_size=1000000):
"""Benchmark different tensor shapes with same total size"""
shapes = [
(total_size,), # 1D
(1000, total_size//1000), # 2D
(100, 100, total_size//10000), # 3D
]
results = []
for shape in shapes:
x = torch.randn(*shape, device='cuda')
time_stoch, throughput_stoch = measure_performance(
stochastic_rounding_cuda.stochastic_round_bf16, x)
results.append([
'x'.join(str(d) for d in shape),
time_stoch * 1000,
throughput_stoch / 1e9
])
print("\nShape Comparison (same total size):")
print(tabulate(results,
headers=['Shape', 'Time (ms)', 'GElements/s'],
floatfmt='.3f'))
def stress_test(duration=10):
"""Run a stress test for specified duration"""
print(f"\nRunning stress test for {duration} seconds...")
size = 1000000
x = torch.randn(size, device='cuda')
start_time = time.time()
iterations = 0
while time.time() - start_time < duration:
stochastic_rounding_cuda.stochastic_round_bf16(x)
iterations += 1
print(f"Completed {iterations} iterations without errors")
print(f"Average throughput: {(iterations * size) / (duration * 1e9):.2f} GElements/s")
def memory_test(max_size=1e9):
"""Test memory scaling"""
sizes = np.logspace(3, min(9, np.log10(max_size)), num=7, dtype=int)
results = []
for size in sizes:
try:
torch.cuda.empty_cache()
x = torch.randn(size, device='cuda')
torch.cuda.synchronize()
# Measure peak memory during operation
torch.cuda.reset_peak_memory_stats()
_ = stochastic_rounding_cuda.stochastic_round_bf16(x)
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated() / 1e6 # MB
results.append([size, peak_memory])
except RuntimeError as e:
print(f"Out of memory at size {size}")
break
print("\nMemory Usage:")
print(tabulate(results,
headers=['Size', 'Peak Memory (MB)'],
floatfmt='.2f'))
def main():
parser = argparse.ArgumentParser(description='Benchmark stochastic rounding')
parser.add_argument('--sizes', action='store_true', help='Run size benchmarks')
parser.add_argument('--shapes', action='store_true', help='Run shape benchmarks')
parser.add_argument('--stress', action='store_true', help='Run stress test')
parser.add_argument('--memory', action='store_true', help='Run memory test')
parser.add_argument('--all', action='store_true', help='Run all benchmarks')
args = parser.parse_args()
# Print device information
device = torch.cuda.get_device_properties(0)
print(f"\nRunning on: {device.name}")
print(f"Compute Capability: {device.major}.{device.minor}")
if args.all or args.sizes:
benchmark_sizes()
if args.all or args.shapes:
benchmark_shapes()
if args.all or args.stress:
stress_test()
if args.all or args.memory:
memory_test()
if __name__ == '__main__':
main()
================================================
FILE: dev/sr/tests/core_unit_tests.py
================================================
import torch
import numpy as np
from collections import Counter
import unittest
import stochastic_rounding_cuda
import time
class TestStochasticRounding(unittest.TestCase):
def setUp(self):
# Ensure deterministic behavior for some tests
torch.manual_seed(42)
np.random.seed(42)
def _test_rounding_statistics_helper(self, value, lower_value, upper_value, tensor_size=10000, rounds=100):
"""Helper method for testing stochastic rounding statistics"""
print(f"\nInput value: {value}")
MAX_VARIANCE = 0.03
x = torch.full((tensor_size,), value, device='cuda')
torch.cuda.manual_seed(42)
# Single round test - isolate and show the round up and round down values
single_result = stochastic_rounding_cuda.stochastic_round_bf16(x)
print(f"Possible rounded values: {torch.unique(single_result)}")
# Multiple rounds
results = torch.empty((rounds, tensor_size), device='cuda', dtype=torch.bfloat16)
for i in range(rounds):
results[i] = stochastic_rounding_cuda.stochastic_round_bf16(x)
prob_up = (results == upper_value).float().mean().item()
print(f"Kernel's probability of rounding up: {prob_up:.4f}")
distance_to_lower = abs(value - lower_value)
total_distance = upper_value - lower_value
expected_prob = distance_to_lower / total_distance
print(f"Expected probability: {expected_prob:.4f}")
self.assertTrue(abs(prob_up - expected_prob) < MAX_VARIANCE)
def test_special_values(self):
"""Test handling of special values like inf, -inf, nan"""
special_values = torch.tensor([float('inf'), float('-inf'), float('nan'), 0.0, -0.0],
device='cuda')
rounded = stochastic_rounding_cuda.stochastic_round_bf16(special_values)
# Check inf and -inf are preserved
self.assertTrue(torch.isinf(rounded[0]))
self.assertTrue(torch.isinf(rounded[1]))
self.assertTrue(rounded[0] > 0)
self.assertTrue(rounded[1] < 0)
# Check nan is preserved
self.assertTrue(torch.isnan(rounded[2]))
# Check zeros are preserved
self.assertEqual(rounded[3].item(), 0.0)
self.assertEqual(rounded[4].item(), 0.0)
def test_small_values(self):
"""Test handling of small values near zero"""
small_values = torch.tensor([1e-38, -1e-38, 1e-20, -1e-20], device='cuda')
rounded = stochastic_rounding_cuda.stochastic_round_bf16(small_values)
# Check that very small values are handled properly
self.assertTrue(torch.all(torch.isfinite(rounded)))
def test_vectorized_loading(self):
"""Test if vectorized loading works correctly for different tensor sizes"""
sizes = [4, 8, 9, 16, 32, 100] # Test various sizes including non-aligned
for size in sizes:
x = torch.linspace(1, size, size, device='cuda')
rounded = stochastic_rounding_cuda.stochastic_round_bf16(x)
# Check output size matches input
self.assertEqual(rounded.size(0), size)
# Check dtype
self.assertEqual(rounded.dtype, torch.bfloat16)
def test_large_values(self):
"""Test handling of large values"""
large_values = torch.tensor([1e38, -1e38, 1e20, -1e20], device='cuda')
rounded = stochastic_rounding_cuda.stochastic_round_bf16(large_values)
# Values should be preserved approximately in BF16 range
self.assertTrue(torch.all(torch.isfinite(rounded)))
def test_rounding_statistics(self):
"""Test if rounding probabilities match expected distribution"""
self._test_rounding_statistics_helper(2.1999969482421875, 2.1875, 2.2031)
def test_rounding_statistics_2(self):
"""Test stochastic rounding with different BF16 boundary values"""
self._test_rounding_statistics_helper(1.7999992370605469, 1.7969, 1.8047)
def test_rounding_statistics_small(self):
"""Test stochastic rounding for number between 0 and 1"""
self._test_rounding_statistics_helper(0.7499847412109375, 0.7480, 0.7500)
def test_rounding_statistics_large(self):
"""Test stochastic rounding for large number, over 100"""
self._test_rounding_statistics_helper(128.99998474121094, 128.875, 129.000)
if __name__ == '__main__':
unittest.main(verbosity=2)
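As a worked example of the statistic these tests check, take the boundary values from test_rounding_statistics: the exact bfloat16 neighbors of the input are 2.1875 and 2.203125 (the test passes 2.2031, a rounded display value), so the expected round-up probability follows directly from the distance to the lower neighbor.

value, lower, upper = 2.1999969482421875, 2.1875, 2.203125
expected_prob = (value - lower) / (upper - lower)
print(f"{expected_prob:.4f}")  # ~0.7998, so roughly 80% of elements should round up to 2.203125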
================================================
FILE: dev/sr/usage.py
================================================
import torch
import stochastic_rounding_cuda
# Create input tensor
input_tensor = torch.randn(12, device='cuda', dtype=torch.float32)
# Apply stochastic rounding
output_tensor = stochastic_rounding_cuda.stochastic_round_bf16(input_tensor)
print(f"Input tensor: {input_tensor}")
print(f"Output tensor: {output_tensor}")
print(f"Output tensor dtype: {output_tensor.dtype}")
print(f"Success!")
'''
# Test tensor
x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00, -1.3683e+00,
4.0467e-01, 1.0759e-03, 2.8418e-01, -4.9392e-01,
8.7239e-01, -9.0545e-01, 1.1134e+00, 0], # -2.6872e+00
device='cuda')
# Convert to BF16
y = stochastic_rounding_cuda.stochastic_round_bf16(x)
print(f"Input: {x}")
print(f"Output: {y}")
'''
================================================
FILE: dev/sr/usage2.py
================================================
import torch
import stochastic_rounding_cuda
# Test tensor
x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00], device='cuda')
# Compare with regular rounding
y_normal = x.to(torch.bfloat16)
y_stochastic = stochastic_rounding_cuda.stochastic_round_bf16(x)
print(f"Input: {x}")
print(f"Normal BF16: {y_normal}")
print(f"Stochastic BF16: {y_stochastic}")
================================================
FILE: dev/triton_groupGEMM/groupgemm.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import functools
from typing import Optional
import tma_utils as utils
import torch
import triton
import triton.language as tl
from triton.runtime import driver # @manual
_NV_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
)
for block_size_m in [64, 128]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
]
_AMD_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"waves_per_eu": waves_per_cu,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
},
num_stages=num_stages,
num_warps=num_warps,
)
for block_size_m in [32, 64, 128]
for block_size_n in [32, 64, 128, 256]
for block_size_k in [128, 256]
for num_stages in [1, 2]
for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
for matrix_instr_nonkdim in [16]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
device = torch.cuda.current_device()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
if dtsize is None:
dtsize = named_args["c_ptr"].element_size()
if dtype is None:
dtype = named_args["c_ptr"].dtype
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
kw["BLOCK_SIZE_M"],
kw["BLOCK_SIZE_N"],
kw["BLOCK_SIZE_K"],
config.num_stages,
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
"max_shared_mem"
]
if torch.version.hip:
required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
else:
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
M_PER_GROUP = M // G
MIN_M_TILES = 32 if torch.version.hip else 64
# 2. make sure we don't load M tiles that are too big
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
continue
# 3. make sure we don't load N tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = driver.active.utils.get_device_properties(device)[
"multiprocessor_count"
]
N_TILES = N // BLOCK_N
MIN_N_TILES = 32 if torch.version.hip else 64
# 4. make sure we don't load N tiles that are too big
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
pruned_configs.append(config)
return pruned_configs
@triton.autotune(
configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_grouped_gemm(
a_desc_ptr,
b_desc_ptr,
c_ptr,
workspace,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
tidx = tl.program_id(0)
dtype: tl.dtype = c_ptr.dtype.element_ty
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
iterated_tiles = 0
for g in tl.range(G):
# Move across groups
M_start_offset = M_end_offset
m_size = tl.load(m_sizes + g)
M_end_offset = M_start_offset + m_size
if m_size > 0:
N_start_offset = g * N
n_size = N
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if USE_TMA_STORE:
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_SIZE_K == 0)
if USE_TMA_LOAD:
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
accumulator += tl.dot(a, b.T)
else:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
accumulator += tl.dot(a, b.T)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
if USE_TMA_STORE:
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
else:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
)
tidx += NUM_SMS
iterated_tiles += num_tiles
TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv
@triton.autotune(
configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={
"early_config_prune": functools.partial(
early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
)
},
)
@triton.jit
def _kernel_grouped_gemm_fp8_rowwise(
a_desc_ptr,
a_scale_ptr,
b_desc_ptr,
b_scale_ptr,
c_ptr,
workspace,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
tidx = tl.program_id(0)
dtype = TT_FP8_DTYPE
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
iterated_tiles = 0
for g in tl.range(G):
# Move across groups
M_start_offset = M_end_offset
m_size = tl.load(m_sizes + g)
M_end_offset = M_start_offset + m_size
if m_size > 0:
N_start_offset = g * N
n_size = N
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if USE_TMA_STORE:
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_SIZE_K == 0)
if USE_TMA_LOAD:
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
accumulator += tl.dot(a, b.T)
else:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
accumulator += tl.dot(a, b.T)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
a_scale = tl.load(
a_scale_ptr + M_start_offset + offs_am[:, None],
mask=offs_am[:, None] < m_size,
)
b_scale = tl.load(
b_scale_ptr + N_start_offset + offs_bn[None, :],
mask=offs_bn[None, :] < n_size,
)
c = accumulator.to(tl.float32) * a_scale * b_scale
if USE_TMA_STORE:
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
c.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
else:
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
)
tidx += NUM_SMS
iterated_tiles += num_tiles
def _grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
x_scale: Optional[torch.Tensor] = None,
w_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not utils.HAS_TMA_DESC:
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
G = m_sizes.shape[0]
assert x.is_contiguous()
assert w.is_contiguous()
assert m_sizes.is_contiguous()
M, K = x.shape
N = w.shape[0] // G
assert K == w.shape[1]
y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
USE_TMA_LOAD = not torch.version.hip
USE_TMA_STORE = False
desc_helper = None
desc_x = x
desc_w = w
workspace = None
if USE_TMA_LOAD:
desc_helper = utils.TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("x")
desc_helper.init_tma_descriptor("w")
desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
if USE_TMA_STORE:
workspace = torch.empty(
NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
device=x.device,
dtype=torch.uint8,
)
def grid(META):
if USE_TMA_LOAD:
nonlocal desc_helper
desc_helper.fill_2d_tma_descriptor(
"x",
x.data_ptr(),
M,
K,
META["BLOCK_SIZE_M"],
META["BLOCK_SIZE_K"],
x.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"w",
w.data_ptr(),
N * G,
K,
META["BLOCK_SIZE_N"],
META["BLOCK_SIZE_K"],
w.element_size(),
)
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M)
if x_scale is not None and w_scale is not None:
assert x_scale.is_contiguous()
assert w_scale.is_contiguous()
_kernel_grouped_gemm_fp8_rowwise[grid](
desc_x,
x_scale,
desc_w,
w_scale,
y,
workspace,
m_sizes,
G,
M_BUCKET,
N,
K,
NUM_SMS,
USE_TMA_LOAD,
USE_TMA_STORE,
)
else:
assert x_scale is None
assert w_scale is None
_kernel_grouped_gemm[grid](
desc_x,
desc_w,
y,
workspace,
m_sizes,
G,
M_BUCKET,
N,
K,
NUM_SMS,
USE_TMA_LOAD,
USE_TMA_STORE,
)
return y
def grouped_gemm(
x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor
) -> torch.Tensor:
return _grouped_gemm(x, w, m_sizes)
def grouped_gemm_fp8_rowwise(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
) -> torch.Tensor:
return _grouped_gemm(x, w, m_sizes, x_scale, w_scale)
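A minimal usage sketch, assuming a CUDA device with TMA descriptor support (the module raises NotImplementedError otherwise); the shapes mirror the test files that follow: the rows of x are partitioned into G contiguous groups by m_sizes, and group g is multiplied by the weight slice w[g * N : (g + 1) * N].

import torch
from groupgemm import grouped_gemm

G, M, N, K = 4, 128, 64, 64
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")      # all groups stacked along M
w = torch.randn(N * G, K, dtype=torch.bfloat16, device="cuda")  # one (N, K) weight block per group
m_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32, device="cuda")  # rows per group, sums to M
y = grouped_gemm(x, w, m_sizes)  # (M, N); group g uses w[g * N : (g + 1) * N]
print(y.shape)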
================================================
FILE: dev/triton_groupGEMM/testing/base_testing.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import logging
# Configure logging to print to stdout
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
import os
import sys
import unittest
from typing import Tuple
import torch
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if torch.cuda.is_available():
# from fp8_gemm import quantize_fp8_row
from groupgemm import grouped_gemm # , grouped_gemm_fp8_rowwise
from tma_utils import HAS_TMA_DESC
@unittest.skipIf(
not torch.cuda.is_available()
or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9
or not HAS_TMA_DESC,
"Skip when H100 or TMA is not available",
)
class TestGroupedGEMM(unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(0)
"""def test_grouped_gemm_fp8_rowwise(self) -> None:
def _test_grouped_gemm_fp8_rowwise(
shape: Tuple[int, int, int, int],
device: torch.device,
) -> None:
G, M, N, K = shape
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_ends, _ = torch.sort(
torch.randint(
low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
)
)
m_ends = m_ends.tolist()
m_starts = [0] + m_ends
m_ends = m_ends + [M]
m_sizes = torch.tensor(
[m_ends[i] - m_starts[i] for i in range(G)], device=device
).to(torch.int32)
a_fp8, a_scale = quantize_fp8_row(a)
b_fp8, b_scale = quantize_fp8_row(b)
result = grouped_gemm_fp8_rowwise(
a_fp8,
b_fp8,
m_sizes,
a_scale,
b_scale,
)
self.assertTrue(result.shape == (M, N))
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
# Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation.
for g in range(G):
m_start = m_starts[g]
m_end = m_ends[g]
n_start = g * N
n_end = (g + 1) * N
expected_result[m_start:m_end, :] = (
a_fp8[m_start:m_end, :].to(torch.float32)
@ b_fp8[n_start:n_end, :].to(torch.float32).T
* a_scale[m_start:m_end][:, None]
* b_scale[n_start:n_end][None, :]
).to(torch.bfloat16)
torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)
for G in (1, 4, 16):
for M in (64, 512):
logging.info(f"Testing FP8 GMM with G={G}, M={M}")
_test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device("cuda"))
"""
def test_grouped_gemm_bf16(self) -> None:
def _test_grouped_gemm_bf16(
shape: Tuple[int, int, int, int],
device: torch.device,
) -> None:
G, M, N, K = shape
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_ends, _ = torch.sort(
torch.randint(
low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
)
)
m_ends = m_ends.tolist()
m_starts = [0] + m_ends
m_ends = m_ends + [M]
m_sizes = torch.tensor(
[m_ends[i] - m_starts[i] for i in range(G)], device=device
).to(torch.int32)
result = grouped_gemm(
a,
b,
m_sizes,
)
self.assertTrue(result.shape == (M, N))
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
for g in range(G):
m_start = m_starts[g]
m_end = m_ends[g]
expected_result[m_start:m_end, :] = (
a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
)
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
for G in (1, 4, 16):
for M in (64, 512):
logging.info(f"Testing BF16 GMM with G={G}, M={M}")
_test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))
if __name__ == "__main__":
unittest.main(exit=False)
================================================
FILE: dev/triton_groupGEMM/testing/unit_tests.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm
import logging
import os
import sys
import unittest
from typing import Tuple
import torch
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from groupgemm import grouped_gemm
class TestGroupedGEMM(unittest.TestCase):
def test_grouped_gemm_bf16(self) -> None:
def _test_grouped_gemm_bf16(
shape: Tuple[int, int, int, int],
device: torch.device,
) -> None:
G, M, N, K = shape
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_ends, _ = torch.sort(
torch.randint(
low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
)
)
m_ends = m_ends.tolist()
m_starts = [0] + m_ends
m_ends = m_ends + [M]
m_sizes = torch.tensor(
[m_ends[i] - m_starts[i] for i in range(G)], device=device
).to(torch.int32)
result = grouped_gemm(
a,
b,
m_sizes,
)
self.assertTrue(result.shape == (M, N))
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
for g in range(G):
m_start = m_starts[g]
m_end = m_ends[g]
expected_result[m_start:m_end, :] = (
a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
)
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
for G in (1, 4, 16):
for M in (64, 512):
logging.info(f"Testing BF16 GMM with G={G}, M={M}")
_test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))
def test_grouped_gemm_bf16_various_dimensions(self) -> None:
"""Test grouped_gemm with bf16 precision and various dimensions"""
def _test_grouped_gemm_bf16(
shape: Tuple[int, int, int, int],
device: torch.device,
) -> None:
G, M, N, K = shape
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_ends, _ = torch.sort(
torch.randint(
low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
)
)
m_ends = m_ends.tolist()
m_starts = [0] + m_ends
m_ends = m_ends + [M]
m_sizes = torch.tensor(
[m_ends[i] - m_starts[i] for i in range(G)], device=device
).to(torch.int32)
result = grouped_gemm(
a,
b,
m_sizes,
)
self.assertTrue(result.shape == (M, N))
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
for g in range(G):
m_start = m_starts[g]
m_end = m_ends[g]
expected_result[m_start:m_end, :] = (
a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
)
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
for G in (4, 8):
for M in (128, 256):
for N, K in [(128, 256), (256, 128), (64, 64)]:
logging.info(f"Testing BF16 GMM with G={G}, M={M}, N={N}, K={K}")
_test_grouped_gemm_bf16((G, M, N, K), torch.device("cuda"))
def test_grouped_gemm_bf16_edge_cases(self) -> None:
"""Test grouped_gemm with bfloat16 for various edge cases"""
device = torch.device("cuda")
# Test with G=1 (single group case)
G, M, N, K = 1, 32, 32, 32
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([M], device=device).to(torch.int32)
result = grouped_gemm(a, b, m_sizes)
expected_result = a @ b.T
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
# Test with uneven group sizes
G, M, N, K = 3, 100, 32, 32
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([25, 50, 25], device=device).to(torch.int32)
result = grouped_gemm(a, b, m_sizes)
self.assertTrue(result.shape == (M, N))
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
m_start = 0
for g in range(G):
m_end = m_start + m_sizes[g].item()
expected_result[m_start:m_end, :] = (
a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
)
m_start = m_end
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
# Test with extremely small matrices
G, M, N, K = 2, 8, 8, 8
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([4, 4], device=device).to(torch.int32)
result = grouped_gemm(a, b, m_sizes)
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
m_start = 0
for g in range(G):
m_end = m_start + m_sizes[g].item()
expected_result[m_start:m_end, :] = (
a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
)
m_start = m_end
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
# Test with large group count but small matrix sizes
G, M, N, K = 32, 128, 16, 16
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.ones(G, device=device).to(torch.int32) * (M // G)
m_sizes[-1] += M % G # Adjust the last group size to account for remainder
result = grouped_gemm(a, b, m_sizes)
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
m_start = 0
for g in range(G):
m_end = m_start + m_sizes[g].item()
expected_result[m_start:m_end, :] = (
a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
)
m_start = m_end
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
def test_grouped_gemm_bf16_invalid_inputs(self) -> None:
"""Test grouped_gemm with invalid inputs to ensure proper error handling"""
device = torch.device("cuda")
# Test with mismatched dimensions
G, M, N, K = 2, 64, 32, 32
a = torch.randn(
M, K + 1, dtype=torch.bfloat16, device=device
) # Wrong K dimension
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([32, 32], device=device).to(torch.int32)
with self.assertRaises(RuntimeError):
grouped_gemm(a, b, m_sizes)
# Test with mismatched G and m_sizes length
G, M, N, K = 2, 64, 32, 32
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([32, 32, 32], device=device).to(
torch.int32
) # Too many groups
with self.assertRaises((RuntimeError, ValueError, IndexError)):
grouped_gemm(a, b, m_sizes)
# Test with incorrect sum of m_sizes
G, M, N, K = 2, 64, 32, 32
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([32, 40], device=device).to(torch.int32) # Sum > M
with self.assertRaises((RuntimeError, ValueError, IndexError)):
grouped_gemm(a, b, m_sizes)
# Test with negative m_sizes values
G, M, N, K = 2, 64, 32, 32
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([40, -8], device=device).to(
torch.int32
) # Negative group size
with self.assertRaises((RuntimeError, ValueError)):
grouped_gemm(a, b, m_sizes)
def test_grouped_gemm_bf16_deterministic(self) -> None:
"""Test that grouped_gemm produces deterministic results with the same inputs"""
G, M, N, K = 4, 128, 64, 64
device = torch.device("cuda")
# Fix the random seed for reproducibility
torch.manual_seed(42)
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([32, 32, 32, 32], device=device).to(torch.int32)
# First run
result1 = grouped_gemm(a, b, m_sizes)
# Second run with same inputs
result2 = grouped_gemm(a, b, m_sizes)
# Results should be exactly the same
self.assertTrue(torch.all(result1 == result2))
def test_grouped_gemm_bf16_large_matrices(self) -> None:
"""Test grouped_gemm with larger matrices to stress test performance and stability"""
device = torch.device("cuda")
# Test with large matrices but fewer groups
G, M, N, K = 2, 2048, 512, 1024
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([1024, 1024], device=device).to(torch.int32)
result = grouped_gemm(a, b, m_sizes)
expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device)
m_start = 0
for g in range(G):
m_end = m_start + m_sizes[g].item()
expected_result[m_start:m_end, :] = (
a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T
)
m_start = m_end
torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
if __name__ == "__main__":
unittest.main(argv=["first-arg-is-ignored"], exit=False)
================================================
FILE: dev/triton_groupGEMM/tma_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm
import sys
import torch
import triton # @manual
import triton.language as tl # @manual
def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
"""
Maps torch dtype to triton dtype.
Args:
dtype (torch.dtype): input dtype.
Returns:
tl.dtype: triton dtype.
"""
if dtype == torch.float16:
return tl.float16
elif dtype == torch.bfloat16:
return tl.bfloat16
elif dtype == torch.float32:
return tl.float32
elif dtype == torch.int32:
return tl.int32
elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
return tl.float8e4nv
else:
raise ValueError(f"Unsupported dtype {dtype}")
# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
if HAS_TMA_DESC:
print(
"TMA benchmarks will be running with experimental grid constant TMA descriptor.",
file=sys.stderr,
)
else:
print(
"Missing: This group gemm code will not run without TMA descriptor support....",
file=sys.stderr,
)
raise NotImplementedError("grouped Gemm without TMA is not supported")
class TmaAutoTuneHelper:
# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc
def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()
TMA_SIZE = 128
def __init__(self):
self.fill_1d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_1d_tma_descriptor
)
self.fill_2d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_2d_tma_descriptor
)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}
# Call this method outside of the lambda function for grid size
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
)
else:
self.cuda_descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
)
# Call this method inside the lambda function for grid size
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)
# Call this method inside the lambda function for grid size
def fill_2d_tma_descriptor(
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)
def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]
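# Typical usage sketch (illustrative only; mirrors how the grouped GEMM drivers in this
# repo consume the helper -- the tensor name and block sizes below are assumptions):
#     helper = TmaAutoTuneHelper()
#     helper.init_tma_descriptor("x")
#     helper.fill_2d_tma_descriptor("x", x.data_ptr(), M, K, BLOCK_M, BLOCK_K, x.element_size())
#     desc_x = helper.get_tma_descriptor_kernel_param("x")  # pass this as a kernel argument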
================================================
FILE: dev/triton_groupGEMM/triton_tutorial_groupgemm.py
================================================
"""
Group GEMM
============================
This grouped GEMM kernel launches a fixed number of CTAs to compute a group
of GEMMs. The scheduling is static and done on-device.
"""
# Copyright (c) 2023 - 2025 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# from: https://github.com/triton-lang/triton/blob/main/python/tutorials/08-grouped-gemm.py
from typing import Optional
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def supports_tma():
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
def num_sms():
if is_cuda():
return torch.cuda.get_device_properties("cuda").multi_processor_count
return 148
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"NUM_SM": 84,
}
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"NUM_SM": 128,
}
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"NUM_SM": 84,
}
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"NUM_SM": 128,
}
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"NUM_SM": num_sms(),
}
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"NUM_SM": num_sms(),
}
),
],
key=["group_size"],
)
@triton.jit
def grouped_matmul_kernel(
# device tensor of matrices pointers
group_a_ptrs,
group_b_ptrs,
group_c_ptrs,
# device tensor of gemm sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
group_gemm_sizes,
# device tensor of leading dimension sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
g_lds,
# number of gemms
group_size,
# number of virtual SM
NUM_SM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
tile_idx = tl.program_id(0)
last_problem_end = 0
for g in range(group_size):
# get the gemm size of the current problem
gm = tl.load(group_gemm_sizes + g * 3)
gn = tl.load(group_gemm_sizes + g * 3 + 1)
gk = tl.load(group_gemm_sizes + g * 3 + 2)
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
# iterate through the tiles in the current gemm problem
while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
# pick up a tile from the current gemm problem
k = gk
lda = tl.load(g_lds + g * 3)
ldb = tl.load(g_lds + g * 3 + 1)
ldc = tl.load(g_lds + g * 3 + 2)
a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
# figure out tile coordinates
tile_idx_in_gemm = tile_idx - last_problem_end
tile_m_idx = tile_idx_in_gemm // num_n_tiles
tile_n_idx = tile_idx_in_gemm % num_n_tiles
# do regular gemm here
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
# hint to Triton compiler to do proper loop pipelining
tl.multiple_of(a_ptrs, [16, 16])
tl.multiple_of(b_ptrs, [16, 16])
# assume full tile for now
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K * ldb
c = accumulator.to(tl.float16)
offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]
# assumes full tile for now
tl.store(c_ptrs, c)
# go to the next tile by advancing NUM_SM
tile_idx += NUM_SM
# get ready to go to the next gemm problem
last_problem_end = last_problem_end + num_tiles
def group_gemm_fn(group_A, group_B):
assert len(group_A) == len(group_B)
group_size = len(group_A)
A_addrs = []
B_addrs = []
C_addrs = []
g_sizes = []
g_lds = []
group_C = []
for i in range(group_size):
A = group_A[i]
B = group_B[i]
assert A.shape[1] == B.shape[0]
M, K = A.shape
K, N = B.shape
C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
C_addrs.append(C.data_ptr())
g_sizes += [M, N, K]
g_lds += [A.stride(0), B.stride(0), C.stride(0)]
# note these are device tensors
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
# we use a fixed number of CTAs, and it's auto-tunable
grid = lambda META: (META["NUM_SM"],)
grouped_matmul_kernel[grid](
d_a_ptrs,
d_b_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_lds,
group_size,
)
return group_C
tma_configs = [
triton.Config(
{"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK},
num_stages=s,
num_warps=w,
)
for BM in [128]
for BN in [128, 256]
for BK in [64, 128]
for s in ([3, 4])
for w in [4, 8]
]
@triton.autotune(
tma_configs,
key=["group_a_ptrs", "group_b_ptrs", "gropup_c_ptrs", "group_size"],
)
@triton.jit
def grouped_matmul_tma_kernel(
# device tensor of matrices pointers
group_a_ptrs,
group_b_ptrs,
group_c_ptrs,
# device tensor of gemm sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
group_gemm_sizes,
# device tensor of leading dimension sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
g_lds,
# number of gemms
group_size,
# number of virtual SM
NUM_SM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
# is the output FP8 or FP16
FP8: tl.constexpr,
):
dtype = tl.float8e4nv if FP8 else tl.float16
tile_idx = tl.program_id(0)
last_problem_end = 0
for g in range(group_size):
# get the gemm size of the current problem
gm = tl.load(group_gemm_sizes + g * 3)
gn = tl.load(group_gemm_sizes + g * 3 + 1)
gk = tl.load(group_gemm_sizes + g * 3 + 2)
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
# pick up a tile from the current gemm problem
lda = tl.load(g_lds + g * 3)
ldb = tl.load(g_lds + g * 3 + 1)
ldc = tl.load(g_lds + g * 3 + 2)
a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))
a_desc = tl._experimental_make_tensor_descriptor(
a_ptr,
shape=[gm, gk],
strides=[lda, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl._experimental_make_tensor_descriptor(
b_ptr,
shape=[gn, gk],
strides=[ldb, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
c_desc = tl._experimental_make_tensor_descriptor(
c_ptr,
shape=[gm, gn],
strides=[ldc, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
# iterate through the tiles in the current gemm problem
while (
tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles
):
k = gk
# figure out tile coordinates
tile_idx_in_gemm = tile_idx - last_problem_end
tile_m_idx = tile_idx_in_gemm // num_n_tiles
tile_n_idx = tile_idx_in_gemm % num_n_tiles
# do regular gemm here
offs_am = tile_m_idx * BLOCK_SIZE_M
offs_bn = tile_n_idx * BLOCK_SIZE_N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
accumulator += tl.dot(a, b.T)
offs_cm = tile_m_idx * BLOCK_SIZE_M
offs_cn = tile_n_idx * BLOCK_SIZE_N
c = accumulator.to(dtype)
c_desc.store([offs_cm, offs_cn], c)
# go to the next tile by advancing NUM_SM
tile_idx += NUM_SM
# get ready to go to the next gemm problem
last_problem_end = last_problem_end + num_tiles
def group_gemm_tma_fn(group_A, group_B):
assert supports_tma()
assert len(group_A) == len(group_B)
group_size = len(group_A)
A_addrs = []
B_addrs = []
C_addrs = []
g_sizes = []
g_lds = []
group_C = []
for i in range(group_size):
A = group_A[i]
B = group_B[i]
assert A.shape[1] == B.shape[1]
M, K = A.shape
N, K = B.shape
C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
C_addrs.append(C.data_ptr())
g_sizes += [M, N, K]
g_lds += [A.stride(0), B.stride(0), C.stride(0)]
# note these are device tensors
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
# we use a fixed number of CTAs, and it's auto-tunable
# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
grid = lambda META: (META["NUM_SM"],)
grouped_matmul_tma_kernel[grid](
d_a_ptrs,
d_b_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_lds,
group_size,
FP8=torch.float8_e4m3fn == group_A[0].dtype,
NUM_SM=num_sms(),
)
return group_C
group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
group_A = []
group_B = []
group_B_T = []
assert len(group_m) == len(group_n)
assert len(group_n) == len(group_k)
group_size = len(group_m)
for i in range(group_size):
M = group_m[i]
N = group_n[i]
K = group_k[i]
A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
B_T = B.T.contiguous()
group_A.append(A)
group_B.append(B)
group_B_T.append(B_T)
tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]
for i in range(group_size):
assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0)
if supports_tma():
tri_tma_out = group_gemm_tma_fn(group_A, group_B_T)
for i in range(group_size):
assert torch.allclose(ref_out[i], tri_tma_out[i], atol=1e-2, rtol=0)
# only launch the kernel, no tensor preparation here to remove all overhead
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
grid = lambda META: (META["NUM_SM"],)
grouped_matmul_kernel[grid](
a_ptrs,
b_ptrs,
c_ptrs,
sizes,
lds,
group_size,
)
def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, dtype):
grid = lambda META: (META["NUM_SM"],)
grouped_matmul_tma_kernel[grid](
a_ptrs,
b_ptrs,
c_ptrs,
sizes,
lds,
group_size,
FP8=torch.float8_e4m3fn == dtype,
NUM_SM=num_sms(),
)
def torch_perf_fn(group_A, group_B):
for a, b in zip(group_A, group_B):
torch.matmul(a, b)
@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=["N"],
x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name`
line_arg="provider",
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg`
line_vals=["cublas", "triton"] + (["triton-tma"] if supports_tma() else []),
# label name for the lines
line_names=["cuBLAS", "Triton"] + (["Triton + TMA"] if supports_tma() else []),
# line styles
styles=[("green", "-"), ("blue", "-")]
+ ([("red", "-")] if supports_tma() else []),
ylabel="runtime(ms)", # label name for the y-axis
plot_name="group-gemm-performance",
# name for the plot. Used also as a file name for saving the plot.
args={},
)
)
def benchmark_square_matrices(N, provider):
group_size = 4
group_A = []
group_B = []
group_B_T = []
A_addrs = []
B_addrs = []
B_T_addrs = []
C_addrs = []
g_sizes = []
g_lds = []
group_C = []
for i in range(group_size):
A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
B_T = B.T.contiguous()
group_A.append(A)
group_B.append(B)
group_B_T.append(B_T)
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
B_T_addrs.append(B_T.data_ptr())
C_addrs.append(C.data_ptr())
g_sizes += [N, N, N]
g_lds += [N, N, N]
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
quantiles = [0.5, 0.2, 0.8]
if provider == "cublas":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_perf_fn(
d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size
),
quantiles=quantiles,
)
if provider == "triton-tma":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_tma_perf_fn(
d_a_ptrs,
d_b_t_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_lds,
group_size,
dtype=torch.float16,
),
quantiles=quantiles,
)
return ms, max_ms, min_ms
@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=["M"],
x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name`
line_arg="provider",
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg`
line_vals=["cublas", "triton"] + (["triton-tma"] if supports_tma() else []),
# label name for the lines
line_names=["cuBLAS", "Triton"] + (["Triton + TMA"] if supports_tma() else []),
# line styles
styles=[("green", "-"), ("blue", "-")]
+ ([("red", "-")] if supports_tma() else []),
ylabel="runtime(ms)", # label name for the y-axis
plot_name="group-gemm-performance-m-8192-k-8192",
# name for the plot. Used also as a file name for saving the plot.
args={},
)
)
def benchmark_batches(M, provider):
N = 8192
K = 8192
group_size = 4
group_A = []
group_B = []
group_B_T = []
A_addrs = []
B_addrs = []
B_T_addrs = []
C_addrs = []
g_sizes = []
g_lds = []
g_T_lds = []
group_C = []
for i in range(group_size):
A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
C = torch.empty((M, N), device=DEVICE, dtype=torch.float16)
B_T = B.T.contiguous()
group_A.append(A)
group_B.append(B)
group_B_T.append(B_T)
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
B_T_addrs.append(B_T.data_ptr())
C_addrs.append(C.data_ptr())
g_sizes += [M, N, K]
g_lds += [A.stride(0), B.stride(0), C.stride(0)]
g_T_lds += [A.stride(0), B_T.stride(0), C.stride(0)]
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
d_g_t_lds = torch.tensor(g_T_lds, dtype=torch.int32, device=DEVICE)
quantiles = [0.5, 0.2, 0.8]
if provider == "cublas":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_perf_fn(
d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size
),
quantiles=quantiles,
)
if provider == "triton-tma":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_tma_perf_fn(
d_a_ptrs,
d_b_t_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_t_lds,
group_size,
dtype=torch.float16,
),
quantiles=quantiles,
)
return ms, max_ms, min_ms
benchmark_square_matrices.run(show_plots=True, print_data=True)
benchmark_batches.run(show_plots=True, print_data=True)
================================================
FILE: kernels/MoE/group_GEMM/triton/readme.md
================================================
## Experimental
Triton Group GEMM for supporting MoE training.
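A minimal usage sketch (function names and shapes are taken from the test files in `testing/`; the module paths are assumed to match the files in this directory, so treat this as illustrative rather than a stable API):

```python
import torch

from tgroup_gemm_forward import grouped_gemm_forward
from tgroup_gemm_backwards import grouped_gemm_backward

G, M, N, K = 4, 128, 64, 64
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N * G, K, dtype=torch.bfloat16, device="cuda")        # expert weights stacked as [N*G, K]
m_sizes = torch.full((G,), M // G, dtype=torch.int32, device="cuda")  # rows per group, summing to M

y = grouped_gemm_forward(x, w, m_sizes)                               # output shape [M, N*G]
grad_x, grad_w = grouped_gemm_backward(torch.randn_like(y), x, w, m_sizes)
```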
================================================
FILE: kernels/MoE/group_GEMM/triton/testing/fast_verification.py
================================================
import logging

import numpy as np
import torch
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# import the reference implementations
from pytorch_reference_backwards import (
_compute_grad_w_pytorch,
_compute_grad_x_pytorch,
_pytorch_fallback_backward,
_pytorch_reference_backward,
)
# Import the grouped GEMM modules
from tgrouped_gemm_backwards import grouped_gemm_backward
from tgrouped_gemm_forward import grouped_gemm_forward as grouped_gemm
def test_backward_pass():
"""
A simple test for the grouped GEMM backward pass with detailed error handling.
"""
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Test parameters
G = 20 # Number of groups
M = 1024  # Total number of input rows (across all groups)
N = 512  # Output dimension per group
K = 256  # Shared inner (hidden) dimension
# Create input and weight tensors
x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
w = torch.randn(
N * G, K, dtype=torch.bfloat16, device=device, requires_grad=True
)
# Create group sizes
m_sizes = torch.zeros(G, device=device, dtype=torch.int32)
base_size = M // G
remainder = M % G
for i in range(G):
m_sizes[i] = base_size + (1 if i < remainder else 0)
# Log the setup
print(f"Test setup - G: {G}, M: {M}, N: {N}, K: {K}")
print(f"Input x shape: {x.shape}")
logging.info(f"Weight w shape: {w.shape}")
logging.info(f"Group sizes: {m_sizes}")
# Step 1: Run forward pass
logging.info("Running forward pass")
result = grouped_gemm(x, w, m_sizes)
logging.info(f"Forward result shape: {result.shape}")
# Create a gradient for backpropagation
grad_output = torch.randn_like(result)
logging.info(f"Created gradient with shape: {grad_output.shape}")
# Step 2: Run backward pass directly
logging.info("Running backward pass directly")
grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)
# Verify gradient shapes
logging.info(
f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
)
# Step 3: Verify gradient computation using PyTorch's autograd
# First create autograd-enabled tensors
x_autograd = x.detach().clone().requires_grad_(True)
w_autograd = w.detach().clone().requires_grad_(True)
# Create a PyTorch reference implementation to compare against
logging.info("Running PyTorch reference implementation")
# Compute reference result
reference_result = torch.zeros_like(result)
m_start = 0
for g in range(G):
m_size = m_sizes[g].item()
m_end = m_start + m_size
n_start = g * N
n_end = (g + 1) * N
if m_size > 0:
reference_result[m_start:m_end, n_start:n_end] = (
x_autograd[m_start:m_end, :] @ w_autograd[n_start:n_end, :].T
)
m_start = m_end
# Backpropagate using PyTorch
reference_result.backward(grad_output)
# Compare gradients
logging.info("Comparing gradients with PyTorch reference")
grad_x_error = (grad_x - x_autograd.grad).abs().max().item()
grad_w_error = (grad_w - w_autograd.grad).abs().max().item()
logging.info(
f"Maximum gradient error - grad_x: {grad_x_error}, grad_w: {grad_w_error}"
)
# Check if gradients are close using allclose
rtol = 1e-2 # Relative tolerance for bfloat16
atol = 1e-2 # Absolute tolerance for bfloat16
grad_x_close = torch.allclose(grad_x, x_autograd.grad, rtol=rtol, atol=atol)
if not grad_x_close:
logging.warning("FAILED: Gradient mismatch detected in grad_x")
else:
logging.info(
"✓ SUCCESS! grad_X matches the PyTorch reference (allclose check passed)"
)
grad_w_close = torch.allclose(grad_w, w_autograd.grad, rtol=rtol, atol=atol)
if not grad_w_close:
logging.warning("FAILED: Gradient mismatch detected in grad_w")
else:
logging.info(
"✓ SUCCESS! grad_W matches the PyTorch reference (allclose check passed)"
)
logging.info(
f"Gradients allclose check - grad_x: {grad_x_close}, grad_w: {grad_w_close}"
)
if grad_x_close and grad_w_close:
logging.info(
"✓ SUCCESS: Gradients match the PyTorch reference (allclose check passed)"
)
else:
logging.error("✗ FAILURE: Gradient mismatch detected in allclose check")
# Additional diagnostics (for failed cases or in general)
if True: # not grad_x_close:
# Find where the largest differences are
diff_x = (grad_x - x_autograd.grad).abs()
max_idx_x = diff_x.argmax().item()
flat_idx_x = max_idx_x
idx_x = np.unravel_index(flat_idx_x, grad_x.shape)
logging.error(
f"Largest grad_x difference at {idx_x}: "
f"{grad_x[idx_x].item()} vs {x_autograd.grad[idx_x].item()}"
)
# Count zeros
zeros_grad_x = (grad_x == 0).sum().item()
zeros_autograd_x = (x_autograd.grad == 0).sum().item()
logging.error(
f"Zeros in grad_x: {zeros_grad_x}/{grad_x.numel()} ({zeros_grad_x/grad_x.numel()*100:.2f}%)"
)
logging.error(
f"Zeros in x_autograd.grad: {zeros_autograd_x}/{x_autograd.grad.numel()} ({zeros_autograd_x/x_autograd.grad.numel()*100:.2f}%)"
)
if True: # not grad_w_close:
# Find where the largest differences are
diff_w = (grad_w - w_autograd.grad).abs()
max_idx_w = diff_w.argmax().item()
flat_idx_w = max_idx_w
idx_w = np.unravel_index(flat_idx_w, grad_w.shape)
logging.error(
f"Largest grad_w difference at {idx_w}: "
f"{grad_w[idx_w].item()} vs {w_autograd.grad[idx_w].item()}"
)
# Count zeros
zeros_grad_w = (grad_w == 0).sum().item()
zeros_autograd_w = (w_autograd.grad == 0).sum().item()
logging.error(
f"Zeros in grad_w: {zeros_grad_w}/{grad_w.numel()} ({zeros_grad_w/grad_w.numel()*100:.2f}%)"
)
logging.error(
f"Zeros in w_autograd.grad: {zeros_autograd_w}/{w_autograd.grad.numel()} ({zeros_autograd_w/w_autograd.grad.numel()*100:.2f}%)"
)
return grad_x_close and grad_w_close
except Exception as e:
logging.error(f"Test failed with error: {e}")
import traceback
logging.error(traceback.format_exc())
return False
if __name__ == "__main__":
print("Running test_backward_pass")
logging.debug("Running test_backward_pass")
success = test_backward_pass()
logging.info(f"Test {'succeeded' if success else 'failed'}")
================================================
FILE: kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging

import torch
# This is a series of helper functions for grouped GEMM backward that compute the gradients
# using eager PyTorch operations. They are used as a verification reference for the Triton kernels.
# They can also be used as a fallback when the Triton kernels cannot be used, though ideally that should not be needed.
def _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x):
"""
Compute grad_x using pure PyTorch operations (FP32 buffers with chunked FP64 matmuls for accuracy).
"""
G = m_sizes.shape[0]
M, K = grad_x.shape
N = w.shape[0] // G
# Zero out the output tensor first
grad_x.zero_()
# Store original dtype and convert to float32 for computation
orig_dtype = grad_x.dtype
grad_output_fp32 = grad_output.float()
w_fp32 = w.float()
grad_x_fp32 = torch.zeros_like(grad_x, dtype=torch.float32)
# Process each group separately
m_start = 0
for g in range(G):
m_size = m_sizes[g].item()
if m_size > 0:
m_end = m_start + m_size
n_start = g * N
n_end = (g + 1) * N
# Get slices for this group
grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end]
w_slice = w_fp32[n_start:n_end]
# Process in chunks for better precision on large matrices
CHUNK_SIZE = 256
for chunk_start in range(0, m_size, CHUNK_SIZE):
chunk_end = min(chunk_start + CHUNK_SIZE, m_size)
chunk_size = chunk_end - chunk_start
# Compute matrix multiplication with higher precision
grad_output_chunk = grad_output_slice[chunk_start:chunk_end]
result_chunk = torch.matmul(
grad_output_chunk.double(), w_slice.double()
)
# Store the result
grad_x_fp32[m_start + chunk_start : m_start + chunk_end].copy_(
result_chunk.float()
)
m_start = m_end
# Convert back to original dtype
grad_x.copy_(grad_x_fp32.to(orig_dtype))
def _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w):
"""
Compute grad_w using pure PyTorch operations with FP64 precision for better accuracy.
"""
G = m_sizes.shape[0]
N_times_G, K = grad_w.shape
N = N_times_G // G
# Zero out the output tensor first
grad_w.zero_()
# Store original dtype and convert to float32 for computation
orig_dtype = grad_w.dtype
grad_output_fp32 = grad_output.float()
x_fp32 = x.float()
grad_w_fp32 = torch.zeros_like(grad_w, dtype=torch.float32)
# Handle potential K dimension mismatches
K_x = x.shape[1]
min_K = min(K, K_x)
# Process each group separately
m_start = 0
for g in range(G):
m_size = m_sizes[g].item()
if m_size > 0:
m_end = m_start + m_size
n_start = g * N
n_end = (g + 1) * N
# Get slices for this group
grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end]
x_slice = x_fp32[m_start:m_end, :min_K]
# Process in chunks for better precision
CHUNK_SIZE = 32
result = torch.zeros(
(grad_output_slice.shape[1], min_K),
dtype=torch.float64,
device=grad_output_slice.device,
)
for chunk_start in range(0, m_size, CHUNK_SIZE):
chunk_end = min(chunk_start + CHUNK_SIZE, m_size)
# Get chunks
grad_output_chunk = grad_output_slice[chunk_start:chunk_end].double()
x_chunk = x_slice[chunk_start:chunk_end].double()
# Matrix multiplication in FP64
chunk_result = torch.matmul(grad_output_chunk.t(), x_chunk)
result += chunk_result
# Handle K dimension padding if needed
if K > min_K:
temp_result = torch.zeros(
(grad_output_slice.shape[1], K),
dtype=torch.float32,
device=grad_output_slice.device,
)
temp_result[:, :min_K] = result.float()
grad_w_fp32[n_start:n_end].copy_(temp_result)
else:
grad_w_fp32[n_start:n_end].copy_(result.float())
m_start = m_end
# Convert back to original dtype
grad_w.copy_(grad_w_fp32.to(orig_dtype))
def _pytorch_fallback_backward(grad_output, x, w, m_sizes):
"""
Pure PyTorch implementation of grouped GEMM backward with high precision.
Used as a fallback when the Triton kernels cannot be used.
"""
logging.info(
"WARNING: Using PyTorch fallback for grouped GEMM backward with high precision"
)
# Ensure inputs are contiguous
x = x.contiguous()
w = w.contiguous()
grad_output = grad_output.contiguous()
m_sizes = m_sizes.contiguous()
# Allocate output tensors
grad_x = torch.zeros_like(x)
grad_w = torch.zeros_like(w)
# Compute gradients using the helper functions
_compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x)
_compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w)
return grad_x, grad_w
def _pytorch_reference_backward(grad_output, x, w, m_sizes):
"""
Pure PyTorch implementation of grouped GEMM backward for validation.
Simple version that's easy to verify but may be less numerically accurate
for large matrices.
"""
# Create output gradients
grad_x = torch.zeros_like(x)
grad_w = torch.zeros_like(w)
# Compute group-by-group
G = m_sizes.shape[0]
N = w.shape[0] // G
m_start = 0
for g in range(G):
m_size = m_sizes[g].item()
if m_size > 0:
m_end = m_start + m_size
n_start = g * N
n_end = (g + 1) * N
# Compute gradients
grad_x[m_start:m_end] = torch.matmul(
grad_output[m_start:m_end, n_start:n_end], w[n_start:n_end]
)
grad_w[n_start:n_end] = torch.matmul(
grad_output[m_start:m_end, n_start:n_end].t(), x[m_start:m_end]
)
m_start += m_size
return grad_x, grad_w
# ========== End helper functions ==========
================================================
FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Tuple
import torch
import triton
import triton.language as tl
from tma_utils import TmaAutoTuneHelper
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
"""
Backward pass for grouped GEMM with Triton, where grouping is N*G
We are computing gradients with respect to both the input (`grad_x`) and the weights (`grad_w`).
"""
# =============== Start Triton Kernels ===============
@triton.jit
def _kernel_grouped_gemm_backward_x_scheduled(
grad_y_ptr, # grad of dl/dY [M, N*G]
w_t_ptr, # w transposed [K, N*G]
grad_x_ptr, # output of kernel [M, K]
group_offsets_ptr, # Pre-computed group offsets [G+1]
workspace, # Workspace for TMA descriptors
G, # Number of groups
M, # Total M dimension size
N, # N per group
K, # K dimension size
stride_go_m,
stride_go_n,
stride_w_n,
stride_w_k,
stride_gx_m,
stride_gx_k,
NUM_SMS,
USE_TMA_LOAD: tl.constexpr = False,
USE_TMA_STORE: tl.constexpr = False,
BLOCK_SIZE_M: tl.constexpr = 64,
BLOCK_SIZE_N: tl.constexpr = 64,
BLOCK_SIZE_K: tl.constexpr = 64,
GROUP_SIZE_M: tl.constexpr = 8,
EVEN_K: tl.constexpr = True,
) -> None:
"""
Scheduled grouped GEMM backward for X with TMA support.
For each group g, computes: grad_x[g] = grad_y[g] @ w_t[g].T
Where:
- grad_y is [M, N*G]
- w_t is [K, N*G] (transposed from [N*G, K])
- grad_x is [M, K]
"""
# Get coordinates for the current program
tidx = tl.program_id(axis=0)
dtype = grad_x_ptr.dtype.element_ty
TMA_SIZE: tl.constexpr = 128
# Initialize workspace pointer if using TMA store
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
# Calculate work distribution parameters
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = num_pid_m * num_pid_k
# Process all assigned work items
pid = tidx
while pid < G * num_pid_in_group:
# Calculate work distribution for this pid
group_id = pid // num_pid_in_group
pid_in_group = pid % num_pid_in_group
pid_m = pid_in_group % num_pid_m
pid_k = pid_in_group // num_pid_m
# Get group boundaries
valid_group = group_id < G
group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0)
group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0)
group_size = group_end - group_start
# Calculate a mask for valid processing (valid group and non-empty)
valid_work = valid_group & (group_size > 0)
# Only process if we have valid work
if valid_work:
# Compute offsets for this group
n_start = group_id * N
# Block dimensions
m_block_offset = pid_m * BLOCK_SIZE_M
k_block_offset = pid_k * BLOCK_SIZE_K
# Setup TMA descriptor for output if using TMA
if USE_TMA_STORE:
m_size = tl.minimum(
BLOCK_SIZE_M, group_end - (group_start + m_block_offset)
)
if m_size > 0:
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=grad_x_ptr
+ (group_start + m_block_offset) * stride_gx_m
+ k_block_offset * stride_gx_k,
load_size=[
m_size,
tl.minimum(BLOCK_SIZE_K, K - k_block_offset),
],
global_size=[m_size, K],
element_ty=dtype,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Initialize offsets for this block
offs_m = group_start + m_block_offset + tl.arange(0, BLOCK_SIZE_M)
# For K dimension, optimize memory access if EVEN_K is True
offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K)
# Create masks
m_mask = offs_m < group_end
k_mask = offs_k < K
# Initialize accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
# Loop over the reduction dimension (N)
# Use smaller steps to improve precision and avoid numerical issues
for n_offset in range(0, N, BLOCK_SIZE_N):
# Handle boundary conditions for the reduction dimension
n_size = tl.minimum(BLOCK_SIZE_N, N - n_offset)
offs_n = n_start + n_offset + tl.arange(0, BLOCK_SIZE_N)
n_mask = offs_n < (n_start + N)
# Fixed stride formats to ensure consistent memory access
grad_y_block = tl.load(
grad_y_ptr
+ offs_m[:, None] * stride_go_m
+ offs_n[None, :] * stride_go_n,
mask=m_mask[:, None] & n_mask[None, :],
other=0.0,
)
# Load w_t [K, N*G] block with correct strides
w_t_block = tl.load(
w_t_ptr
+ offs_k[:, None] * stride_w_k
+ offs_n[None, :] * stride_w_n,
mask=k_mask[:, None] & n_mask[None, :],
other=0.0,
)
# grad_y @ w_t.T
# Allow TF32 if K is even and divisible by 8
if EVEN_K:
accumulator += tl.dot(
grad_y_block.to(tl.float32),
w_t_block.to(tl.float32).T,
allow_tf32=True,
)
else:
accumulator += tl.dot(
grad_y_block.to(tl.float32),
w_t_block.to(tl.float32).T,
allow_tf32=False,
)
# Store result to grad_x with explicit strides
if USE_TMA_STORE:
# TMA store
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(dtype),
[0, 0], # Starting offset in the output block
)
else:
# Standard store
tl.store(
grad_x_ptr
+ offs_m[:, None] * stride_gx_m
+ offs_k[None, :] * stride_gx_k,
accumulator.to(dtype),
mask=m_mask[:, None] & k_mask[None, :],
)
pid = pid + NUM_SMS
@triton.jit
def _kernel_grouped_gemm_backward_w_scheduled(
x_t_ptr, # x transposed [K, M]
grad_y_ptr, # grad of dl/dY [M, N*G]
grad_w_ptr, # output of kernel (grad_w) [N*G, K]
group_offsets_ptr, # Pre-computed group offsets [G+1]
workspace, # Workspace for TMA descriptors
G, # Number of groups
M, # Total M dimension size
N, # N per group
K, # K dimension size
stride_x_m,
stride_x_k,
stride_go_m,
stride_go_n,
stride_gw_n,
stride_gw_k,
NUM_SMS,
USE_TMA_LOAD: tl.constexpr = False,
USE_TMA_STORE: tl.constexpr = False,
BLOCK_SIZE_N: tl.constexpr = 64,
BLOCK_SIZE_K: tl.constexpr = 64,
BLOCK_SIZE_M: tl.constexpr = 32,
GROUP_SIZE_N: tl.constexpr = 8,
EVEN_K: tl.constexpr = True,
) -> None:
"""
Scheduled implementation of grouped GEMM backward for W with TMA support.
For each group g, computes:
grad_w[g] = grad_y[g].T @ x[g]
Where:
- x_t is [K, M] (transposed from [M, K])
- grad_y is [M, N*G]
- grad_w is [N*G, K]
"""
# Define coordinates for the current program
tidx = tl.program_id(axis=0)
dtype = grad_w_ptr.dtype.element_ty
TMA_SIZE: tl.constexpr = 128
# Initialize workspace pointer if using TMA store
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
# Calculate work distribution parameters
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = num_pid_n * num_pid_k
# Process all assigned work items
pid = tidx
while pid < G * num_pid_in_group:
# Calculate work distribution for this pid
group_id = pid // num_pid_in_group
pid_in_group = pid % num_pid_in_group
pid_n = pid_in_group % num_pid_n
pid_k = pid_in_group // num_pid_n
# Get group boundaries
valid_group = group_id < G
group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0)
group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0)
group_size = group_end - group_start
# Calculate a mask for valid processing (valid group and non-empty)
valid_work = valid_group & (group_size > 0)
# Only process if we have valid work
if valid_work:
# Compute offsets for this group
n_start = group_id * N
# Block dimensions
n_block_offset = pid_n * BLOCK_SIZE_N
k_block_offset = pid_k * BLOCK_SIZE_K
# Setup TMA descriptor for output if using TMA
if USE_TMA_STORE:
n_size = tl.minimum(BLOCK_SIZE_N, N - n_block_offset)
if n_size > 0:
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=grad_w_ptr
+ (n_start + n_block_offset) * stride_gw_n
+ k_block_offset * stride_gw_k,
load_size=[
n_size,
tl.minimum(BLOCK_SIZE_K, K - k_block_offset),
],
global_size=[n_size, K],
element_ty=dtype,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Initialize offsets for this block
offs_n = n_start + n_block_offset + tl.arange(0, BLOCK_SIZE_N)
offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K)
# Create masks
n_mask = offs_n < (n_start + N)
k_mask = offs_k < K
# Initialize accumulator
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
# Loop over the reduction dimension (M) with smaller steps to avoid overflow
for m_offset in range(0, group_size, BLOCK_SIZE_M):
# Handle boundary conditions for the reduction dimension
m_size = tl.minimum(BLOCK_SIZE_M, group_size - m_offset)
offs_m = group_start + m_offset + tl.arange(0, BLOCK_SIZE_M)
m_mask = offs_m < group_end
# Load grad_y [M, N*G] block with explicit strides
grad_y_block = tl.load(
grad_y_ptr
+ offs_m[:, None] * stride_go_m
+ offs_n[None, :] * stride_go_n,
mask=m_mask[:, None] & n_mask[None, :],
other=0.0,
)
# Load x_t [K, M] block with explicit strides
x_t_block = tl.load(
x_t_ptr
+ offs_k[:, None] * stride_x_k
+ offs_m[None, :] * stride_x_m,
mask=k_mask[:, None] & m_mask[None, :],
other=0.0,
)
# Matrix multiplication: (grad_y_block.T @ x_t_block.T)
if EVEN_K:
accumulator += tl.dot(
grad_y_block.to(
tl.float32
).T, # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M]
x_t_block.to(
tl.float32
).T, # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K]
allow_tf32=True,
)
else:
accumulator += tl.dot(
grad_y_block.to(
tl.float32
).T, # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M]
x_t_block.to(
tl.float32
).T, # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K]
allow_tf32=False,
)
# Store result to grad_w with explicit strides
if USE_TMA_STORE:
# TMA store
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(dtype),
[0, 0], # Starting offset in the output block
)
else:
# Standard store with explicit strides
tl.store(
grad_w_ptr
+ offs_n[:, None] * stride_gw_n
+ offs_k[None, :] * stride_gw_k,
accumulator.to(dtype),
mask=n_mask[:, None] & k_mask[None, :],
)
pid = pid + NUM_SMS
# ========== End Triton kernels ==========
# ========== Begin grouped_gemm_backward cover function ==========
def grouped_gemm_backward(
grad_output: torch.Tensor,
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Backward pass for grouped matrix multiplication using scheduled kernels with TMA support.
Args:
grad_output: Gradient with respect to output, shape [M, N*G]
x: Input tensor from forward pass, shape [M, K]
w: Weight tensor from forward pass, shape [N*G, K]
m_sizes: Group sizes tensor, shape [G]
Returns:
Tuple of gradients with respect to x and w: (grad_x, grad_w)
"""
logging.info("Starting grouped_gemm_backward with TMA-enabled scheduling")
# Check CUDA availability
if not torch.cuda.is_available():
logging.error("CUDA not available for backward pass")
raise RuntimeError("CUDA not available for backward pass")
# return _pytorch_fallback_backward(grad_output, x, w, m_sizes)
# Get GPU parameters - TODO: this can use PyTorch cached info...
device_props = torch.cuda.get_device_properties("cuda")
NUM_SMS = device_props.multi_processor_count
# Check TMA support
has_tma = hasattr(tl.extra, "cuda") and device_props.major >= 9
if has_tma:
logging.info(f"TMA support detected on GPU with {NUM_SMS} SMs")
USE_TMA_LOAD = True  # TODO: currently a no-op; left out for now to focus on numerical correctness first.
USE_TMA_STORE = False
else:
logging.warning("TMA support not detected, disabling TMA optimizations")
USE_TMA_LOAD = False
USE_TMA_STORE = False
# Validate input dimensions
G = m_sizes.shape[0]
M, K_x = x.shape
N_times_G, K_w = w.shape
# Check that K dimensions match
if K_x != K_w:
logging.warning(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
raise ValueError("K dimensions must match for grouped GEMM backward")
# return _pytorch_fallback_backward(grad_output, x, w, m_sizes)
try:
# Ensure contiguous tensors
grad_output = grad_output.contiguous()
x = x.contiguous()
w = w.contiguous()
m_sizes = m_sizes.contiguous()
# Allocate output tensors
grad_x = torch.zeros_like(x)
grad_w = torch.zeros_like(w)
# Determine N per group
# N*G is the second dimension size of grad_output
N = N_times_G // G
# Set stride values
# Direct access pattern for grad_output tensor
stride_go_m = grad_output.stride(0) # grad_output in M dimension
stride_go_n = grad_output.stride(1) # grad_output in N dimension
# Pattern match the transposed weight tensor
stride_w_n = 1 # transposed weights in N dimension
stride_w_k = N * G # transposed weights in K dimension
# Pattern match the output grad_x tensor
stride_gx_m = grad_x.stride(0) # grad_x in M dimension
stride_gx_k = grad_x.stride(1)  # grad_x in K dimension
# Pattern match the transposed x tensor
stride_x_m = 1 # Stride for transposed x in M dimension
stride_x_k = M # Stride for transposed x in K dimension
# Pattern match the output grad_w tensor
stride_gw_n = grad_w.stride(0) # grad_w in N dimension
stride_gw_k = grad_w.stride(1) # grad_w in K dimension
# Pre-compute group offsets for indexing
group_offsets = torch.zeros(G + 1, device=m_sizes.device, dtype=torch.int32)
m_offset = 0
for g in range(G):
group_offsets[g] = m_offset
m_offset += m_sizes[g].item()
group_offsets[G] = m_offset # Total M
# Check if K dimension is even (optimize memory access patterns)
EVEN_K = (K_x % 8) == 0
logging.info(f"EVEN_K optimization enabled: {EVEN_K} (K={K_x})")
# Transpose x and w for backward computation
x_t = x.T.contiguous() # Shape: [K, M]
w_t = w.T.contiguous() # Shape: [K, N*G]
# Allocate workspace for TMA descriptors if needed
if USE_TMA_LOAD or USE_TMA_STORE:
workspace = torch.empty((NUM_SMS * 128), device=x.device, dtype=torch.uint8)
else:
# Empty tensor when TMA is not used
workspace = torch.empty(0, device=x.device, dtype=torch.uint8)
# Set block sizes based on K dimension
# For larger K, use smaller blocks to reduce register pressure
BLOCK_SIZE = 64 if K_x <= 64 else 32
BLOCK_SIZE_K = BLOCK_SIZE
BLOCK_SIZE_M = BLOCK_SIZE
BLOCK_SIZE_N = BLOCK_SIZE
# Determine maximum size needed and set the grid size
num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
num_pid_k = triton.cdiv(K_x, BLOCK_SIZE_K)
num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
# Compute total number of blocks needed for each kernel
total_blocks_x = G * num_pid_m * num_pid_k
total_blocks_w = G * num_pid_n * num_pid_k
try:
logging.info("Computing grad_x with TMA-enabled kernel")
# Fixed grid size based on SM count
grid = (NUM_SMS,)
_kernel_grouped_gemm_backward_x_scheduled[grid](
grad_output,
w_t, # Using transposed weights
grad_x,
group_offsets,
workspace,
G,
M,
N,
K_x,
stride_go_m,
stride_go_n,
stride_w_n,
stride_w_k,
stride_gx_m,
stride_gx_k,
NUM_SMS,
USE_TMA_LOAD,
USE_TMA_STORE,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
EVEN_K=EVEN_K,
)
logging.info(
"Kernel run success: grad_X computation successful with TMA-enabled kernel"
)
except Exception as e:
logging.error(f"FAILED: Error in TMA-enabled backward_x kernel: {e}")
logging.info("WARNING: Falling back to PyTorch for grad_x")
_compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x)
try:
logging.info("Computing grad_w with TMA-enabled kernel")
# Fixed grid size based on SM count
grid = (NUM_SMS,)
_kernel_grouped_gemm_backward_w_scheduled[grid](
x_t, # Using transposed inputs
grad_output,
grad_w,
group_offsets,
workspace,
G,
M,
N,
K_w,
stride_x_m,
stride_x_k,
stride_go_m,
stride_go_n,
stride_gw_n,
stride_gw_k,
NUM_SMS,
USE_TMA_LOAD,
USE_TMA_STORE,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_M=BLOCK_SIZE_M,
EVEN_K=EVEN_K,
)
logging.info(
"Kernel run success - grad_W computation successful with TMA-enabled kernel"
)
except Exception as e:
logging.error(f"FAILED: Error in TMA-enabled backward_w kernel: {e}")
logging.info("WARNING: Falling back to PyTorch for grad_w")
# _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w)
return grad_x, grad_w
except Exception as e:
logging.error(f"Error in grouped_gemm_backward: {e}")
# return _pytorch_fallback_backward(grad_output, x, w, m_sizes)
================================================
FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This is copied from FBGEMM, with some modifications. It is not kept in sync with the original code:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py
import functools
from typing import Optional
import tma_utils as utils
import torch
import triton
import triton.language as tl
from triton.runtime import driver # @manual
"""
_NV_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
)
for block_size_m in [64, 128]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
]
_AMD_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"waves_per_eu": waves_per_cu,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
},
num_stages=num_stages,
num_warps=num_warps,
)
for block_size_m in [32, 64, 128]
for block_size_n in [32, 64, 128, 256]
for block_size_k in [128, 256]
for num_stages in [1, 2]
for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
for matrix_instr_nonkdim in [16]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
device = torch.cuda.current_device()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
if dtsize is None:
dtsize = named_args["c_ptr"].element_size()
if dtype is None:
dtype = named_args["c_ptr"].dtype
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
kw["BLOCK_SIZE_M"],
kw["BLOCK_SIZE_N"],
kw["BLOCK_SIZE_K"],
config.num_stages,
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
"max_shared_mem"
]
if torch.version.hip:
required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
else:
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
M_PER_GROUP = M // G
MIN_M_TILES = 32 if torch.version.hip else 64
# 2. make sure we don't load M tiles that are too big
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
continue
# 3. make sure we don't load N tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = driver.active.utils.get_device_properties(device)[
"multiprocessor_count"
]
N_TILES = N // BLOCK_N
MIN_N_TILES = 32 if torch.version.hip else 64
# 4. make sure we don't load N tiles that are too big
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
pruned_configs.append(config)
return pruned_configs
@triton.autotune(
configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
"""
@triton.jit
def _kernel_grouped_gemm(
a_desc_ptr,
b_desc_ptr,
c_ptr,
workspace,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr, # N is per group
K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
tidx = tl.program_id(0)
dtype: tl.dtype = c_ptr.dtype.element_ty
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
iterated_tiles = 0
for g in tl.range(G):
# Move across groups
M_start_offset = M_end_offset
m_size = tl.load(m_sizes + g)
M_end_offset = M_start_offset + m_size
if m_size > 0:
# Compute for this group
N_start_offset = g * N
n_size = N # N is already per group
# Calculate the number of tiles for this group
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if USE_TMA_STORE:
# Set up TMA descriptor for output
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr
+ M_start_offset * (N * G)
+ N_start_offset, # Offset to this group's output
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_SIZE_K == 0)
if USE_TMA_LOAD:
# Use TMA to load input and weight blocks
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
# Load input block [M, K]
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
# Load weight block [N, K]
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
# Compute matrix multiplication
accumulator += tl.dot(a, b.T)
else:
# Manual load without TMA
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_SIZE_K):
# Load with bounds checking
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
# Compute matrix multiplication
accumulator += tl.dot(a, b.T)
# Update pointers for next block
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
# Store result
if USE_TMA_STORE:
# Store using TMA
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
else:
# Manual store
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None])
* (N * G) # Row stride is N*G
+ (
N_start_offset + offs_bn[None, :]
), # Column offset to this group's N
c,
mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
)
tidx += NUM_SMS # Move to next tile
iterated_tiles += num_tiles
TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv
"""@triton.autotune(
configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={
"early_config_prune": functools.partial(
early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
)
},
)
"""
@triton.jit
def _kernel_grouped_gemm_fp8_rowwise(
a_desc_ptr,
a_scale_ptr,
b_desc_ptr,
b_scale_ptr,
c_ptr,
workspace,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr, # N is per group
K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
tidx = tl.program_id(0)
dtype = TT_FP8_DTYPE
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
iterated_tiles = 0
for g in tl.range(G):
# Move across groups
M_start_offset = M_end_offset
m_size = tl.load(m_sizes + g)
M_end_offset = M_start_offset + m_size
if m_size > 0:
# Compute for this group
N_start_offset = g * N
n_size = N # N is already per group
# Calculate the number of tiles for this group
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if USE_TMA_STORE:
# Set up TMA descriptor for output
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr
+ M_start_offset * (N * G)
+ N_start_offset, # Offset to this group's output
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_SIZE_K == 0)
if USE_TMA_LOAD:
# Use TMA to load input and weight blocks with FP8 support
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
# Load input block [M, K] with FP8
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
# Load weight block [N, K] with FP8
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
# Compute matrix multiplication
accumulator += tl.dot(a, b.T)
else:
# Manual load without TMA for FP8
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_SIZE_K):
# Load with bounds checking
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
# Compute matrix multiplication
accumulator += tl.dot(a, b.T)
# Update pointers for next block
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
# Load FP8 scales
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
a_scale = tl.load(
a_scale_ptr + M_start_offset + offs_am[:, None],
mask=offs_am[:, None] < m_size,
)
b_scale = tl.load(
b_scale_ptr + N_start_offset + offs_bn[None, :],
mask=offs_bn[None, :] < n_size,
)
# Apply scales to result
c = accumulator.to(tl.float32) * a_scale * b_scale
# Store result
if USE_TMA_STORE:
# Store using TMA
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
c.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
else:
# Manual store
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None])
* (N * G) # Row stride is N*G
+ (
N_start_offset + offs_bn[None, :]
), # Column offset to this group's N
c,
                        mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
)
tidx += NUM_SMS # Move to next tile
iterated_tiles += num_tiles
def _grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
x_scale: Optional[torch.Tensor] = None,
w_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not utils.HAS_TMA_DESC:
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
G = m_sizes.shape[0]
assert x.is_contiguous()
assert w.is_contiguous()
assert m_sizes.is_contiguous()
M, K = x.shape
N_times_G = w.shape[0]
# Ensure N is per group
assert (
N_times_G % G == 0
), f"Weight dimension ({N_times_G}) must be divisible by groups ({G})"
N = N_times_G // G
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
# Create output tensor with correct shape [M, N*G]
y = torch.empty((M, N_times_G), device=x.device, dtype=torch.bfloat16)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
USE_TMA_LOAD = True # not torch.version.hip
USE_TMA_STORE = True
desc_helper = None
desc_x = x
desc_w = w
workspace = None
if USE_TMA_LOAD:
desc_helper = utils.TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("x")
desc_helper.init_tma_descriptor("w")
desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
if USE_TMA_STORE:
workspace = torch.empty(
NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
device=x.device,
dtype=torch.uint8,
)
# Skip autotuning - use fixed grid size
grid_size = (min(NUM_SMS, 4),) # Use smaller grid for small inputs
M_BUCKET = triton.next_power_of_2(M)
try:
if USE_TMA_LOAD and desc_helper is not None:
# Fixed block sizes that work well for most cases
BLOCK_SIZE_M = 64
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = 32
desc_helper.fill_2d_tma_descriptor(
"x",
x.data_ptr(),
M,
K,
BLOCK_SIZE_M,
BLOCK_SIZE_K,
x.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"w",
w.data_ptr(),
N_times_G,
K,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
w.element_size(),
)
except Exception as e:
print(f"Error in TMA descriptor setup: {e}")
if x_scale is not None and w_scale is not None:
assert x_scale.is_contiguous()
assert w_scale.is_contiguous()
# Call kernel directly without autotuning
_kernel_grouped_gemm_fp8_rowwise[grid_size](
desc_x,
x_scale,
desc_w,
w_scale,
y,
workspace,
m_sizes,
G,
M_BUCKET,
N, # N is per group
K,
NUM_SMS,
USE_TMA_LOAD,
USE_TMA_STORE,
BLOCK_SIZE_M=64, # Fixed block sizes
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=32,
)
else:
assert x_scale is None
assert w_scale is None
# Call kernel directly without autotuning
_kernel_grouped_gemm[grid_size](
desc_x,
desc_w,
y,
workspace,
m_sizes,
G,
M_BUCKET,
N, # N is per group
K,
NUM_SMS,
USE_TMA_LOAD,
USE_TMA_STORE,
BLOCK_SIZE_M=64, # Fixed block sizes
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=32,
)
# Verify the output shape
expected_output_shape = (M, N_times_G)
assert y.shape == expected_output_shape, (
f"Output shape mismatch: got {y.shape}, " f"expected {expected_output_shape}"
)
return y
def grouped_gemm_forward(
x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor
) -> torch.Tensor:
return _grouped_gemm(x, w, m_sizes)
def grouped_gemm_fp8_rowwise(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
) -> torch.Tensor:
return _grouped_gemm(x, w, m_sizes, x_scale, w_scale)
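# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): how the two public entry
# points above are expected to be called. Shapes follow the asserts in
# _grouped_gemm: x is [M, K], w is [N * G, K], and m_sizes holds one row count
# per group, summing to M. The sizes and dtypes below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__" and torch.cuda.is_available():
    G, K, N = 4, 256, 128
    m_sizes = torch.tensor([64, 64, 64, 64], dtype=torch.int32, device="cuda")
    M = int(m_sizes.sum().item())
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N * G, K, device="cuda", dtype=torch.bfloat16)
    y = grouped_gemm_forward(x, w, m_sizes)  # expected shape [M, N * G], bfloat16
    print("grouped_gemm_forward output:", y.shape, y.dtype)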
================================================
FILE: kernels/MoE/group_GEMM/triton/utils/tma_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm
import sys
import torch
import triton # @manual
import triton.language as tl # @manual
def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
"""
Maps torch dtype to triton dtype.
Args:
dtype (torch.dtype): input dtype.
Returns:
tl.dtype: triton dtype.
"""
if dtype == torch.float16:
return tl.float16
elif dtype == torch.bfloat16:
return tl.bfloat16
elif dtype == torch.float32:
return tl.float32
elif dtype == torch.int32:
return tl.int32
elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
return tl.float8e4nv
else:
raise ValueError(f"Unsupported dtype {dtype}")
# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
if HAS_TMA_DESC:
print(
"TMA benchmarks will be running with experimental grid constant TMA descriptor.",
file=sys.stderr,
)
else:
print(
"Missing: This group gemm code will not run without TMA descriptor support....",
file=sys.stderr,
)
raise NotImplementedError("grouped Gemm without TMA is not supported")
class TmaAutoTuneHelper:
# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc
def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()
TMA_SIZE = 128
def __init__(self):
self.fill_1d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_1d_tma_descriptor
)
self.fill_2d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_2d_tma_descriptor
)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}
# Call this method outside of the lambda function for grid size
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
)
else:
self.cuda_descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
)
# Call this method inside the lambda function for grid size
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)
# Call this method inside the lambda function for grid size
def fill_2d_tma_descriptor(
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)
def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]
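# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): the helper above is
# normally driven in three steps, mirroring the grouped GEMM host code in this
# repo. The tensor shape and block sizes below are illustrative assumptions.
# ---------------------------------------------------------------------------
def _example_fill_descriptor():
    x = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    helper = TmaAutoTuneHelper()
    helper.init_tma_descriptor("x")  # allocate host-side descriptor storage
    helper.fill_2d_tma_descriptor(  # describe the global tensor and its tile shape
        "x", x.data_ptr(), x.shape[0], x.shape[1], 64, 32, x.element_size()
    )
    # The returned wrapper (or device buffer) is what gets passed to the kernel.
    return helper.get_tma_descriptor_kernel_param("x")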
================================================
FILE: kernels/blackwell/cute_gemm_01/Makefile
================================================
# Makefile for SM100 GEMM PyTorch Extension
# Set these paths according to your installation
CUTLASS_PATH ?= /path/to/cutlass
CUDA_HOME ?= $(shell python -c "import torch; print(torch.utils.cpp_extension.CUDA_HOME)")
# Build the extension
build:
CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace
# Install the extension
install:
CUTLASS_PATH=$(CUTLASS_PATH) pip install .
# Clean build artifacts
clean:
rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so
# Test the installation
test:
python python_interface.py
# Check CUDA device capability
check_device:
python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')"
.PHONY: build install clean test check_device
================================================
FILE: kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_log
================================================
# ninja log v5
0 15279 1748131038212164071 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 1163be77f63db063
6 13596 1748131241209889865 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 79aa61597088743a
8 13684 1748132015451659084 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 89ead7aaccf82852
================================================
FILE: kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/build.ninja
================================================
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda-12.8/bin/nvcc
cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_dlink_post_cflags =
sycl_dlink_post_cflags =
ldflags =
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
depfile = $out.d
deps = gcc
command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm.cu
build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm_pytorch.cpp
================================================
FILE: kernels/blackwell/cute_gemm_01/driver.py
================================================
# ==============================================================================
# python_interface.py - High-level Python interface
# ==============================================================================
import torch
try:
    import sm100_gemm  # The compiled extension; must be imported after torch (auto-formatters may try to reorder it)
except ImportError:
print("❌ SM100 not ready!")
raise ImportError(
"SM100 not ready! Please build the extension using `python setup.py install`"
)
def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0):
"""
Perform GEMM using SM100 optimized kernel: D = alpha * A @ B^T + beta * C
Args:
A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16
B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16
C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32
If None, creates zero tensor
alpha (float): Scaling factor for A @ B^T
beta (float): Scaling factor for C
Returns:
torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32
Note:
- A and B are K-major (transposed in BLAS terms)
- C and D are N-major (row-major)
- All tensors must be on CUDA
- M must be multiple of 128, N multiple of 256, K multiple of 64
"""
# Input validation
assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}"
assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}"
assert A.is_cuda and B.is_cuda, "A and B must be on CUDA"
assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous"
M, K = A.shape
N, K_B = B.shape
assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}"
# Check alignment requirements
assert M % 128 == 0, f"M={M} must be multiple of 128"
assert N % 256 == 0, f"N={N} must be multiple of 256"
assert K % 64 == 0, f"K={K} must be multiple of 64"
# Create C if not provided
if C is None:
C = torch.zeros(M, N, dtype=torch.float32, device=A.device)
else:
assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}"
assert C.is_cuda, "C must be on CUDA"
assert C.is_contiguous(), "C must be contiguous"
assert C.shape == (
M,
N,
), f"C shape {C.shape} must match output shape ({M}, {N})"
# Call the extension
return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta)
def benchmark_sm100_vs_torch(
M=1024, N=2048, K=256, num_warmup=1, num_trials=10
): # M=512, N=1024, K=256, num_warmup=10, num_trials=100):
"""
Benchmark SM100 GEMM against PyTorch's native GEMM
"""
# Ensure dimensions are aligned
M = ((M + 127) // 128) * 128
N = ((N + 255) // 256) * 256
K = ((K + 63) // 64) * 64
print(f"Benchmarking GEMM with shape: ({M}, {N}, {K})")
# Create test tensors
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")
C = torch.randn(M, N, dtype=torch.float16, device="cuda")
C32 = C.to(torch.float32).clone()
# Keep A and B as FP16 for PyTorch
A_fp16 = A
B_fp16 = B
# Warmup
for _ in range(num_warmup):
# PyTorch GEMM (using FP16)
torch_result = torch.addmm(C, A_fp16, B_fp16.T)
# SM100 GEMM
sm100_result = sm100_gemm_f16(A, B, C32)
torch.cuda.synchronize()
# Benchmark PyTorch
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(num_trials):
torch_result = torch.addmm(C, A_fp16, B_fp16.T)
end.record()
torch.cuda.synchronize()
torch_time = start.elapsed_time(end) / num_trials
# Benchmark SM100
start.record()
for _ in range(num_trials):
sm100_result = sm100_gemm_f16(A, B, C32)
end.record()
torch.cuda.synchronize()
sm100_time = start.elapsed_time(end) / num_trials
# Check correctness
max_diff = torch.max(torch.abs(torch_result - sm100_result.to(torch.float16)))
rel_error = max_diff / torch.max(torch.abs(torch_result))
# Calculate FLOPS
flops = 2 * M * N * K # Multiply-add operations
torch_tflops = flops / (torch_time * 1e-3) / 1e12
sm100_tflops = flops / (sm100_time * 1e-3) / 1e12
print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)")
print(f"SM100 time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)")
print(f"Speedup: {torch_time/sm100_time:.2f}x")
print(f"Max difference: {max_diff:.6f}")
print(f"Relative error: {rel_error:.6f}")
return {
"torch_time": torch_time,
"sm100_time": sm100_time,
"speedup": torch_time / sm100_time,
"torch_tflops": torch_tflops,
"sm100_tflops": sm100_tflops,
"max_diff": max_diff.item(),
"rel_error": rel_error.item(),
}
# Example usage and test
if __name__ == "__main__":
# Test basic functionality
print("Testing SM100 GEMM...")
M, N, K = 512, 1024, 256
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")
C = torch.randn(M, N, dtype=torch.float32, device="cuda")
# Test the GEMM
result = sm100_gemm_f16(A, B, C, alpha=1.0, beta=0.5)
print(f"Result shape: {result.shape}, dtype: {result.dtype}")
# Run benchmark
print("\nRunning benchmark...")
benchmark_results = benchmark_sm100_vs_torch(M, N, K)
# ==============================================================================
# Makefile for easy building
# ==============================================================================
'''
MAKEFILE_CONTENT = """
# Makefile for SM100 GEMM PyTorch Extension
# Set these paths according to your installation
CUTLASS_PATH ?= /path/to/cutlass
CUDA_HOME ?= $(shell python -c "import torch; print(torch.utils.cpp_extension.CUDA_HOME)")
# Build the extension
build:
CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace
# Install the extension
install:
CUTLASS_PATH=$(CUTLASS_PATH) pip install .
# Clean build artifacts
clean:
rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so
# Test the installation
test:
python python_interface.py
# Check CUDA device capability
check_device:
python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')"
.PHONY: build install clean test check_device
"""
# Write Makefile
with open("Makefile", "w") as f:
f.write(MAKEFILE_CONTENT)
print("Setup files created!")
print("To build:")
print("1. Set CUTLASS_PATH environment variable to your CUTLASS installation")
print("2. Run: make build")
print("3. Test: make test")
'''
================================================
FILE: kernels/blackwell/cute_gemm_01/setup.py
================================================
# setup.py
import os
import pybind11
import torch
from pybind11 import get_cmake_dir
from pybind11.setup_helpers import build_ext, Pybind11Extension
from setuptools import Extension, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
# IMPORTANT: The following two lines are the only ones you need to change
# Get CUTLASS path (you'll need to set this to your CUTLASS installation)
CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/home/less/local/cutlass40")
# CUDA and PyTorch paths
cuda_home = torch.utils.cpp_extension.CUDA_HOME
pytorch_includes = torch.utils.cpp_extension.include_paths()
ext_modules = [
CUDAExtension(
name="sm100_gemm",
sources=[
"sm100_gemm_pytorch.cpp", # PyTorch bindings (C++)
"sm100_gemm.cu", # CUDA kernel implementation
],
include_dirs=[
# PyTorch includes
*pytorch_includes,
# CUTLASS includes
f"{CUTLASS_PATH}/include",
f"{CUTLASS_PATH}/tools/util/include",
# CUDA includes
f"{cuda_home}/include",
],
library_dirs=[
f"{cuda_home}/lib64",
],
libraries=["cuda", "cudart"],
extra_compile_args={
"cxx": [
"-O3",
"-std=c++17",
"-DCUTLASS_ARCH_MMA_SM100_SUPPORTED",
"-DCUTE_SM100_ENABLED",
],
"nvcc": [
"-O3",
"-std=c++17",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"-gencode=arch=compute_100a,code=sm_100a", # SM100 architecture
"-DCUTLASS_ARCH_MMA_SM100_SUPPORTED",
"-DCUTE_SM100_ENABLED",
"--use_fast_math",
"-Xcompiler=-fPIC",
"-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1", # Enable TCGEN05_TMEM
],
},
extra_link_args=["-lcuda", "-lcudart"],
language="c++",
)
]
setup(
name="sm100_gemm",
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
zip_safe=False,
python_requires=">=3.8",
install_requires=["torch>=1.12.0"],
)
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.cu
================================================
// sm100_gemm_kernel.cu - CUDA kernel implementation
#include "sm100_gemm.h"
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#include <cutlass/arch/barrier.h>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/half.h>
#include <cutlass/util/print_error.hpp>
#include <cute/algorithm/cooperative_copy.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/tensor.hpp>
using namespace cute;
// Shared storage structure
template <class TypeA, class TypeB, class ASmemLayout, class BSmemLayout>
struct SharedStorage {
alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
alignas(16) cute::uint64_t mma_barrier;
alignas(16) cute::uint32_t tmem_base_ptr;
CUTE_DEVICE constexpr auto tensor_sA() {
return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{});
}
CUTE_DEVICE constexpr auto tensor_sB() {
return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{});
}
};
// Device kernel
template <class SharedStorage, class ATensor, class BTensor, class CTensor,
class DTensor, class MmaTiler_MNK, class TiledMMA,
class ClusterShape_MNK, class Alpha, class Beta>
__global__ static void
gemm_device(ATensor mA, BTensor mB, CTensor mC, DTensor mD,
MmaTiler_MNK mma_tiler, TiledMMA tiled_mma,
ClusterShape_MNK cluster_shape, Alpha alpha, Beta beta) {
// Step 1: The Prologue
Layout cluster_layout_vmnk = tiled_divide(
make_layout(cluster_shape), make_tile(typename TiledMMA::AtomThrID{}));
auto mma_coord_vmnk =
make_coord(blockIdx.x % size<0>(cluster_layout_vmnk),
blockIdx.x / size<0>(cluster_layout_vmnk), blockIdx.y, _);
auto mma_coord = select<1, 2, 3>(mma_coord_vmnk);
Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X, _1>{});
Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step<X, _1, _1>{});
Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1, _1, X>{});
Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1, _1, X>{});
// SMEM allocation
extern __shared__ char shared_memory[];
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(shared_memory);
Tensor tCsA = shared_storage.tensor_sA();
Tensor tCsB = shared_storage.tensor_sB();
// MMA partitioning
auto mma_v = get<0>(mma_coord_vmnk);
ThrMMA cta_mma = tiled_mma.get_slice(mma_v);
Tensor tCgA = cta_mma.partition_A(gA);
Tensor tCgB = cta_mma.partition_B(gB);
Tensor tCgC = cta_mma.partition_C(gC);
Tensor tCgD = cta_mma.partition_C(gD);
// Fragment allocation
Tensor tCrA = cta_mma.make_fragment_A(tCsA);
Tensor tCrB = cta_mma.make_fragment_B(tCsB);
Tensor tCtAcc = cta_mma.make_fragment_C(tCgC);
uint32_t elect_one_thr = cute::elect_one_sync();
uint32_t elect_one_warp = (threadIdx.x / 32 == 0);
using TmemAllocator = cute::TMEM::Allocator1Sm;
TmemAllocator tmem_allocator{};
// TMEM allocation
if (elect_one_warp) {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,
&shared_storage.tmem_base_ptr);
}
__syncthreads();
tCtAcc.data() = shared_storage.tmem_base_ptr;
// Barrier initialization
if (elect_one_warp && elect_one_thr) {
cute::initialize_barrier(shared_storage.mma_barrier, 1);
}
int mma_barrier_phase_bit = 0;
__syncthreads();
// Step 2: The Mainloop
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) {
// Load A and B tiles
cooperative_copy<128>(threadIdx.x, tCgA(_, _, _, k_tile), tCsA);
cooperative_copy<128>(threadIdx.x, tCgB(_, _, _, k_tile), tCsB);
__syncthreads();
// Execute MMAs
if (elect_one_warp) {
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCtAcc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
cutlass::arch::umma_arrive(&shared_storage.mma_barrier);
}
cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit);
mma_barrier_phase_bit ^= 1;
}
// Step 3: The Epilogue
TiledCopy tiled_t2r_copy =
make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);
ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);
Tensor tDgC = thr_t2r_copy.partition_D(tCgC);
Tensor tDrC = make_fragment_like(tDgC);
copy(tDgC, tDrC);
Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc);
Tensor tDgD = thr_t2r_copy.partition_D(tCgD);
using AccType = typename decltype(tCtAcc)::value_type;
Tensor tDrAcc = make_tensor<AccType>(shape(tDgD));
copy(tiled_t2r_copy, tDtAcc, tDrAcc);
// AXPBY and store result
axpby(alpha, tDrAcc, beta, tDrC);
copy(tDrC, tDgD);
__syncthreads();
// Cleanup TMEM
if (elect_one_warp) {
tmem_allocator.release_allocation_lock();
tmem_allocator.free(shared_storage.tmem_base_ptr,
TmemAllocator::Sm100TmemCapacityColumns);
}
}
// Host function that launches the kernel
cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,
int M, int N, int K, float alpha, float beta,
cudaStream_t stream) {
// Define types
using TypeA = cutlass::half_t;
using TypeB = cutlass::half_t;
using TypeC = float;
using TypeD = float;
// Create layouts (K-major for A and B, N-major for C and D)
auto layout_A = make_layout(make_shape(M, K), make_stride(K, Int<1>{}));
auto layout_B = make_layout(make_shape(N, K), make_stride(K, Int<1>{}));
auto layout_C = make_layout(make_shape(M, N), make_stride(N, Int<1>{}));
auto layout_D = layout_C;
// Create CuTe tensors
auto mA =
make_tensor(make_gmem_ptr(reinterpret_cast<TypeA *>(d_A)), layout_A);
auto mB =
make_tensor(make_gmem_ptr(reinterpret_cast<TypeB *>(d_B)), layout_B);
auto mC =
make_tensor(make_gmem_ptr(reinterpret_cast<TypeC *>(d_C)), layout_C);
auto mD =
make_tensor(make_gmem_ptr(reinterpret_cast<TypeD *>(d_D)), layout_D);
// Create TiledMMA
TiledMMA tiled_mma =
make_tiled_mma(SM100_MMA_F16BF16_SS<TypeA, TypeB, TypeC, 128, 256,
UMMA::Major::K, UMMA::Major::K>{});
// Define MMA tiler sizes
auto bM = tile_size<0>(tiled_mma); // 128
auto bN = tile_size<1>(tiled_mma); // 256
auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // 64
auto mma_tiler = make_shape(bM, bN, bK);
// Check alignment
if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) {
return cudaErrorInvalidValue;
}
// Create SMEM layouts
auto mma_shape_A = partition_shape_A(
tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler)));
auto mma_shape_B = partition_shape_B(
tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler)));
auto sA_layout =
UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeA>{}, mma_shape_A);
auto sB_layout =
UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom<TypeB>{}, mma_shape_B);
using SMEMStorage =
SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
// Cluster configuration
auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});
// Launch parameters
dim3 dimBlock(128);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape),
size<2>(cluster_shape));
dim3 dimGrid(ceil_div(M, int(bM)), ceil_div(N, int(bN)));
int smemBytes = sizeof(SMEMStorage);
// Get kernel pointer
auto *kernel_ptr =
&gemm_device<SMEMStorage, decltype(mA), decltype(mB), decltype(mC),
decltype(mD), decltype(mma_tiler), decltype(tiled_mma),
decltype(cluster_shape), float, float>;
// Set kernel attributes
cudaError_t error = cudaFuncSetAttribute(
kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);
if (error != cudaSuccess) {
return error;
}
// Launch kernel
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster,
smemBytes};
cutlass::Status status = cutlass::launch_kernel_on_cluster(
params, (void const *)kernel_ptr, mA, mB, mC, mD, mma_tiler, tiled_mma,
cluster_shape, alpha, beta);
return (status == cutlass::Status::kSuccess) ? cudaSuccess
: cudaErrorLaunchFailure;
}
#else
cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,
int M, int N, int K, float alpha, float beta,
cudaStream_t stream) {
return cudaErrorNotSupported;
}
#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/PKG-INFO
================================================
Metadata-Version: 2.4
Name: sm100_gemm
Version: 0.0.0
Requires-Python: >=3.8
Requires-Dist: torch>=1.12.0
Dynamic: requires-dist
Dynamic: requires-python
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/SOURCES.txt
================================================
setup.py
sm100_gemm.cu
sm100_gemm_pytorch.cpp
sm100_gemm.egg-info/PKG-INFO
sm100_gemm.egg-info/SOURCES.txt
sm100_gemm.egg-info/dependency_links.txt
sm100_gemm.egg-info/not-zip-safe
sm100_gemm.egg-info/requires.txt
sm100_gemm.egg-info/top_level.txt
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/dependency_links.txt
================================================
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/not-zip-safe
================================================
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/requires.txt
================================================
torch>=1.12.0
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/top_level.txt
================================================
sm100_gemm
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.h
================================================
// sm100_gemm_kernel.h - Header file for CUDA kernel
#pragma once
#include <cuda_runtime.h>
#ifdef __cplusplus
extern "C" {
#endif
/**
* Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta * C
*
* @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major)
* @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major)
* @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major)
* @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major)
* @param M Number of rows in A and C/D
* @param N Number of rows in B and columns in C/D
* @param K Number of columns in A and B
* @param alpha Scaling factor for A @ B^T
* @param beta Scaling factor for C
* @param stream CUDA stream (currently unused, for future async support)
*
* @return cudaSuccess on success, error code otherwise
*
* Requirements:
* - M must be multiple of 128
* - N must be multiple of 256
* - K must be multiple of 64
* - All pointers must be valid device memory
* - Tensors must be contiguous with specified layouts
*/
cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D,
int M, int N, int K, float alpha, float beta,
cudaStream_t stream = 0);
#ifdef __cplusplus
}
#endif
================================================
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp
================================================
// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "sm100_gemm.h"
// Check if SM100 support is available at compile time
bool is_sm100_supported() {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
return true;
#else
return false;
#endif
}
// Check if current GPU supports SM100 at runtime
bool check_sm100_device() {
int device;
cudaGetDevice(&device);
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, device);
if (error != cudaSuccess) {
return false;
}
// Check for SM100 architecture (compute capability 10.0a)
return (props.major == 10 && props.minor == 0);
}
torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B,
const torch::Tensor &C, float alpha = 1.0f,
float beta = 0.0f) {
// Check compile-time support
TORCH_CHECK(
is_sm100_supported(),
"SM100 support not compiled. Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED");
// Check runtime device support
TORCH_CHECK(check_sm100_device(),
"Current GPU does not support SM100 architecture (requires "
"compute capability 10.0a)");
// Input validation
TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor");
TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor");
TORCH_CHECK(C.device().is_cuda(), "C must be a CUDA tensor");
TORCH_CHECK(A.dtype() == torch::kFloat16, "A must be float16");
TORCH_CHECK(B.dtype() == torch::kFloat16, "B must be float16");
TORCH_CHECK(C.dtype() == torch::kFloat32, "C must be float32");
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
TORCH_CHECK(B.is_contiguous(), "B must be contiguous");
TORCH_CHECK(C.is_contiguous(), "C must be contiguous");
TORCH_CHECK(A.dim() == 2, "A must be 2D");
TORCH_CHECK(B.dim() == 2, "B must be 2D");
TORCH_CHECK(C.dim() == 2, "C must be 2D");
// Get dimensions
int64_t M = A.size(0);
int64_t K = A.size(1);
int64_t N = B.size(0);
int64_t K_B = B.size(1);
TORCH_CHECK(K == K_B, "Inner dimensions must match: A.shape[1]=", K,
", B.shape[1]=", K_B);
TORCH_CHECK(C.size(0) == M && C.size(1) == N, "C dimensions (", C.size(0),
", ", C.size(1), ") must match output shape (", M, ", ", N, ")");
// Check alignment requirements for SM100
TORCH_CHECK(M % 128 == 0, "M=", M, " must be multiple of 128");
TORCH_CHECK(N % 256 == 0, "N=", N, " must be multiple of 256");
TORCH_CHECK(K % 64 == 0, "K=", K, " must be multiple of 64");
// Check size limits (avoid overflow in int conversion)
TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX,
"Dimensions too large for int conversion");
// Create output tensor
auto D = torch::empty_like(C);
// Set CUDA device guard
const auto device = A.device();
c10::cuda::CUDAGuard device_guard(device);
// Get current CUDA stream
cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream();
// Launch the kernel
cudaError_t error = launch_sm100_gemm_f16(
A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(),
static_cast<int>(M), static_cast<int>(N), static_cast<int>(K), alpha,
beta, stream);
// Check for launch errors
TORCH_CHECK(error == cudaSuccess,
"SM100 GEMM kernel launch failed: ", cudaGetErrorString(error));
// Check for kernel execution errors
C10_CUDA_CHECK(cudaGetLastError());
return D;
}
// Utility functions for debugging and information
torch::Tensor get_device_info() {
int device;
cudaGetDevice(&device);
cudaDeviceProp props;
cudaGetDeviceProperties(&props, device);
// Return device info as a tensor (for easy Python access)
auto info = torch::zeros({4}, torch::kInt32);
auto accessor = info.accessor<int32_t, 1>();
accessor[0] = props.major; // Compute capability major
accessor[1] = props.minor; // Compute capability minor
accessor[2] = is_sm100_supported(); // Compile-time support
accessor[3] = check_sm100_device(); // Runtime device support
return info;
}
std::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {
// Return properly aligned dimensions for SM100
int64_t aligned_M = ((M + 127) / 128) * 128;
int64_t aligned_N = ((N + 255) / 256) * 256;
int64_t aligned_K = ((K + 63) / 64) * 64;
return {aligned_M, aligned_N, aligned_K};
}
torch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,
torch::ScalarType dtype,
torch::Device device) {
// Create a tensor with SM100-aligned dimensions
TORCH_CHECK(shape.size() == 2, "Shape must be 2D");
auto aligned_shape =
get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? shape[2] : 64);
if (shape.size() == 2) {
return torch::zeros({aligned_shape[0], aligned_shape[1]},
torch::TensorOptions().dtype(dtype).device(device));
} else {
return torch::zeros({aligned_shape[0], aligned_shape[2]},
torch::TensorOptions().dtype(dtype).device(device));
}
}
// Python bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "SM100 GEMM PyTorch Extension";
// Main GEMM function
m.def("sm100_gemm_f16", &sm100_gemm_f16,
"SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + "
"beta * C",
py::arg("A"), py::arg("B"), py::arg("C"), py::arg("alpha") = 1.0f,
py::arg("beta") = 0.0f);
// Utility functions
m.def("is_sm100_supported", &is_sm100_supported,
"Check if SM100 support was compiled in");
m.def("check_sm100_device", &check_sm100_device,
"Check if current GPU supports SM100 architecture");
m.def("get_device_info", &get_device_info,
"Get device compute capability and SM100 support info");
m.def("get_aligned_shape", &get_aligned_shape,
"Get SM100-aligned dimensions for given shape", py::arg("M"),
py::arg("N"), py::arg("K"));
m.def("create_aligned_tensor", &create_aligned_tensor,
"Create tensor with SM100-aligned dimensions", py::arg("shape"),
py::arg("dtype"), py::arg("device"));
// Constants for alignment requirements
m.attr("MMA_TILE_M") = 128;
m.attr("MMA_TILE_N") = 256;
m.attr("MMA_TILE_K") = 64;
}
================================================
FILE: kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_log
================================================
# ninja log v5
1 15202 1748185895110710199 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b
7 78 1748186494782816813 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b
6 15086 1748186805607894090 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b
6 14058 1748187024415643408 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 6c5f77cfca7cfb81
================================================
FILE: kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/build.ninja
================================================
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda-12.8/bin/nvcc
cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm
cuda_dlink_post_cflags =
sycl_dlink_post_cflags =
ldflags =
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
depfile = $out.d
deps = gcc
command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu
build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp
================================================
FILE: kernels/blackwell/cute_gemm_02_tma/driver.py
================================================
# python_interface.py - High-level Python interface with TMA support
import torch
try:
import sm100_gemm # The compiled extension
except ImportError:
print("❌ SM100 not ready!")
raise ImportError(
"SM100 not ready! Please build the extension using `python setup.py install`"
)
def check_sm100_compatibility():
"""Check if SM100 is supported and available"""
compile_support = sm100_gemm.is_sm100_supported()
device_support = sm100_gemm.check_sm100_device()
info = sm100_gemm.get_device_info()
major, minor, compile_flag, device_flag = info.tolist()
print(f"Device compute capability: {major}.{minor}")
print(f"Compile-time SM100 support: {bool(compile_flag)}")
print(f"Runtime SM100 device support: {bool(device_flag)}")
if not compile_support:
print(
"❌ SM100 support not compiled in. Rebuild with CUTLASS_ARCH_MMA_SM100_SUPPORTED"
)
elif not device_support:
print("❌ Current GPU does not support SM100 (need compute capability 10.0a)")
else:
print("✅ SM100 with TMA ready!")
return compile_support and device_support
def sm100_gemm_f16_tma(A, B, C=None, alpha=1.0, beta=0.0, check_alignment=True):
"""
Perform GEMM using SM100 optimized kernel with TMA: D = alpha * A @ B^T + beta * C
Args:
A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16
B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16
C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32
If None, creates zero tensor
alpha (float): Scaling factor for A @ B^T
beta (float): Scaling factor for C
check_alignment (bool): Whether to check and suggest aligned dimensions
Returns:
torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32
Note:
- Uses TMA (Tensor Memory Accelerator) for efficient memory transfers
- A and B are K-major (transposed in BLAS terms)
- C and D are N-major (row-major)
- All tensors must be on CUDA
- M must be multiple of 128, N multiple of 256, K multiple of 64
"""
# Input validation
assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}"
assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}"
assert A.is_cuda and B.is_cuda, "A and B must be on CUDA"
assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous"
M, K = A.shape
N, K_B = B.shape
assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}"
# Check or fix alignment requirements
if check_alignment:
aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K)
if M != aligned_M or N != aligned_N or K != aligned_K:
print(f"Warning: Dimensions ({M}, {N}, {K}) not aligned for SM100")
print(
f"Suggested aligned dimensions: ({aligned_M}, {aligned_N}, {aligned_K})"
)
print("Consider padding tensors or use create_aligned_tensors()")
# Strict alignment check
assert (
M % sm100_gemm.MMA_TILE_M == 0
), f"M={M} must be multiple of {sm100_gemm.MMA_TILE_M}"
assert (
N % sm100_gemm.MMA_TILE_N == 0
), f"N={N} must be multiple of {sm100_gemm.MMA_TILE_N}"
assert (
K % sm100_gemm.MMA_TILE_K == 0
), f"K={K} must be multiple of {sm100_gemm.MMA_TILE_K}"
# Create C if not provided
if C is None:
C = torch.zeros(M, N, dtype=torch.float32, device=A.device)
else:
assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}"
assert C.is_cuda, "C must be on CUDA"
assert C.is_contiguous(), "C must be contiguous"
assert C.shape == (
M,
N,
), f"C shape {C.shape} must match output shape ({M}, {N})"
# Call the extension (now uses TMA internally)
return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta)
# Keep the old name for compatibility
sm100_gemm_f16 = sm100_gemm_f16_tma
def create_aligned_tensors(
M, N, K, device="cuda", dtype_AB=torch.float16, dtype_C=torch.float32
):
"""
Create properly aligned tensors for SM100 GEMM with TMA
Returns:
tuple: (A, B, C) tensors with aligned dimensions
"""
aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K)
A = torch.zeros(aligned_M, aligned_K, dtype=dtype_AB, device=device)
B = torch.zeros(aligned_N, aligned_K, dtype=dtype_AB, device=device)
C = torch.zeros(aligned_M, aligned_N, dtype=dtype_C, device=device)
return A, B, C
def pad_to_aligned(tensor, target_shape=None, dim_requirements=None):
"""
Pad tensor to meet SM100 alignment requirements
Args:
tensor: Input tensor to pad
target_shape: Specific target shape (optional)
dim_requirements: Tuple of (M_align, N_align, K_align) requirements
Returns:
Padded tensor and padding info for later unpadding
"""
if dim_requirements is None:
dim_requirements = (
sm100_gemm.MMA_TILE_M,
sm100_gemm.MMA_TILE_N,
sm100_gemm.MMA_TILE_K,
)
if tensor.dim() == 2:
M, N = tensor.shape
if target_shape:
target_M, target_N = target_shape
else:
target_M = (
(M + dim_requirements[0] - 1) // dim_requirements[0]
) * dim_requirements[0]
target_N = (
(N + dim_requirements[1] - 1) // dim_requirements[1]
) * dim_requirements[1]
pad_M = target_M - M
pad_N = target_N - N
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
padded = torch.nn.functional.pad(tensor, (0, pad_N, 0, pad_M))
return padded, (M, N, pad_M, pad_N)
else:
raise ValueError("Only 2D tensors supported")
def unpad_result(tensor, padding_info):
"""Remove padding from result tensor"""
orig_M, orig_N, pad_M, pad_N = padding_info
return tensor[:orig_M, :orig_N]
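# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): pad unaligned operands,
# run the TMA GEMM, then strip the padding from the result. dim_requirements is
# passed explicitly so each operand is padded against the right tile sizes
# (rows of A to MMA_TILE_M, rows of B to MMA_TILE_N, K to MMA_TILE_K); the
# helper name _example_padded_gemm is illustrative only.
# ---------------------------------------------------------------------------
def _example_padded_gemm(A, B):
    # A: (M, K) float16, B: (N, K) float16, possibly unaligned
    A_pad, a_info = pad_to_aligned(
        A, dim_requirements=(sm100_gemm.MMA_TILE_M, sm100_gemm.MMA_TILE_K, 1)
    )
    B_pad, b_info = pad_to_aligned(
        B, dim_requirements=(sm100_gemm.MMA_TILE_N, sm100_gemm.MMA_TILE_K, 1)
    )
    # Zero padding along K adds zero contributions, so the valid block of D is unchanged.
    D_pad = sm100_gemm_f16_tma(A_pad, B_pad, check_alignment=False)
    # unpad_result only reads the original sizes, so the pad amounts can be 0 here.
    return unpad_result(D_pad, (a_info[0], b_info[0], 0, 0))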
def benchmark_sm100_vs_torch(
M=512,
N=1024,
K=256,
num_warmup=1,
num_trials=10,
auto_align=True,
compare_tma=True,
):
"""
Benchmark SM100 GEMM with TMA against PyTorch's native GEMM
"""
# Ensure dimensions are aligned
if auto_align:
M = (
(M + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M
) * sm100_gemm.MMA_TILE_M
N = (
(N + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N
) * sm100_gemm.MMA_TILE_N
K = (
(K + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K
) * sm100_gemm.MMA_TILE_K
print(f"Benchmarking GEMM with TMA for shape: ({M}, {N}, {K})")
# Check SM100 availability
if not check_sm100_compatibility():
print("SM100 not available, skipping benchmark")
return None
# Create test tensors
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")
C = torch.randn(M, N, dtype=torch.float32, device="cuda")
# PyTorch baseline (using mixed precision)
A_fp32 = A.float()
B_fp32 = B.float()
# Warmup
for _ in range(num_warmup):
# PyTorch GEMM
torch_result = torch.addmm(C, A_fp32, B_fp32.T)
# SM100 GEMM with TMA
sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)
torch.cuda.synchronize()
# Benchmark PyTorch
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
# warmup
torch_result = torch.addmm(C, A_fp32, B_fp32.T)
start.record()
for _ in range(num_trials):
torch_result = torch.addmm(C, A_fp32, B_fp32.T)
end.record()
torch.cuda.synchronize()
torch_time = start.elapsed_time(end) / num_trials
# Benchmark SM100 with TMA
# warmup
sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)
start.record()
for _ in range(num_trials):
sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False)
end.record()
torch.cuda.synchronize()
sm100_time = start.elapsed_time(end) / num_trials
# Check correctness
max_diff = torch.max(torch.abs(torch_result - sm100_result))
rel_error = max_diff / torch.max(torch.abs(torch_result))
# Calculate FLOPS
flops = 2 * M * N * K # Multiply-add operations
torch_tflops = flops / (torch_time * 1e-3) / 1e12
sm100_tflops = flops / (sm100_time * 1e-3) / 1e12
print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)")
print(f"SM100+TMA time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)")
print(f"Speedup: {torch_time/sm100_time:.2f}x")
print(f"Max difference: {max_diff:.6f}")
print(f"Relative error: {rel_error:.6f}")
print(f"🚀 TMA provides efficient memory transfers for large matrices!")
return {
"torch_time": torch_time,
"sm100_time": sm100_time,
"speedup": torch_time / sm100_time,
"torch_tflops": torch_tflops,
"sm100_tflops": sm100_tflops,
"max_diff": max_diff.item(),
"rel_error": rel_error.item(),
}
# Neural network layer implementations with TMA
class SM100LinearTMA(torch.nn.Module):
"""
Linear layer using SM100 GEMM with TMA for forward pass
"""
def __init__(self, in_features, out_features, bias=True, device="cuda"):
super().__init__()
# Align dimensions
self.orig_in_features = in_features
self.orig_out_features = out_features
aligned_in = (
(in_features + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K
) * sm100_gemm.MMA_TILE_K
aligned_out = (
(out_features + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N
) * sm100_gemm.MMA_TILE_N
self.in_features = aligned_in
self.out_features = aligned_out
# Parameters (with padding)
self.weight = torch.nn.Parameter(
torch.randn(aligned_out, aligned_in, dtype=torch.float16, device=device)
* 0.1
)
if bias:
self.bias = torch.nn.Parameter(
torch.zeros(aligned_out, dtype=torch.float32, device=device)
)
else:
self.register_parameter("bias", None)
print(
f"SM100LinearTMA: {in_features} -> {out_features} (aligned: {aligned_in} -> {aligned_out})"
)
print("🚀 Using TMA for efficient memory transfers")
def forward(self, x):
# Pad input if necessary
batch_size = x.size(0)
# Align batch size
aligned_batch = (
(batch_size + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M
) * sm100_gemm.MMA_TILE_M
if x.size(1) != self.in_features or batch_size != aligned_batch:
x_padded = torch.zeros(
aligned_batch, self.in_features, dtype=torch.float16, device=x.device
)
x_padded[:batch_size, : self.orig_in_features] = x
x = x_padded
# Prepare bias
if self.bias is not None:
C = (
self.bias.unsqueeze(0)
.expand(aligned_batch, self.out_features)
.contiguous()
)
beta = 1.0
else:
C = torch.zeros(
aligned_batch, self.out_features, dtype=torch.float32, device=x.device
)
beta = 0.0
# SM100 GEMM with TMA: output = x @ weight^T + bias
output = sm100_gemm_f16_tma(
x, self.weight, C, alpha=1.0, beta=beta, check_alignment=False
)
# Remove padding
return output[:batch_size, : self.orig_out_features]
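# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): a minimal forward pass
# through SM100LinearTMA on an SM100 device. The batch size and feature counts
# are illustrative; the input is float16 and the output comes back as float32,
# sliced to the original (unpadded) feature count.
# ---------------------------------------------------------------------------
def _example_linear_tma():
    layer = SM100LinearTMA(in_features=512, out_features=1024, bias=True)
    x = torch.randn(256, 512, dtype=torch.float16, device="cuda")
    out = layer(x)  # shape (256, 1024), dtype float32
    return out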
def benchmark_tma_vs_cooperative_copy(M=512, N=1024, K=256, num_trials=50):
"""
TMA addition
"""
results = benchmark_sm100_vs_torch(M, N, K, num_trials=num_trials)
if results:
print(f"\nTMA-accelerated SM100 GEMM achieved:")
print(f" Performance: {results['sm100_tflops']:.2f} TFLOPS")
print(f" Speedup: {results['speedup']:.2f}x over PyTorch")
print(f" Memory efficiency: Hardware-optimized transfers")
def stress_test_large_matrices():
"""
Test TMA performance with large matrices that benefit most from TMA
"""
print("\n=== Large Matrix Stress Test with TMA ==="
SYMBOL INDEX (317 symbols across 52 files)
FILE: dev/sr/src/stochastic_rounding.hpp
type philox (line 9) | namespace philox {
class PhiloxGenerator (line 18) | class PhiloxGenerator {
FILE: dev/sr/tests/benchmark.py
function measure_performance (line 8) | def measure_performance(func, input_tensor, warmup=0, repeats=1):
function benchmark_sizes (line 27) | def benchmark_sizes(sizes= [1000, 10000, 100000, 1000000, 10000000, (100...
function benchmark_shapes (line 59) | def benchmark_shapes(total_size=1000000):
function stress_test (line 84) | def stress_test(duration=10):
function memory_test (line 100) | def memory_test(max_size=1e9):
function main (line 128) | def main():
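
A minimal sketch of the warmup-then-time pattern that the measure_performance(func, input_tensor, warmup, repeats) signature suggests; this is an illustration under assumed semantics, not the file's actual body:

import torch

def measure_performance_sketch(func, input_tensor, warmup=0, repeats=1):
    # Assumed semantics: run `warmup` untimed calls, then time `repeats` calls
    # with CUDA events and return the mean latency in milliseconds.
    for _ in range(warmup):
        func(input_tensor)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(repeats):
        func(input_tensor)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / max(repeats, 1)
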
FILE: dev/sr/tests/core_unit_tests.py
class TestStochasticRounding (line 8) | class TestStochasticRounding(unittest.TestCase):
method setup (line 9) | def setup(self):
method _test_rounding_statistics_helper (line 14) | def _test_rounding_statistics_helper(self, value, lower_value, upper_v...
method test_special_values (line 40) | def test_special_values(self):
method test_small_values (line 59) | def test_small_values(self):
method test_vectorized_loading (line 67) | def test_vectorized_loading(self):
method test_large_values (line 81) | def test_large_values(self):
method test_rounding_statistics (line 89) | def test_rounding_statistics(self):
method test_rounding_statistics_2 (line 93) | def test_rounding_statistics_2(self):
method test_rounding_statistics_small (line 97) | def test_rounding_statistics_small(self):
method test_rounding_statistics_large (line 101) | def test_rounding_statistics_large(self):
FILE: dev/triton_groupGEMM/groupgemm.py
function early_config_prune (line 61) | def early_config_prune(configs, named_args, dtsize=None, dtype=None, **k...
function _kernel_grouped_gemm (line 131) | def _kernel_grouped_gemm(
function _kernel_grouped_gemm_fp8_rowwise (line 270) | def _kernel_grouped_gemm_fp8_rowwise(
function _grouped_gemm (line 407) | def _grouped_gemm(
function grouped_gemm (line 518) | def grouped_gemm(
function grouped_gemm_fp8_rowwise (line 524) | def grouped_gemm_fp8_rowwise(
FILE: dev/triton_groupGEMM/testing/base_testing.py
class TestGroupedGEMM (line 39) | class TestGroupedGEMM(unittest.TestCase):
method setUp (line 40) | def setUp(self) -> None:
method test_grouped_gemm_bf16 (line 98) | def test_grouped_gemm_bf16(self) -> None:
FILE: dev/triton_groupGEMM/testing/unit_tests.py
class TestGroupedGEMM (line 23) | class TestGroupedGEMM(unittest.TestCase):
method test_grouped_gemm_bf16 (line 24) | def test_grouped_gemm_bf16(self) -> None:
method test_grouped_gemm_bf16_various_dimensions (line 63) | def test_grouped_gemm_bf16_various_dimensions(self) -> None:
method test_grouped_gemm_bf16_edge_cases (line 105) | def test_grouped_gemm_bf16_edge_cases(self) -> None:
method test_grouped_gemm_bf16_invalid_inputs (line 168) | def test_grouped_gemm_bf16_invalid_inputs(self) -> None:
method test_grouped_gemm_bf16_deterministic (line 214) | def test_grouped_gemm_bf16_deterministic(self) -> None:
method test_grouped_gemm_bf16_large_matrices (line 235) | def test_grouped_gemm_bf16_large_matrices(self) -> None:
FILE: dev/triton_groupGEMM/tma_utils.py
function map_dtype_to_triton (line 18) | def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
class TmaAutoTuneHelper (line 58) | class TmaAutoTuneHelper:
class KernelParamWrapper (line 61) | class KernelParamWrapper:
method __init__ (line 62) | def __init__(self, desc):
method tma_desc_cpu_ptr (line 65) | def tma_desc_cpu_ptr(self):
method __init__ (line 70) | def __init__(self):
method init_tma_descriptor (line 83) | def init_tma_descriptor(self, name):
method fill_1d_tma_descriptor (line 94) | def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_si...
method fill_2d_tma_descriptor (line 110) | def fill_2d_tma_descriptor(
method get_tma_descriptor_kernel_param (line 127) | def get_tma_descriptor_kernel_param(self, name):
FILE: dev/triton_groupGEMM/triton_tutorial_groupgemm.py
function is_cuda (line 41) | def is_cuda():
function supports_tma (line 45) | def supports_tma():
function num_sms (line 49) | def num_sms():
function grouped_matmul_kernel (line 109) | def grouped_matmul_kernel(
function group_gemm_fn (line 187) | def group_gemm_fn(group_A, group_B):
function grouped_matmul_tma_kernel (line 250) | def grouped_matmul_tma_kernel(
function group_gemm_tma_fn (line 346) | def group_gemm_tma_fn(group_A, group_B):
function triton_perf_fn (line 433) | def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
function triton_tma_perf_fn (line 445) | def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, d...
function torch_perf_fn (line 459) | def torch_perf_fn(group_A, group_B):
function benchmark_square_matrices (line 484) | def benchmark_square_matrices(N, provider):
function benchmark_batches (line 567) | def benchmark_batches(M, provider):
FILE: kernels/MoE/group_GEMM/triton/testing/fast_verification.py
function test_backward_pass (line 23) | def test_backward_pass():
FILE: kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py
function _compute_grad_x_pytorch (line 15) | def _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x):
function _compute_grad_w_pytorch (line 68) | def _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w):
function _pytorch_fallback_backward (line 139) | def _pytorch_fallback_backward(grad_output, x, w, m_sizes):
function _pytorch_reference_backward (line 165) | def _pytorch_reference_backward(grad_output, x, w, m_sizes):
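
The backward references above share an m_sizes convention: rows of x are partitioned into consecutive per-expert groups. A hedged pure-PyTorch sketch of the matching forward computation, written only to illustrate that convention (the expert weight layout [num_groups, N, K] is an assumption):

import torch

def _pytorch_reference_forward_sketch(x, w, m_sizes):
    # Illustrative only: group g owns m_sizes[g] consecutive rows of x and is
    # multiplied by its own expert weight w[g] (assumed shape [N, K]).
    outputs, start = [], 0
    for g, m in enumerate(m_sizes):
        outputs.append(x[start : start + m] @ w[g].t())
        start += m
    return torch.cat(outputs, dim=0)
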
FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py
function _kernel_grouped_gemm_backward_x_scheduled (line 28) | def _kernel_grouped_gemm_backward_x_scheduled(
function _kernel_grouped_gemm_backward_w_scheduled (line 202) | def _kernel_grouped_gemm_backward_w_scheduled(
function grouped_gemm_backward (line 382) | def grouped_gemm_backward(
FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py
function _kernel_grouped_gemm (line 137) | def _kernel_grouped_gemm(
function _kernel_grouped_gemm_fp8_rowwise (line 312) | def _kernel_grouped_gemm_fp8_rowwise(
function _grouped_gemm (line 485) | def _grouped_gemm(
function grouped_gemm_forward (line 626) | def grouped_gemm_forward(
function grouped_gemm_fp8_rowwise (line 632) | def grouped_gemm_fp8_rowwise(
FILE: kernels/MoE/group_GEMM/triton/utils/tma_utils.py
function map_dtype_to_triton (line 18) | def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
class TmaAutoTuneHelper (line 58) | class TmaAutoTuneHelper:
class KernelParamWrapper (line 61) | class KernelParamWrapper:
method __init__ (line 62) | def __init__(self, desc):
method tma_desc_cpu_ptr (line 65) | def tma_desc_cpu_ptr(self):
method __init__ (line 70) | def __init__(self):
method init_tma_descriptor (line 83) | def init_tma_descriptor(self, name):
method fill_1d_tma_descriptor (line 94) | def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_si...
method fill_2d_tma_descriptor (line 110) | def fill_2d_tma_descriptor(
method get_tma_descriptor_kernel_param (line 127) | def get_tma_descriptor_kernel_param(self, name):
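
Both copies of tma_utils.py expose the same host-side descriptor helper. A hedged sketch of the call sequence the method names suggest; the argument order for fill_1d_tma_descriptor follows the truncated signature above and is otherwise an assumption:

import torch
from tma_utils import TmaAutoTuneHelper  # assuming the utils directory is on the Python path

x = torch.randn(1 << 20, device="cuda")
helper = TmaAutoTuneHelper()
helper.init_tma_descriptor("x")                       # allocate a named host-side descriptor
helper.fill_1d_tma_descriptor(                        # (name, ptr, dim, block_dim, element_size)
    "x", x.data_ptr(), x.numel(), 1024, x.element_size()
)
desc_x = helper.get_tma_descriptor_kernel_param("x")  # value handed to the Triton kernel launch
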
FILE: kernels/blackwell/cute_gemm_01/driver.py
function sm100_gemm_f16 (line 17) | def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0):
function benchmark_sm100_vs_torch (line 70) | def benchmark_sm100_vs_torch(
FILE: kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp
function is_sm100_supported (line 11) | bool is_sm100_supported() {
function check_sm100_device (line 20) | bool check_sm100_device() {
function sm100_gemm_f16 (line 34) | torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor...
function get_device_info (line 109) | torch::Tensor get_device_info() {
function get_aligned_shape (line 128) | std::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {
function create_aligned_tensor (line 137) | torch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,
function PYBIND11_MODULE (line 156) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: kernels/blackwell/cute_gemm_02_tma/driver.py
function check_sm100_compatibility (line 14) | def check_sm100_compatibility():
function sm100_gemm_f16_tma (line 38) | def sm100_gemm_f16_tma(A, B, C=None, alpha=1.0, beta=0.0, check_alignmen...
function create_aligned_tensors (line 114) | def create_aligned_tensors(
function pad_to_aligned (line 132) | def pad_to_aligned(tensor, target_shape=None, dim_requirements=None):
function unpad_result (line 175) | def unpad_result(tensor, padding_info):
function benchmark_sm100_vs_torch (line 181) | def benchmark_sm100_vs_torch(
class SM100LinearTMA (line 284) | class SM100LinearTMA(torch.nn.Module):
method __init__ (line 289) | def __init__(self, in_features, out_features, bias=True, device="cuda"):
method forward (line 324) | def forward(self, x):
function benchmark_tma_vs_cooperative_copy (line 363) | def benchmark_tma_vs_cooperative_copy(M=512, N=1024, K=256, num_trials=50):
function stress_test_large_matrices (line 377) | def stress_test_large_matrices():
function check_sm100_compatibility (line 528) | def check_sm100_compatibility():
function sm100_gemm_f16 (line 552) | def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0, check_alignment=Tr...
function create_aligned_tensors (line 623) | def create_aligned_tensors(
function pad_to_aligned (line 641) | def pad_to_aligned(tensor, target_shape=None, dim_requirements=None):
function unpad_result (line 684) | def unpad_result(tensor, padding_info):
function benchmark_sm100_vs_torch (line 690) | def benchmark_sm100_vs_torch(
class SM100Linear (line 781) | class SM100Linear(torch.nn.Module):
method __init__ (line 786) | def __init__(self, in_features, out_features, bias=True, device="cuda"):
method forward (line 820) | def forward(self, x):
FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp
function is_sm100_supported (line 11) | bool is_sm100_supported() {
function check_sm100_device (line 20) | bool check_sm100_device() {
function sm100_gemm_f16 (line 34) | torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor...
function get_device_info (line 109) | torch::Tensor get_device_info() {
function get_aligned_shape (line 128) | std::vector<int64_t> get_aligned_shape(int64_t M, int64_t N, int64_t K) {
function create_aligned_tensor (line 137) | torch::Tensor create_aligned_tensor(const std::vector<int64_t> &shape,
function PYBIND11_MODULE (line 156) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: kernels/cuda/cutlass_gemm/broadcast_load_epilogue_c3x.hpp
type cutlass::epilogue::fusion (line 60) | namespace cutlass::epilogue::fusion {
type Sm90RowOrScalarBroadcast (line 75) | struct Sm90RowOrScalarBroadcast {
type SharedStorage (line 82) | struct SharedStorage {
type Arguments (line 89) | struct Arguments {
method Params (line 98) | static constexpr Params
method get_workspace_size (line 104) | static size_t
method initialize_workspace (line 110) | static cutlass::Status
method CUTLASS_HOST_DEVICE (line 116) | CUTLASS_HOST_DEVICE
method CUTLASS_DEVICE (line 127) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 132) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 137) | CUTLASS_DEVICE bool
type ProducerLoadCallbacks (line 143) | struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
method CUTLASS_DEVICE (line 154) | CUTLASS_DEVICE void
method get_producer_load_callbacks (line 174) | CUTLASS_DEVICE auto
type ConsumerStoreCallbacks (line 191) | struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
method CUTLASS_DEVICE (line 202) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 218) | CUTLASS_DEVICE Array<Element, FragmentSize>
method get_consumer_store_callbacks (line 235) | CUTLASS_DEVICE auto
type Sm90ColOrScalarBroadcast (line 261) | struct Sm90ColOrScalarBroadcast {
type SharedStorage (line 269) | struct SharedStorage { }
type Arguments (line 274) | struct Arguments {
method Params (line 283) | static constexpr Params
method get_workspace_size (line 289) | static size_t
method initialize_workspace (line 295) | static cutlass::Status
method CUTLASS_DEVICE (line 301) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 306) | CUTLASS_DEVICE bool
method CUTLASS_DEVICE (line 311) | CUTLASS_DEVICE bool
method CUTLASS_HOST_DEVICE (line 316) | CUTLASS_HOST_DEVICE
method CUTLASS_HOST_DEVICE (line 319) | CUTLASS_HOST_DEVICE
method get_producer_load_callbacks (line 326) | CUTLASS_DEVICE auto
type ConsumerStoreCallbacks (line 332) | struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
method CUTLASS_DEVICE (line 343) | CUTLASS_DEVICE void
method CUTLASS_DEVICE (line 356) | CUTLASS_DEVICE Array<Element, FragmentSize>
method get_consumer_store_callbacks (line 375) | CUTLASS_DEVICE auto
FILE: kernels/cuda/cutlass_gemm/common.hpp
function next_pow_2 (line 15) | inline uint32_t next_pow_2(uint32_t const num) {
function get_cuda_max_shared_memory_per_block_opt_in (line 20) | inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
FILE: kernels/cuda/cutlass_gemm/cutlass.cpp
function cutlass_scaled_mm (line 7) | torch::Tensor cutlass_scaled_mm(torch::Tensor a, torch::Tensor b, torch:...
function PYBIND11_MODULE (line 17) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: kernels/cuda/inference/hadamard_transform/hadamard_transform.cpp
function is_power_of_two (line 11) | constexpr bool is_power_of_two(uint32_t x) {
function hadamard_transform (line 15) | torch::Tensor hadamard_transform(at::Tensor& in, bool inplace) {
function PYBIND11_MODULE (line 58) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: kernels/cuda/inference/hadamard_transform/test.py
function get_scale (line 24) | def get_scale(size):
function truth_hadamard_transform_inplace (line 34) | def truth_hadamard_transform_inplace(a: torch.Tensor, truth_hadamards):
function test_hadamard_transform_inplace_rowmajor (line 42) | def test_hadamard_transform_inplace_rowmajor(a: torch.Tensor):
function check_correctness (line 48) | def check_correctness(m, elem_c, a, result, truth, atol=1e-2, rtol=0):
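
The test file's preview (further below) imports scipy.linalg, which points at the usual reference for this kernel: a normalized Walsh-Hadamard transform. A minimal sketch of that reference; the exact scaling used by the kernel lives in get_scale and may differ:

import math
import scipy.linalg
import torch

def hadamard_reference_sketch(a: torch.Tensor) -> torch.Tensor:
    # a: [rows, n] with n a power of two; multiply each row by H_n / sqrt(n).
    n = a.shape[-1]
    H = torch.from_numpy(scipy.linalg.hadamard(n)).to(a.dtype).to(a.device)
    return (a @ H) / math.sqrt(n)
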
FILE: kernels/needs_perf_help/fp8_gemm_bench.py
function bench (line 28) | def bench(cuda_graph: bool, rowwise_tma: bool=True) -> None:
function bf16_bench (line 107) | def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
function scale_row_bench (line 114) | def scale_row_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
function row_gemm_bench (line 136) | def row_gemm_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
function row_gemm_bench_tma (line 159) | def row_gemm_bench_tma(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
FILE: kernels/needs_perf_help/fp8_rowwise_tma_persistent.py
function get_fp8_constants (line 27) | def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
function convert_fp8_type (line 46) | def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:
function init_to_zero (line 60) | def init_to_zero(name):
function get_configs_io_bound (line 64) | def get_configs_io_bound() -> List[Config]:
function _kernel_matmul_fp8_row_tma_persistent (line 108) | def _kernel_matmul_fp8_row_tma_persistent(
function matmul_fp8_row (line 235) | def matmul_fp8_row(
FILE: kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py
function torch_moe (line 16) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
function test_fused_moe (line 33) | def test_fused_moe(
function benchmark (line 99) | def benchmark(m, provider):
FILE: kernels/triton/inference/col_major_moe_gemm/profile_moe.py
function torch_moe (line 15) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
function test_fused_moe (line 32) | def test_fused_moe(
FILE: kernels/triton/inference/col_major_moe_gemm/test_moe_gemm.py
function torch_moe (line 16) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
function test_fused_moe (line 39) | def test_fused_moe(
FILE: kernels/triton/inference/col_major_moe_gemm/v0_moe_fused.py
function fused_moe_kernel (line 18) | def fused_moe_kernel(
function moe_align_block_size (line 138) | def moe_align_block_size(
function invoke_fused_moe_kernel (line 183) | def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.T...
function fused_moe (line 222) | def fused_moe(hidden_states: torch.Tensor,
FILE: kernels/triton/inference/col_major_moe_gemm/v1_moe_fused.py
function grouped_launch (line 21) | def grouped_launch(pid,
function fused_moe_kernel_splitk (line 39) | def fused_moe_kernel_splitk(
function moe_align_block_size (line 150) | def moe_align_block_size(
function invoke_fused_moe_kernel (line 195) | def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.T...
function fused_moe (line 249) | def fused_moe(hidden_states: torch.Tensor,
FILE: kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py
function col_major (line 18) | def col_major(pid,
function fused_moe_kernel (line 31) | def fused_moe_kernel(
function moe_align_block_size (line 136) | def moe_align_block_size(
function invoke_fused_moe_kernel (line 181) | def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.T...
function fused_moe (line 220) | def fused_moe(hidden_states: torch.Tensor,
FILE: kernels/triton/inference/flash_attention/stay_attention.py
function stay_attention (line 7) | def stay_attention(
function flash_fn (line 107) | def flash_fn(q, k, v):
FILE: kernels/triton/inference/fp8/float8_groupwise_quant.py
function _float8_groupwise_quant_kernel (line 21) | def _float8_groupwise_quant_kernel(
function float8_groupwise_quantize (line 53) | def float8_groupwise_quantize(x: torch.Tensor, block_size=128):
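
float8_groupwise_quantize(x, block_size=128) names a standard recipe: one scale per contiguous block of 128 elements along the last dimension. A hedged pure-PyTorch sketch of that idea (not the Triton kernel's code; requires a PyTorch build with torch.float8_e4m3fn):

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def groupwise_quant_sketch(x: torch.Tensor, block_size: int = 128):
    # Assumes x.shape[-1] is divisible by block_size.
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_E4M3_MAX
    q = (blocks / scales).to(torch.float8_e4m3fn)
    return q.reshape_as(x), scales.squeeze(-1)
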
FILE: kernels/triton/inference/fp8/scaled_fp8_gemm.py
function grouped_launch (line 10) | def grouped_launch(pid,
function column_major (line 27) | def column_major(pid,
function scaled_gemm_splitk (line 39) | def scaled_gemm_splitk(a_ptr, b_ptr, c_ptr,
function scaled_mm_splitk (line 94) | def scaled_mm_splitk(a, b, scale_a: float=1.0, scale_b: float=1.0):
FILE: kernels/triton/inference/fp8/splitk_gemm_fp8.py
function grouped_launch (line 9) | def grouped_launch(pid,
function col_major (line 27) | def col_major(pid,
function gemm_split_k_kernel (line 40) | def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr,
function gemm_split_k (line 90) | def gemm_split_k(a, b):
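
Both fp8 GEMMs above use a split-K launch: the K dimension is partitioned so that several programs each produce a partial product for the same output tile, and the partials are then reduced (typically with atomic adds inside the Triton kernel). A small pure-PyTorch illustration of why the decomposition is equivalent:

import torch

def splitk_matmul_sketch(a, b, split_k=4):
    # a: [M, K], b: [K, N]; each K-slice contributes an independent partial product.
    K = a.shape[1]
    parts = [a[:, idx] @ b[idx, :] for idx in torch.arange(K).chunk(split_k)]
    return torch.stack(parts).sum(dim=0)

a, b = torch.randn(64, 256), torch.randn(256, 32)
assert torch.allclose(splitk_matmul_sketch(a, b), a @ b, rtol=1e-4, atol=1e-4)
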
FILE: kernels/triton/inference/fp8/tma_gemm.py
function gemm_kernel_tma (line 7) | def gemm_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
function matmul (line 32) | def matmul(a, b, config=None):
FILE: kernels/triton/inference/gptq/a100_qlinear.py
function _a100_quantized_matmul (line 6) | def _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
class a100_qlinear (line 77) | class a100_qlinear(torch.autograd.Function):
method forward (line 78) | def forward(ctx, a, b, scales, zeros):
FILE: kernels/triton/inference/gptq/benchmark.py
function benchmark_generation_speed (line 15) | def benchmark_generation_speed(model, tokenizer, prompt, batch_size, dev...
function main (line 59) | def main():
FILE: kernels/triton/inference/gptq/h100_qlinear.py
function _h100_quantized_matmul (line 7) | def _h100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
class h100_qlinear (line 87) | class h100_qlinear(torch.autograd.Function):
method forward (line 88) | def forward(ctx, a, b, scales, zeros):
FILE: kernels/triton/inference/gptq/mixtral/test_dequant_moe_gemm.py
function torch_moe (line 9) | def torch_moe(a, w1, w2, topk_weight, topk_ids):
function test_dequant_moe (line 25) | def test_dequant_moe(
FILE: kernels/triton/inference/gptq/mixtral/w4a16_fused_dequant_gemm.py
function print_tensor_dim (line 9) | def print_tensor_dim(tensor, str_name):
function print_value (line 13) | def print_value(value):
function grouped_launch (line 18) | def grouped_launch(pid,
function col_major (line 36) | def col_major(pid,
function w4a16_fused_moe_kernel (line 50) | def w4a16_fused_moe_kernel(
function invoke_dequant_gemm_moe (line 153) | def invoke_dequant_gemm_moe(activations: torch.Tensor,
function moe_align_block_size (line 211) | def moe_align_block_size(
function dequant_gemm_moe (line 255) | def dequant_gemm_moe(hidden_states: torch.Tensor,
FILE: kernels/triton/inference/gptq/small_benchmark_cuda_graphs.py
function swizzle_tile (line 11) | def swizzle_tile(pid,
function matmul_data_parallel_kernel (line 29) | def matmul_data_parallel_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
class small_qlinear (line 109) | class small_qlinear(torch.autograd.Function):
method forward (line 110) | def forward(ctx, a, b, scales, zeros):
function matmul_split_k_kernel (line 161) | def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
function matmul_split_k (line 228) | def matmul_split_k(a, b, scales, zeros):
function make_tensor (line 281) | def make_tensor(M, N, dtype):
function gen_quant4 (line 292) | def gen_quant4(m, n, groupsize=-1):
FILE: kernels/triton/inference/gptq/splitk_dequant_gemm.py
function swizzle_tile (line 7) | def swizzle_tile(pid,
function matmul_split_k_kernel (line 24) | def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
function matmul_split_k (line 91) | def matmul_split_k(a, b, scales, zeros):
function make_tensor (line 143) | def make_tensor(M, N, dtype):
FILE: kernels/triton/inference/mamba/causal_1d_conv/causal_1d_conv/causal_1d_conv.py
function _causal_conv1d_fwd_kernel (line 27) | def _causal_conv1d_fwd_kernel(
function causal_conv1d_fwd (line 121) | def causal_conv1d_fwd(
class CausalConv1dFn (line 203) | class CausalConv1dFn(torch.autograd.Function):
method forward (line 205) | def forward(
method backward (line 259) | def backward(ctx, dout, *args):
function causal_conv1d_fn (line 295) | def causal_conv1d_fn(
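
causal_conv1d_fn implements the depthwise causal 1-D convolution used in Mamba-style models; the tests directory carries its own causal_conv1d_ref. A hedged plain-PyTorch sketch of that reference, showing only the left-padding that makes the convolution causal:

import torch
import torch.nn.functional as F

def causal_conv1d_ref_sketch(x, weight, bias=None):
    # x: [batch, dim, seqlen]; weight: [dim, width] -> one depthwise filter per channel.
    dim, width = weight.shape
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    return out[..., : x.shape[-1]]  # drop trailing positions so position t sees only x[..t]
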
FILE: kernels/triton/inference/mamba/causal_1d_conv/tests/test_causal_1d_conv.py
function _undecorated_test_causal_conv1d (line 24) | def _undecorated_test_causal_conv1d(
function causal_conv1d_ref (line 117) | def causal_conv1d_ref(
function test_causal_conv1d (line 183) | def test_causal_conv1d(
FILE: kernels/triton/inference/paged_attention/attention_triton.py
function print_tensor_dim (line 14) | def print_tensor_dim(tensor, str_name):
function print_value (line 20) | def print_value(value):
function print_line (line 27) | def print_line(str_line):
function paged_attention_v1 (line 33) | def paged_attention_v1(
function paged_attention_triton_v1 (line 158) | def paged_attention_triton_v1(
function paged_attention_v2 (line 206) | def paged_attention_v2(
function paged_attention_triton_v2 (line 358) | def paged_attention_triton_v2(
FILE: kernels/triton/inference/torch_compile/flash_backward.py
class MetaData (line 43) | class MetaData():
method __init__ (line 55) | def __init__(self, sm_scale=1.0):
method set_varlen_params (line 58) | def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
method need_bias (line 70) | def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):
method need_alibi (line 77) | def need_alibi(self, alibi_slopes, batch, nheads):
method need_causal (line 84) | def need_causal(self):
method need_dropout (line 87) | def need_dropout(dropout_p, return_encoded_softmax):
method check_args (line 91) | def check_args(self, q, k, v, o):
function cdiv_fn (line 120) | def cdiv_fn(x,y):
function max_fn (line 124) | def max_fn(x, y):
function dropout_offsets (line 128) | def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
function dropout_rng (line 134) | def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
function dropout_mask (line 140) | def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
function load_fn (line 146) | def load_fn(block_ptr, first, second, pad):
function _attn_fwd_inner (line 158) | def _attn_fwd_inner(
function attn_fwd (line 270) | def attn_fwd(
function attention (line 553) | def attention(q, k, v, sm_scale):
function _attn_bwd_preprocess (line 637) | def _attn_bwd_preprocess(
function _bwd_kernel_dk_dv (line 691) | def _bwd_kernel_dk_dv(
function _bwd_kernel_dq (line 761) | def _bwd_kernel_dq(dq, q, K, V,
function _attn_bwd (line 821) | def _attn_bwd(Q, K, V, sm_scale, alibi_slopes,
function flash_bwd (line 1016) | def flash_bwd(q: torch.Tensor, k: torch.Tensor, v:torch.Tensor, o: torch...
function _ (line 1074) | def _(q, k, v, o, M, do):
function flash (line 1083) | def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Te...
function _ (line 1159) | def _(q, k, v, o, M):
function setup_context (line 1163) | def setup_context(ctx, inputs, output) -> torch.Tensor:
function backward (line 1167) | def backward(ctx, do):
function f (line 1187) | def f(q, k, v):
FILE: kernels/triton/training/fused_softmax/softmax.py
function _get_num_warps (line 18) | def _get_num_warps(block_size: int)-> int:
function _softmax_kernel_fwd (line 27) | def _softmax_kernel_fwd(
function _softmax_kernel_bwd (line 56) | def _softmax_kernel_bwd(
class triton_softmax (line 93) | class triton_softmax(autograd.Function):
method forward (line 95) | def forward(ctx, x):
method backward (line 122) | def backward(ctx, grad_probs):
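
triton_softmax is a torch.autograd.Function, so the forward/backward pair above is driven through .apply; a hedged usage example (the import assumes softmax.py is on the Python path):

import torch
from softmax import triton_softmax  # kernels/triton/training/fused_softmax/softmax.py

x = torch.randn(8, 4096, device="cuda", requires_grad=True)
probs = triton_softmax.apply(x)   # forward pass (Triton fwd kernel)
probs.sum().backward()            # backward pass fills x.grad (Triton bwd kernel)
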
FILE: kernels/triton/training/rms_norm/fused_rms_norm.py
function _rms_norm_fwd_kernel (line 33) | def _rms_norm_fwd_kernel(
function _rms_norm_bwd_kernel_sm (line 83) | def _rms_norm_bwd_kernel_sm(
class ttt_RMSNorm (line 181) | class ttt_RMSNorm(torch.autograd.Function):
method forward (line 183) | def forward(ctx, x, weight, eps):
method backward (line 225) | def backward(ctx, dy):
function fused_rms_norm_fn (line 313) | def fused_rms_norm_fn(
class FusedRMSNorm (line 325) | class FusedRMSNorm(torch.nn.Module):
method __init__ (line 326) | def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, ...
method reset_parameters (line 333) | def reset_parameters(self):
method forward (line 336) | def forward(
FILE: tutorials/triton/kernels/fused_softmax.py
function _get_num_warps (line 15) | def _get_num_warps(block_size: int)-> int:
function _softmax_kernel_fwd (line 24) | def _softmax_kernel_fwd(
function _softmax_kernel_bwd (line 53) | def _softmax_kernel_bwd(
class triton_softmax (line 90) | class triton_softmax(autograd.Function):
method forward (line 92) | def forward(ctx, x):
method backward (line 119) | def backward(ctx, grad_probs):
FILE: tutorials/triton/kernels/vector_add.py
function kernel_vector_addition (line 9) | def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
function ceil_div (line 24) | def ceil_div(x: int, y: int)-> int:
function vector_addition (line 27) | def vector_addition(a: torch.tensor, b: torch.tensor)-> torch.tensor:
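
The three symbols map onto the canonical Triton vector-add pattern; a self-contained hedged sketch of that pattern (illustrative, not a copy of the tutorial file):

import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements              # guard the final partial block
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, a + b, mask=mask)

def vector_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(a)
    n = a.numel()
    grid = (triton.cdiv(n, 1024),)           # ceil_div(n, BLOCK_SIZE) programs
    vector_add_kernel[grid](a, b, out, n, BLOCK_SIZE=1024)
    return out
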
FILE: tutorials/triton/tests/test_softmax.py
function set_seed (line 13) | def set_seed():
class TestForwardSoftMax (line 18) | class TestForwardSoftMax:
method test_forward_2D_float32 (line 20) | def test_forward_2D_float32(self,):
method test_forward_2D_bfloat16 (line 36) | def test_forward_2D_bfloat16(self,):
method test_forward_3D_bfloat16 (line 51) | def test_forward_3D_bfloat16(self,):
class TestBackwardSoftMax (line 70) | class TestBackwardSoftMax:
method test_backward_2D (line 72) | def test_backward_2D(self,):
method test_bwd_3D (line 96) | def test_bwd_3D(self,):
FILE: tutorials/triton/tests/test_utils.py
function assert_expected (line 10) | def assert_expected(
function set_rng_seed (line 26) | def set_rng_seed(seed):
function gpu_test (line 31) | def gpu_test(gpu_count: int = 1):
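
gpu_test(gpu_count=1) reads as a skip-if helper for the pytest suites above; a hedged sketch of how such a decorator is typically written (an assumption about the helper, not its actual body):

import pytest
import torch

def gpu_test(gpu_count: int = 1):
    # Skip the decorated test unless at least `gpu_count` CUDA devices are visible.
    return pytest.mark.skipif(
        torch.cuda.device_count() < gpu_count,
        reason=f"requires at least {gpu_count} GPU(s)",
    )

@gpu_test(1)
def test_runs_on_gpu():
    assert torch.ones(4, device="cuda").sum().item() == 4
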
Condensed preview — 119 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 28,
"preview": "*.pyc\n**/.ipynb_checkpoints\n"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 3537,
"preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
},
{
"path": "CONTRIBUTING.md",
"chars": 1445,
"preview": "# Contributing to Applied AI\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Our D"
},
{
"path": "LICENSE",
"chars": 1449,
"preview": "Copyright 2024 Meta\n\nRedistribution and use in source and binary forms, with or without modification, are permitted prov"
},
{
"path": "assets/images/dev-discuss-asynctp/readme.md",
"chars": 297,
"preview": "This folder is for hosting the images for the AsyncTP public post at: \n[https://discuss.pytorch.org/t/distributed-w-t"
},
{
"path": "assets/images/readme.md",
"chars": 43,
"preview": "Folder for housing images for the readmes.\n"
},
{
"path": "dev/sr/.gitignore",
"chars": 73,
"preview": "*.o\n*.ninja\n*.txt\n*.egg-info\n*.ninja-deps\n*.ninja-log/\n*.so\ndist/\nbuild/\n"
},
{
"path": "dev/sr/readme.md",
"chars": 98,
"preview": "Branch for stochastic rounding kernel\nCurrently processes 4 elements per thread to leverage rand4\n"
},
{
"path": "dev/sr/setup.py",
"chars": 929,
"preview": "\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n name='stoc"
},
{
"path": "dev/sr/src/stochastic_rounding.cu",
"chars": 2183,
"preview": "\n#include <pybind11/pybind11.h>\n#include \"stochastic_rounding.hpp\"\n#include <random>\n\nnamespace py = pybind11;\n\n__host__"
},
{
"path": "dev/sr/src/stochastic_rounding.hpp",
"chars": 1217,
"preview": "\n#pragma once\n#include <cuda_bf16.h>\n#include <cuda_runtime.h>\n#include <vector_types.h>\n#include <torch/extension.h>\n#i"
},
{
"path": "dev/sr/src/stochastic_rounding_cuda.cu",
"chars": 3637,
"preview": " #include \"stochastic_rounding.hpp\"\n#include <cstdint>\n\n// Philox RNG implementation\n\n__device__ __forceinline__ PhiloxG"
},
{
"path": "dev/sr/test.md",
"chars": 545,
"preview": "(tkdev11) [less@devgpu115.cco2 ~/local/applied-ai/dev/sr (sr_kernel)]$ python usage.py\nLaunching kernel with blocks=1, t"
},
{
"path": "dev/sr/tests/benchmark.py",
"chars": 5028,
"preview": "import torch\nimport stochastic_rounding_cuda\nimport numpy as np\nimport time\nfrom tabulate import tabulate\nimport argpars"
},
{
"path": "dev/sr/tests/core_unit_tests.py",
"chars": 4443,
"preview": "import torch\nimport numpy as np\nfrom collections import Counter\nimport unittest\nimport stochastic_rounding_cuda\nimport t"
},
{
"path": "dev/sr/usage.py",
"chars": 772,
"preview": "import torch\nimport stochastic_rounding_cuda\n\n# Create input tensor\ninput_tensor = torch.randn(12, device='cuda', dtype="
},
{
"path": "dev/sr/usage2.py",
"chars": 359,
"preview": "import torch\nimport stochastic_rounding_cuda\n\n# Test tensor\nx = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00], devi"
},
{
"path": "dev/triton_groupGEMM/groupgemm.py",
"chars": 17760,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "dev/triton_groupGEMM/testing/base_testing.py",
"chars": 4906,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "dev/triton_groupGEMM/testing/unit_tests.py",
"chars": 10701,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "dev/triton_groupGEMM/tma_utils.py",
"chars": 4501,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "dev/triton_groupGEMM/triton_tutorial_groupgemm.py",
"chars": 21412,
"preview": "\"\"\"\nGroup GEMM\n============================\nThis group gemm kernel launches a fixed number of CTA to compute a group\nof "
},
{
"path": "kernels/MoE/group_GEMM/triton/readme.md",
"chars": 66,
"preview": "## Experimental\n\nTriton Group GEMM for supporting MoE training. \n"
},
{
"path": "kernels/MoE/group_GEMM/triton/testing/fast_verification.py",
"chars": 7319,
"preview": "import logging\n\nimport torch\n\n# Configure logging\nlogging.basicConfig(\n level=logging.INFO, format=\"%(asctime)s - %(l"
},
{
"path": "kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py",
"chars": 6472,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py",
"chars": 21469,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py",
"chars": 21277,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/MoE/group_GEMM/triton/utils/tma_utils.py",
"chars": 4501,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/blackwell/cute_gemm_01/Makefile",
"chars": 764,
"preview": "\n# Makefile for SM100 GEMM PyTorch Extension\n\n# Set these paths according to your installation\nCUTLASS_PATH ?= /path/to/"
},
{
"path": "kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_log",
"chars": 473,
"preview": "# ninja log v5\n0\t15279\t1748131038212164071\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_"
},
{
"path": "kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/build.ninja",
"chars": 3098,
"preview": "ninja_required_version = 1.3\ncxx = c++\nnvcc = /usr/local/cuda-12.8/bin/nvcc\n\ncflags = -pthread -B /home/less/.conda/envs"
},
{
"path": "kernels/blackwell/cute_gemm_01/driver.py",
"chars": 6895,
"preview": "# ==============================================================================\n# python_interface.py - High-level Pyth"
},
{
"path": "kernels/blackwell/cute_gemm_01/setup.py",
"chars": 2229,
"preview": "# setup.py\nimport os\n\nimport pybind11\nimport torch\nfrom pybind11 import get_cmake_dir\nfrom pybind11.setup_helpers import"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.cu",
"chars": 8809,
"preview": "// sm100_gemm_kernel.cu - CUDA kernel implementation\n#include \"sm100_gemm.h\"\n\n#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORT"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/PKG-INFO",
"chars": 154,
"preview": "Metadata-Version: 2.4\nName: sm100_gemm\nVersion: 0.0.0\nRequires-Python: >=3.8\nRequires-Dist: torch>=1.12.0\nDynamic: requi"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/SOURCES.txt",
"chars": 247,
"preview": "setup.py\nsm100_gemm.cu\nsm100_gemm_pytorch.cpp\nsm100_gemm.egg-info/PKG-INFO\nsm100_gemm.egg-info/SOURCES.txt\nsm100_gemm.eg"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/dependency_links.txt",
"chars": 1,
"preview": "\n"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/not-zip-safe",
"chars": 1,
"preview": "\n"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/requires.txt",
"chars": 14,
"preview": "torch>=1.12.0\n"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/top_level.txt",
"chars": 11,
"preview": "sm100_gemm\n"
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm.h",
"chars": 1305,
"preview": "// sm100_gemm_kernel.h - Header file for CUDA kernel\n#pragma once\n\n#include <cuda_runtime.h>\n\n#ifdef __cplusplus\nextern "
},
{
"path": "kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp",
"chars": 6465,
"preview": "// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext."
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_log",
"chars": 664,
"preview": "# ninja log v5\n1\t15202\t1748185895110710199\t/data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.lin"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/build.ninja",
"chars": 3126,
"preview": "ninja_required_version = 1.3\ncxx = c++\nnvcc = /usr/local/cuda-12.8/bin/nvcc\n\ncflags = -pthread -B /home/less/.conda/envs"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/driver.py",
"chars": 30452,
"preview": "# python_interface.py - High-level Python interface with TMA support\n\nimport torch\n\ntry:\n import sm100_gemm # The co"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/setup.py",
"chars": 2229,
"preview": "# setup.py\nimport os\n\nimport pybind11\nimport torch\nfrom pybind11 import get_cmake_dir\nfrom pybind11.setup_helpers import"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu",
"chars": 12042,
"preview": "// sm100_gemm_kernel.cu - CUDA kernel implementation with TMA\n#include \"sm100_gemm.h\"\n\n#if defined(CUTLASS_ARCH_MMA_SM10"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/PKG-INFO",
"chars": 154,
"preview": "Metadata-Version: 2.4\nName: sm100_gemm\nVersion: 0.0.0\nRequires-Python: >=3.8\nRequires-Dist: torch>=1.12.0\nDynamic: requi"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/SOURCES.txt",
"chars": 247,
"preview": "setup.py\nsm100_gemm.cu\nsm100_gemm_pytorch.cpp\nsm100_gemm.egg-info/PKG-INFO\nsm100_gemm.egg-info/SOURCES.txt\nsm100_gemm.eg"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/dependency_links.txt",
"chars": 1,
"preview": "\n"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/not-zip-safe",
"chars": 1,
"preview": "\n"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/requires.txt",
"chars": 14,
"preview": "torch>=1.12.0\n"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/top_level.txt",
"chars": 11,
"preview": "sm100_gemm\n"
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm.h",
"chars": 1355,
"preview": "// sm100_gemm_kernel.h - Header file for CUDA kernel\n#pragma once\n\n#include <cuda_runtime.h>\n\n#ifdef __cplusplus\nextern "
},
{
"path": "kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp",
"chars": 6469,
"preview": "// sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code)\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext."
},
{
"path": "kernels/cuda/cutlass_gemm/broadcast_load_epilogue_c3x.hpp",
"chars": 15348,
"preview": "/***************************************************************************************************\n * Copyright (c) 20"
},
{
"path": "kernels/cuda/cutlass_gemm/common.hpp",
"chars": 807,
"preview": "#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include <climits>\n\n/**\n * Helper function for checking CUTLASS errors\n */\n#d"
},
{
"path": "kernels/cuda/cutlass_gemm/cutlass.cpp",
"chars": 764,
"preview": "#include <torch/extension.h>\n#include<torch/all.h>\n\nvoid cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const&"
},
{
"path": "kernels/cuda/cutlass_gemm/cutlass_kernel.cu",
"chars": 18444,
"preview": "// clang-format will break include orders\n// clang-format off\n#include <cudaTypedefs.h>\n\n#if defined CUDA_VERSION && CUD"
},
{
"path": "kernels/cuda/cutlass_gemm/readme.md",
"chars": 160,
"preview": "Currently the CPP extension builds with Cutlass 3.5.1 (credit to @SamirMoustafa for the update). \n3.6 will fail atm due"
},
{
"path": "kernels/cuda/cutlass_gemm/setup.py",
"chars": 1139,
"preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n name='cutla"
},
{
"path": "kernels/cuda/cutlass_gemm/test_cutlass_gemm.py",
"chars": 490,
"preview": "from pingpong_gemm import cutlass_scaled_mm\nimport torch\n\nm, k, n = 16, 4096, 4096\ndtype = torch.float8_e4m3fn\nout_dtype"
},
{
"path": "kernels/cuda/inference/README.md",
"chars": 13,
"preview": "cuda kernels\n"
},
{
"path": "kernels/cuda/inference/hadamard_transform/hadamard_transform.cpp",
"chars": 2100,
"preview": "#include <torch/extension.h>\n#include <pybind11/pybind11.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGu"
},
{
"path": "kernels/cuda/inference/hadamard_transform/hadamard_transform_cuda.cu",
"chars": 38515,
"preview": "#include <torch/extension.h>\n#include <stdint.h>\n#include <cuda_runtime.h>\n#include <mma.h>\n#include <cuda/annotated_ptr"
},
{
"path": "kernels/cuda/inference/hadamard_transform/setup.py",
"chars": 978,
"preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nversions = [\n \"-gen"
},
{
"path": "kernels/cuda/inference/hadamard_transform/test.py",
"chars": 4743,
"preview": "import torch\nimport faster_hadamard_transform\nimport scipy.linalg\nimport math\n\n# set to false to check performance\ncorre"
},
{
"path": "kernels/cuda/training/README.md",
"chars": 35,
"preview": "kernels with backward pass support\n"
},
{
"path": "kernels/cuda/tutorials/README.md",
"chars": 15,
"preview": "CUDA tutorials\n"
},
{
"path": "kernels/cuda/tutorials/flash2.cu",
"chars": 1545,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/needs_perf_help/fp8_gemm_bench.py",
"chars": 5124,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/needs_perf_help/fp8_rowwise_tma_persistent.py",
"chars": 12923,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/README.md",
"chars": 25,
"preview": "Triton Inference kernels\n"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/README.md",
"chars": 497,
"preview": "\n**MoE (Mixture of Experts) GEMM Kernels**\n\n\nTriton kernel supporting and accelerating MoE inference (Mixtral).\nThis ker"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py",
"chars": 4397,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/profile_moe.py",
"chars": 1821,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/results.html",
"chars": 52,
"preview": "<html><body>\n<image src=\"test.png\"/>\n</body></html>\n"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/test.csv",
"chars": 345,
"preview": "m,Fused MoE GEMM Kernel - Column Major,vLLM MoE GEMM Kernel\n1.000000,0.412454,0.259585\n2.000000,0.883064,0.269004\n4.0000"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/test_moe_gemm.py",
"chars": 2795,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/v0_moe_fused.py",
"chars": 13030,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/v1_moe_fused.py",
"chars": 13671,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py",
"chars": 12229,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/flash_attention/stay_attention.py",
"chars": 4721,
"preview": "import triton.language as tl\nimport triton\nimport torch\n\n\n@triton.jit()\ndef stay_attention(\n q_ptr, k_ptr, v_ptr, o_p"
},
{
"path": "kernels/triton/inference/fp8/float8_groupwise_quant.py",
"chars": 2352,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/inference/fp8/scaled_fp8_gemm.py",
"chars": 3840,
"preview": "import torch\nimport triton\nimport triton.language as tl\nimport time\nimport os\nos.environ['ENABLE_TMA'] = '1'\n\n\n@triton.j"
},
{
"path": "kernels/triton/inference/fp8/splitk_gemm_fp8.py",
"chars": 4537,
"preview": "import torch\nimport triton\nimport triton.language as tl\nimport time\nimport os\nos.environ['ENABLE_TMA'] = '1'\n\n@triton.ji"
},
{
"path": "kernels/triton/inference/fp8/tma_gemm.py",
"chars": 3199,
"preview": "import triton\nimport triton.language as tl\nimport numpy as np\nimport torch\n\n@triton.jit\ndef gemm_kernel_tma(a_desc_ptr, "
},
{
"path": "kernels/triton/inference/gptq/a100_qlinear.py",
"chars": 4633,
"preview": "import triton\nimport triton.language as tl\nimport torch \n\n@triton.jit()\ndef _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, "
},
{
"path": "kernels/triton/inference/gptq/benchmark.py",
"chars": 3723,
"preview": "import argparse\nimport time\nimport logging\nfrom tqdm import tqdm\nimport torch\nfrom transformers import AutoTokenizer\nfro"
},
{
"path": "kernels/triton/inference/gptq/h100_qlinear.py",
"chars": 4341,
"preview": "import triton\nimport triton.language as tl\nimport torch \n\n\n@triton.jit()\ndef _h100_quantized_matmul(a_ptr, b_ptr, c_ptr,"
},
{
"path": "kernels/triton/inference/gptq/mixtral/test_dequant_moe_gemm.py",
"chars": 2876,
"preview": "import pytest\nimport torch\nfrom vllm.model_executor.layers.fused_moe import fused_moe\nfrom vllm.model_executor.layers.ac"
},
{
"path": "kernels/triton/inference/gptq/mixtral/w4a16_fused_dequant_gemm.py",
"chars": 11959,
"preview": "\"\"\"Fused MoE W4A16 Kernel.\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom vllm._C import ops\n\n@triton."
},
{
"path": "kernels/triton/inference/gptq/small_benchmark_cuda_graphs.py",
"chars": 14410,
"preview": "import torch\nimport triton\nfrom triton import language as tl\nimport sys\nimport marlin \nimport torch.nn as nn\nfrom auto_g"
},
{
"path": "kernels/triton/inference/gptq/splitk_dequant_gemm.py",
"chars": 6121,
"preview": "import torch\nimport triton\nfrom triton import language as tl\n# from actual_base_gptq_4 import triton_matmul4\n\n@triton.ji"
},
{
"path": "kernels/triton/inference/mamba/causal_1d_conv/causal_1d_conv/causal_1d_conv.py",
"chars": 12632,
"preview": "# Copyright (c) 2025, IBM Research\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange"
},
{
"path": "kernels/triton/inference/mamba/causal_1d_conv/tests/test_causal_1d_conv.py",
"chars": 7734,
"preview": "# Copyright (C) 2025, IBM Research.\n# python -m pytest tests/test_causal_conv1d.py\n\nimport sys\nfrom einops import rearra"
},
{
"path": "kernels/triton/inference/paged_attention/attention_triton.py",
"chars": 15747,
"preview": "#from einops import rearrange\nimport torch\nimport triton\nimport triton.language as tl\n\n# Credit:\n# vedantroy https://git"
},
{
"path": "kernels/triton/inference/torch_compile/flash_backward.py",
"chars": 44315,
"preview": "#!/usr/bin/env python\n\"\"\"\nCode copied from https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-att"
},
{
"path": "kernels/triton/training/README.md",
"chars": 24,
"preview": "Triton training kernels\n"
},
{
"path": "kernels/triton/training/fused_softmax/README.md",
"chars": 253,
"preview": "Fused Softmax in Triton, supporting both inference (fwd) and training (fwd/backward). \n\nPerf testing on A100:\n\n<img widt"
},
{
"path": "kernels/triton/training/fused_softmax/softmax.py",
"chars": 4418,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
},
{
"path": "kernels/triton/training/rms_norm/fused_rms_norm.py",
"chars": 8996,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# This software may be used and distributed according to the terms "
},
{
"path": "kernels/triton/tutorials/README.md",
"chars": 17,
"preview": "Triton tutorials\n"
},
{
"path": "readme.md",
"chars": 2373,
"preview": "\n### Applied AI repo\nFor experiments and research on Applied AI.\n\n### Projects\n\n#### Kernels\n\nHousing a variety of Trito"
},
{
"path": "tutorials/triton/kernels/__init__.py",
"chars": 1,
"preview": "\n"
},
{
"path": "tutorials/triton/kernels/flash_attention_fwd.py",
"chars": 19,
"preview": "# flash forward v2\n"
},
{
"path": "tutorials/triton/kernels/fused_softmax.py",
"chars": 4292,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# ---- Fused Softmax written in Triton ----"
},
{
"path": "tutorials/triton/kernels/readme.md",
"chars": 165,
"preview": "Triton tutorials\n\n1 - Vector Add - Starting tutorial on simple first kernel \n2 - Fused Softmax - Full fused softmax wit"
},
{
"path": "tutorials/triton/kernels/vector_add.py",
"chars": 1291,
"preview": "# coding up a Triton vector addition kernel\n# links to\n\nimport triton\nimport triton.language as tl \nimport torch\n\n@trito"
},
{
"path": "tutorials/triton/tests/test_softmax.py",
"chars": 4467,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\nimport pytest\nimport torch\nimport sys\nsys.p"
},
{
"path": "tutorials/triton/tests/test_utils.py",
"chars": 951,
"preview": "from pathlib import Path\nfrom typing import Any, Dict, NamedTuple, Optional, Tuple, Union\n\nimport pytest\nimport torch\nim"
}
]
// ... and 8 more files