Repository: pytorch-labs/applied-ai Branch: main Commit: 2391954b1998 Files: 119 Total size: 547.0 KB Directory structure: gitextract_gcbwaf39/ ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── assets/ │ └── images/ │ ├── dev-discuss-asynctp/ │ │ └── readme.md │ └── readme.md ├── dev/ │ ├── sr/ │ │ ├── .gitignore │ │ ├── readme.md │ │ ├── setup.py │ │ ├── src/ │ │ │ ├── stochastic_rounding.cu │ │ │ ├── stochastic_rounding.hpp │ │ │ └── stochastic_rounding_cuda.cu │ │ ├── test.md │ │ ├── tests/ │ │ │ ├── benchmark.py │ │ │ └── core_unit_tests.py │ │ ├── usage.py │ │ └── usage2.py │ └── triton_groupGEMM/ │ ├── groupgemm.py │ ├── testing/ │ │ ├── base_testing.py │ │ └── unit_tests.py │ ├── tma_utils.py │ └── triton_tutorial_groupgemm.py ├── kernels/ │ ├── MoE/ │ │ └── group_GEMM/ │ │ └── triton/ │ │ ├── readme.md │ │ ├── testing/ │ │ │ ├── fast_verification.py │ │ │ └── pytorch_reference_backwards.py │ │ ├── tgroup_gemm_backwards.py │ │ ├── tgroup_gemm_forward.py │ │ └── utils/ │ │ └── tma_utils.py │ ├── blackwell/ │ │ ├── cute_gemm_01/ │ │ │ ├── Makefile │ │ │ ├── build/ │ │ │ │ └── temp.linux-x86_64-cpython-312/ │ │ │ │ ├── .ninja_deps │ │ │ │ ├── .ninja_log │ │ │ │ ├── build.ninja │ │ │ │ ├── sm100_gemm.o │ │ │ │ └── sm100_gemm_pytorch.o │ │ │ ├── dist/ │ │ │ │ └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg │ │ │ ├── driver.py │ │ │ ├── setup.py │ │ │ ├── sm100_gemm.cu │ │ │ ├── sm100_gemm.egg-info/ │ │ │ │ ├── PKG-INFO │ │ │ │ ├── SOURCES.txt │ │ │ │ ├── dependency_links.txt │ │ │ │ ├── not-zip-safe │ │ │ │ ├── requires.txt │ │ │ │ └── top_level.txt │ │ │ ├── sm100_gemm.h │ │ │ └── sm100_gemm_pytorch.cpp │ │ └── cute_gemm_02_tma/ │ │ ├── build/ │ │ │ └── temp.linux-x86_64-cpython-312/ │ │ │ ├── .ninja_deps │ │ │ ├── .ninja_log │ │ │ ├── build.ninja │ │ │ ├── sm100_gemm.o │ │ │ └── sm100_gemm_pytorch.o │ │ ├── dist/ │ │ │ └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg │ │ ├── driver.py │ │ ├── setup.py │ │ ├── sm100_gemm.cu │ │ ├── sm100_gemm.egg-info/ │ │ │ ├── PKG-INFO │ │ │ ├── SOURCES.txt │ │ │ ├── dependency_links.txt │ │ │ ├── not-zip-safe │ │ │ ├── requires.txt │ │ │ └── top_level.txt │ │ ├── sm100_gemm.h │ │ └── sm100_gemm_pytorch.cpp │ ├── cuda/ │ │ ├── cutlass_gemm/ │ │ │ ├── broadcast_load_epilogue_c3x.hpp │ │ │ ├── common.hpp │ │ │ ├── cutlass.cpp │ │ │ ├── cutlass_kernel.cu │ │ │ ├── readme.md │ │ │ ├── setup.py │ │ │ └── test_cutlass_gemm.py │ │ ├── inference/ │ │ │ ├── README.md │ │ │ └── hadamard_transform/ │ │ │ ├── hadamard_transform.cpp │ │ │ ├── hadamard_transform_cuda.cu │ │ │ ├── setup.py │ │ │ └── test.py │ │ ├── training/ │ │ │ └── README.md │ │ └── tutorials/ │ │ ├── README.md │ │ └── flash2.cu │ ├── needs_perf_help/ │ │ ├── fp8_gemm_bench.py │ │ └── fp8_rowwise_tma_persistent.py │ └── triton/ │ ├── inference/ │ │ ├── README.md │ │ ├── col_major_moe_gemm/ │ │ │ ├── README.md │ │ │ ├── perf_test_moe.py │ │ │ ├── profile_moe.py │ │ │ ├── results.html │ │ │ ├── test.csv │ │ │ ├── test_moe_gemm.py │ │ │ ├── v0_moe_fused.py │ │ │ ├── v1_moe_fused.py │ │ │ └── v2_moe_fused.py │ │ ├── flash_attention/ │ │ │ └── stay_attention.py │ │ ├── fp8/ │ │ │ ├── float8_groupwise_quant.py │ │ │ ├── scaled_fp8_gemm.py │ │ │ ├── splitk_gemm_fp8.py │ │ │ └── tma_gemm.py │ │ ├── gptq/ │ │ │ ├── a100_qlinear.py │ │ │ ├── benchmark.py │ │ │ ├── h100_qlinear.py │ │ │ ├── mixtral/ │ │ │ │ ├── test_dequant_moe_gemm.py │ │ │ │ └── w4a16_fused_dequant_gemm.py │ │ │ ├── small_benchmark_cuda_graphs.py │ │ │ └── splitk_dequant_gemm.py │ │ ├── mamba/ │ │ │ └── causal_1d_conv/ │ 
│ │ ├── causal_1d_conv/ │ │ │ │ └── causal_1d_conv.py │ │ │ └── tests/ │ │ │ └── test_causal_1d_conv.py │ │ ├── paged_attention/ │ │ │ └── attention_triton.py │ │ └── torch_compile/ │ │ └── flash_backward.py │ ├── training/ │ │ ├── README.md │ │ ├── fused_softmax/ │ │ │ ├── README.md │ │ │ └── softmax.py │ │ └── rms_norm/ │ │ └── fused_rms_norm.py │ └── tutorials/ │ └── README.md ├── readme.md └── tutorials/ └── triton/ ├── kernels/ │ ├── __init__.py │ ├── flash_attention_fwd.py │ ├── fused_softmax.py │ ├── readme.md │ └── vector_add.py └── tests/ ├── test_softmax.py └── test_utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pyc **/.ipynb_checkpoints ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at . All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Applied AI We want to make contributing to this project as easy and transparent as possible. ## Our Development Process ... (in particular how this is synced with internal changes to the project) ## Pull Requests We actively welcome your pull requests. 1. Fork the repo and create your branch from `main`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. 5. Make sure your code lints. 6. If you haven't already, complete the Contributor License Agreement ("CLA"). ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Meta's open source projects. Complete your CLA here: ## Issues We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. ## Coding Style * 2 spaces for indentation rather than tabs * 80 character line length * ... ## License By contributing to applied-ai, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. ================================================ FILE: LICENSE ================================================ Copyright 2024 Meta Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: assets/images/dev-discuss-asynctp/readme.md ================================================ This folder is for hosting the images for the AsyncTP public post at: [https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487) ================================================ FILE: assets/images/readme.md ================================================ Folder for housing images for the readmes. ================================================ FILE: dev/sr/.gitignore ================================================ *.o *.ninja *.txt *.egg-info *.ninja-deps *.ninja-log/ *.so dist/ build/ ================================================ FILE: dev/sr/readme.md ================================================ Branch for stochastic rounding kernel Currently processes 4 elements per thread to leverage rand4 ================================================ FILE: dev/sr/setup.py ================================================ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension setup( name='stochastic_rounding_cuda', version='0.1.021825', ext_modules=[ CUDAExtension('stochastic_rounding_cuda', [ 'src/stochastic_rounding.cu', 'src/stochastic_rounding_cuda.cu' ], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': [ '-O3', '--expt-relaxed-constexpr', # better template support #'-gencode=arch=compute_70,code=sm_70', # Volta #'-gencode=arch=compute_75,code=sm_75', # Turing #'-gencode=arch=compute_80,code=sm_80' # Amper #'-gencode=arch=compute_86,code=sm_86' # Ampere '-gencode=arch=compute_90,code=sm_90', # Hopper ] }) ], cmdclass={ 'build_ext': BuildExtension } ) ================================================ FILE: dev/sr/src/stochastic_rounding.cu ================================================ #include #include "stochastic_rounding.hpp" #include namespace py = pybind11; __host__ int getOptimalBlockSize() { cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); return std::min(prop.maxThreadsPerBlock, 256); } torch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad) { TORCH_CHECK(input.is_cuda(), "Input tensor must be on CUDA device"); TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous"); TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input tensor must be float32"); const int threads_per_block = 256; const int num_elements = input.numel(); const int elements_per_thread = 4; const int min_blocks = (num_elements + elements_per_thread * threads_per_block - 1) / (elements_per_thread * threads_per_block); cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); const int blocks_per_sm = 4; const int min_blocks_for_sms = prop.multiProcessorCount * blocks_per_sm; const int num_blocks = std::max(min_blocks, min_blocks_for_sms); auto options = torch::TensorOptions() .dtype(torch::kBFloat16) 
.device(input.device()) .requires_grad(requires_grad); auto output = torch::empty_like(input, options); std::random_device rd; std::mt19937_64 gen(rd()); std::uniform_int_distribution dis; const unsigned long long seed = dis(gen); stochastic_round_bf16<<>>( input.data_ptr(), reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), num_elements, seed); cudaError_t err = cudaGetLastError(); TORCH_CHECK(err == cudaSuccess, "CUDA kernel execution failed: ", cudaGetErrorString(err)); return output; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("stochastic_round_bf16", static_cast(&stochastic_round_bf16_cuda), "Stochastic rounding to BFloat16", py::arg("input"), py::arg("requires_grad") = false); } ================================================ FILE: dev/sr/src/stochastic_rounding.hpp ================================================ #pragma once #include #include #include #include #include namespace philox { constexpr unsigned int W32_0 = 0x9E3779B9; constexpr unsigned int W32_1 = 0xBB67AE85; constexpr unsigned int M0 = 0xD2511F53; constexpr unsigned int M1 = 0xCD9E8D57; constexpr int ROUNDS = 7; } // Forward declarations class PhiloxGenerator { public: __device__ __forceinline__ PhiloxGenerator(); __device__ __forceinline__ void init(const unsigned long long seed, const unsigned int thread_id); __device__ __forceinline__ uint4 next(); private: uint2 key; uint4 counter; static __device__ __forceinline__ uint2 mulhilo(const unsigned int a, const unsigned int b); static __device__ __forceinline__ uint4 round(uint4 ctr, uint2 key); }; // CUDA kernel declaration __global__ void stochastic_round_bf16( float *__restrict__ input, __nv_bfloat16 *__restrict__ output, const int size, const unsigned long long seed); // Host functions __host__ int getOptimalBlockSize(); torch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad = false); ================================================ FILE: dev/sr/src/stochastic_rounding_cuda.cu ================================================ #include "stochastic_rounding.hpp" #include // Philox RNG implementation __device__ __forceinline__ PhiloxGenerator::PhiloxGenerator() : key(make_uint2(0, 0)), counter(make_uint4(0, 0, 0, 0)) {} __device__ __forceinline__ void PhiloxGenerator::init(const unsigned long long seed, const unsigned int thread_id) { key.x = static_cast(seed); key.y = static_cast(seed >> 32); counter = make_uint4(thread_id, 0, 0, 0); __threadfence_block(); } __device__ __forceinline__ uint2 PhiloxGenerator::mulhilo(const unsigned int a, const unsigned int b) { uint2 result; unsigned long long prod; asm("mul.wide.u32 %0, %1, %2;" : "=l"(prod) : "r"(a), "r"(b)); result.x = static_cast(prod); result.y = static_cast(prod >> 32); return result; } __device__ __forceinline__ uint4 PhiloxGenerator::round(uint4 ctr, uint2 key) { const uint2 mul0 = mulhilo(philox::M0, ctr.x); const uint2 mul1 = mulhilo(philox::M1, ctr.z); return make_uint4( mul1.y ^ ctr.y ^ key.x, mul1.x, mul0.y ^ ctr.w ^ key.y, mul0.x ); } __device__ __forceinline__ uint4 PhiloxGenerator::next() { uint4 ctr = counter; uint2 k = key; #pragma unroll for (int i = 0; i < philox::ROUNDS; ++i) { ctr = round(ctr, k); k.x += philox::W32_0; k.y += philox::W32_1; } counter.x += 4; return ctr; } __device__ __forceinline__ __nv_bfloat16 float_to_bf16_stochastic(const float value, const uint32_t rand) { const uint32_t val_bits = __float_as_uint(value); const uint32_t rounding_bits = val_bits & 0xFFFF; uint32_t result = val_bits & 0xFFFF0000u; result += (rand & 0xFFFF) < rounding_bits ? 
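// Round up by one bf16 ULP (0x10000 in float32 bit space) with probability
// rounding_bits / 2^16, i.e. proportional to the fraction that truncation
// would otherwise discard: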
0x10000u : 0; return __float2bfloat16(__uint_as_float(result)); } __device__ __forceinline__ void float4_to_bf16_stochastic( const float4& values, uint4& rand_vals, __nv_bfloat16* output) { float vals[4] = {values.x, values.y, values.z, values.w}; uint32_t rands[4] = {rand_vals.x, rand_vals.y, rand_vals.z, rand_vals.w}; #pragma unroll for (int i = 0; i < 4; i++) { output[i] = float_to_bf16_stochastic(vals[i], rands[i]); } } __global__ void stochastic_round_bf16( float *__restrict__ input, __nv_bfloat16 *__restrict__ output, const int size, const unsigned long long seed) { PhiloxGenerator rng; rng.init(seed, blockIdx.x * blockDim.x + threadIdx.x); int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; int stride = blockDim.x * gridDim.x * 4; float4 values; __nv_bfloat16 local_output[4]; // Process full vectors of 4 elements for (; idx <= size - 4; idx += stride) { values = *reinterpret_cast(&input[idx]); uint4 rand = rng.next(); float4_to_bf16_stochastic(values, rand, local_output); for (int j = 0; j < 4; j++) { output[idx + j] = local_output[j]; } } // Handle remaining elements if (idx < size) { float remaining_values[4] = {0.0f, 0.0f, 0.0f, 0.0f}; int remainder = size - idx; for (int j = 0; j < remainder; j++) { remaining_values[j] = input[idx + j]; } values.x = remaining_values[0]; values.y = remaining_values[1]; values.z = remaining_values[2]; values.w = remaining_values[3]; uint4 rand = rng.next(); float4_to_bf16_stochastic(values, rand, local_output); for (int j = 0; j < remainder; j++) { output[idx + j] = local_output[j]; } } } ================================================ FILE: dev/sr/test.md ================================================ (tkdev11) [less@devgpu115.cco2 ~/local/applied-ai/dev/sr (sr_kernel)]$ python usage.py Launching kernel with blocks=1, threads_per_block=256, num_elements=12 Input tensor: tensor([ 0.3282, -0.4513, -1.0612, 0.1446, -0.8440, -1.4669, -0.7135, -0.6183, -2.2411, 2.1464, 1.4772, -1.3564], device='cuda:0') Output tensor: tensor([ 0.3281, -0.4512, -1.0625, 0.1445, -0.8438, -1.4688, -0.7109, -0.6172, -2.2344, 2.1406, 1.4766, -1.3516], device='cuda:0', dtype=torch.bfloat16) Output tensor dtype: torch.bfloat16 Success! 
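For eyeballing runs like the one above, here is a minimal host-side sketch (not part of this repo; the helper name `stochastic_round_bf16_reference` is made up) that mirrors the rounding rule used by `float_to_bf16_stochastic`: keep the high 16 bits of the float32 encoding and round up by one bf16 ULP with probability proportional to the discarded low bits.

```python
import torch

def stochastic_round_bf16_reference(x: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch emulation of the CUDA kernel's per-element rounding rule."""
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)
    # Low 16 bits that are discarded when truncating float32 -> bfloat16.
    frac = bits & 0xFFFF
    # One uniform 16-bit threshold per element.
    rand = torch.randint(0, 1 << 16, x.shape, device=x.device, dtype=torch.int32)
    # Keep the high 16 bits and add one bf16 ULP (0x10000) when rand < frac.
    round_up = (rand < frac).to(torch.int32) * 0x10000
    rounded_bits = (bits & ~0xFFFF) + round_up
    return rounded_bits.view(torch.float32).to(torch.bfloat16)

# Averaged over many draws, the stochastic result should be unbiased, i.e.
# close to the original float32 input:
x = torch.randn(12, device="cuda")
print(torch.stack([stochastic_round_bf16_reference(x).float() for _ in range(1000)]).mean(0))
print(x)
```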
================================================ FILE: dev/sr/tests/benchmark.py ================================================ import torch import stochastic_rounding_cuda import numpy as np import time from tabulate import tabulate import argparse def measure_performance(func, input_tensor, warmup=0, repeats=1): """Measure performance of a function with proper CUDA synchronization""" # Warmup for _ in range(warmup): output = func(input_tensor) torch.cuda.synchronize() start = time.perf_counter() for _ in range(repeats): output = func(input_tensor) torch.cuda.synchronize() end = time.perf_counter() avg_time = (end - start) / repeats elements_per_second = input_tensor.numel() / avg_time return avg_time, elements_per_second def benchmark_sizes(sizes= [1000, 10000, 100000, 1000000, 10000000, (10000000*10), (10000000*100)]): #[ 50,000,000]): # """Benchmark different input sizes""" results = [] for size in sizes: # Create input tensor x = torch.randn(size, device='cuda') # Measure stochastic rounding time_stoch, throughput_stoch = measure_performance( stochastic_rounding_cuda.stochastic_round_bf16, x) # Measure regular BF16 casting time_regular, throughput_regular = measure_performance( lambda t: t.to(torch.bfloat16), x) results.append([ size, time_stoch * 1000, # convert to ms throughput_stoch / 1e6, # convert to GElements/s time_regular * 1000, throughput_regular / 1e6, throughput_regular / throughput_stoch # speedup ]) print("\nSize Comparison:") print(tabulate(results, headers=['Size', 'Stoch Time (ms)', 'Stoch ME/s', 'Regular Time (ms)', 'Regular ME/s', 'Casting faster by'], floatfmt='.3f')) def benchmark_shapes(total_size=1000000): """Benchmark different tensor shapes with same total size""" shapes = [ (total_size,), # 1D (1000, total_size//1000), # 2D (100, 100, total_size//10000), # 3D ] results = [] for shape in shapes: x = torch.randn(*shape, device='cuda') time_stoch, throughput_stoch = measure_performance( stochastic_rounding_cuda.stochastic_round_bf16, x) results.append([ 'x'.join(str(d) for d in shape), time_stoch * 1000, throughput_stoch / 1e9 ]) print("\nShape Comparison (same total size):") print(tabulate(results, headers=['Shape', 'Time (ms)', 'GElements/s'], floatfmt='.3f')) def stress_test(duration=10): """Run a stress test for specified duration""" print(f"\nRunning stress test for {duration} seconds...") size = 1000000 x = torch.randn(size, device='cuda') start_time = time.time() iterations = 0 while time.time() - start_time < duration: stochastic_rounding_cuda.stochastic_round_bf16(x) iterations += 1 print(f"Completed {iterations} iterations without errors") print(f"Average throughput: {(iterations * size) / (duration * 1e9):.2f} GElements/s") def memory_test(max_size=1e9): """Test memory scaling""" sizes = np.logspace(3, min(9, np.log10(max_size)), num=7, dtype=int) results = [] for size in sizes: try: torch.cuda.empty_cache() x = torch.randn(size, device='cuda') torch.cuda.synchronize() # Measure peak memory during operation torch.cuda.reset_peak_memory_stats() _ = stochastic_rounding_cuda.stochastic_round_bf16(x) torch.cuda.synchronize() peak_memory = torch.cuda.max_memory_allocated() / 1e6 # MB results.append([size, peak_memory]) except RuntimeError as e: print(f"Out of memory at size {size}") break print("\nMemory Usage:") print(tabulate(results, headers=['Size', 'Peak Memory (MB)'], floatfmt='.2f')) def main(): parser = argparse.ArgumentParser(description='Benchmark stochastic rounding') parser.add_argument('--sizes', action='store_true', help='Run size 
benchmarks') parser.add_argument('--shapes', action='store_true', help='Run shape benchmarks') parser.add_argument('--stress', action='store_true', help='Run stress test') parser.add_argument('--memory', action='store_true', help='Run memory test') parser.add_argument('--all', action='store_true', help='Run all benchmarks') args = parser.parse_args() # Print device information device = torch.cuda.get_device_properties(0) print(f"\nRunning on: {device.name}") print(f"Compute Capability: {device.major}.{device.minor}") if args.all or args.sizes: benchmark_sizes() if args.all or args.shapes: benchmark_shapes() if args.all or args.stress: stress_test() if args.all or args.memory: memory_test() if __name__ == '__main__': main() ================================================ FILE: dev/sr/tests/core_unit_tests.py ================================================ import torch import numpy as np from collections import Counter import unittest import stochastic_rounding_cuda import time class TestStochasticRounding(unittest.TestCase): def setup(self): # Ensure deterministic behavior for some tests torch.manual_seed(42) np.random.seed(42) def _test_rounding_statistics_helper(self, value, lower_value, upper_value, tensor_size=10000, rounds=100): """Helper method for testing stochastic rounding statistics""" print(f"\nInput value: {value}") MAX_VARIANCE = 0.03 x = torch.full((tensor_size,), value, device='cuda') torch.cuda.manual_seed(42) # Single round test - isolate and show the round up and round down values single_result = stochastic_rounding_cuda.stochastic_round_bf16(x) print(f"Possible rounded values: {torch.unique(single_result)}") # Multiple rounds results = torch.empty((rounds, tensor_size), device='cuda', dtype=torch.bfloat16) for i in range(rounds): results[i] = stochastic_rounding_cuda.stochastic_round_bf16(x) prob_up = (results == upper_value).float().mean().item() print(f"Kernel's probability of rounding up: {prob_up:.4f}") distance_to_lower = abs(value - lower_value) total_distance = upper_value - lower_value expected_prob = distance_to_lower / total_distance print(f"Expected probability: {expected_prob:.4f}") self.assertTrue(abs(prob_up - expected_prob) < MAX_VARIANCE) def test_special_values(self): """Test handling of special values like inf, -inf, nan""" special_values = torch.tensor([float('inf'), float('-inf'), float('nan'), 0.0, -0.0], device='cuda') rounded = stochastic_rounding_cuda.stochastic_round_bf16(special_values) # Check inf and -inf are preserved self.assertTrue(torch.isinf(rounded[0])) self.assertTrue(torch.isinf(rounded[1])) self.assertTrue(rounded[0] > 0) self.assertTrue(rounded[1] < 0) # Check nan is preserved self.assertTrue(torch.isnan(rounded[2])) # Check zeros are preserved self.assertEqual(rounded[3].item(), 0.0) self.assertEqual(rounded[4].item(), 0.0) def test_small_values(self): """Test handling of small values near zero""" small_values = torch.tensor([1e-38, -1e-38, 1e-20, -1e-20], device='cuda') rounded = stochastic_rounding_cuda.stochastic_round_bf16(small_values) # Check that very small values are handled properly self.assertTrue(torch.all(torch.isfinite(rounded))) def test_vectorized_loading(self): """Test if vectorized loading works correctly for different tensor sizes""" sizes = [4, 8, 9, 16, 32, 100] # Test various sizes including non-aligned for size in sizes: x = torch.linspace(1, size, size, device='cuda') rounded = stochastic_rounding_cuda.stochastic_round_bf16(x) # Check output size matches input self.assertEqual(rounded.size(0), size) # Check 
dtype self.assertEqual(rounded.dtype, torch.bfloat16) def test_large_values(self): """Test handling of large values""" large_values = torch.tensor([1e38, -1e38, 1e20, -1e20], device='cuda') rounded = stochastic_rounding_cuda.stochastic_round_bf16(large_values) # Values should be preserved approximately in BF16 range self.assertTrue(torch.all(torch.isfinite(rounded))) def test_rounding_statistics(self): """Test if rounding probabilities match expected distribution""" self._test_rounding_statistics_helper(2.1999969482421875, 2.1875, 2.2031) def test_rounding_statistics_2(self): """Test stochastic rounding with different BF16 boundary values""" self._test_rounding_statistics_helper(1.7999992370605469, 1.7969, 1.8047) def test_rounding_statistics_small(self): """Test stochastic rounding for number between 0 and 1""" self._test_rounding_statistics_helper(0.7499847412109375, 0.7480, 0.7500) def test_rounding_statistics_large(self): """Test stochastic rounding for large number, over 100""" self._test_rounding_statistics_helper(128.99998474121094, 128.875, 129.000) if __name__ == '__main__': unittest.main(verbosity=2) ================================================ FILE: dev/sr/usage.py ================================================ import torch import stochastic_rounding_cuda # Create input tensor input_tensor = torch.randn(12, device='cuda', dtype=torch.float32) # Apply stochastic rounding output_tensor = stochastic_rounding_cuda.stochastic_round_bf16(input_tensor) print(f"Input tensor: {input_tensor}") print(f"Output tensor: {output_tensor}") print(f"Output tensor dtype: {output_tensor.dtype}") print(f"Success!") ''' # Test tensor x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00, -1.3683e+00, 4.0467e-01, 1.0759e-03, 2.8418e-01, -4.9392e-01, 8.7239e-01, -9.0545e-01, 1.1134e+00, 0], # -2.6872e+00 device='cuda') # Convert to BF16 y = stochastic_rounding_cuda.stochastic_round_bf16(x) print(f"Input: {x}") print(f"Output: {y}") ''' ================================================ FILE: dev/sr/usage2.py ================================================ import torch import stochastic_rounding_cuda # Test tensor x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00], device='cuda') # Compare with regular rounding y_normal = x.to(torch.bfloat16) y_stochastic = stochastic_rounding_cuda.stochastic_round_bf16(x) print(f"Input: {x}") print(f"Normal BF16: {y_normal}") print(f"Stochastic BF16: {y_stochastic}") ================================================ FILE: dev/triton_groupGEMM/groupgemm.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
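# Overview (summarizing the entry points defined below): this module exposes
#   grouped_gemm(x, w, m_sizes)                               -> bf16 output
#   grouped_gemm_fp8_rowwise(x, w, m_sizes, x_scale, w_scale) -> rowwise-scaled fp8 GEMM
# where x is (M, K), w is (G * N, K), and m_sizes holds the G per-group row
# counts that partition the M rows of x. Group g multiplies its x rows by
# w[g * N : (g + 1) * N, :].T, producing an (M, N) output.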
# pyre-unsafe import functools from typing import Optional import tma_utils as utils import torch import triton import triton.language as tl from triton.runtime import driver # @manual _NV_CONFIGS = [ triton.Config( { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, }, num_stages=num_stages, num_warps=num_warps, num_ctas=num_ctas, ) for block_size_m in [64, 128] for block_size_n in [64, 128, 256] for block_size_k in [64, 128, 256] for num_stages in [3, 4] for num_warps in [4, 8] for num_ctas in [1] ] _AMD_CONFIGS = [ triton.Config( { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "waves_per_eu": waves_per_cu, "matrix_instr_nonkdim": matrix_instr_nonkdim, }, num_stages=num_stages, num_warps=num_warps, ) for block_size_m in [32, 64, 128] for block_size_n in [32, 64, 128, 256] for block_size_k in [128, 256] for num_stages in [1, 2] for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)] for matrix_instr_nonkdim in [16] ] def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs): device = torch.cuda.current_device() # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages if dtsize is None: dtsize = named_args["c_ptr"].element_size() if dtype is None: dtype = named_args["c_ptr"].dtype pruned_configs = [] for config in configs: kw = config.kwargs BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( kw["BLOCK_SIZE_M"], kw["BLOCK_SIZE_N"], kw["BLOCK_SIZE_K"], config.num_stages, ) G, M, N, K = ( named_args["G"], named_args["M_BUCKET"], named_args["N"], named_args["K"], ) # 1. make sure we have enough smem max_shared_memory = driver.active.utils.get_device_properties(device)[ "max_shared_mem" ] if torch.version.hip: required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize else: required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize if required_shared_memory > max_shared_memory: continue M_PER_GROUP = M // G MIN_M_TILES = 32 if torch.version.hip else 64 # 2. make sure we don't load M tiles that are too big if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2): continue # 3. make sure we don't load N tiles that are too small if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2): continue num_sm = driver.active.utils.get_device_properties(device)[ "multiprocessor_count" ] N_TILES = N // BLOCK_N MIN_N_TILES = 32 if torch.version.hip else 64 # 4. make sure we don't load N tiles that are too big if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm: continue # 5. make sure we don't load N tiles that are too small if BLOCK_N < 128 and M * N_TILES > 2 * num_sm: continue # 6. 
make sure K can be evenly divided if K % BLOCK_K != 0: continue pruned_configs.append(config) return pruned_configs @triton.autotune( configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, key=["G", "M_BUCKET", "N", "K"], prune_configs_by={"early_config_prune": early_config_prune}, ) @triton.jit def _kernel_grouped_gemm( a_desc_ptr, b_desc_ptr, c_ptr, workspace, m_sizes, # problem sizes G: tl.constexpr, M_BUCKET: tl.constexpr, N: tl.constexpr, K: tl.constexpr, NUM_SMS: tl.constexpr, USE_TMA_LOAD: tl.constexpr, USE_TMA_STORE: tl.constexpr, # tile sizes BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ) -> None: tidx = tl.program_id(0) dtype: tl.dtype = c_ptr.dtype.element_ty TMA_SIZE: tl.constexpr = tl.constexpr(128) if USE_TMA_STORE: c_desc_ptr = workspace + tidx * TMA_SIZE else: c_desc_ptr = None M_end_offset = 0 iterated_tiles = 0 for g in tl.range(G): # Move across groups M_start_offset = M_end_offset m_size = tl.load(m_sizes + g) M_end_offset = M_start_offset + m_size if m_size > 0: N_start_offset = g * N n_size = N num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles if USE_TMA_STORE: # pyre-ignore tl.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=c_desc_ptr, global_address=c_ptr + M_start_offset * N, load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[m_size, n_size], element_ty=c_ptr.dtype.element_ty, ) # pyre-ignore tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) # Move across tiles while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: gidx = tidx - iterated_tiles # Split M first and N second. tile_m_idx = gidx % num_m_tiles tile_n_idx = gidx // num_m_tiles accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) tl.static_assert(K % BLOCK_SIZE_K == 0) if USE_TMA_LOAD: m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): a = tl._experimental_descriptor_load( a_desc_ptr, [m_offset, k_offset], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype, ) b = tl._experimental_descriptor_load( b_desc_ptr, [n_offset, k_offset], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype, ) accumulator += tl.dot(a, b.T) else: offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = ( a_desc_ptr + (M_start_offset + offs_am[:, None]) * K + offs_k[None, :] ) b_ptrs = ( b_desc_ptr + (N_start_offset + offs_bn[:, None]) * K + offs_k[None, :] ) for k_offset in range(0, K, BLOCK_SIZE_K): a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) accumulator += tl.dot(a, b.T) a_ptrs += BLOCK_SIZE_K b_ptrs += BLOCK_SIZE_K if USE_TMA_STORE: m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) tl._experimental_descriptor_store( c_desc_ptr, accumulator.to(c_ptr.dtype.element_ty), [m_offset, n_offset], ) else: offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c = accumulator.to(c_ptr.dtype.element_ty) tl.store( c_ptr + (M_start_offset + offs_am[:, None]) * N + offs_bn[None, :], c, mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, ) tidx += NUM_SMS iterated_tiles += num_tiles TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv 
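# TT_FP8_DTYPE selects the e4m3-style FP8 element type for the active backend
# (float8e4b8 on ROCm, float8e4nv on CUDA). The rowwise kernel below runs the
# same tiling loop as _kernel_grouped_gemm but accumulates in float32 and then
# applies per-row / per-column scales before the store:
#   C[m, n] = acc[m, n] * a_scale[m] * b_scale[n]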
@triton.autotune( configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, key=["G", "M_BUCKET", "N", "K"], prune_configs_by={ "early_config_prune": functools.partial( early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1 ) }, ) @triton.jit def _kernel_grouped_gemm_fp8_rowwise( a_desc_ptr, a_scale_ptr, b_desc_ptr, b_scale_ptr, c_ptr, workspace, m_sizes, # problem sizes G: tl.constexpr, M_BUCKET: tl.constexpr, N: tl.constexpr, K: tl.constexpr, NUM_SMS: tl.constexpr, USE_TMA_LOAD: tl.constexpr, USE_TMA_STORE: tl.constexpr, # tile sizes BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ) -> None: tidx = tl.program_id(0) dtype = TT_FP8_DTYPE TMA_SIZE: tl.constexpr = tl.constexpr(128) if USE_TMA_STORE: c_desc_ptr = workspace + tidx * TMA_SIZE else: c_desc_ptr = None M_end_offset = 0 iterated_tiles = 0 for g in tl.range(G): # Move across groups M_start_offset = M_end_offset m_size = tl.load(m_sizes + g) M_end_offset = M_start_offset + m_size if m_size > 0: N_start_offset = g * N n_size = N num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles if USE_TMA_STORE: # pyre-ignore tl.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=c_desc_ptr, global_address=c_ptr + M_start_offset * N, load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[m_size, n_size], element_ty=c_ptr.dtype.element_ty, ) # pyre-ignore tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) # Move across tiles while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: gidx = tidx - iterated_tiles # Split M first and N second. tile_m_idx = gidx % num_m_tiles tile_n_idx = gidx // num_m_tiles accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) tl.static_assert(K % BLOCK_SIZE_K == 0) if USE_TMA_LOAD: m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): a = tl._experimental_descriptor_load( a_desc_ptr, [m_offset, k_offset], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype, ) b = tl._experimental_descriptor_load( b_desc_ptr, [n_offset, k_offset], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype, ) accumulator += tl.dot(a, b.T) else: offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = ( a_desc_ptr + (M_start_offset + offs_am[:, None]) * K + offs_k[None, :] ) b_ptrs = ( b_desc_ptr + (N_start_offset + offs_bn[:, None]) * K + offs_k[None, :] ) for k_offset in range(0, K, BLOCK_SIZE_K): a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) accumulator += tl.dot(a, b.T) a_ptrs += BLOCK_SIZE_K b_ptrs += BLOCK_SIZE_K offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) a_scale = tl.load( a_scale_ptr + M_start_offset + offs_am[:, None], mask=offs_am[:, None] < m_size, ) b_scale = tl.load( b_scale_ptr + N_start_offset + offs_bn[None, :], mask=offs_bn[None, :] < n_size, ) c = accumulator.to(tl.float32) * a_scale * b_scale if USE_TMA_STORE: m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) tl._experimental_descriptor_store( c_desc_ptr, c.to(c_ptr.dtype.element_ty), [m_offset, n_offset], ) else: tl.store( c_ptr + (M_start_offset + offs_am[:, None]) * N + offs_bn[None, :], c, mask=offs_am[:, None] < 
m_size and offs_bn[None, :] < n_size, ) tidx += NUM_SMS iterated_tiles += num_tiles def _grouped_gemm( x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor, x_scale: Optional[torch.Tensor] = None, w_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if not utils.HAS_TMA_DESC: raise NotImplementedError("Grouped GEMM without TMA is not supported yet") G = m_sizes.shape[0] assert x.is_contiguous() assert w.is_contiguous() assert m_sizes.is_contiguous() M, K = x.shape N = w.shape[0] // G assert K == w.shape[1] y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count USE_TMA_LOAD = not torch.version.hip USE_TMA_STORE = False desc_helper = None desc_x = x desc_w = w workspace = None if USE_TMA_LOAD: desc_helper = utils.TmaAutoTuneHelper() desc_helper.init_tma_descriptor("x") desc_helper.init_tma_descriptor("w") desc_x = desc_helper.get_tma_descriptor_kernel_param("x") desc_w = desc_helper.get_tma_descriptor_kernel_param("w") if USE_TMA_STORE: workspace = torch.empty( NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE, device=x.device, dtype=torch.uint8, ) def grid(META): if USE_TMA_LOAD: nonlocal desc_helper desc_helper.fill_2d_tma_descriptor( "x", x.data_ptr(), M, K, META["BLOCK_SIZE_M"], META["BLOCK_SIZE_K"], x.element_size(), ) desc_helper.fill_2d_tma_descriptor( "w", w.data_ptr(), N * G, K, META["BLOCK_SIZE_N"], META["BLOCK_SIZE_K"], w.element_size(), ) return (NUM_SMS,) M_BUCKET = triton.next_power_of_2(M) if x_scale is not None and w_scale is not None: assert x_scale.is_contiguous() assert w_scale.is_contiguous() _kernel_grouped_gemm_fp8_rowwise[grid]( desc_x, x_scale, desc_w, w_scale, y, workspace, m_sizes, G, M_BUCKET, N, K, NUM_SMS, USE_TMA_LOAD, USE_TMA_STORE, ) else: assert x_scale is None assert w_scale is None _kernel_grouped_gemm[grid]( desc_x, desc_w, y, workspace, m_sizes, G, M_BUCKET, N, K, NUM_SMS, USE_TMA_LOAD, USE_TMA_STORE, ) return y def grouped_gemm( x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor ) -> torch.Tensor: return _grouped_gemm(x, w, m_sizes) def grouped_gemm_fp8_rowwise( x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, ) -> torch.Tensor: return _grouped_gemm(x, w, m_sizes, x_scale, w_scale) ================================================ FILE: dev/triton_groupGEMM/testing/base_testing.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
# pyre-strict import logging # Configure logging to print to stdout logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) import os import sys import unittest from typing import Tuple import torch # Add parent directory to path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) if torch.cuda.is_available(): # from fp8_gemm import quantize_fp8_row from groupgemm import grouped_gemm # , grouped_gemm_fp8_rowwise from tma_utils import HAS_TMA_DESC @unittest.skipIf( not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9 or not HAS_TMA_DESC, "Skip when H100 or TMA is not available", ) class TestGroupedGEMM(unittest.TestCase): def setUp(self) -> None: torch.manual_seed(0) """def test_grouped_gemm_fp8_rowwise(self) -> None: def _test_grouped_gemm_fp8_rowwise( shape: Tuple[int, int, int, int], device: torch.device, ) -> None: G, M, N, K = shape a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_ends, _ = torch.sort( torch.randint( low=0, high=M, size=[G - 1], device=device, dtype=torch.int32 ) ) m_ends = m_ends.tolist() m_starts = [0] + m_ends m_ends = m_ends + [M] m_sizes = torch.tensor( [m_ends[i] - m_starts[i] for i in range(G)], device=device ).to(torch.int32) a_fp8, a_scale = quantize_fp8_row(a) b_fp8, b_scale = quantize_fp8_row(b) result = grouped_gemm_fp8_rowwise( a_fp8, b_fp8, m_sizes, a_scale, b_scale, ) self.assertTrue(result.shape == (M, N)) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) # Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation. 
for g in range(G): m_start = m_starts[g] m_end = m_ends[g] n_start = g * N n_end = (g + 1) * N expected_result[m_start:m_end, :] = ( a_fp8[m_start:m_end, :].to(torch.float32) @ b_fp8[n_start:n_end, :].to(torch.float32).T * a_scale[m_start:m_end][:, None] * b_scale[n_start:n_end][None, :] ).to(torch.bfloat16) torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2) for G in (1, 4, 16): for M in (64, 512): logging.info(f"Testing FP8 GMM with G={G}, M={M}") _test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device("cuda")) """ def test_grouped_gemm_bf16(self) -> None: def _test_grouped_gemm_bf16( shape: Tuple[int, int, int, int], device: torch.device, ) -> None: G, M, N, K = shape a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_ends, _ = torch.sort( torch.randint( low=0, high=M, size=[G - 1], device=device, dtype=torch.int32 ) ) m_ends = m_ends.tolist() m_starts = [0] + m_ends m_ends = m_ends + [M] m_sizes = torch.tensor( [m_ends[i] - m_starts[i] for i in range(G)], device=device ).to(torch.int32) result = grouped_gemm( a, b, m_sizes, ) self.assertTrue(result.shape == (M, N)) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) for g in range(G): m_start = m_starts[g] m_end = m_ends[g] expected_result[m_start:m_end, :] = ( a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T ) torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) for G in (1, 4, 16): for M in (64, 512): logging.info(f"Testing BF16 GMM with G={G}, M={M}") _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda")) if __name__ == "__main__": unittest.main(exit=False) ================================================ FILE: dev/triton_groupGEMM/testing/unit_tests.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
# pyre-unsafe # This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm import logging import unittest from typing import Tuple import torch # Add parent directory to path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from groupgemm import grouped_gemm class TestGroupedGEMM(unittest.TestCase): def test_grouped_gemm_bf16(self) -> None: def _test_grouped_gemm_bf16( shape: Tuple[int, int, int, int], device: torch.device, ) -> None: G, M, N, K = shape a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_ends, _ = torch.sort( torch.randint( low=0, high=M, size=[G - 1], device=device, dtype=torch.int32 ) ) m_ends = m_ends.tolist() m_starts = [0] + m_ends m_ends = m_ends + [M] m_sizes = torch.tensor( [m_ends[i] - m_starts[i] for i in range(G)], device=device ).to(torch.int32) result = grouped_gemm( a, b, m_sizes, ) self.assertTrue(result.shape == (M, N)) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) for g in range(G): m_start = m_starts[g] m_end = m_ends[g] expected_result[m_start:m_end, :] = ( a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T ) torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) for G in (1, 4, 16): for M in (64, 512): logging.info(f"Testing BF16 GMM with G={G}, M={M}") _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda")) def test_grouped_gemm_bf16_various_dimensions(self) -> None: """Test grouped_gemm with bf16 precision and various dimensions""" def _test_grouped_gemm_bf16( shape: Tuple[int, int, int, int], device: torch.device, ) -> None: G, M, N, K = shape a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_ends, _ = torch.sort( torch.randint( low=0, high=M, size=[G - 1], device=device, dtype=torch.int32 ) ) m_ends = m_ends.tolist() m_starts = [0] + m_ends m_ends = m_ends + [M] m_sizes = torch.tensor( [m_ends[i] - m_starts[i] for i in range(G)], device=device ).to(torch.int32) result = grouped_gemm( a, b, m_sizes, ) self.assertTrue(result.shape == (M, N)) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) for g in range(G): m_start = m_starts[g] m_end = m_ends[g] expected_result[m_start:m_end, :] = ( a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T ) torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) for G in (4, 8): for M in (128, 256): for N, K in [(128, 256), (256, 128), (64, 64)]: logging.info(f"Testing BF16 GMM with G={G}, M={M}, N={N}, K={K}") _test_grouped_gemm_bf16((G, M, N, K), torch.device("cuda")) def test_grouped_gemm_bf16_edge_cases(self) -> None: """Test grouped_gemm with bfloat16 for various edge cases""" device = torch.device("cuda") # Test with G=1 (single group case) G, M, N, K = 1, 32, 32, 32 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([M], device=device).to(torch.int32) result = grouped_gemm(a, b, m_sizes) expected_result = a @ b.T torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) # Test with uneven group sizes G, M, N, K = 3, 100, 32, 32 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([25, 50, 25], device=device).to(torch.int32) result = grouped_gemm(a, b, m_sizes) self.assertTrue(result.shape == (M, 
N)) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) m_start = 0 for g in range(G): m_end = m_start + m_sizes[g].item() expected_result[m_start:m_end, :] = ( a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T ) m_start = m_end torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) # Test with extremely small matrices G, M, N, K = 2, 8, 8, 8 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([4, 4], device=device).to(torch.int32) result = grouped_gemm(a, b, m_sizes) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) m_start = 0 for g in range(G): m_end = m_start + m_sizes[g].item() expected_result[m_start:m_end, :] = ( a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T ) m_start = m_end torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) # Test with large group count but small matrix sizes G, M, N, K = 32, 128, 16, 16 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.ones(G, device=device).to(torch.int32) * (M // G) m_sizes[-1] += M % G # Adjust the last group size to account for remainder result = grouped_gemm(a, b, m_sizes) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) m_start = 0 for g in range(G): m_end = m_start + m_sizes[g].item() expected_result[m_start:m_end, :] = ( a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T ) m_start = m_end torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) def test_grouped_gemm_bf16_invalid_inputs(self) -> None: """Test grouped_gemm with invalid inputs to ensure proper error handling""" device = torch.device("cuda") # Test with mismatched dimensions G, M, N, K = 2, 64, 32, 32 a = torch.randn( M, K + 1, dtype=torch.bfloat16, device=device ) # Wrong K dimension b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([32, 32], device=device).to(torch.int32) with self.assertRaises(RuntimeError): grouped_gemm(a, b, m_sizes) # Test with mismatched G and m_sizes length G, M, N, K = 2, 64, 32, 32 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([32, 32, 32], device=device).to( torch.int32 ) # Too many groups with self.assertRaises((RuntimeError, ValueError, IndexError)): grouped_gemm(a, b, m_sizes) # Test with incorrect sum of m_sizes G, M, N, K = 2, 64, 32, 32 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([32, 40], device=device).to(torch.int32) # Sum > M with self.assertRaises((RuntimeError, ValueError, IndexError)): grouped_gemm(a, b, m_sizes) # Test with negative m_sizes values G, M, N, K = 2, 64, 32, 32 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([40, -8], device=device).to( torch.int32 ) # Negative group size with self.assertRaises((RuntimeError, ValueError)): grouped_gemm(a, b, m_sizes) def test_grouped_gemm_bf16_deterministic(self) -> None: """Test that grouped_gemm produces deterministic results with the same inputs""" G, M, N, K = 4, 128, 64, 64 device = torch.device("cuda") # Fix the random seed for reproducibility torch.manual_seed(42) a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, 
K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([32, 32, 32, 32], device=device).to(torch.int32) # First run result1 = grouped_gemm(a, b, m_sizes) # Second run with same inputs result2 = grouped_gemm(a, b, m_sizes) # Results should be exactly the same self.assertTrue(torch.all(result1 == result2)) def test_grouped_gemm_bf16_large_matrices(self) -> None: """Test grouped_gemm with larger matrices to stress test performance and stability""" device = torch.device("cuda") # Test with large matrices but fewer groups G, M, N, K = 2, 2048, 512, 1024 a = torch.randn(M, K, dtype=torch.bfloat16, device=device) b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) m_sizes = torch.tensor([1024, 1024], device=device).to(torch.int32) result = grouped_gemm(a, b, m_sizes) expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) m_start = 0 for g in range(G): m_end = m_start + m_sizes[g].item() expected_result[m_start:m_end, :] = ( a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T ) m_start = m_end torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) if __name__ == "__main__": unittest.main(argv=["first-arg-is-ignored"], exit=False) ================================================ FILE: dev/triton_groupGEMM/tma_utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe # This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm import sys import torch import triton # @manual import triton.language as tl # @manual def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: """ Maps torch dtype to triton dtype. Args: dtype (torch.dtype): input dtype. Returns: tl.dtype: triton dtype. """ if dtype == torch.float16: return tl.float16 elif dtype == torch.bfloat16: return tl.bfloat16 elif dtype == torch.float32: return tl.float32 elif dtype == torch.int32: return tl.int32 elif dtype == torch.float8_e4m3fn and torch.version.hip is None: return tl.float8e4nv else: raise ValueError(f"Unsupported dtype {dtype}") # check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). 
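# (nv_tma_desc_type is only exposed by Triton builds that include that PR;
# without it this module raises NotImplementedError rather than falling back
# to a non-TMA grouped GEMM path.)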
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) if HAS_TMA_DESC: print( "TMA benchmarks will be running with experimental grid constant TMA descriptor.", file=sys.stderr, ) else: print( "Missing: This group gemm code will not run without TMA descriptor support....", file=sys.stderr, ) raise NotImplementedError("grouped Gemm without TMA is not supported") class TmaAutoTuneHelper: # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 class KernelParamWrapper: def __init__(self, desc): self.desc = desc def tma_desc_cpu_ptr(self): return self.desc.data_ptr() TMA_SIZE = 128 def __init__(self): self.fill_1d_tma_descriptor_inner = ( triton.runtime.driver.active.utils.fill_1d_tma_descriptor ) self.fill_2d_tma_descriptor_inner = ( triton.runtime.driver.active.utils.fill_2d_tma_descriptor ) if HAS_TMA_DESC: self.descriptors = {} else: self.cuda_descriptors = {} # Call this method outside of the lambda function for grid size def init_tma_descriptor(self, name): if HAS_TMA_DESC: self.descriptors[name] = torch.empty( TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 ) else: self.cuda_descriptors[name] = torch.empty( TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 ) # Call this method inside the lambda function for grid size def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): if HAS_TMA_DESC: desc_x = self.descriptors[name] assert desc_x.data_ptr() % 64 == 0 self.fill_1d_tma_descriptor_inner( ptr, dim, block_dim, element_size, desc_x.data_ptr() ) else: desc_x = self.cuda_descriptors[name] buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) self.fill_1d_tma_descriptor_inner( ptr, dim, block_dim, element_size, buf_x.data_ptr() ) desc_x.copy_(buf_x, non_blocking=True) # Call this method inside the lambda function for grid size def fill_2d_tma_descriptor( self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size ): if HAS_TMA_DESC: desc_x = self.descriptors[name] assert desc_x.data_ptr() % 64 == 0 self.fill_2d_tma_descriptor_inner( ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() ) else: desc_x = self.cuda_descriptors[name] buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) self.fill_2d_tma_descriptor_inner( ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() ) desc_x.copy_(buf_x, non_blocking=True) def get_tma_descriptor_kernel_param(self, name): if HAS_TMA_DESC: assert self.descriptors[name] is not None return self.KernelParamWrapper(self.descriptors[name]) else: assert self.cuda_descriptors[name] is not None return self.cuda_descriptors[name] ================================================ FILE: dev/triton_groupGEMM/triton_tutorial_groupgemm.py ================================================ """ Group GEMM ============================ This group gemm kernel launches a fixed number of CTA to compute a group of gemms. The scheduling is static and we do it on device. """ # Copyright (c) 2023 - 2025 NVIDIA Corporation & Affiliates. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files # (the "Software"), to deal in the Software without restriction, # including without limitation the rights to use, copy, modify, merge, # publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, # subject to the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # from: https://github.com/triton-lang/triton/blob/main/python/tutorials/08-grouped-gemm.py from typing import Optional import torch import triton import triton.language as tl DEVICE = triton.runtime.driver.active.get_active_torch_device() def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" def supports_tma(): return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 def num_sms(): if is_cuda(): return torch.cuda.get_device_properties("cuda").multi_processor_count return 148 @triton.autotune( configs=[ triton.Config( { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "NUM_SM": 84, } ), triton.Config( { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "NUM_SM": 128, } ), triton.Config( { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "NUM_SM": 84, } ), triton.Config( { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "NUM_SM": 128, } ), triton.Config( { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "NUM_SM": num_sms(), } ), triton.Config( { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "NUM_SM": num_sms(), } ), ], key=["group_size"], ) @triton.jit def grouped_matmul_kernel( # device tensor of matrices pointers group_a_ptrs, group_b_ptrs, group_c_ptrs, # device tensor of gemm sizes. its shape is [group_size, 3] # dim 0 is group_size, dim 1 is the values of of each gemm group_gemm_sizes, # device tensor of leading dimension sizes. 
its shape is [group_size, 3] # dim 0 is group_size, dim 1 is the values of of each gemm g_lds, # number of gemms group_size, # number of virtual SM NUM_SM: tl.constexpr, # tile sizes BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): tile_idx = tl.program_id(0) last_problem_end = 0 for g in range(group_size): # get the gemm size of the current problem gm = tl.load(group_gemm_sizes + g * 3) gn = tl.load(group_gemm_sizes + g * 3 + 1) gk = tl.load(group_gemm_sizes + g * 3 + 2) num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles # iterate through the tiles in the current gemm problem while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles: # pick up a tile from the current gemm problem k = gk lda = tl.load(g_lds + g * 3) ldb = tl.load(g_lds + g * 3 + 1) ldc = tl.load(g_lds + g * 3 + 2) a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) # figure out tile coordinates tile_idx_in_gemm = tile_idx - last_problem_end tile_m_idx = tile_idx_in_gemm // num_n_tiles tile_n_idx = tile_idx_in_gemm % num_n_tiles # do regular gemm here offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): # hint to Triton compiler to do proper loop pipelining tl.multiple_of(a_ptrs, [16, 16]) tl.multiple_of(b_ptrs, [16, 16]) # assume full tile for now a = tl.load(a_ptrs) b = tl.load(b_ptrs) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K b_ptrs += BLOCK_SIZE_K * ldb c = accumulator.to(tl.float16) offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] # assumes full tile for now tl.store(c_ptrs, c) # go to the next tile by advancing NUM_SM tile_idx += NUM_SM # get ready to go to the next gemm problem last_problem_end = last_problem_end + num_tiles def group_gemm_fn(group_A, group_B): assert len(group_A) == len(group_B) group_size = len(group_A) A_addrs = [] B_addrs = [] C_addrs = [] g_sizes = [] g_lds = [] group_C = [] for i in range(group_size): A = group_A[i] B = group_B[i] assert A.shape[1] == B.shape[0] M, K = A.shape K, N = B.shape C = torch.empty((M, N), device=DEVICE, dtype=A.dtype) group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) C_addrs.append(C.data_ptr()) g_sizes += [M, N, K] g_lds += [A.stride(0), B.stride(0), C.stride(0)] # note these are device tensors d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) # we use a fixed number of CTA, and it's auto-tunable grid = lambda META: (META["NUM_SM"],) grouped_matmul_kernel[grid]( d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, ) return group_C tma_configs = [ triton.Config( {"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK}, num_stages=s, num_warps=w, ) for BM 
in [128] for BN in [128, 256] for BK in [64, 128] for s in ([3, 4]) for w in [4, 8] ] @triton.autotune( tma_configs, key=["group_a_ptrs", "group_b_ptrs", "gropup_c_ptrs", "group_size"], ) @triton.jit def grouped_matmul_tma_kernel( # device tensor of matrices pointers group_a_ptrs, group_b_ptrs, group_c_ptrs, # device tensor of gemm sizes. its shape is [group_size, 3] # dim 0 is group_size, dim 1 is the values of of each gemm group_gemm_sizes, # device tensor of leading dimension sizes. its shape is [group_size, 3] # dim 0 is group_size, dim 1 is the values of of each gemm g_lds, # number of gemms group_size, # number of virtual SM NUM_SM: tl.constexpr, # tile sizes BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # is the output FP8 or FP16 FP8: tl.constexpr, ): dtype = tl.float8e4nv if FP8 else tl.float16 tile_idx = tl.program_id(0) last_problem_end = 0 for g in range(group_size): # get the gemm size of the current problem gm = tl.load(group_gemm_sizes + g * 3) gn = tl.load(group_gemm_sizes + g * 3 + 1) gk = tl.load(group_gemm_sizes + g * 3 + 2) num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles: # pick up a tile from the current gemm problem lda = tl.load(g_lds + g * 3) ldb = tl.load(g_lds + g * 3 + 1) ldc = tl.load(g_lds + g * 3 + 2) a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype)) b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype)) c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype)) a_desc = tl._experimental_make_tensor_descriptor( a_ptr, shape=[gm, gk], strides=[lda, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], ) b_desc = tl._experimental_make_tensor_descriptor( b_ptr, shape=[gn, gk], strides=[ldb, 1], block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], ) c_desc = tl._experimental_make_tensor_descriptor( c_ptr, shape=[gm, gn], strides=[ldc, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], ) # iterate through the tiles in the current gemm problem while ( tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles ): k = gk # figure out tile coordinates tile_idx_in_gemm = tile_idx - last_problem_end tile_m_idx = tile_idx_in_gemm // num_n_tiles tile_n_idx = tile_idx_in_gemm % num_n_tiles # do regular gemm here offs_am = tile_m_idx * BLOCK_SIZE_M offs_bn = tile_n_idx * BLOCK_SIZE_N accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): a = a_desc.load([offs_am, kk * BLOCK_SIZE_K]) b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K]) accumulator += tl.dot(a, b.T) offs_cm = tile_m_idx * BLOCK_SIZE_M offs_cn = tile_n_idx * BLOCK_SIZE_N c = accumulator.to(dtype) c_desc.store([offs_cm, offs_cn], c) # go to the next tile by advancing NUM_SM tile_idx += NUM_SM # get ready to go to the next gemm problem last_problem_end = last_problem_end + num_tiles def group_gemm_tma_fn(group_A, group_B): assert supports_tma() assert len(group_A) == len(group_B) group_size = len(group_A) A_addrs = [] B_addrs = [] C_addrs = [] g_sizes = [] g_lds = [] group_C = [] for i in range(group_size): A = group_A[i] B = group_B[i] assert A.shape[1] == B.shape[1] M, K = A.shape N, K = B.shape C = torch.empty((M, N), device=DEVICE, dtype=A.dtype) group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) C_addrs.append(C.data_ptr()) g_sizes += [M, N, K] g_lds += [A.stride(0), B.stride(0), C.stride(0)] # note these are device tensors d_a_ptrs = 
torch.tensor(A_addrs, device=DEVICE) d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) # we use a fixed number of CTA, and it's auto-tunable # TMA descriptors require a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): return torch.empty(size, device="cuda", dtype=torch.int8) triton.set_allocator(alloc_fn) grid = lambda META: (META["NUM_SM"],) grouped_matmul_tma_kernel[grid]( d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, FP8=torch.float8_e4m3fn == group_A[0].dtype, NUM_SM=num_sms(), ) return group_C group_m = [1024, 512, 256, 128] group_n = [1024, 512, 256, 128] group_k = [1024, 512, 256, 128] group_A = [] group_B = [] group_B_T = [] assert len(group_m) == len(group_n) assert len(group_n) == len(group_k) group_size = len(group_m) for i in range(group_size): M = group_m[i] N = group_n[i] K = group_k[i] A = torch.rand((M, K), device=DEVICE, dtype=torch.float16) B = torch.rand((K, N), device=DEVICE, dtype=torch.float16) B_T = B.T.contiguous() group_A.append(A) group_B.append(B) group_B_T.append(B_T) tri_out = group_gemm_fn(group_A, group_B) ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] for i in range(group_size): assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0) if supports_tma(): tri_tma_out = group_gemm_tma_fn(group_A, group_B_T) for i in range(group_size): assert torch.allclose(ref_out[i], tri_tma_out[i], atol=1e-2, rtol=0) # only launch the kernel, no tensor preparation here to remove all overhead def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): grid = lambda META: (META["NUM_SM"],) grouped_matmul_kernel[grid]( a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, ) def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, dtype): grid = lambda META: (META["NUM_SM"],) grouped_matmul_tma_kernel[grid]( a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, FP8=torch.float8_e4m3fn == dtype, NUM_SM=num_sms(), ) def torch_perf_fn(group_A, group_B): for a, b in zip(group_A, group_B): torch.matmul(a, b) @triton.testing.perf_report( triton.testing.Benchmark( # argument names to use as an x-axis for the plot x_names=["N"], x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name` line_arg="provider", # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` line_vals=["cublas", "triton"] + (["triton-tma"] if supports_tma() else []), # label name for the lines line_names=["cuBLAS", "Triton"] + (["Triton + TMA"] if supports_tma() else []), # line styles styles=[("green", "-"), ("blue", "-")] + ([("red", "-")] if supports_tma() else []), ylabel="runtime(ms)", # label name for the y-axis plot_name="group-gemm-performance", # name for the plot. Used also as a file name for saving the plot. 
args={}, ) ) def benchmark_square_matrices(N, provider): group_size = 4 group_A = [] group_B = [] group_B_T = [] A_addrs = [] B_addrs = [] B_T_addrs = [] C_addrs = [] g_sizes = [] g_lds = [] group_C = [] for i in range(group_size): A = torch.rand((N, N), device=DEVICE, dtype=torch.float16) B = torch.rand((N, N), device=DEVICE, dtype=torch.float16) C = torch.empty((N, N), device=DEVICE, dtype=torch.float16) B_T = B.T.contiguous() group_A.append(A) group_B.append(B) group_B_T.append(B_T) group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) B_T_addrs.append(B_T.data_ptr()) C_addrs.append(C.data_ptr()) g_sizes += [N, N, N] g_lds += [N, N, N] d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE) d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) quantiles = [0.5, 0.2, 0.8] if provider == "cublas": ms, min_ms, max_ms = triton.testing.do_bench( lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles ) if provider == "triton": ms, min_ms, max_ms = triton.testing.do_bench( lambda: triton_perf_fn( d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size ), quantiles=quantiles, ) if provider == "triton-tma": ms, min_ms, max_ms = triton.testing.do_bench( lambda: triton_tma_perf_fn( d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, dtype=torch.float16, ), quantiles=quantiles, ) return ms, max_ms, min_ms @triton.testing.perf_report( triton.testing.Benchmark( # argument names to use as an x-axis for the plot x_names=["M"], x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name` line_arg="provider", # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` line_vals=["cublas", "triton"] + (["triton-tma"] if supports_tma() else []), # label name for the lines line_names=["cuBLAS", "Triton"] + (["Triton + TMA"] if supports_tma() else []), # line styles styles=[("green", "-"), ("blue", "-")] + ([("red", "-")] if supports_tma() else []), ylabel="runtime(ms)", # label name for the y-axis plot_name="group-gemm-performance-m-8192-k-8192", # name for the plot. Used also as a file name for saving the plot. 
args={}, ) ) def benchmark_batches(M, provider): N = 8192 K = 8192 group_size = 4 group_A = [] group_B = [] group_B_T = [] A_addrs = [] B_addrs = [] B_T_addrs = [] C_addrs = [] g_sizes = [] g_lds = [] g_T_lds = [] group_C = [] for i in range(group_size): A = torch.rand((M, K), device=DEVICE, dtype=torch.float16) B = torch.rand((K, N), device=DEVICE, dtype=torch.float16) C = torch.empty((M, N), device=DEVICE, dtype=torch.float16) B_T = B.T.contiguous() group_A.append(A) group_B.append(B) group_B_T.append(B_T) group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) B_T_addrs.append(B_T.data_ptr()) C_addrs.append(C.data_ptr()) g_sizes += [M, N, K] g_lds += [A.stride(0), B.stride(0), C.stride(0)] g_T_lds += [A.stride(0), B_T.stride(0), C.stride(0)] d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE) d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) d_g_t_lds = torch.tensor(g_T_lds, dtype=torch.int32, device=DEVICE) quantiles = [0.5, 0.2, 0.8] if provider == "cublas": ms, min_ms, max_ms = triton.testing.do_bench( lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles ) if provider == "triton": ms, min_ms, max_ms = triton.testing.do_bench( lambda: triton_perf_fn( d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size ), quantiles=quantiles, ) if provider == "triton-tma": ms, min_ms, max_ms = triton.testing.do_bench( lambda: triton_tma_perf_fn( d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_t_lds, group_size, dtype=torch.float16, ), quantiles=quantiles, ) return ms, max_ms, min_ms benchmark_square_matrices.run(show_plots=True, print_data=True) benchmark_batches.run(show_plots=True, print_data=True) ================================================ FILE: kernels/MoE/group_GEMM/triton/readme.md ================================================ ## Experimental Triton Group GEMM for supporting MoE training. ================================================ FILE: kernels/MoE/group_GEMM/triton/testing/fast_verification.py ================================================ import logging import torch # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) # import the reference implementations from pytorch_reference_backwards import ( _compute_grad_w_pytorch, _compute_grad_x_pytorch, _pytorch_fallback_backward, _pytorch_reference_backward, ) # Import the grouped GEMM modules from tgrouped_gemm_backwards import grouped_gemm_backward from tgrouped_gemm_forward import grouped_gemm_forward as grouped_gemm def test_backward_pass(): """ A simple test for the grouped GEMM backward pass with detailed error handling. 
""" try: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Test parameters G = 20 # Number of groups M = 1024 # Input dimension N = 512 # Output dimension per group K = 256 # Hidden dimension # Create input and weight tensors x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True) w = torch.randn( N * G, K, dtype=torch.bfloat16, device=device, requires_grad=True ) # Create group sizes m_sizes = torch.zeros(G, device=device, dtype=torch.int32) base_size = M // G remainder = M % G for i in range(G): m_sizes[i] = base_size + (1 if i < remainder else 0) # Log the setup print(f"Test setup - G: {G}, M: {M}, N: {N}, K: {K}") print(f"Input x shape: {x.shape}") logging.info(f"Weight w shape: {w.shape}") logging.info(f"Group sizes: {m_sizes}") # Step 1: Run forward pass logging.info("Running forward pass") result = grouped_gemm(x, w, m_sizes) logging.info(f"Forward result shape: {result.shape}") # Create a gradient for backpropagation grad_output = torch.randn_like(result) logging.info(f"Created gradient with shape: {grad_output.shape}") # Step 2: Run backward pass directly logging.info("Running backward pass directly") grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes) # Verify gradient shapes logging.info( f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}" ) # Step 3: Verify gradient computation using PyTorch's autograd # First create autograd-enabled tensors x_autograd = x.detach().clone().requires_grad_(True) w_autograd = w.detach().clone().requires_grad_(True) # Create a PyTorch reference implementation to compare against logging.info("Running PyTorch reference implementation") # Compute reference result reference_result = torch.zeros_like(result) m_start = 0 for g in range(G): m_size = m_sizes[g].item() m_end = m_start + m_size n_start = g * N n_end = (g + 1) * N if m_size > 0: reference_result[m_start:m_end, n_start:n_end] = ( x_autograd[m_start:m_end, :] @ w_autograd[n_start:n_end, :].T ) m_start = m_end # Backpropagate using PyTorch reference_result.backward(grad_output) # Compare gradients logging.info("Comparing gradients with PyTorch reference") grad_x_error = (grad_x - x_autograd.grad).abs().max().item() grad_w_error = (grad_w - w_autograd.grad).abs().max().item() logging.info( f"Maximum gradient error - grad_x: {grad_x_error}, grad_w: {grad_w_error}" ) # Check if gradients are close using allclose rtol = 1e-2 # Relative tolerance for bfloat16 atol = 1e-2 # Absolute tolerance for bfloat16 grad_x_close = torch.allclose(grad_x, x_autograd.grad, rtol=rtol, atol=atol) if not grad_x_close: logging.warning("FAILED: Gradient mismatch detected in grad_x") else: logging.info( "✓ SUCCESS! grad_X matches the PyTorch reference (allclose check passed)" ) grad_w_close = torch.allclose(grad_w, w_autograd.grad, rtol=rtol, atol=atol) if not grad_w_close: logging.warning("FAILED: Gradient mismatch detected in grad_w") else: logging.info( "✓ SUCCESS! 
grad_W matches the PyTorch reference (allclose check passed)" ) logging.info( f"Gradients allclose check - grad_x: {grad_x_close}, grad_w: {grad_w_close}" ) if grad_x_close and grad_w_close: logging.info( "✓ SUCCESS: Gradients match the PyTorch reference (allclose check passed)" ) else: logging.error("✗ FAILURE: Gradient mismatch detected in allclose check") # Additional diagnostics (for failed cases or in general) if True: # not grad_x_close: # Find where the largest differences are diff_x = (grad_x - x_autograd.grad).abs() max_idx_x = diff_x.argmax().item() flat_idx_x = max_idx_x idx_x = np.unravel_index(flat_idx_x, grad_x.shape) logging.error( f"Largest grad_x difference at {idx_x}: " f"{grad_x[idx_x].item()} vs {x_autograd.grad[idx_x].item()}" ) # Count zeros zeros_grad_x = (grad_x == 0).sum().item() zeros_autograd_x = (x_autograd.grad == 0).sum().item() logging.error( f"Zeros in grad_x: {zeros_grad_x}/{grad_x.numel()} ({zeros_grad_x/grad_x.numel()*100:.2f}%)" ) logging.error( f"Zeros in x_autograd.grad: {zeros_autograd_x}/{x_autograd.grad.numel()} ({zeros_autograd_x/x_autograd.grad.numel()*100:.2f}%)" ) if True: # not grad_w_close: # Find where the largest differences are diff_w = (grad_w - w_autograd.grad).abs() max_idx_w = diff_w.argmax().item() flat_idx_w = max_idx_w idx_w = np.unravel_index(flat_idx_w, grad_w.shape) logging.error( f"Largest grad_w difference at {idx_w}: " f"{grad_w[idx_w].item()} vs {w_autograd.grad[idx_w].item()}" ) # Count zeros zeros_grad_w = (grad_w == 0).sum().item() zeros_autograd_w = (w_autograd.grad == 0).sum().item() logging.error( f"Zeros in grad_w: {zeros_grad_w}/{grad_w.numel()} ({zeros_grad_w/grad_w.numel()*100:.2f}%)" ) logging.error( f"Zeros in w_autograd.grad: {zeros_autograd_w}/{w_autograd.grad.numel()} ({zeros_autograd_w/w_autograd.grad.numel()*100:.2f}%)" ) return grad_x_close and grad_w_close except Exception as e: logging.error(f"Test failed with error: {e}") import traceback logging.error(traceback.format_exc()) return False if __name__ == "__main__": print("Running test_backward_pass") logging.debug("Running test_backward_pass") # Add numpy import for unravel_index import numpy as np success = test_backward_pass() logging.info(f"Test {'succeeded' if success else 'failed'}") ================================================ FILE: kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch # This is a series of helper functions for grouped GEMM backward that compute the gradients # using eager PyTorch operations. They are used as a verification reference for the Triton kernels. # They can also used as a fallback when the Triton kernels cannot be used, though lets hope that is not needed. 
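# A minimal sketch of the per-group math these reference helpers implement
# (assuming grad_output is [M, N*G], x is [M, K], w is [N*G, K], and the rows of x
# are laid out contiguously per group according to m_sizes):
#
#   for each group g with rows [m_start, m_end) and columns [g*N, (g+1)*N):
#       grad_x[m_start:m_end, :] = grad_output[m_start:m_end, g*N:(g+1)*N] @ w[g*N:(g+1)*N, :]
#       grad_w[g*N:(g+1)*N, :]   = grad_output[m_start:m_end, g*N:(g+1)*N].T @ x[m_start:m_end, :]
#
# i.e. the usual d/dx (x @ w.T) = grad_y @ w and d/dw (x @ w.T) = grad_y.T @ x,
# applied independently to each expert's slice. The FP32/FP64 chunking in the helpers
# below only changes accumulation precision, not this math.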
def _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x): """ Compute grad_x using pure PyTorch operations with FP32 precision """ G = m_sizes.shape[0] M, K = grad_x.shape N = w.shape[0] // G # Zero out the output tensor first grad_x.zero_() # Store original dtype and convert to float32 for computation orig_dtype = grad_x.dtype grad_output_fp32 = grad_output.float() w_fp32 = w.float() grad_x_fp32 = torch.zeros_like(grad_x, dtype=torch.float32) # Process each group separately m_start = 0 for g in range(G): m_size = m_sizes[g].item() if m_size > 0: m_end = m_start + m_size n_start = g * N n_end = (g + 1) * N # Get slices for this group grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end] w_slice = w_fp32[n_start:n_end] # Process in chunks for better precision on large matrices CHUNK_SIZE = 256 for chunk_start in range(0, m_size, CHUNK_SIZE): chunk_end = min(chunk_start + CHUNK_SIZE, m_size) chunk_size = chunk_end - chunk_start # Compute matrix multiplication with higher precision grad_output_chunk = grad_output_slice[chunk_start:chunk_end] result_chunk = torch.matmul( grad_output_chunk.double(), w_slice.double() ) # Store the result grad_x_fp32[m_start + chunk_start : m_start + chunk_end].copy_( result_chunk.float() ) m_start = m_end # Convert back to original dtype grad_x.copy_(grad_x_fp32.to(orig_dtype)) def _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w): """ Compute grad_w using pure PyTorch operations with FP64 precision for better accuracy. """ G = m_sizes.shape[0] N_times_G, K = grad_w.shape N = N_times_G // G # Zero out the output tensor first grad_w.zero_() # Store original dtype and convert to float32 for computation orig_dtype = grad_w.dtype grad_output_fp32 = grad_output.float() x_fp32 = x.float() grad_w_fp32 = torch.zeros_like(grad_w, dtype=torch.float32) # Handle potential K dimension mismatches K_x = x.shape[1] min_K = min(K, K_x) # Process each group separately m_start = 0 for g in range(G): m_size = m_sizes[g].item() if m_size > 0: m_end = m_start + m_size n_start = g * N n_end = (g + 1) * N # Get slices for this group grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end] x_slice = x_fp32[m_start:m_end, :min_K] # Process in chunks for better precision CHUNK_SIZE = 32 result = torch.zeros( (grad_output_slice.shape[1], min_K), dtype=torch.float64, device=grad_output_slice.device, ) for chunk_start in range(0, m_size, CHUNK_SIZE): chunk_end = min(chunk_start + CHUNK_SIZE, m_size) # Get chunks grad_output_chunk = grad_output_slice[chunk_start:chunk_end].double() x_chunk = x_slice[chunk_start:chunk_end].double() # Matrix multiplication in FP64 chunk_result = torch.matmul(grad_output_chunk.t(), x_chunk) result += chunk_result # Handle K dimension padding if needed if K > min_K: temp_result = torch.zeros( (grad_output_slice.shape[1], K), dtype=torch.float32, device=grad_output_slice.device, ) temp_result[:, :min_K] = result.float() grad_w_fp32[n_start:n_end].copy_(temp_result) else: grad_w_fp32[n_start:n_end].copy_(result.float()) m_start = m_end # Convert back to original dtype grad_w.copy_(grad_w_fp32.to(orig_dtype)) def _pytorch_fallback_backward(grad_output, x, w, m_sizes): """ Pure PyTorch implementation of grouped GEMM backward with high precision. Used as a fallback when the Triton kernels cannot be used. 
""" logging.info( "WARNING: Using PyTorch fallback for grouped GEMM backward with high precision" ) # Ensure inputs are contiguous x = x.contiguous() w = w.contiguous() grad_output = grad_output.contiguous() m_sizes = m_sizes.contiguous() # Allocate output tensors grad_x = torch.zeros_like(x) grad_w = torch.zeros_like(w) # Compute gradients using the helper functions _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x) _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w) return grad_x, grad_w def _pytorch_reference_backward(grad_output, x, w, m_sizes): """ Pure PyTorch implementation of grouped GEMM backward for validation. Simple version that's easy to verify but may be less numerically accurate for large matrices. """ # Create output gradients grad_x = torch.zeros_like(x) grad_w = torch.zeros_like(w) # Compute group-by-group G = m_sizes.shape[0] N = w.shape[0] // G m_start = 0 for g in range(G): m_size = m_sizes[g].item() if m_size > 0: m_end = m_start + m_size n_start = g * N n_end = (g + 1) * N # Compute gradients grad_x[m_start:m_end] = torch.matmul( grad_output[m_start:m_end, n_start:n_end], w[n_start:n_end] ) grad_w[n_start:n_end] = torch.matmul( grad_output[m_start:m_end, n_start:n_end].t(), x[m_start:m_end] ) m_start += m_size return grad_x, grad_w # ========== End helper functions ========== ================================================ FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_backwards.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import logging from typing import Tuple import torch import triton import triton.language as tl from tma_utils import TmaAutoTuneHelper # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) """ Backward pass for grouped GEMM with Triton, where grouping is N*G We are computing gradients with respect to both the input (`grad_x`) and the weights (`grad_w`). """ # =============== Start Triton Kernels =============== @triton.jit def _kernel_grouped_gemm_backward_x_scheduled( grad_y_ptr, # grad of dl/dY [M, N*G] w_t_ptr, # w transposed [K, N*G] grad_x_ptr, # output of kernel [M, K] group_offsets_ptr, # Pre-computed group offsets [G+1] workspace, # Workspace for TMA descriptors G, # Number of groups M, # Total M dimension size N, # N per group K, # K dimension size stride_go_m, stride_go_n, stride_w_n, stride_w_k, stride_gx_m, stride_gx_k, NUM_SMS, USE_TMA_LOAD: tl.constexpr = False, USE_TMA_STORE: tl.constexpr = False, BLOCK_SIZE_M: tl.constexpr = 64, BLOCK_SIZE_N: tl.constexpr = 64, BLOCK_SIZE_K: tl.constexpr = 64, GROUP_SIZE_M: tl.constexpr = 8, EVEN_K: tl.constexpr = True, ) -> None: """ Scheduled grouped GEMM backward for X with TMA support. 
For each group g, computes: grad_x[g] = grad_y[g] @ w_t[g].T Where: - grad_y is [M, N*G] - w_t is [K, N*G] (transposed from [N*G, K]) - grad_x is [M, K] """ # Get coordinates for the current program tidx = tl.program_id(axis=0) dtype = grad_x_ptr.dtype.element_ty TMA_SIZE: tl.constexpr = 128 # Initialize workspace pointer if using TMA store if USE_TMA_STORE: c_desc_ptr = workspace + tidx * TMA_SIZE else: c_desc_ptr = None # Calculate work distribution parameters num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) num_pid_in_group = num_pid_m * num_pid_k # Process all assigned work items pid = tidx while pid < G * num_pid_in_group: # Calculate work distribution for this pid group_id = pid // num_pid_in_group pid_in_group = pid % num_pid_in_group pid_m = pid_in_group % num_pid_m pid_k = pid_in_group // num_pid_m # Get group boundaries valid_group = group_id < G group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0) group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0) group_size = group_end - group_start # Calculate a mask for valid processing (valid group and non-empty) valid_work = valid_group & (group_size > 0) # Only process if we have valid work if valid_work: # Compute offsets for this group n_start = group_id * N # Block dimensions m_block_offset = pid_m * BLOCK_SIZE_M k_block_offset = pid_k * BLOCK_SIZE_K # Setup TMA descriptor for output if using TMA if USE_TMA_STORE: m_size = tl.minimum( BLOCK_SIZE_M, group_end - (group_start + m_block_offset) ) if m_size > 0: tl.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=c_desc_ptr, global_address=grad_x_ptr + (group_start + m_block_offset) * stride_gx_m + k_block_offset * stride_gx_k, load_size=[ m_size, tl.minimum(BLOCK_SIZE_K, K - k_block_offset), ], global_size=[m_size, K], element_ty=dtype, ) tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) # Initialize offsets for this block offs_m = group_start + m_block_offset + tl.arange(0, BLOCK_SIZE_M) # For K dimension, optimize memory access if EVEN_K is True offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K) # Create masks m_mask = offs_m < group_end k_mask = offs_k < K # Initialize accumulator accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) # Loop over the reduction dimension (N) # Use smaller steps to improve precision and avoid numerical issues for n_offset in range(0, N, BLOCK_SIZE_N): # Handle boundary conditions for the reduction dimension n_size = tl.minimum(BLOCK_SIZE_N, N - n_offset) offs_n = n_start + n_offset + tl.arange(0, BLOCK_SIZE_N) n_mask = offs_n < (n_start + N) # Fixed stride formats to ensure consistent memory access grad_y_block = tl.load( grad_y_ptr + offs_m[:, None] * stride_go_m + offs_n[None, :] * stride_go_n, mask=m_mask[:, None] & n_mask[None, :], other=0.0, ) # Load w_t [K, N*G] block with correct strides w_t_block = tl.load( w_t_ptr + offs_k[:, None] * stride_w_k + offs_n[None, :] * stride_w_n, mask=k_mask[:, None] & n_mask[None, :], other=0.0, ) # grad_y @ w_t.T # Allow TF32 if K is even and divisible by 8 if EVEN_K: accumulator += tl.dot( grad_y_block.to(tl.float32), w_t_block.to(tl.float32).T, allow_tf32=True, ) else: accumulator += tl.dot( grad_y_block.to(tl.float32), w_t_block.to(tl.float32).T, allow_tf32=False, ) # Store result to grad_x with explicit strides if USE_TMA_STORE: # TMA store tl._experimental_descriptor_store( c_desc_ptr, accumulator.to(dtype), [0, 0], # Starting offset in the output block ) else: # Standard store tl.store( 
grad_x_ptr + offs_m[:, None] * stride_gx_m + offs_k[None, :] * stride_gx_k, accumulator.to(dtype), mask=m_mask[:, None] & k_mask[None, :], ) pid = pid + NUM_SMS @triton.jit def _kernel_grouped_gemm_backward_w_scheduled( x_t_ptr, # x transposed [K, M] grad_y_ptr, # grad of dl/dY [M, N*G] grad_w_ptr, # output of kernel (grad_w) [N*G, K] group_offsets_ptr, # Pre-computed group offsets [G+1] workspace, # Workspace for TMA descriptors G, # Number of groups M, # Total M dimension size N, # N per group K, # K dimension size stride_x_m, stride_x_k, stride_go_m, stride_go_n, stride_gw_n, stride_gw_k, NUM_SMS, USE_TMA_LOAD: tl.constexpr = False, USE_TMA_STORE: tl.constexpr = False, BLOCK_SIZE_N: tl.constexpr = 64, BLOCK_SIZE_K: tl.constexpr = 64, BLOCK_SIZE_M: tl.constexpr = 32, GROUP_SIZE_N: tl.constexpr = 8, EVEN_K: tl.constexpr = True, ) -> None: """ Scheduled implementation of grouped GEMM backward for W with TMA support. For each group g, computes: grad_w[g] = grad_y[g].T @ x[g] Where: - x_t is [K, M] (transposed from [M, K]) - grad_y is [M, N*G] - grad_w is [N*G, K] """ # Define coordinates for the current program tidx = tl.program_id(axis=0) dtype = grad_w_ptr.dtype.element_ty TMA_SIZE: tl.constexpr = 128 # Initialize workspace pointer if using TMA store if USE_TMA_STORE: c_desc_ptr = workspace + tidx * TMA_SIZE else: c_desc_ptr = None # Calculate work distribution parameters num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) num_pid_in_group = num_pid_n * num_pid_k # Process all assigned work items pid = tidx while pid < G * num_pid_in_group: # Calculate work distribution for this pid group_id = pid // num_pid_in_group pid_in_group = pid % num_pid_in_group pid_n = pid_in_group % num_pid_n pid_k = pid_in_group // num_pid_n # Get group boundaries valid_group = group_id < G group_start = tl.where(valid_group, tl.load(group_offsets_ptr + group_id), 0) group_end = tl.where(valid_group, tl.load(group_offsets_ptr + group_id + 1), 0) group_size = group_end - group_start # Calculate a mask for valid processing (valid group and non-empty) valid_work = valid_group & (group_size > 0) # Only process if we have valid work if valid_work: # Compute offsets for this group n_start = group_id * N # Block dimensions n_block_offset = pid_n * BLOCK_SIZE_N k_block_offset = pid_k * BLOCK_SIZE_K # Setup TMA descriptor for output if using TMA if USE_TMA_STORE: n_size = tl.minimum(BLOCK_SIZE_N, N - n_block_offset) if n_size > 0: tl.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=c_desc_ptr, global_address=grad_w_ptr + (n_start + n_block_offset) * stride_gw_n + k_block_offset * stride_gw_k, load_size=[ n_size, tl.minimum(BLOCK_SIZE_K, K - k_block_offset), ], global_size=[n_size, K], element_ty=dtype, ) tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) # Initialize offsets for this block offs_n = n_start + n_block_offset + tl.arange(0, BLOCK_SIZE_N) offs_k = k_block_offset + tl.arange(0, BLOCK_SIZE_K) # Create masks n_mask = offs_n < (n_start + N) k_mask = offs_k < K # Initialize accumulator accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) # Loop over the reduction dimension (M) with smaller steps to avoid overflow for m_offset in range(0, group_size, BLOCK_SIZE_M): # Handle boundary conditions for the reduction dimension m_size = tl.minimum(BLOCK_SIZE_M, group_size - m_offset) offs_m = group_start + m_offset + tl.arange(0, BLOCK_SIZE_M) m_mask = offs_m < group_end # Load grad_y [M, N*G] block with explicit strides grad_y_block = tl.load( 
grad_y_ptr + offs_m[:, None] * stride_go_m + offs_n[None, :] * stride_go_n, mask=m_mask[:, None] & n_mask[None, :], other=0.0, ) # Load x_t [K, M] block with explicit strides x_t_block = tl.load( x_t_ptr + offs_k[:, None] * stride_x_k + offs_m[None, :] * stride_x_m, mask=k_mask[:, None] & m_mask[None, :], other=0.0, ) # Matrix multiplication: (grad_y_block.T @ x_t_block.T) if EVEN_K: accumulator += tl.dot( grad_y_block.to( tl.float32 ).T, # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M] x_t_block.to( tl.float32 ).T, # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K] allow_tf32=True, ) else: accumulator += tl.dot( grad_y_block.to( tl.float32 ).T, # Shape: [BLOCK_SIZE_N, BLOCK_SIZE_M] x_t_block.to( tl.float32 ).T, # Shape: [BLOCK_SIZE_M, BLOCK_SIZE_K] allow_tf32=False, ) # Store result to grad_w with explicit strides if USE_TMA_STORE: # TMA store tl._experimental_descriptor_store( c_desc_ptr, accumulator.to(dtype), [0, 0], # Starting offset in the output block ) else: # Standard store with explicit strides tl.store( grad_w_ptr + offs_n[:, None] * stride_gw_n + offs_k[None, :] * stride_gw_k, accumulator.to(dtype), mask=n_mask[:, None] & k_mask[None, :], ) pid = pid + NUM_SMS # ========== End Triton kernels ========== # ========== Begin grouped_gemm_backward cover function ========== def grouped_gemm_backward( grad_output: torch.Tensor, x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Backward pass for grouped matrix multiplication using scheduled kernels with TMA support. Args: grad_output: Gradient with respect to output, shape [M, N*G] x: Input tensor from forward pass, shape [M, K] w: Weight tensor from forward pass, shape [N*G, K] m_sizes: Group sizes tensor, shape [G] Returns: Tuple of gradients with respect to x and w: (grad_x, grad_w) """ logging.info("Starting grouped_gemm_backward with TMA-enabled scheduling") # Check CUDA availability if not torch.cuda.is_available(): logging.error("CUDA not available for backward pass") raise RuntimeError("CUDA not available for backward pass") # return _pytorch_fallback_backward(grad_output, x, w, m_sizes) # Get GPU parameters - TODO: this can use PyTorch cached info... device_props = torch.cuda.get_device_properties("cuda") NUM_SMS = device_props.multi_processor_count # Check TMA support has_tma = hasattr(tl.extra, "cuda") and device_props.major >= 9 if has_tma: logging.info(f"TMA support detected on GPU with {NUM_SMS} SMs") USE_TMA_LOAD = True # TODO - this does nothing atm..removed to focus on numerical correctness first. 
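# USE_TMA_STORE is likewise kept off below: both flags are plumbed through to the
# backward kernels, but with stores disabled the kernels fall back to regular
# tl.store, keeping the focus on numerical correctness first (see the TODO above).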
USE_TMA_STORE = False else: logging.warning("TMA support not detected, disabling TMA optimizations") USE_TMA_LOAD = False USE_TMA_STORE = False # Validate input dimensions G = m_sizes.shape[0] M, K_x = x.shape N_times_G, K_w = w.shape # Check that K dimensions match if K_x != K_w: logging.warning(f"K dimension mismatch: x has K={K_x}, w has K={K_w}") raise ValueError("K dimensions must match for grouped GEMM backward") # return _pytorch_fallback_backward(grad_output, x, w, m_sizes) try: # Ensure contiguous tensors grad_output = grad_output.contiguous() x = x.contiguous() w = w.contiguous() m_sizes = m_sizes.contiguous() # Allocate output tensors grad_x = torch.zeros_like(x) grad_w = torch.zeros_like(w) # Determine N per group # N*G is the second dimension size of grad_output N = N_times_G // G # Set stride values # Direct access pattern for grad_output tensor stride_go_m = grad_output.stride(0) # grad_output in M dimension stride_go_n = grad_output.stride(1) # grad_output in N dimension # Pattern match the transposed weight tensor stride_w_n = 1 # transposed weights in N dimension stride_w_k = N * G # transposed weights in K dimension # Pattern match the output grad_x tensor stride_gx_m = grad_x.stride(0) # grad_x in M dimension stride_gx_k = grad_x.stride(1) # Sgrad_x in K dimension # Pattern match the transposed x tensor stride_x_m = 1 # Stride for transposed x in M dimension stride_x_k = M # Stride for transposed x in K dimension # Pattern match the output grad_w tensor stride_gw_n = grad_w.stride(0) # grad_w in N dimension stride_gw_k = grad_w.stride(1) # grad_w in K dimension # Pre-compute group offsets for indexing group_offsets = torch.zeros(G + 1, device=m_sizes.device, dtype=torch.int32) m_offset = 0 for g in range(G): group_offsets[g] = m_offset m_offset += m_sizes[g].item() group_offsets[G] = m_offset # Total M # Check if K dimension is even (optimize memory access patterns) EVEN_K = (K_x % 8) == 0 logging.info(f"EVEN_K optimization enabled: {EVEN_K} (K={K_x})") # Transpose x and w for backward computation x_t = x.T.contiguous() # Shape: [K, M] w_t = w.T.contiguous() # Shape: [K, N*G] # Allocate workspace for TMA descriptors if needed if USE_TMA_LOAD or USE_TMA_STORE: workspace = torch.empty((NUM_SMS * 128), device=x.device, dtype=torch.uint8) else: # Empty tensor when TMA is not used workspace = torch.empty(0, device=x.device, dtype=torch.uint8) # Set block sizes based on K dimension # For larger K, use smaller blocks to reduce register pressure BLOCK_SIZE = 64 if K_x <= 64 else 32 BLOCK_SIZE_K = BLOCK_SIZE BLOCK_SIZE_M = BLOCK_SIZE BLOCK_SIZE_N = BLOCK_SIZE # Determine maximum size needed and set the grid size num_pid_m = triton.cdiv(M, BLOCK_SIZE_M) num_pid_k = triton.cdiv(K_x, BLOCK_SIZE_K) num_pid_n = triton.cdiv(N, BLOCK_SIZE_N) # Compute total number of blocks needed for each kernel total_blocks_x = G * num_pid_m * num_pid_k total_blocks_w = G * num_pid_n * num_pid_k try: logging.info("Computing grad_x with TMA-enabled kernel") # Fixed grid size based on SM count grid = (NUM_SMS,) _kernel_grouped_gemm_backward_x_scheduled[grid]( grad_output, w_t, # Using transposed weights grad_x, group_offsets, workspace, G, M, N, K_x, stride_go_m, stride_go_n, stride_w_n, stride_w_k, stride_gx_m, stride_gx_k, NUM_SMS, USE_TMA_LOAD, USE_TMA_STORE, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, EVEN_K=EVEN_K, ) logging.info( "Kernel run success: grad_X computation successful with TMA-enabled kernel" ) except Exception as e: logging.error(f"FAILED: Error 
in TMA-enabled backward_x kernel: {e}") logging.info("WARNING: Falling back to PyTorch for grad_x") _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x) try: logging.info("Computing grad_w with TMA-enabled kernel") # Fixed grid size based on SM count grid = (NUM_SMS,) _kernel_grouped_gemm_backward_w_scheduled[grid]( x_t, # Using transposed inputs grad_output, grad_w, group_offsets, workspace, G, M, N, K_w, stride_x_m, stride_x_k, stride_go_m, stride_go_n, stride_gw_n, stride_gw_k, NUM_SMS, USE_TMA_LOAD, USE_TMA_STORE, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_M=BLOCK_SIZE_M, EVEN_K=EVEN_K, ) logging.info( "Kernel run success - grad_W computation successful with TMA-enabled kernel" ) except Exception as e: logging.error(f"FAILED: Error in TMA-enabled backward_w kernel: {e}") logging.info("WARNING: Falling back to PyTorch for grad_w") # _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w) return grad_x, grad_w except Exception as e: logging.error(f"Error in grouped_gemm_backward: {e}") # return _pytorch_fallback_backward(grad_output, x, w, m_sizes) ================================================ FILE: kernels/MoE/group_GEMM/triton/tgroup_gemm_forward.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe # This is copied from FBGEMM, with some modifications. Not kept in sync, Original code: # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py import functools from typing import Optional import tma_utils as utils import torch import triton import triton.language as tl from triton.runtime import driver # @manual """ _NV_CONFIGS = [ triton.Config( { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, }, num_stages=num_stages, num_warps=num_warps, num_ctas=num_ctas, ) for block_size_m in [64, 128] for block_size_n in [64, 128, 256] for block_size_k in [64, 128, 256] for num_stages in [3, 4] for num_warps in [4, 8] for num_ctas in [1] ] _AMD_CONFIGS = [ triton.Config( { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "waves_per_eu": waves_per_cu, "matrix_instr_nonkdim": matrix_instr_nonkdim, }, num_stages=num_stages, num_warps=num_warps, ) for block_size_m in [32, 64, 128] for block_size_n in [32, 64, 128, 256] for block_size_k in [128, 256] for num_stages in [1, 2] for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)] for matrix_instr_nonkdim in [16] ] def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs): device = torch.cuda.current_device() # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages if dtsize is None: dtsize = named_args["c_ptr"].element_size() if dtype is None: dtype = named_args["c_ptr"].dtype pruned_configs = [] for config in configs: kw = config.kwargs BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( kw["BLOCK_SIZE_M"], kw["BLOCK_SIZE_N"], kw["BLOCK_SIZE_K"], config.num_stages, ) G, M, N, K = ( named_args["G"], named_args["M_BUCKET"], named_args["N"], named_args["K"], ) # 1. 
make sure we have enough smem max_shared_memory = driver.active.utils.get_device_properties(device)[ "max_shared_mem" ] if torch.version.hip: required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize else: required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize if required_shared_memory > max_shared_memory: continue M_PER_GROUP = M // G MIN_M_TILES = 32 if torch.version.hip else 64 # 2. make sure we don't load M tiles that are too big if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2): continue # 3. make sure we don't load N tiles that are too small if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2): continue num_sm = driver.active.utils.get_device_properties(device)[ "multiprocessor_count" ] N_TILES = N // BLOCK_N MIN_N_TILES = 32 if torch.version.hip else 64 # 4. make sure we don't load N tiles that are too big if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm: continue # 5. make sure we don't load N tiles that are too small if BLOCK_N < 128 and M * N_TILES > 2 * num_sm: continue # 6. make sure K can be evenly divided if K % BLOCK_K != 0: continue pruned_configs.append(config) return pruned_configs @triton.autotune( configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, key=["G", "M_BUCKET", "N", "K"], prune_configs_by={"early_config_prune": early_config_prune}, ) """ @triton.jit def _kernel_grouped_gemm( a_desc_ptr, b_desc_ptr, c_ptr, workspace, m_sizes, # problem sizes G: tl.constexpr, M_BUCKET: tl.constexpr, N: tl.constexpr, # N is per group K: tl.constexpr, NUM_SMS: tl.constexpr, USE_TMA_LOAD: tl.constexpr, USE_TMA_STORE: tl.constexpr, # tile sizes BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ) -> None: tidx = tl.program_id(0) dtype: tl.dtype = c_ptr.dtype.element_ty TMA_SIZE: tl.constexpr = tl.constexpr(128) if USE_TMA_STORE: c_desc_ptr = workspace + tidx * TMA_SIZE else: c_desc_ptr = None M_end_offset = 0 iterated_tiles = 0 for g in tl.range(G): # Move across groups M_start_offset = M_end_offset m_size = tl.load(m_sizes + g) M_end_offset = M_start_offset + m_size if m_size > 0: # Compute for this group N_start_offset = g * N n_size = N # N is already per group # Calculate the number of tiles for this group num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles if USE_TMA_STORE: # Set up TMA descriptor for output # pyre-ignore tl.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=c_desc_ptr, global_address=c_ptr + M_start_offset * (N * G) + N_start_offset, # Offset to this group's output load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[m_size, n_size], element_ty=c_ptr.dtype.element_ty, ) # pyre-ignore tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) # Move across tiles while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: gidx = tidx - iterated_tiles # Split M first and N second. 
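# gidx enumerates the tiles owned by this group; with the modulo/division below the
# M tile index varies fastest, so consecutive tile indices fall on adjacent row
# blocks within the same column block before advancing to the next column block.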
tile_m_idx = gidx % num_m_tiles tile_n_idx = gidx // num_m_tiles accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) tl.static_assert(K % BLOCK_SIZE_K == 0) if USE_TMA_LOAD: # Use TMA to load input and weight blocks m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): # Load input block [M, K] a = tl._experimental_descriptor_load( a_desc_ptr, [m_offset, k_offset], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype, ) # Load weight block [N, K] b = tl._experimental_descriptor_load( b_desc_ptr, [n_offset, k_offset], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype, ) # Compute matrix multiplication accumulator += tl.dot(a, b.T) else: # Manual load without TMA offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = ( a_desc_ptr + (M_start_offset + offs_am[:, None]) * K + offs_k[None, :] ) b_ptrs = ( b_desc_ptr + (N_start_offset + offs_bn[:, None]) * K + offs_k[None, :] ) for k_offset in range(0, K, BLOCK_SIZE_K): # Load with bounds checking a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) # Compute matrix multiplication accumulator += tl.dot(a, b.T) # Update pointers for next block a_ptrs += BLOCK_SIZE_K b_ptrs += BLOCK_SIZE_K # Store result if USE_TMA_STORE: # Store using TMA m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) tl._experimental_descriptor_store( c_desc_ptr, accumulator.to(c_ptr.dtype.element_ty), [m_offset, n_offset], ) else: # Manual store offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c = accumulator.to(c_ptr.dtype.element_ty) tl.store( c_ptr + (M_start_offset + offs_am[:, None]) * (N * G) # Row stride is N*G + ( N_start_offset + offs_bn[None, :] ), # Column offset to this group's N c, mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, ) tidx += NUM_SMS # Move to next tile iterated_tiles += num_tiles TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv """@triton.autotune( configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, key=["G", "M_BUCKET", "N", "K"], prune_configs_by={ "early_config_prune": functools.partial( early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1 ) }, ) """ @triton.jit def _kernel_grouped_gemm_fp8_rowwise( a_desc_ptr, a_scale_ptr, b_desc_ptr, b_scale_ptr, c_ptr, workspace, m_sizes, # problem sizes G: tl.constexpr, M_BUCKET: tl.constexpr, N: tl.constexpr, # N is per group K: tl.constexpr, NUM_SMS: tl.constexpr, USE_TMA_LOAD: tl.constexpr, USE_TMA_STORE: tl.constexpr, # tile sizes BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ) -> None: tidx = tl.program_id(0) dtype = TT_FP8_DTYPE TMA_SIZE: tl.constexpr = tl.constexpr(128) if USE_TMA_STORE: c_desc_ptr = workspace + tidx * TMA_SIZE else: c_desc_ptr = None M_end_offset = 0 iterated_tiles = 0 for g in tl.range(G): # Move across groups M_start_offset = M_end_offset m_size = tl.load(m_sizes + g) M_end_offset = M_start_offset + m_size if m_size > 0: # Compute for this group N_start_offset = g * N n_size = N # N is already per group # Calculate the number of tiles for this group num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles if USE_TMA_STORE: # 
Set up TMA descriptor for output # pyre-ignore tl.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=c_desc_ptr, global_address=c_ptr + M_start_offset * (N * G) + N_start_offset, # Offset to this group's output load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[m_size, n_size], element_ty=c_ptr.dtype.element_ty, ) # pyre-ignore tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) # Move across tiles while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: gidx = tidx - iterated_tiles # Split M first and N second. tile_m_idx = gidx % num_m_tiles tile_n_idx = gidx // num_m_tiles accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) tl.static_assert(K % BLOCK_SIZE_K == 0) if USE_TMA_LOAD: # Use TMA to load input and weight blocks with FP8 support m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): # Load input block [M, K] with FP8 a = tl._experimental_descriptor_load( a_desc_ptr, [m_offset, k_offset], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype, ) # Load weight block [N, K] with FP8 b = tl._experimental_descriptor_load( b_desc_ptr, [n_offset, k_offset], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype, ) # Compute matrix multiplication accumulator += tl.dot(a, b.T) else: # Manual load without TMA for FP8 offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = ( a_desc_ptr + (M_start_offset + offs_am[:, None]) * K + offs_k[None, :] ) b_ptrs = ( b_desc_ptr + (N_start_offset + offs_bn[:, None]) * K + offs_k[None, :] ) for k_offset in range(0, K, BLOCK_SIZE_K): # Load with bounds checking a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) # Compute matrix multiplication accumulator += tl.dot(a, b.T) # Update pointers for next block a_ptrs += BLOCK_SIZE_K b_ptrs += BLOCK_SIZE_K # Load FP8 scales offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) a_scale = tl.load( a_scale_ptr + M_start_offset + offs_am[:, None], mask=offs_am[:, None] < m_size, ) b_scale = tl.load( b_scale_ptr + N_start_offset + offs_bn[None, :], mask=offs_bn[None, :] < n_size, ) # Apply scales to result c = accumulator.to(tl.float32) * a_scale * b_scale # Store result if USE_TMA_STORE: # Store using TMA m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) tl._experimental_descriptor_store( c_desc_ptr, c.to(c_ptr.dtype.element_ty), [m_offset, n_offset], ) else: # Manual store tl.store( c_ptr + (M_start_offset + offs_am[:, None]) * (N * G) # Row stride is N*G + ( N_start_offset + offs_bn[None, :] ), # Column offset to this group's N c, mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, ) tidx += NUM_SMS # Move to next tile iterated_tiles += num_tiles def _grouped_gemm( x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor, x_scale: Optional[torch.Tensor] = None, w_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if not utils.HAS_TMA_DESC: raise NotImplementedError("Grouped GEMM without TMA is not supported yet") G = m_sizes.shape[0] assert x.is_contiguous() assert w.is_contiguous() assert m_sizes.is_contiguous() M, K = x.shape N_times_G = w.shape[0] # Ensure N is per group assert ( N_times_G % G == 0 ), f"Weight dimension ({N_times_G}) must be divisible by 
groups ({G})" N = N_times_G // G assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})" # Create output tensor with correct shape [M, N*G] y = torch.empty((M, N_times_G), device=x.device, dtype=torch.bfloat16) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count USE_TMA_LOAD = True # not torch.version.hip USE_TMA_STORE = True desc_helper = None desc_x = x desc_w = w workspace = None if USE_TMA_LOAD: desc_helper = utils.TmaAutoTuneHelper() desc_helper.init_tma_descriptor("x") desc_helper.init_tma_descriptor("w") desc_x = desc_helper.get_tma_descriptor_kernel_param("x") desc_w = desc_helper.get_tma_descriptor_kernel_param("w") if USE_TMA_STORE: workspace = torch.empty( NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE, device=x.device, dtype=torch.uint8, ) # Skip autotuning - use fixed grid size grid_size = (min(NUM_SMS, 4),) # Use smaller grid for small inputs M_BUCKET = triton.next_power_of_2(M) try: if USE_TMA_LOAD and desc_helper is not None: # Fixed block sizes that work well for most cases BLOCK_SIZE_M = 64 BLOCK_SIZE_N = 64 BLOCK_SIZE_K = 32 desc_helper.fill_2d_tma_descriptor( "x", x.data_ptr(), M, K, BLOCK_SIZE_M, BLOCK_SIZE_K, x.element_size(), ) desc_helper.fill_2d_tma_descriptor( "w", w.data_ptr(), N_times_G, K, BLOCK_SIZE_N, BLOCK_SIZE_K, w.element_size(), ) except Exception as e: print(f"Error in TMA descriptor setup: {e}") if x_scale is not None and w_scale is not None: assert x_scale.is_contiguous() assert w_scale.is_contiguous() # Call kernel directly without autotuning _kernel_grouped_gemm_fp8_rowwise[grid_size]( desc_x, x_scale, desc_w, w_scale, y, workspace, m_sizes, G, M_BUCKET, N, # N is per group K, NUM_SMS, USE_TMA_LOAD, USE_TMA_STORE, BLOCK_SIZE_M=64, # Fixed block sizes BLOCK_SIZE_N=64, BLOCK_SIZE_K=32, ) else: assert x_scale is None assert w_scale is None # Call kernel directly without autotuning _kernel_grouped_gemm[grid_size]( desc_x, desc_w, y, workspace, m_sizes, G, M_BUCKET, N, # N is per group K, NUM_SMS, USE_TMA_LOAD, USE_TMA_STORE, BLOCK_SIZE_M=64, # Fixed block sizes BLOCK_SIZE_N=64, BLOCK_SIZE_K=32, ) # Verify the output shape expected_output_shape = (M, N_times_G) assert y.shape == expected_output_shape, ( f"Output shape mismatch: got {y.shape}, " f"expected {expected_output_shape}" ) return y def grouped_gemm_forward( x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor ) -> torch.Tensor: return _grouped_gemm(x, w, m_sizes) def grouped_gemm_fp8_rowwise( x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, ) -> torch.Tensor: return _grouped_gemm(x, w, m_sizes, x_scale, w_scale) ================================================ FILE: kernels/MoE/group_GEMM/triton/utils/tma_utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe # This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm import sys import torch import triton # @manual import triton.language as tl # @manual def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: """ Maps torch dtype to triton dtype. Args: dtype (torch.dtype): input dtype. Returns: tl.dtype: triton dtype. 
""" if dtype == torch.float16: return tl.float16 elif dtype == torch.bfloat16: return tl.bfloat16 elif dtype == torch.float32: return tl.float32 elif dtype == torch.int32: return tl.int32 elif dtype == torch.float8_e4m3fn and torch.version.hip is None: return tl.float8e4nv else: raise ValueError(f"Unsupported dtype {dtype}") # check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) if HAS_TMA_DESC: print( "TMA benchmarks will be running with experimental grid constant TMA descriptor.", file=sys.stderr, ) else: print( "Missing: This group gemm code will not run without TMA descriptor support....", file=sys.stderr, ) raise NotImplementedError("grouped Gemm without TMA is not supported") class TmaAutoTuneHelper: # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 class KernelParamWrapper: def __init__(self, desc): self.desc = desc def tma_desc_cpu_ptr(self): return self.desc.data_ptr() TMA_SIZE = 128 def __init__(self): self.fill_1d_tma_descriptor_inner = ( triton.runtime.driver.active.utils.fill_1d_tma_descriptor ) self.fill_2d_tma_descriptor_inner = ( triton.runtime.driver.active.utils.fill_2d_tma_descriptor ) if HAS_TMA_DESC: self.descriptors = {} else: self.cuda_descriptors = {} # Call this method outside of the lambda function for grid size def init_tma_descriptor(self, name): if HAS_TMA_DESC: self.descriptors[name] = torch.empty( TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 ) else: self.cuda_descriptors[name] = torch.empty( TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 ) # Call this method inside the lambda function for grid size def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): if HAS_TMA_DESC: desc_x = self.descriptors[name] assert desc_x.data_ptr() % 64 == 0 self.fill_1d_tma_descriptor_inner( ptr, dim, block_dim, element_size, desc_x.data_ptr() ) else: desc_x = self.cuda_descriptors[name] buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) self.fill_1d_tma_descriptor_inner( ptr, dim, block_dim, element_size, buf_x.data_ptr() ) desc_x.copy_(buf_x, non_blocking=True) # Call this method inside the lambda function for grid size def fill_2d_tma_descriptor( self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size ): if HAS_TMA_DESC: desc_x = self.descriptors[name] assert desc_x.data_ptr() % 64 == 0 self.fill_2d_tma_descriptor_inner( ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() ) else: desc_x = self.cuda_descriptors[name] buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) self.fill_2d_tma_descriptor_inner( ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() ) desc_x.copy_(buf_x, non_blocking=True) def get_tma_descriptor_kernel_param(self, name): if HAS_TMA_DESC: assert self.descriptors[name] is not None return self.KernelParamWrapper(self.descriptors[name]) else: assert self.cuda_descriptors[name] is not None return self.cuda_descriptors[name] ================================================ FILE: kernels/blackwell/cute_gemm_01/Makefile ================================================ # Makefile for SM100 GEMM PyTorch Extension # Set these paths according to your installation CUTLASS_PATH ?= /path/to/cutlass CUDA_HOME ?= $(shell python -c "import torch; print(torch.utils.cpp_extension.CUDA_HOME)") # Build the extension build: CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace # Install the extension install: 
CUTLASS_PATH=$(CUTLASS_PATH) pip install . # Clean build artifacts clean: rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so # Test the installation test: python python_interface.py # Check CUDA device capability check_device: python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')" .PHONY: build install clean test check_device ================================================ FILE: kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_log ================================================ # ninja log v5 0 15279 1748131038212164071 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 1163be77f63db063 6 13596 1748131241209889865 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 79aa61597088743a 8 13684 1748132015451659084 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 89ead7aaccf82852 ================================================ FILE: kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/build.ninja ================================================ ninja_required_version = 1.3 cxx = c++ nvcc = /usr/local/cuda-12.8/bin/nvcc cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm cuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' 
-DTORCH_EXTENSION_NAME=sm100_gemm cuda_dlink_post_cflags = sycl_dlink_post_cflags = ldflags = rule compile command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags depfile = $out.d deps = gcc rule cuda_compile depfile = $out.d deps = gcc command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm.cu build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm_pytorch.cpp ================================================ FILE: kernels/blackwell/cute_gemm_01/driver.py ================================================ # ============================================================================== # python_interface.py - High-level Python interface # ============================================================================== import torch try: import sm100_gemm # The compiled extension - this has to go after import torch...but auto-formatting is blocking except ImportError: print("❌ SM100 not ready!") raise ImportError( "SM100 not ready! Please build the extension using `python setup.py install`" ) def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0): """ Perform GEMM using SM100 optimized kernel: D = alpha * A @ B^T + beta * C Args: A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16 B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16 C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32 If None, creates zero tensor alpha (float): Scaling factor for A @ B^T beta (float): Scaling factor for C Returns: torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32 Note: - A and B are K-major (transposed in BLAS terms) - C and D are N-major (row-major) - All tensors must be on CUDA - M must be multiple of 128, N multiple of 256, K multiple of 64 """ # Input validation assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}" assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}" assert A.is_cuda and B.is_cuda, "A and B must be on CUDA" assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous" M, K = A.shape N, K_B = B.shape assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}" # Check alignment requirements assert M % 128 == 0, f"M={M} must be multiple of 128" assert N % 256 == 0, f"N={N} must be multiple of 256" assert K % 64 == 0, f"K={K} must be multiple of 64" # Create C if not provided if C is None: C = torch.zeros(M, N, dtype=torch.float32, device=A.device) else: assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}" assert C.is_cuda, "C must be on CUDA" assert C.is_contiguous(), "C must be contiguous" assert C.shape == ( M, N, ), f"C shape {C.shape} must match output shape ({M}, {N})" # Call the extension return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta) def benchmark_sm100_vs_torch( M=1024, N=2048, K=256, num_warmup=1, num_trials=10 ): # M=512, N=1024, K=256, num_warmup=10, num_trials=100): """ Benchmark SM100 GEMM against PyTorch's native GEMM """ # Ensure dimensions are aligned M = ((M + 127) // 128) * 128 N = ((N + 255) // 256) * 256 K = ((K + 63) // 64) * 64 print(f"Benchmarking GEMM with shape: ({M}, {N}, {K})") # Create test 
tensors A = torch.randn(M, K, dtype=torch.float16, device="cuda") B = torch.randn(N, K, dtype=torch.float16, device="cuda") C = torch.randn(M, N, dtype=torch.float16, device="cuda") C32 = C.to(torch.float32).clone() # Keep A and B as FP16 for PyTorch A_fp16 = A B_fp16 = B # Warmup for _ in range(num_warmup): # PyTorch GEMM (using FP16) torch_result = torch.addmm(C, A_fp16, B_fp16.T) # SM100 GEMM sm100_result = sm100_gemm_f16(A, B, C32) torch.cuda.synchronize() # Benchmark PyTorch torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(num_trials): torch_result = torch.addmm(C, A_fp16, B_fp16.T) end.record() torch.cuda.synchronize() torch_time = start.elapsed_time(end) / num_trials # Benchmark SM100 start.record() for _ in range(num_trials): sm100_result = sm100_gemm_f16(A, B, C32) end.record() torch.cuda.synchronize() sm100_time = start.elapsed_time(end) / num_trials # Check correctness max_diff = torch.max(torch.abs(torch_result - sm100_result.to(torch.float16))) rel_error = max_diff / torch.max(torch.abs(torch_result)) # Calculate FLOPS flops = 2 * M * N * K # Multiply-add operations torch_tflops = flops / (torch_time * 1e-3) / 1e12 sm100_tflops = flops / (sm100_time * 1e-3) / 1e12 print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)") print(f"SM100 time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)") print(f"Speedup: {torch_time/sm100_time:.2f}x") print(f"Max difference: {max_diff:.6f}") print(f"Relative error: {rel_error:.6f}") return { "torch_time": torch_time, "sm100_time": sm100_time, "speedup": torch_time / sm100_time, "torch_tflops": torch_tflops, "sm100_tflops": sm100_tflops, "max_diff": max_diff.item(), "rel_error": rel_error.item(), } # Example usage and test if __name__ == "__main__": # Test basic functionality print("Testing SM100 GEMM...") M, N, K = 512, 1024, 256 A = torch.randn(M, K, dtype=torch.float16, device="cuda") B = torch.randn(N, K, dtype=torch.float16, device="cuda") C = torch.randn(M, N, dtype=torch.float32, device="cuda") # Test the GEMM result = sm100_gemm_f16(A, B, C, alpha=1.0, beta=0.5) print(f"Result shape: {result.shape}, dtype: {result.dtype}") # Run benchmark print("\nRunning benchmark...") benchmark_results = benchmark_sm100_vs_torch(M, N, K) # ============================================================================== # Makefile for easy building # ============================================================================== ''' MAKEFILE_CONTENT = """ # Makefile for SM100 GEMM PyTorch Extension # Set these paths according to your installation CUTLASS_PATH ?= /path/to/cutlass CUDA_HOME ?= $(shell python -c "import torch; print(torch.utils.cpp_extension.CUDA_HOME)") # Build the extension build: CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace # Install the extension install: CUTLASS_PATH=$(CUTLASS_PATH) pip install . # Clean build artifacts clean: rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so # Test the installation test: python python_interface.py # Check CUDA device capability check_device: python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')" .PHONY: build install clean test check_device """ # Write Makefile with open("Makefile", "w") as f: f.write(MAKEFILE_CONTENT) print("Setup files created!") print("To build:") print("1. Set CUTLASS_PATH environment variable to your CUTLASS installation") print("2. Run: make build") print("3. 
Test: make test") ''' ================================================ FILE: kernels/blackwell/cute_gemm_01/setup.py ================================================ # setup.py import os import pybind11 import torch from pybind11 import get_cmake_dir from pybind11.setup_helpers import build_ext, Pybind11Extension from setuptools import Extension, setup from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension # IMPORTANT: The following two lines are the only ones you need to change # Get CUTLASS path (you'll need to set this to your CUTLASS installation) CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/home/less/local/cutlas40") # CUDA and PyTorch paths cuda_home = torch.utils.cpp_extension.CUDA_HOME pytorch_includes = torch.utils.cpp_extension.include_paths() ext_modules = [ CUDAExtension( name="sm100_gemm", sources=[ "sm100_gemm_pytorch.cpp", # PyTorch bindings (C++) "sm100_gemm.cu", # CUDA kernel implementation ], include_dirs=[ # PyTorch includes *pytorch_includes, # CUTLASS includes f"{CUTLASS_PATH}/include", f"{CUTLASS_PATH}/tools/util/include", # CUDA includes f"{cuda_home}/include", ], library_dirs=[ f"{cuda_home}/lib64", ], libraries=["cuda", "cudart"], extra_compile_args={ "cxx": [ "-O3", "-std=c++17", "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", "-DCUTE_SM100_ENABLED", ], "nvcc": [ "-O3", "-std=c++17", "--expt-relaxed-constexpr", "--expt-extended-lambda", "-gencode=arch=compute_100a,code=sm_100a", # SM100 architecture "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", "-DCUTE_SM100_ENABLED", "--use_fast_math", "-Xcompiler=-fPIC", "-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1", # Enable TCGEN05_TMEM ], }, extra_link_args=["-lcuda", "-lcudart"], language="c++", ) ] setup( name="sm100_gemm", ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, zip_safe=False, python_requires=">=3.8", install_requires=["torch>=1.12.0"], ) ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.cu ================================================ // sm100_gemm_kernel.cu - CUDA kernel implementation #include "sm100_gemm.h" #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) #include #include #include #include #include #include #include #include #include using namespace cute; // Shared storage structure template struct SharedStorage { alignas(128) cute::ArrayEngine> A; alignas(128) cute::ArrayEngine> B; alignas(16) cute::uint64_t mma_barrier; alignas(16) cute::uint32_t tmem_base_ptr; CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; // Device kernel template __global__ static void gemm_device(ATensor mA, BTensor mB, CTensor mC, DTensor mD, MmaTiler_MNK mma_tiler, TiledMMA tiled_mma, ClusterShape_MNK cluster_shape, Alpha alpha, Beta beta) { // Step 1: The Prologue Layout cluster_layout_vmnk = tiled_divide( make_layout(cluster_shape), make_tile(typename TiledMMA::AtomThrID{})); auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), blockIdx.x / size<0>(cluster_layout_vmnk), blockIdx.y, _); auto mma_coord = select<1, 2, 3>(mma_coord_vmnk); Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X, _1>{}); Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step{}); Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1, _1, X>{}); Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1, _1, X>{}); // SMEM allocation extern __shared__ char shared_memory[]; SharedStorage &shared_storage = 
*reinterpret_cast(shared_memory); Tensor tCsA = shared_storage.tensor_sA(); Tensor tCsB = shared_storage.tensor_sB(); // MMA partitioning auto mma_v = get<0>(mma_coord_vmnk); ThrMMA cta_mma = tiled_mma.get_slice(mma_v); Tensor tCgA = cta_mma.partition_A(gA); Tensor tCgB = cta_mma.partition_B(gB); Tensor tCgC = cta_mma.partition_C(gC); Tensor tCgD = cta_mma.partition_C(gD); // Fragment allocation Tensor tCrA = cta_mma.make_fragment_A(tCsA); Tensor tCrB = cta_mma.make_fragment_B(tCsB); Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); uint32_t elect_one_thr = cute::elect_one_sync(); uint32_t elect_one_warp = (threadIdx.x / 32 == 0); using TmemAllocator = cute::TMEM::Allocator1Sm; TmemAllocator tmem_allocator{}; // TMEM allocation if (elect_one_warp) { tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); } __syncthreads(); tCtAcc.data() = shared_storage.tmem_base_ptr; // Barrier initialization if (elect_one_warp && elect_one_thr) { cute::initialize_barrier(shared_storage.mma_barrier, 1); } int mma_barrier_phase_bit = 0; __syncthreads(); // Step 2: The Mainloop tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) { // Load A and B tiles cooperative_copy<128>(threadIdx.x, tCgA(_, _, _, k_tile), tCsA); cooperative_copy<128>(threadIdx.x, tCgB(_, _, _, k_tile), tCsB); __syncthreads(); // Execute MMAs if (elect_one_warp) { for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCtAcc); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } cutlass::arch::umma_arrive(&shared_storage.mma_barrier); } cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit); mma_barrier_phase_bit ^= 1; } // Step 3: The Epilogue TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc); ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x); Tensor tDgC = thr_t2r_copy.partition_D(tCgC); Tensor tDrC = make_fragment_like(tDgC); copy(tDgC, tDrC); Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); Tensor tDgD = thr_t2r_copy.partition_D(tCgD); using AccType = typename decltype(tCtAcc)::value_type; Tensor tDrAcc = make_tensor(shape(tDgD)); copy(tiled_t2r_copy, tDtAcc, tDrAcc); // AXPBY and store result axpby(alpha, tDrAcc, beta, tDrC); copy(tDrC, tDgD); __syncthreads(); // Cleanup TMEM if (elect_one_warp) { tmem_allocator.release_allocation_lock(); tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); } } // Host function that launches the kernel cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D, int M, int N, int K, float alpha, float beta, cudaStream_t stream) { // Define types using TypeA = cutlass::half_t; using TypeB = cutlass::half_t; using TypeC = float; using TypeD = float; // Create layouts (K-major for A and B, N-major for C and D) auto layout_A = make_layout(make_shape(M, K), make_stride(K, Int<1>{})); auto layout_B = make_layout(make_shape(N, K), make_stride(K, Int<1>{})); auto layout_C = make_layout(make_shape(M, N), make_stride(N, Int<1>{})); auto layout_D = layout_C; // Create CuTe tensors auto mA = make_tensor(make_gmem_ptr(reinterpret_cast(d_A)), layout_A); auto mB = make_tensor(make_gmem_ptr(reinterpret_cast(d_B)), layout_B); auto mC = make_tensor(make_gmem_ptr(reinterpret_cast(d_C)), layout_C); auto mD = make_tensor(make_gmem_ptr(reinterpret_cast(d_D)), layout_D); // Create TiledMMA TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}); // Define MMA tiler 
sizes auto bM = tile_size<0>(tiled_mma); // 128 auto bN = tile_size<1>(tiled_mma); // 256 auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // 64 auto mma_tiler = make_shape(bM, bN, bK); // Check alignment if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) { return cudaErrorInvalidValue; } // Create SMEM layouts auto mma_shape_A = partition_shape_A( tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler))); auto mma_shape_B = partition_shape_B( tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler))); auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_A); auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_B); using SMEMStorage = SharedStorage; // Cluster configuration auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{}); // Launch parameters dim3 dimBlock(128); dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); dim3 dimGrid(ceil_div(M, int(bM)), ceil_div(N, int(bN))); int smemBytes = sizeof(SMEMStorage); // Get kernel pointer auto *kernel_ptr = &gemm_device; // Set kernel attributes cudaError_t error = cudaFuncSetAttribute( kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes); if (error != cudaSuccess) { return error; } // Launch kernel cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; cutlass::Status status = cutlass::launch_kernel_on_cluster( params, (void const *)kernel_ptr, mA, mB, mC, mD, mma_tiler, tiled_mma, cluster_shape, alpha, beta); return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorLaunchFailure; } #else cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D, int M, int N, int K, float alpha, float beta, cudaStream_t stream) { return cudaErrorNotSupported; } #endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/PKG-INFO ================================================ Metadata-Version: 2.4 Name: sm100_gemm Version: 0.0.0 Requires-Python: >=3.8 Requires-Dist: torch>=1.12.0 Dynamic: requires-dist Dynamic: requires-python ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/SOURCES.txt ================================================ setup.py sm100_gemm.cu sm100_gemm_pytorch.cpp sm100_gemm.egg-info/PKG-INFO sm100_gemm.egg-info/SOURCES.txt sm100_gemm.egg-info/dependency_links.txt sm100_gemm.egg-info/not-zip-safe sm100_gemm.egg-info/requires.txt sm100_gemm.egg-info/top_level.txt ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/dependency_links.txt ================================================ ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/not-zip-safe ================================================ ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/requires.txt ================================================ torch>=1.12.0 ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/top_level.txt ================================================ sm100_gemm ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm.h ================================================ // sm100_gemm_kernel.h - Header file for CUDA kernel #pragma once #include #ifdef __cplusplus extern "C" { #endif 
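// Illustrative host-side call sequence (a minimal sketch; allocation sizes
// assume the smallest aligned problem M=128, N=256, K=64, buffers already
// filled with data, and error handling omitted):
//
//   void *d_A, *d_B, *d_C, *d_D;
//   cudaMalloc(&d_A, 128 * 64 * 2);     // A: M x K, FP16 (2 bytes/elem), K-major
//   cudaMalloc(&d_B, 256 * 64 * 2);     // B: N x K, FP16, K-major
//   cudaMalloc(&d_C, 128 * 256 * 4);    // C: M x N, FP32 (4 bytes/elem), N-major
//   cudaMalloc(&d_D, 128 * 256 * 4);    // D: M x N, FP32, N-major
//   cudaError_t err = launch_sm100_gemm_f16(d_A, d_B, d_C, d_D,
//                                           128, 256, 64, 1.0f, 0.0f, 0);
//   // err == cudaSuccess on a successful launch; the full contract is documented below.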
/** * Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta * C * * @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major) * @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major) * @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major) * @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major) * @param M Number of rows in A and C/D * @param N Number of rows in B and columns in C/D * @param K Number of columns in A and B * @param alpha Scaling factor for A @ B^T * @param beta Scaling factor for C * @param stream CUDA stream (currently unused, for future async support) * * @return cudaSuccess on success, error code otherwise * * Requirements: * - M must be multiple of 128 * - N must be multiple of 256 * - K must be multiple of 64 * - All pointers must be valid device memory * - Tensors must be contiguous with specified layouts */ cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D, int M, int N, int K, float alpha, float beta, cudaStream_t stream = 0); #ifdef __cplusplus } #endif ================================================ FILE: kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp ================================================ // sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code) #include #include #include #include #include #include "sm100_gemm.h" // Check if SM100 support is available at compile time bool is_sm100_supported() { #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) return true; #else return false; #endif } // Check if current GPU supports SM100 at runtime bool check_sm100_device() { int device; cudaGetDevice(&device); cudaDeviceProp props; cudaError_t error = cudaGetDeviceProperties(&props, device); if (error != cudaSuccess) { return false; } // Check for SM100 architecture (compute capability 10.0a) return (props.major == 10 && props.minor == 0); } torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, float alpha = 1.0f, float beta = 0.0f) { // Check compile-time support TORCH_CHECK( is_sm100_supported(), "SM100 support not compiled. 
Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED"); // Check runtime device support TORCH_CHECK(check_sm100_device(), "Current GPU does not support SM100 architecture (requires " "compute capability 10.0a)"); // Input validation TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor"); TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor"); TORCH_CHECK(C.device().is_cuda(), "C must be a CUDA tensor"); TORCH_CHECK(A.dtype() == torch::kFloat16, "A must be float16"); TORCH_CHECK(B.dtype() == torch::kFloat16, "B must be float16"); TORCH_CHECK(C.dtype() == torch::kFloat32, "C must be float32"); TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); TORCH_CHECK(B.is_contiguous(), "B must be contiguous"); TORCH_CHECK(C.is_contiguous(), "C must be contiguous"); TORCH_CHECK(A.dim() == 2, "A must be 2D"); TORCH_CHECK(B.dim() == 2, "B must be 2D"); TORCH_CHECK(C.dim() == 2, "C must be 2D"); // Get dimensions int64_t M = A.size(0); int64_t K = A.size(1); int64_t N = B.size(0); int64_t K_B = B.size(1); TORCH_CHECK(K == K_B, "Inner dimensions must match: A.shape[1]=", K, ", B.shape[1]=", K_B); TORCH_CHECK(C.size(0) == M && C.size(1) == N, "C dimensions (", C.size(0), ", ", C.size(1), ") must match output shape (", M, ", ", N, ")"); // Check alignment requirements for SM100 TORCH_CHECK(M % 128 == 0, "M=", M, " must be multiple of 128"); TORCH_CHECK(N % 256 == 0, "N=", N, " must be multiple of 256"); TORCH_CHECK(K % 64 == 0, "K=", K, " must be multiple of 64"); // Check size limits (avoid overflow in int conversion) TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX, "Dimensions too large for int conversion"); // Create output tensor auto D = torch::empty_like(C); // Set CUDA device guard const auto device = A.device(); c10::cuda::CUDAGuard device_guard(device); // Get current CUDA stream cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream(); // Launch the kernel cudaError_t error = launch_sm100_gemm_f16( A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(), static_cast(M), static_cast(N), static_cast(K), alpha, beta, stream); // Check for launch errors TORCH_CHECK(error == cudaSuccess, "SM100 GEMM kernel launch failed: ", cudaGetErrorString(error)); // Check for kernel execution errors C10_CUDA_CHECK(cudaGetLastError()); return D; } // Utility functions for debugging and information torch::Tensor get_device_info() { int device; cudaGetDevice(&device); cudaDeviceProp props; cudaGetDeviceProperties(&props, device); // Return device info as a tensor (for easy Python access) auto info = torch::zeros({4}, torch::kInt32); auto accessor = info.accessor(); accessor[0] = props.major; // Compute capability major accessor[1] = props.minor; // Compute capability minor accessor[2] = is_sm100_supported(); // Compile-time support accessor[3] = check_sm100_device(); // Runtime device support return info; } std::vector get_aligned_shape(int64_t M, int64_t N, int64_t K) { // Return properly aligned dimensions for SM100 int64_t aligned_M = ((M + 127) / 128) * 128; int64_t aligned_N = ((N + 255) / 256) * 256; int64_t aligned_K = ((K + 63) / 64) * 64; return {aligned_M, aligned_N, aligned_K}; } torch::Tensor create_aligned_tensor(const std::vector &shape, torch::ScalarType dtype, torch::Device device) { // Create a tensor with SM100-aligned dimensions TORCH_CHECK(shape.size() == 2, "Shape must be 2D"); auto aligned_shape = get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? 
shape[2] : 64); if (shape.size() == 2) { return torch::zeros({aligned_shape[0], aligned_shape[1]}, torch::TensorOptions().dtype(dtype).device(device)); } else { return torch::zeros({aligned_shape[0], aligned_shape[2]}, torch::TensorOptions().dtype(dtype).device(device)); } } // Python bindings PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "SM100 GEMM PyTorch Extension"; // Main GEMM function m.def("sm100_gemm_f16", &sm100_gemm_f16, "SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + " "beta * C", py::arg("A"), py::arg("B"), py::arg("C"), py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f); // Utility functions m.def("is_sm100_supported", &is_sm100_supported, "Check if SM100 support was compiled in"); m.def("check_sm100_device", &check_sm100_device, "Check if current GPU supports SM100 architecture"); m.def("get_device_info", &get_device_info, "Get device compute capability and SM100 support info"); m.def("get_aligned_shape", &get_aligned_shape, "Get SM100-aligned dimensions for given shape", py::arg("M"), py::arg("N"), py::arg("K")); m.def("create_aligned_tensor", &create_aligned_tensor, "Create tensor with SM100-aligned dimensions", py::arg("shape"), py::arg("dtype"), py::arg("device")); // Constants for alignment requirements m.attr("MMA_TILE_M") = 128; m.attr("MMA_TILE_N") = 256; m.attr("MMA_TILE_K") = 64; } ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_log ================================================ # ninja log v5 1 15202 1748185895110710199 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b 7 78 1748186494782816813 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b 6 15086 1748186805607894090 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b 6 14058 1748187024415643408 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 6c5f77cfca7cfb81 ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/build.ninja ================================================ ninja_required_version = 1.3 cxx = c++ nvcc = /usr/local/cuda-12.8/bin/nvcc cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm cuda_cflags = 
-I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm cuda_dlink_post_cflags = sycl_dlink_post_cflags = ldflags = rule compile command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags depfile = $out.d deps = gcc rule cuda_compile depfile = $out.d deps = gcc command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/driver.py ================================================ # python_interface.py - High-level Python interface with TMA support import torch try: import sm100_gemm # The compiled extension except ImportError: print("❌ SM100 not ready!") raise ImportError( "SM100 not ready! Please build the extension using `python setup.py install`" ) def check_sm100_compatibility(): """Check if SM100 is supported and available""" compile_support = sm100_gemm.is_sm100_supported() device_support = sm100_gemm.check_sm100_device() info = sm100_gemm.get_device_info() major, minor, compile_flag, device_flag = info.tolist() print(f"Device compute capability: {major}.{minor}") print(f"Compile-time SM100 support: {bool(compile_flag)}") print(f"Runtime SM100 device support: {bool(device_flag)}") if not compile_support: print( "❌ SM100 support not compiled in. 
Rebuild with CUTLASS_ARCH_MMA_SM100_SUPPORTED" ) elif not device_support: print("❌ Current GPU does not support SM100 (need compute capability 10.0a)") else: print("✅ SM100 with TMA ready!") return compile_support and device_support def sm100_gemm_f16_tma(A, B, C=None, alpha=1.0, beta=0.0, check_alignment=True): """ Perform GEMM using SM100 optimized kernel with TMA: D = alpha * A @ B^T + beta * C Args: A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16 B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16 C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32 If None, creates zero tensor alpha (float): Scaling factor for A @ B^T beta (float): Scaling factor for C check_alignment (bool): Whether to check and suggest aligned dimensions Returns: torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32 Note: - Uses TMA (Tensor Memory Accelerator) for efficient memory transfers - A and B are K-major (transposed in BLAS terms) - C and D are N-major (row-major) - All tensors must be on CUDA - M must be multiple of 128, N multiple of 256, K multiple of 64 """ # Input validation assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}" assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}" assert A.is_cuda and B.is_cuda, "A and B must be on CUDA" assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous" M, K = A.shape N, K_B = B.shape assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}" # Check or fix alignment requirements if check_alignment: aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K) if M != aligned_M or N != aligned_N or K != aligned_K: print(f"Warning: Dimensions ({M}, {N}, {K}) not aligned for SM100") print( f"Suggested aligned dimensions: ({aligned_M}, {aligned_N}, {aligned_K})" ) print("Consider padding tensors or use create_aligned_tensors()") # Strict alignment check assert ( M % sm100_gemm.MMA_TILE_M == 0 ), f"M={M} must be multiple of {sm100_gemm.MMA_TILE_M}" assert ( N % sm100_gemm.MMA_TILE_N == 0 ), f"N={N} must be multiple of {sm100_gemm.MMA_TILE_N}" assert ( K % sm100_gemm.MMA_TILE_K == 0 ), f"K={K} must be multiple of {sm100_gemm.MMA_TILE_K}" # Create C if not provided if C is None: C = torch.zeros(M, N, dtype=torch.float32, device=A.device) else: assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}" assert C.is_cuda, "C must be on CUDA" assert C.is_contiguous(), "C must be contiguous" assert C.shape == ( M, N, ), f"C shape {C.shape} must match output shape ({M}, {N})" # Call the extension (now uses TMA internally) return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta) # Keep the old name for compatibility sm100_gemm_f16 = sm100_gemm_f16_tma def create_aligned_tensors( M, N, K, device="cuda", dtype_AB=torch.float16, dtype_C=torch.float32 ): """ Create properly aligned tensors for SM100 GEMM with TMA Returns: tuple: (A, B, C) tensors with aligned dimensions """ aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K) A = torch.zeros(aligned_M, aligned_K, dtype=dtype_AB, device=device) B = torch.zeros(aligned_N, aligned_K, dtype=dtype_AB, device=device) C = torch.zeros(aligned_M, aligned_N, dtype=dtype_C, device=device) return A, B, C def pad_to_aligned(tensor, target_shape=None, dim_requirements=None): """ Pad tensor to meet SM100 alignment requirements Args: tensor: Input tensor to pad target_shape: Specific target shape (optional) dim_requirements: Tuple of (M_align, N_align, 
K_align) requirements Returns: Padded tensor and padding info for later unpadding """ if dim_requirements is None: dim_requirements = ( sm100_gemm.MMA_TILE_M, sm100_gemm.MMA_TILE_N, sm100_gemm.MMA_TILE_K, ) if tensor.dim() == 2: M, N = tensor.shape if target_shape: target_M, target_N = target_shape else: target_M = ( (M + dim_requirements[0] - 1) // dim_requirements[0] ) * dim_requirements[0] target_N = ( (N + dim_requirements[1] - 1) // dim_requirements[1] ) * dim_requirements[1] pad_M = target_M - M pad_N = target_N - N # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) padded = torch.nn.functional.pad(tensor, (0, pad_N, 0, pad_M)) return padded, (M, N, pad_M, pad_N) else: raise ValueError("Only 2D tensors supported") def unpad_result(tensor, padding_info): """Remove padding from result tensor""" orig_M, orig_N, pad_M, pad_N = padding_info return tensor[:orig_M, :orig_N] def benchmark_sm100_vs_torch( M=512, N=1024, K=256, num_warmup=1, num_trials=10, auto_align=True, compare_tma=True, ): """ Benchmark SM100 GEMM with TMA against PyTorch's native GEMM """ # Ensure dimensions are aligned if auto_align: M = ( (M + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M ) * sm100_gemm.MMA_TILE_M N = ( (N + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N ) * sm100_gemm.MMA_TILE_N K = ( (K + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K ) * sm100_gemm.MMA_TILE_K print(f"Benchmarking GEMM with TMA for shape: ({M}, {N}, {K})") # Check SM100 availability if not check_sm100_compatibility(): print("SM100 not available, skipping benchmark") return None # Create test tensors A = torch.randn(M, K, dtype=torch.float16, device="cuda") B = torch.randn(N, K, dtype=torch.float16, device="cuda") C = torch.randn(M, N, dtype=torch.float32, device="cuda") # PyTorch baseline (using mixed precision) A_fp32 = A.float() B_fp32 = B.float() # Warmup for _ in range(num_warmup): # PyTorch GEMM torch_result = torch.addmm(C, A_fp32, B_fp32.T) # SM100 GEMM with TMA sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False) torch.cuda.synchronize() # Benchmark PyTorch torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) # warmup torch_result = torch.addmm(C, A_fp32, B_fp32.T) start.record() for _ in range(num_trials): torch_result = torch.addmm(C, A_fp32, B_fp32.T) end.record() torch.cuda.synchronize() torch_time = start.elapsed_time(end) / num_trials # Benchmark SM100 with TMA # warmup sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False) start.record() for _ in range(num_trials): sm100_result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False) end.record() torch.cuda.synchronize() sm100_time = start.elapsed_time(end) / num_trials # Check correctness max_diff = torch.max(torch.abs(torch_result - sm100_result)) rel_error = max_diff / torch.max(torch.abs(torch_result)) # Calculate FLOPS flops = 2 * M * N * K # Multiply-add operations torch_tflops = flops / (torch_time * 1e-3) / 1e12 sm100_tflops = flops / (sm100_time * 1e-3) / 1e12 print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)") print(f"SM100+TMA time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)") print(f"Speedup: {torch_time/sm100_time:.2f}x") print(f"Max difference: {max_diff:.6f}") print(f"Relative error: {rel_error:.6f}") print(f"🚀 TMA provides efficient memory transfers for large matrices!") return { "torch_time": torch_time, "sm100_time": sm100_time, "speedup": torch_time / sm100_time, "torch_tflops": 
torch_tflops, "sm100_tflops": sm100_tflops, "max_diff": max_diff.item(), "rel_error": rel_error.item(), } # Neural network layer implementations with TMA class SM100LinearTMA(torch.nn.Module): """ Linear layer using SM100 GEMM with TMA for forward pass """ def __init__(self, in_features, out_features, bias=True, device="cuda"): super().__init__() # Align dimensions self.orig_in_features = in_features self.orig_out_features = out_features aligned_in = ( (in_features + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K ) * sm100_gemm.MMA_TILE_K aligned_out = ( (out_features + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N ) * sm100_gemm.MMA_TILE_N self.in_features = aligned_in self.out_features = aligned_out # Parameters (with padding) self.weight = torch.nn.Parameter( torch.randn(aligned_out, aligned_in, dtype=torch.float16, device=device) * 0.1 ) if bias: self.bias = torch.nn.Parameter( torch.zeros(aligned_out, dtype=torch.float32, device=device) ) else: self.register_parameter("bias", None) print( f"SM100LinearTMA: {in_features} -> {out_features} (aligned: {aligned_in} -> {aligned_out})" ) print("🚀 Using TMA for efficient memory transfers") def forward(self, x): # Pad input if necessary batch_size = x.size(0) # Align batch size aligned_batch = ( (batch_size + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M ) * sm100_gemm.MMA_TILE_M if x.size(1) != self.in_features or batch_size != aligned_batch: x_padded = torch.zeros( aligned_batch, self.in_features, dtype=torch.float16, device=x.device ) x_padded[:batch_size, : self.orig_in_features] = x x = x_padded # Prepare bias if self.bias is not None: C = ( self.bias.unsqueeze(0) .expand(aligned_batch, self.out_features) .contiguous() ) beta = 1.0 else: C = torch.zeros( aligned_batch, self.out_features, dtype=torch.float32, device=x.device ) beta = 0.0 # SM100 GEMM with TMA: output = x @ weight^T + bias output = sm100_gemm_f16_tma( x, self.weight, C, alpha=1.0, beta=beta, check_alignment=False ) # Remove padding return output[:batch_size, : self.orig_out_features] def benchmark_tma_vs_cooperative_copy(M=512, N=1024, K=256, num_trials=50): """ TMA addition """ results = benchmark_sm100_vs_torch(M, N, K, num_trials=num_trials) if results: print(f"\nTMA-accelerated SM100 GEMM achieved:") print(f" Performance: {results['sm100_tflops']:.2f} TFLOPS") print(f" Speedup: {results['speedup']:.2f}x over PyTorch") print(f" Memory efficiency: Hardware-optimized transfers") def stress_test_large_matrices(): """ Test TMA performance with large matrices that benefit most from TMA """ print("\n=== Large Matrix Stress Test with TMA ===") # Test progressively larger matrices test_sizes = [ (1024, 2048, 512), # 1GB+ tensors (2048, 4096, 1024), # 4GB+ tensors (4096, 8192, 2048), # 16GB+ tensors (if memory allows) ] for M, N, K in test_sizes: try: print(f"\nTesting size: ({M}, {N}, {K})") # Check memory requirements memory_A = M * K * 2 # FP16 memory_B = N * K * 2 # FP16 memory_C = M * N * 4 # FP32 total_memory = (memory_A + memory_B + memory_C * 2) / (1024**3) # GB print(f"Memory requirement: {total_memory:.2f} GB") if total_memory > 20: # Skip if > 20GB print("⚠️ Skipping due to memory constraints") continue # Create tensors A, B, C = create_aligned_tensors(M, N, K) A[:M, :K].normal_(0, 0.1) B[:N, :K].normal_(0, 0.1) C[:M, :N].normal_(0, 0.1) # Warmup for _ in range(3): result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False) torch.cuda.synchronize() # Benchmark start = torch.cuda.Event(enable_timing=True) end = 
torch.cuda.Event(enable_timing=True) num_trials = 10 start.record() for _ in range(num_trials): result = sm100_gemm_f16_tma(A, B, C.clone(), check_alignment=False) end.record() torch.cuda.synchronize() avg_time = start.elapsed_time(end) / num_trials # Calculate performance flops = 2 * M * N * K tflops = flops / (avg_time * 1e-3) / 1e12 bandwidth = total_memory / (avg_time * 1e-3) # GB/s print(f"✅ Time: {avg_time:.2f} ms") print(f"✅ Performance: {tflops:.2f} TFLOPS") print(f"✅ Bandwidth: {bandwidth:.1f} GB/s") print(f"🚀 TMA enables efficient handling of large matrices!") except torch.cuda.OutOfMemoryError: print(f"❌ Out of memory for size ({M}, {N}, {K})") break except Exception as e: print(f"❌ Error: {e}") break # Example usage and test if __name__ == "__main__": print("=== SM100 GEMM with TMA Extension Test ===") # Check compatibility first if not check_sm100_compatibility(): print("Exiting due to compatibility issues") exit(1) print("\n=== Testing basic TMA functionality ===") # Test with properly aligned dimensions M, N, K = 512, 1024, 256 A, B, C = create_aligned_tensors(M, N, K) # Fill with random data (only the actual needed portion) A[:M, :K].normal_() B[:N, :K].normal_() C[:M, :N].normal_() # Test the TMA GEMM result = sm100_gemm_f16_tma(A, B, C, alpha=1.0, beta=0.5, check_alignment=False) print( f"✅ TMA GEMM test passed. Result shape: {result.shape}, dtype: {result.dtype}" ) print("\n=== Testing SM100LinearTMA layer ===") # Test linear layer with TMA layer = SM100LinearTMA(256, 512, bias=True) x = torch.randn(128, 256, dtype=torch.float16, device="cuda") output = layer(x) print(f"✅ TMA Linear layer test passed. Output shape: {output.shape}") print("\n=== Testing padding utilities ===") # Test padding for misaligned tensors misaligned_A = torch.randn(300, 200, dtype=torch.float16, device="cuda") padded_A, pad_info = pad_to_aligned(misaligned_A) print(f"Original shape: {misaligned_A.shape}, Padded shape: {padded_A.shape}") unpadded = unpad_result(padded_A, pad_info) print(f"✅ Padding test passed. 
Unpadded shape: {unpadded.shape}") print("\n=== Running TMA performance benchmark ===") # Run benchmark benchmark_results = benchmark_sm100_vs_torch(M=512, N=1024, K=256, num_trials=50) if benchmark_results: print(f"\n✅ All TMA tests passed!") print( f"🚀 SM100+TMA achieved {benchmark_results['speedup']:.2f}x speedup over PyTorch" ) print(f"🚀 TMA provides hardware-accelerated memory transfers!") # Run additional TMA-specific tests benchmark_tma_vs_cooperative_copy(M=1024, N=2048, K=512) # Test with larger matrices if memory allows print("\n=== Testing TMA with larger matrices ===") stress_test_large_matrices() else: print("❌ Benchmark failed") print("\n=== TMA Summary ===") print("🚀 TMA (Tensor Memory Accelerator) provides:") print(" • Hardware-accelerated global->shared memory transfers") print(" • Reduced CPU overhead and better bandwidth utilization") print(" • Automatic memory layout optimization") print(" • Essential for peak performance on large matrices") print(" • Enables scaling to multi-GB tensor operations") import sm100_gemm # The compiled extension # python_interface.py - High-level Python interface (updated for split files) import torch def check_sm100_compatibility(): """Check if SM100 is supported and available""" compile_support = sm100_gemm.is_sm100_supported() device_support = sm100_gemm.check_sm100_device() info = sm100_gemm.get_device_info() major, minor, compile_flag, device_flag = info.tolist() print(f"Device compute capability: {major}.{minor}") print(f"Compile-time SM100 support: {bool(compile_flag)}") print(f"Runtime SM100 device support: {bool(device_flag)}") if not compile_support: print( "❌ SM100 support not compiled in. Rebuild with CUTLASS_ARCH_MMA_SM100_SUPPORTED" ) elif not device_support: print("❌ Current GPU does not support SM100 (need compute capability 10.0a)") else: print("SM100 ready!") # ✅ return compile_support and device_support def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0, check_alignment=True): """ Perform GEMM using SM100 optimized kernel: D = alpha * A @ B^T + beta * C Args: A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16 B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16 C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32 If None, creates zero tensor alpha (float): Scaling factor for A @ B^T beta (float): Scaling factor for C check_alignment (bool): Whether to check and suggest aligned dimensions Returns: torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32 Note: - A and B are K-major (transposed in BLAS terms) - C and D are N-major (row-major) - All tensors must be on CUDA - M must be multiple of 128, N multiple of 256, K multiple of 64 """ # Input validation assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}" assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}" assert A.is_cuda and B.is_cuda, "A and B must be on CUDA" assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous" M, K = A.shape N, K_B = B.shape assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}" # Check or fix alignment requirements if check_alignment: aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K) if M != aligned_M or N != aligned_N or K != aligned_K: print(f"Warning: Dimensions ({M}, {N}, {K}) not aligned for SM100") print( f"Suggested aligned dimensions: ({aligned_M}, {aligned_N}, {aligned_K})" ) print("Consider padding tensors or use create_aligned_tensors()") # Strict alignment 
check assert ( M % sm100_gemm.MMA_TILE_M == 0 ), f"M={M} must be multiple of {sm100_gemm.MMA_TILE_M}" assert ( N % sm100_gemm.MMA_TILE_N == 0 ), f"N={N} must be multiple of {sm100_gemm.MMA_TILE_N}" assert ( K % sm100_gemm.MMA_TILE_K == 0 ), f"K={K} must be multiple of {sm100_gemm.MMA_TILE_K}" # Create C if not provided if C is None: C = torch.zeros(M, N, dtype=torch.float32, device=A.device) else: assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}" assert C.is_cuda, "C must be on CUDA" assert C.is_contiguous(), "C must be contiguous" assert C.shape == ( M, N, ), f"C shape {C.shape} must match output shape ({M}, {N})" # Call the extension return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta) def create_aligned_tensors( M, N, K, device="cuda", dtype_AB=torch.float16, dtype_C=torch.float32 ): """ Create properly aligned tensors for SM100 GEMM Returns: tuple: (A, B, C) tensors with aligned dimensions """ aligned_M, aligned_N, aligned_K = sm100_gemm.get_aligned_shape(M, N, K) A = torch.zeros(aligned_M, aligned_K, dtype=dtype_AB, device=device) B = torch.zeros(aligned_N, aligned_K, dtype=dtype_AB, device=device) C = torch.zeros(aligned_M, aligned_N, dtype=dtype_C, device=device) return A, B, C def pad_to_aligned(tensor, target_shape=None, dim_requirements=None): """ Pad tensor to meet SM100 alignment requirements Args: tensor: Input tensor to pad target_shape: Specific target shape (optional) dim_requirements: Tuple of (M_align, N_align, K_align) requirements Returns: Padded tensor and padding info for later unpadding """ if dim_requirements is None: dim_requirements = ( sm100_gemm.MMA_TILE_M, sm100_gemm.MMA_TILE_N, sm100_gemm.MMA_TILE_K, ) if tensor.dim() == 2: M, N = tensor.shape if target_shape: target_M, target_N = target_shape else: target_M = ( (M + dim_requirements[0] - 1) // dim_requirements[0] ) * dim_requirements[0] target_N = ( (N + dim_requirements[1] - 1) // dim_requirements[1] ) * dim_requirements[1] pad_M = target_M - M pad_N = target_N - N # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) padded = torch.nn.functional.pad(tensor, (0, pad_N, 0, pad_M)) return padded, (M, N, pad_M, pad_N) else: raise ValueError("Only 2D tensors supported") def unpad_result(tensor, padding_info): """Remove padding from result tensor""" orig_M, orig_N, pad_M, pad_N = padding_info return tensor[:orig_M, :orig_N] def benchmark_sm100_vs_torch( M=512, N=1024, K=256, num_warmup=10, num_trials=100, auto_align=True ): """ Benchmark SM100 GEMM against PyTorch's native GEMM """ # Ensure dimensions are aligned if auto_align: M = ( (M + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M ) * sm100_gemm.MMA_TILE_M N = ( (N + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N ) * sm100_gemm.MMA_TILE_N K = ( (K + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K ) * sm100_gemm.MMA_TILE_K print(f"Benchmarking GEMM with shape: ({M}, {N}, {K})") # Check SM100 availability if not check_sm100_compatibility(): print("SM100 not available, skipping benchmark") return None # Create test tensors A = torch.randn(M, K, dtype=torch.float16, device="cuda") B = torch.randn(N, K, dtype=torch.float16, device="cuda") C = torch.randn(M, N, dtype=torch.float32, device="cuda") # PyTorch baseline (using mixed precision) A_fp32 = A.float() B_fp32 = B.float() # Warmup for _ in range(num_warmup): # PyTorch GEMM torch_result = torch.addmm(C, A_fp32, B_fp32.T) # SM100 GEMM sm100_result = sm100_gemm_f16(A, B, C.clone(), check_alignment=False) torch.cuda.synchronize() # Benchmark PyTorch 
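# Timing note: CUDA events measure GPU time on the current stream;
# end.record() is asynchronous, so the later torch.cuda.synchronize()
# must complete before start.elapsed_time(end) (milliseconds) is read.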
torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(num_trials): torch_result = torch.addmm(C, A_fp32, B_fp32.T) end.record() torch.cuda.synchronize() torch_time = start.elapsed_time(end) / num_trials # Benchmark SM100 start.record() for _ in range(num_trials): sm100_result = sm100_gemm_f16(A, B, C.clone(), check_alignment=False) end.record() torch.cuda.synchronize() sm100_time = start.elapsed_time(end) / num_trials # Check correctness max_diff = torch.max(torch.abs(torch_result - sm100_result)) rel_error = max_diff / torch.max(torch.abs(torch_result)) # Calculate FLOPS flops = 2 * M * N * K # Multiply-add operations torch_tflops = flops / (torch_time * 1e-3) / 1e12 sm100_tflops = flops / (sm100_time * 1e-3) / 1e12 print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)") print(f"SM100 time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)") print(f"Speedup: {torch_time/sm100_time:.2f}x") # print(f"Max difference: {max_diff:.6f}") print(f"Relative error: {rel_error:.6f}") return { "torch_time": torch_time, "sm100_time": sm100_time, "speedup": torch_time / sm100_time, "torch_tflops": torch_tflops, "sm100_tflops": sm100_tflops, "max_diff": max_diff.item(), "rel_error": rel_error.item(), } # Neural network layer implementations class SM100Linear(torch.nn.Module): """ Linear layer using SM100 GEMM for forward pass """ def __init__(self, in_features, out_features, bias=True, device="cuda"): super().__init__() # Align dimensions self.orig_in_features = in_features self.orig_out_features = out_features aligned_in = ( (in_features + sm100_gemm.MMA_TILE_K - 1) // sm100_gemm.MMA_TILE_K ) * sm100_gemm.MMA_TILE_K aligned_out = ( (out_features + sm100_gemm.MMA_TILE_N - 1) // sm100_gemm.MMA_TILE_N ) * sm100_gemm.MMA_TILE_N self.in_features = aligned_in self.out_features = aligned_out # Parameters (with padding) self.weight = torch.nn.Parameter( torch.randn(aligned_out, aligned_in, dtype=torch.float16, device=device) * 0.1 ) if bias: self.bias = torch.nn.Parameter( torch.zeros(aligned_out, dtype=torch.float32, device=device) ) else: self.register_parameter("bias", None) print( f"SM100Linear: {in_features} -> {out_features} (aligned: {aligned_in} -> {aligned_out})" ) def forward(self, x): # Pad input if necessary batch_size = x.size(0) # Align batch size aligned_batch = ( (batch_size + sm100_gemm.MMA_TILE_M - 1) // sm100_gemm.MMA_TILE_M ) * sm100_gemm.MMA_TILE_M if x.size(1) != self.in_features or batch_size != aligned_batch: x_padded = torch.zeros( aligned_batch, self.in_features, dtype=torch.float16, device=x.device ) x_padded[:batch_size, : self.orig_in_features] = x x = x_padded # Prepare bias if self.bias is not None: C = ( self.bias.unsqueeze(0) .expand(aligned_batch, self.out_features) .contiguous() ) beta = 1.0 else: C = torch.zeros( aligned_batch, self.out_features, dtype=torch.float32, device=x.device ) beta = 0.0 # SM100 GEMM: output = x @ weight^T + bias output = sm100_gemm_f16( x, self.weight, C, alpha=1.0, beta=beta, check_alignment=False ) # Remove padding return output[:batch_size, : self.orig_out_features] # Example usage and test if __name__ == "__main__": print("=== SM100 GEMM Extension Test ===") # Check compatibility first if not check_sm100_compatibility(): print("Exiting due to compatibility issues") exit(1) print("\n=== Testing basic functionality ===") # Test with properly aligned dimensions M, N, K = 512, 1024, 256 A, B, C = create_aligned_tensors(M, N, K) # Fill with random 
data (only the actual needed portion) A[:M, :K].normal_() B[:N, :K].normal_() C[:M, :N].normal_() # Test the GEMM result = sm100_gemm_f16(A, B, C, alpha=1.0, beta=0.5, check_alignment=False) print( f"✅ Basic GEMM test passed. Result shape: {result.shape}, dtype: {result.dtype}" ) print("\n=== Testing SM100Linear layer ===") # Test linear layer layer = SM100Linear(256, 512, bias=True) x = torch.randn(128, 256, dtype=torch.float16, device="cuda") output = layer(x) print(f"✅ Linear layer test passed. Output shape: {output.shape}") print("\n=== Testing padding utilities ===") # Test padding for misaligned tensors misaligned_A = torch.randn(300, 200, dtype=torch.float16, device="cuda") padded_A, pad_info = pad_to_aligned(misaligned_A) print(f"Original shape: {misaligned_A.shape}, Padded shape: {padded_A.shape}") unpadded = unpad_result(padded_A, pad_info) print(f"✅ Padding test passed. Unpadded shape: {unpadded.shape}") print("\n=== Running performance benchmark ===") # Run benchmark benchmark_results = benchmark_sm100_vs_torch( M=8192, N=8192 * 2, K=2048, num_trials=50 ) if benchmark_results: print(f"\n✅ All tests passed!") print( f"SM100 achieved {benchmark_results['speedup']:.2f}x speedup over PyTorch" ) else: print("❌ Benchmark failed") ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/setup.py ================================================ # setup.py import os import pybind11 import torch from pybind11 import get_cmake_dir from pybind11.setup_helpers import build_ext, Pybind11Extension from setuptools import Extension, setup from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension # IMPORTANT: The following two lines are the only ones you need to change # Get CUTLASS path (you'll need to set this to your CUTLASS installation) CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/home/less/local/cutlas40") # CUDA and PyTorch paths cuda_home = torch.utils.cpp_extension.CUDA_HOME pytorch_includes = torch.utils.cpp_extension.include_paths() ext_modules = [ CUDAExtension( name="sm100_gemm", sources=[ "sm100_gemm_pytorch.cpp", # PyTorch bindings (C++) "sm100_gemm.cu", # CUDA kernel implementation ], include_dirs=[ # PyTorch includes *pytorch_includes, # CUTLASS includes f"{CUTLASS_PATH}/include", f"{CUTLASS_PATH}/tools/util/include", # CUDA includes f"{cuda_home}/include", ], library_dirs=[ f"{cuda_home}/lib64", ], libraries=["cuda", "cudart"], extra_compile_args={ "cxx": [ "-O3", "-std=c++17", "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", "-DCUTE_SM100_ENABLED", ], "nvcc": [ "-O3", "-std=c++17", "--expt-relaxed-constexpr", "--expt-extended-lambda", "-gencode=arch=compute_100a,code=sm_100a", # SM100 architecture "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", "-DCUTE_SM100_ENABLED", "--use_fast_math", "-Xcompiler=-fPIC", "-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1", # Enable TCGEN05_TMEM ], }, extra_link_args=["-lcuda", "-lcudart"], language="c++", ) ] setup( name="sm100_gemm", ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, zip_safe=False, python_requires=">=3.8", install_requires=["torch>=1.12.0"], ) ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu ================================================ // sm100_gemm_kernel.cu - CUDA kernel implementation with TMA #include "sm100_gemm.h" #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) #include #include #include #include #include #include #include #include #include using namespace cute; // Shared storage structure with TMA barriers template struct 
SharedStorage { alignas(128) cute::ArrayEngine> A; alignas(128) cute::ArrayEngine> B; alignas(16) cute::uint64_t mma_barrier; // Barrier to track MMA computation on SMEM alignas(16) cute::uint64_t tma_barrier; // Barrier to track TMA data transfers to SMEM alignas(16) cute::uint32_t tmem_base_ptr; // Base pointer for TMEM allocation CUTE_DEVICE constexpr auto tensor_sA() { return make_tensor(make_smem_ptr(A.begin()), ASmemLayout{}); } CUTE_DEVICE constexpr auto tensor_sB() { return make_tensor(make_smem_ptr(B.begin()), BSmemLayout{}); } }; // Device kernel with TMA template __global__ static void gemm_device_tma( ATensor mA, BTensor mB, CTensor mC, DTensor mD, MmaTiler_MNK mma_tiler, TiledMMA tiled_mma, ClusterShape_MNK cluster_shape, CUTE_GRID_CONSTANT TmaAtomA const tma_atom_A, CUTE_GRID_CONSTANT TmaAtomB const tma_atom_B, Alpha alpha, Beta beta) { // Step 1: The Prologue Layout cluster_layout_vmnk = tiled_divide( make_layout(cluster_shape), make_tile(typename TiledMMA::AtomThrID{})); auto mma_coord_vmnk = make_coord(blockIdx.x % size<0>(cluster_layout_vmnk), blockIdx.x / size<0>(cluster_layout_vmnk), blockIdx.y, _); auto mma_coord = select<1, 2, 3>(mma_coord_vmnk); Tensor gA = local_tile(mA, mma_tiler, mma_coord, Step<_1, X, _1>{}); Tensor gB = local_tile(mB, mma_tiler, mma_coord, Step{}); Tensor gC = local_tile(mC, mma_tiler, mma_coord, Step<_1, _1, X>{}); Tensor gD = local_tile(mD, mma_tiler, mma_coord, Step<_1, _1, X>{}); // SMEM allocation extern __shared__ char shared_memory[]; SharedStorage &shared_storage = *reinterpret_cast(shared_memory); Tensor tCsA = shared_storage.tensor_sA(); Tensor tCsB = shared_storage.tensor_sB(); // MMA partitioning auto mma_v = get<0>(mma_coord_vmnk); ThrMMA cta_mma = tiled_mma.get_slice(mma_v); Tensor tCgA = cta_mma.partition_A(gA); Tensor tCgB = cta_mma.partition_B(gB); Tensor tCgC = cta_mma.partition_C(gC); Tensor tCgD = cta_mma.partition_C(gD); // Fragment allocation Tensor tCrA = cta_mma.make_fragment_A(tCsA); Tensor tCrB = cta_mma.make_fragment_B(tCsB); Tensor tCtAcc = cta_mma.make_fragment_C(tCgC); uint32_t elect_one_thr = cute::elect_one_sync(); uint32_t elect_one_warp = (threadIdx.x / 32 == 0); using TmemAllocator = cute::TMEM::Allocator1Sm; TmemAllocator tmem_allocator{}; // TMEM allocation if (elect_one_warp) { tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); } __syncthreads(); tCtAcc.data() = shared_storage.tmem_base_ptr; // TMA Setup // TMA partitioning with dedicated custom partitioner // The Int<0>, Layout<_1> indicates that the TMAs are not multicasted // group_modes<0,3> transforms the tensor shape for TMA operation auto [tAgA, tAsA] = tma_partition(tma_atom_A, Int<0>{}, Layout<_1>{}, group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); auto [tBgB, tBsB] = tma_partition(tma_atom_B, Int<0>{}, Layout<_1>{}, group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); // Calculate total bytes that TMA will transfer each tile to track completion int tma_transaction_bytes = sizeof(make_tensor_like(tAsA)) + sizeof(make_tensor_like(tBsB)); // Barrier initialization if (elect_one_warp && elect_one_thr) { cute::initialize_barrier(shared_storage.mma_barrier, 1); cute::initialize_barrier(shared_storage.tma_barrier, 1); } int mma_barrier_phase_bit = 0; int tma_barrier_phase_bit = 0; __syncthreads(); // Step 2: The Mainloop with TMA tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; for (int k_tile = 0; k_tile < size<3>(tCgA); ++k_tile) { // Step 2a: TMA Load Operations // Execute asynchronous TMA loads with single 
thread if (elect_one_warp && elect_one_thr) { cute::set_barrier_transaction_bytes(shared_storage.tma_barrier, tma_transaction_bytes); copy(tma_atom_A.with(shared_storage.tma_barrier), tAgA(_, k_tile), tAsA); // Load MmaTile_M x MmaTile_K A tile copy(tma_atom_B.with(shared_storage.tma_barrier), tBgB(_, k_tile), tBsB); // Load MmaTile_N x MmaTile_K B tile } // Step 2b: Wait for TMA loads and execute MMAs // Wait for TMA loads to SMEM to complete cute::wait_barrier(shared_storage.tma_barrier, tma_barrier_phase_bit); tma_barrier_phase_bit ^= 1; // Execute MMAs if (elect_one_warp) { for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCtAcc); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } cutlass::arch::umma_arrive(&shared_storage.mma_barrier); } // Wait MMAs to complete to avoid overwriting the A and B SMEM cute::wait_barrier(shared_storage.mma_barrier, mma_barrier_phase_bit); mma_barrier_phase_bit ^= 1; } // Step 3: The Epilogue TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc); ThrCopy thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x); Tensor tDgC = thr_t2r_copy.partition_D(tCgC); Tensor tDrC = make_fragment_like(tDgC); copy(tDgC, tDrC); Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc); Tensor tDgD = thr_t2r_copy.partition_D(tCgD); using AccType = typename decltype(tCtAcc)::value_type; Tensor tDrAcc = make_tensor(shape(tDgD)); copy(tiled_t2r_copy, tDtAcc, tDrAcc); // AXPBY and store result axpby(alpha, tDrAcc, beta, tDrC); copy(tDrC, tDgD); __syncthreads(); // Cleanup TMEM if (elect_one_warp) { tmem_allocator.release_allocation_lock(); tmem_allocator.free(shared_storage.tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); } } // Host setup // Host function that creates TMA descriptors and launches the kernel cudaError_t launch_sm100_gemm_f16_tma(void *d_A, void *d_B, void *d_C, void *d_D, int M, int N, int K, float alpha, float beta, cudaStream_t stream) { // Define types using TypeA = cutlass::half_t; using TypeB = cutlass::half_t; using TypeC = float; using TypeD = float; // Create layouts (K-major for A and B, N-major for C and D) auto layout_A = make_layout(make_shape(M, K), make_stride(K, Int<1>{})); auto layout_B = make_layout(make_shape(N, K), make_stride(K, Int<1>{})); auto layout_C = make_layout(make_shape(M, N), make_stride(N, Int<1>{})); auto layout_D = layout_C; // Create CuTe tensors auto mA = make_tensor(make_gmem_ptr(reinterpret_cast(d_A)), layout_A); auto mB = make_tensor(make_gmem_ptr(reinterpret_cast(d_B)), layout_B); auto mC = make_tensor(make_gmem_ptr(reinterpret_cast(d_C)), layout_C); auto mD = make_tensor(make_gmem_ptr(reinterpret_cast(d_D)), layout_D); // Create TiledMMA TiledMMA tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}); // Define MMA tiler sizes auto bM = tile_size<0>(tiled_mma); // 128 auto bN = tile_size<1>(tiled_mma); // 256 auto bK = tile_size<2>(tiled_mma) * Int<4>{}; // 64 auto mma_tiler = make_shape(bM, bN, bK); // Check alignment if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) { return cudaErrorInvalidValue; } // Create SMEM layouts auto mma_shape_A = partition_shape_A( tiled_mma, make_shape(size<0>(mma_tiler), size<2>(mma_tiler))); auto mma_shape_B = partition_shape_B( tiled_mma, make_shape(size<1>(mma_tiler), size<2>(mma_tiler))); auto sA_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_A); auto sB_layout = UMMA::tile_to_mma_shape(UMMA::Layout_K_SW128_Atom{}, mma_shape_B); using SMEMStorage = SharedStorage; // 
Cluster configuration auto cluster_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{}); // Create TMA descriptors for A and B matrices Copy_Atom tma_atom_A = make_tma_atom(SM90_TMA_LOAD{}, // TMA Load Op mA, // Source GMEM tensor sA_layout, // Destination SMEM layout select<0, 2>(mma_tiler) // MK Tiler for TMA operation ); Tensor mA_tma = tma_atom_A.get_tma_tensor(shape(mA)); Copy_Atom tma_atom_B = make_tma_atom(SM90_TMA_LOAD{}, // TMA Load Op mB, // Source GMEM tensor sB_layout, // Destination SMEM layout select<1, 2>(mma_tiler) // NK Tiler for TMA operation ); Tensor mB_tma = tma_atom_B.get_tma_tensor(shape(mB)); // Launch parameters dim3 dimBlock(128); dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); dim3 dimGrid(ceil_div(M, int(bM)), ceil_div(N, int(bN))); int smemBytes = sizeof(SMEMStorage); // Get kernel pointer auto *kernel_ptr = &gemm_device_tma; // Set kernel attributes cudaError_t error = cudaFuncSetAttribute( kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes); if (error != cudaSuccess) { return error; } // Launch kernel cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smemBytes}; cutlass::Status status = cutlass::launch_kernel_on_cluster( params, (void const *)kernel_ptr, mA_tma, mB_tma, mC, mD, mma_tiler, tiled_mma, cluster_shape, tma_atom_A, tma_atom_B, alpha, beta); return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorLaunchFailure; } // Wrapper function to choose between TMA and non-TMA versions cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D, int M, int N, int K, float alpha, float beta, cudaStream_t stream) { // For now, always use TMA version for better performance return launch_sm100_gemm_f16_tma(d_A, d_B, d_C, d_D, M, N, K, alpha, beta, stream); } #else cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D, int M, int N, int K, float alpha, float beta, cudaStream_t stream) { return cudaErrorNotSupported; } #endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/PKG-INFO ================================================ Metadata-Version: 2.4 Name: sm100_gemm Version: 0.0.0 Requires-Python: >=3.8 Requires-Dist: torch>=1.12.0 Dynamic: requires-dist Dynamic: requires-python ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/SOURCES.txt ================================================ setup.py sm100_gemm.cu sm100_gemm_pytorch.cpp sm100_gemm.egg-info/PKG-INFO sm100_gemm.egg-info/SOURCES.txt sm100_gemm.egg-info/dependency_links.txt sm100_gemm.egg-info/not-zip-safe sm100_gemm.egg-info/requires.txt sm100_gemm.egg-info/top_level.txt ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/dependency_links.txt ================================================ ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/not-zip-safe ================================================ ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/requires.txt ================================================ torch>=1.12.0 ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/top_level.txt ================================================ sm100_gemm 
================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm.h ================================================ // sm100_gemm_kernel.h - Header file for CUDA kernel #pragma once #include #ifdef __cplusplus extern "C" { #endif /** * Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta * C * * @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major) * @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major) * @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major) * @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major) * @param M Number of rows in A and C/D * @param N Number of rows in B and columns in C/D * @param K Number of columns in A and B * @param alpha Scaling factor for A @ B^T * @param beta Scaling factor for C * @param stream CUDA stream (currently unused, for future async support) * * @return cudaSuccess on success, error code otherwise * * Requirements: * - M must be multiple of 128 * - N must be multiple of 256 * - K must be multiple of 64 * - All pointers must be valid device memory * - Tensors must be contiguous with specified layouts */ cudaError_t launch_sm100_gemm_f16_tma(void *d_A, void *d_B, void *d_C, void *d_D, int M, int N, int K, float alpha, float beta, cudaStream_t stream = 0); #ifdef __cplusplus } #endif ================================================ FILE: kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp ================================================ // sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code) #include #include #include #include #include #include "sm100_gemm.h" // Check if SM100 support is available at compile time bool is_sm100_supported() { #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) return true; #else return false; #endif } // Check if current GPU supports SM100 at runtime bool check_sm100_device() { int device; cudaGetDevice(&device); cudaDeviceProp props; cudaError_t error = cudaGetDeviceProperties(&props, device); if (error != cudaSuccess) { return false; } // Check for SM100 architecture (compute capability 10.0a) return (props.major == 10 && props.minor == 0); } torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, float alpha = 1.0f, float beta = 0.0f) { // Check compile-time support TORCH_CHECK( is_sm100_supported(), "SM100 support not compiled. 
Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED"); // Check runtime device support TORCH_CHECK(check_sm100_device(), "Current GPU does not support SM100 architecture (requires " "compute capability 10.0a)"); // Input validation TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor"); TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor"); TORCH_CHECK(C.device().is_cuda(), "C must be a CUDA tensor"); TORCH_CHECK(A.dtype() == torch::kFloat16, "A must be float16"); TORCH_CHECK(B.dtype() == torch::kFloat16, "B must be float16"); TORCH_CHECK(C.dtype() == torch::kFloat32, "C must be float32"); TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); TORCH_CHECK(B.is_contiguous(), "B must be contiguous"); TORCH_CHECK(C.is_contiguous(), "C must be contiguous"); TORCH_CHECK(A.dim() == 2, "A must be 2D"); TORCH_CHECK(B.dim() == 2, "B must be 2D"); TORCH_CHECK(C.dim() == 2, "C must be 2D"); // Get dimensions int64_t M = A.size(0); int64_t K = A.size(1); int64_t N = B.size(0); int64_t K_B = B.size(1); TORCH_CHECK(K == K_B, "Inner dimensions must match: A.shape[1]=", K, ", B.shape[1]=", K_B); TORCH_CHECK(C.size(0) == M && C.size(1) == N, "C dimensions (", C.size(0), ", ", C.size(1), ") must match output shape (", M, ", ", N, ")"); // Check alignment requirements for SM100 TORCH_CHECK(M % 128 == 0, "M=", M, " must be multiple of 128"); TORCH_CHECK(N % 256 == 0, "N=", N, " must be multiple of 256"); TORCH_CHECK(K % 64 == 0, "K=", K, " must be multiple of 64"); // Check size limits (avoid overflow in int conversion) TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX, "Dimensions too large for int conversion"); // Create output tensor auto D = torch::empty_like(C); // Set CUDA device guard const auto device = A.device(); c10::cuda::CUDAGuard device_guard(device); // Get current CUDA stream cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream(); // Launch the kernel cudaError_t error = launch_sm100_gemm_f16_tma( A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(), static_cast(M), static_cast(N), static_cast(K), alpha, beta, stream); // Check for launch errors TORCH_CHECK(error == cudaSuccess, "SM100 GEMM kernel launch failed: ", cudaGetErrorString(error)); // Check for kernel execution errors C10_CUDA_CHECK(cudaGetLastError()); return D; } // Utility functions for debugging and information torch::Tensor get_device_info() { int device; cudaGetDevice(&device); cudaDeviceProp props; cudaGetDeviceProperties(&props, device); // Return device info as a tensor (for easy Python access) auto info = torch::zeros({4}, torch::kInt32); auto accessor = info.accessor(); accessor[0] = props.major; // Compute capability major accessor[1] = props.minor; // Compute capability minor accessor[2] = is_sm100_supported(); // Compile-time support accessor[3] = check_sm100_device(); // Runtime device support return info; } std::vector get_aligned_shape(int64_t M, int64_t N, int64_t K) { // Return properly aligned dimensions for SM100 int64_t aligned_M = ((M + 127) / 128) * 128; int64_t aligned_N = ((N + 255) / 256) * 256; int64_t aligned_K = ((K + 63) / 64) * 64; return {aligned_M, aligned_N, aligned_K}; } torch::Tensor create_aligned_tensor(const std::vector &shape, torch::ScalarType dtype, torch::Device device) { // Create a tensor with SM100-aligned dimensions TORCH_CHECK(shape.size() == 2, "Shape must be 2D"); auto aligned_shape = get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? 
shape[2] : 64); if (shape.size() == 2) { return torch::zeros({aligned_shape[0], aligned_shape[1]}, torch::TensorOptions().dtype(dtype).device(device)); } else { return torch::zeros({aligned_shape[0], aligned_shape[2]}, torch::TensorOptions().dtype(dtype).device(device)); } } // Python bindings PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "SM100 GEMM PyTorch Extension"; // Main GEMM function m.def("sm100_gemm_f16", &sm100_gemm_f16, "SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + " "beta * C", py::arg("A"), py::arg("B"), py::arg("C"), py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f); // Utility functions m.def("is_sm100_supported", &is_sm100_supported, "Check if SM100 support was compiled in"); m.def("check_sm100_device", &check_sm100_device, "Check if current GPU supports SM100 architecture"); m.def("get_device_info", &get_device_info, "Get device compute capability and SM100 support info"); m.def("get_aligned_shape", &get_aligned_shape, "Get SM100-aligned dimensions for given shape", py::arg("M"), py::arg("N"), py::arg("K")); m.def("create_aligned_tensor", &create_aligned_tensor, "Create tensor with SM100-aligned dimensions", py::arg("shape"), py::arg("dtype"), py::arg("device")); // Constants for alignment requirements m.attr("MMA_TILE_M") = 128; m.attr("MMA_TILE_N") = 256; m.attr("MMA_TILE_K") = 64; } ================================================ FILE: kernels/cuda/cutlass_gemm/broadcast_load_epilogue_c3x.hpp ================================================ /*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ // // This file is a modified excerpt of // include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp // from https://github.com/NVIDIA/cutlass v3.5.0 // It has been modified to support either row/column or scalar broadcasting // where the tensor being loaded from is always passed in via a device pointer. // This lets one compiled kernel handle all cases of per-tensor or // per-channel/per-token quantization. // // This interface also allows the scales to be passed in as tensors that // consistently reside on the device, which avoids an issue with a previous // implementation where scalars needed to be on the CPU since they // were passed in via float values. This created a potential performance hazard // if scales were initially on the device, and caused torch.compile graphs // breaks when moving scales to the CPU. // #pragma once // Turn off clang-format for the entire file to keep it close to upstream // clang-format off #include "cutlass/cutlass.h" #include "cutlass/arch/barrier.h" #include "cute/tensor.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" namespace cutlass::epilogue::fusion { using namespace cute; using namespace detail; // Row vector broadcast template< // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races int Stages, class CtaTileShapeMNK, class Element, class StrideMNL = Stride<_0,_1,_0>, int Alignment = 128 / sizeof_bits_v > struct Sm90RowOrScalarBroadcast { static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); static_assert( (cute::is_same_v>) || // row vector broadcast, e.g. per-col alpha/bias (cute::is_same_v>)); // batched row vector broadcast // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem struct SharedStorage { alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; }; // This struct has been modified to have a bool indicating that ptr_row is a // scalar that must be broadcast, instead of containing a scalar that is // valid if ptr_row is null. 
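// (In this fork the pointer always refers to device memory; callers derive the flag from the
// number of scale elements, e.g. ScaledEpilogue::prepare_args in cutlass_kernel.cu passes
// {scales.data_ptr(), scales.numel() != 1, {}}, so a single-element tensor is treated as a
// broadcast scalar.)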
struct Arguments { Element const* ptr_row = nullptr; bool row_broadcast = true; StrideMNL dRow = {}; }; using Params = Arguments; template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { return args; } template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { return 0; } template static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return cutlass::Status::kSuccess; } CUTLASS_HOST_DEVICE Sm90RowOrScalarBroadcast() { } CUTLASS_HOST_DEVICE Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) : params(params), smem_row(const_cast(shared_storage.smem_row.data())) { } Params params; Element* smem_row; CUTLASS_DEVICE bool is_producer_load_needed() const { return true; } CUTLASS_DEVICE bool is_C_load_needed() const { return false; } CUTLASS_DEVICE bool is_zero() const { return (!params.row_broadcast && *(params.ptr_row) == Element(0)); } template struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { CUTLASS_DEVICE ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) : gRow(cute::forward(gRow)), sRow(cute::forward(sRow)), params(params) {} GTensor gRow; // (CTA_M,CTA_N) STensor sRow; // (CTA_M,CTA_N,PIPE) Params const& params; CUTLASS_DEVICE void begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { if (!params.row_broadcast) { return; } if (issue_tma_load) { // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); // Issue the TMA bulk copy auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); // Filter so we don't issue redundant copies over stride-0 modes int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); } } }; template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; return ProducerLoadCallbacks( cute::move(gRow), cute::move(sRow), params); } template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) : tCrRow(cute::forward(tCrRow)), tCsRow(cute::forward(tCsRow)), params(params) {} RTensor tCrRow; // (CPY,CPY_M,CPY_N) STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) Params const& params; CUTLASS_DEVICE void previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { if (!params.row_broadcast) { fill(tCrRow, *(params.ptr_row)); return; } if (epi_m == 0) { // Assumes M-major subtile loop // Filter so we don't issue redundant 
copies over stride-0 modes // (only works if 0-strides are in same location, which is by construction) int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); } } template CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { Array frg_row; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { frg_row[i] = tCrRow(epi_v * FragmentSize + i); } return frg_row; } }; template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy class... Args > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) sRow, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; return ConsumerStoreCallbacks( cute::move(tCrRow), cute::move(tCsRow), params); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// // Column vector broadcast template< int Stages, class CtaTileShapeMNK, class Element, class StrideMNL = Stride<_1,_0,_0>, int Alignment = 128 / sizeof_bits_v > struct Sm90ColOrScalarBroadcast { static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); static_assert( (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem struct SharedStorage { }; // This struct has been modified to have a bool indicating that ptr_col is a // scalar that must be broadcast, instead of containing a scalar that is // valid if ptr_col is null. 
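// (As with the row variant above, ptr_col always points to device memory; the flag is set from
// the element count, e.g. prepare_args passes {a_scales.data_ptr(), a_scales.numel() != 1, {}}.)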
struct Arguments { Element const* ptr_col = nullptr; bool col_broadcast = true; StrideMNL dCol = {}; }; using Params = Arguments; template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { return args; } template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { return 0; } template static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return cutlass::Status::kSuccess; } CUTLASS_DEVICE bool is_producer_load_needed() const { return false; } CUTLASS_DEVICE bool is_C_load_needed() const { return false; } CUTLASS_DEVICE bool is_zero() const { return (!params.col_broadcast && *(params.ptr_col) == Element(0)); } CUTLASS_HOST_DEVICE Sm90ColOrScalarBroadcast() { } CUTLASS_HOST_DEVICE Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) : params(params) { } Params params; template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) : tCgCol(cute::forward(tCgCol)), tCrCol(cute::forward(tCrCol)), params(params) {} GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) Params const& params; CUTLASS_DEVICE void begin() { if (!params.col_broadcast) { fill(tCrCol, *(params.ptr_col)); return; } // Filter so we don't issue redundant copies over stride-0 modes // (only works if 0-strides are in same location, which is by construction) copy_aligned(filter(tCgCol), filter(tCrCol)); } template CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { Array frg_col; Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); } return frg_col; } }; template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy class... 
Args > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) return ConsumerStoreCallbacks( cute::move(tCgCol), cute::move(tCrCol), params); } }; } ================================================ FILE: kernels/cuda/cutlass_gemm/common.hpp ================================================ #pragma once #include "cutlass/cutlass.h" #include /** * Helper function for checking CUTLASS errors */ #define CUTLASS_CHECK(status) \ { \ TORCH_CHECK(status == cutlass::Status::kSuccess, \ cutlassGetStatusString(status)) \ } inline uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); return max_shared_mem_per_block_opt_in; } ================================================ FILE: kernels/cuda/cutlass_gemm/cutlass.cpp ================================================ #include #include void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales); torch::Tensor cutlass_scaled_mm(torch::Tensor a, torch::Tensor b, torch::Tensor a_scales, torch::Tensor b_scales) { auto acc_dtype = torch::kFloat16; auto options = torch::TensorOptions().dtype(acc_dtype).device(a.device()); torch::Tensor out = torch::empty({a.size(0), b.size(1)}, options); cutlass_scaled_mm_sm90(out, a, b, a_scales, b_scales); return out; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_scaled_mm", &cutlass_scaled_mm, "CUTLASS Scaled Matrix Multiplication"); } ================================================ FILE: kernels/cuda/cutlass_gemm/cutlass_kernel.cu ================================================ // clang-format will break include orders // clang-format off #include #if defined CUDA_VERSION && CUDA_VERSION >= 12000 #include #include #include #include #include #include "cutlass/cutlass.h" #include "cute/tensor.hpp" #include "cute/atom/mma_atom.hpp" #include "cutlass/numeric_types.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "broadcast_load_epilogue_c3x.hpp" #include "common.hpp" // clang-format on using namespace cute; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. Epilogue functions can be defined to post-process the output before it is written to GPU memory. Epilogues must contain a public type named EVTCompute of type Sm90EVT, as well as a static prepare_args function that constructs an EVTCompute::Arguments struct. */ namespace { // A wrapper for the GEMM kernel that is used to guard against compilation on // architectures that will never use the kernel. The purpose of this is to // reduce the size of the compiled binary. 
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef // into code that will be executed on the device where it is defined. template struct enable_sm90_or_later : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); #endif } }; /* This epilogue function defines a quantized GEMM operation similar to torch.scaled_mm_. A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or per-row. B can be quantized per-tensor or per-column. Any combination of per-tensor and per-row or column is supported. A and B must have symmetric quantization (zero point == 0). So the GEMM operation is D = (a_scales * A) (b_scales * B), where the scales are applied elementwise with numpy-style broadcasting. ScaleA and ScaleB define the epilogue functions that apply the scales for the A and B operands respectively. These scales may be either per-tensor or per row or column. */ template struct ScaledEpilogue { private: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, Stride, Int<0>, Int<0>>>; using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, Stride, Int<1>, Int<0>>>; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< cutlass::multiplies, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; static ArgumentType prepare_args(torch::Tensor const& a_scales, torch::Tensor const& b_scales) { using ScaleA_Args = typename ScaleA::Arguments; using ScaleB_Args = typename ScaleB::Arguments; ScaleA_Args a_args{a_scales.data_ptr(), a_scales.numel() != 1, {}}; ScaleB_Args b_args{b_scales.data_ptr(), b_scales.numel() != 1, {}}; return ArgumentType{a_args, {b_args}}; } }; template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_gemm { using ElementAB = ElementAB_; using ElementD = ElementD_; using ElementAcc = typename std::conditional, int32_t, float>::type; using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, ElementD, EpilogueSchedule>; using Epilogue = Epilogue_; using StrideD = Stride, Int<0>>; using ElementC = void; using StrideC = StrideD; using EVTCompute = typename Epilogue::EVTCompute; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< static_cast(CEStorageSize)>; // clang-format off using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, ElementAB, 
cutlass::layout::RowMajor, 16, ElementAB, cutlass::layout::ColumnMajor, 16, ElementAcc, TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp; // clang-format on using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>>; struct GemmKernel : public KernelType {}; }; template void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, EpilogueArgs&&... epilogue_params) { using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; int32_t m = a.size(0); int32_t n = b.size(1); int32_t k = a.size(1); int64_t lda = a.stride(0); int64_t ldb = b.stride(1); int64_t ldc = out.stride(0); using StrideA = Stride, int64_t>; using StrideB = Stride, int64_t>; using StrideC = typename Gemm::StrideC; StrideA a_stride{lda, Int<1>{}, 0}; StrideB b_stride{ldb, Int<1>{}, 0}; StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; using GemmKernel = typename Gemm::GemmKernel; typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; auto a_ptr = static_cast(a.data_ptr()); auto b_ptr = static_cast(b.data_ptr()); typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, b_stride}; auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), c_ptr, c_stride, c_ptr, c_stride}; typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args, epilogue_args}; // Launch the CUTLASS GEMM kernel. using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; // CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace = torch::empty(workspace_size, workspace_options); auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } template typename Epilogue> struct sm90_fp8_config_default { // M in (128, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; template typename Epilogue> struct sm90_fp8_config_M128 { // M in (64, 128] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; template typename Epilogue> struct sm90_fp8_config_M64 { // M in [1, 64] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _64, _128>; using ClusterShape = Shape<_1, _8, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; template typename Epilogue> struct sm90_int8_config_default { // For M > 128 and any N static_assert(std::is_same()); using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using 
ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; template typename Epilogue> struct sm90_int8_config_M128 { // For M in (64, 128] and any N static_assert(std::is_same()); using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; template typename Epilogue> struct sm90_int8_config_M64 { // For M in (32, 64] and any N static_assert(std::is_same()); using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _64, _256>; using ClusterShape = Shape<_1, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; template typename Epilogue> struct sm90_int8_config_M32_NBig { // For M in [1, 32] and N >= 8192 static_assert(std::is_same()); using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _256>; using ClusterShape = Shape<_1, _4, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; template typename Epilogue> struct sm90_int8_config_M32_NSmall { // For M in [1, 32] and N < 8192 static_assert(std::is_same()); using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _64, _256>; using ClusterShape = Shape<_1, _8, _1>; using Cutlass3xGemm = cutlass_3x_gemm; }; } // namespace template typename Epilogue, typename... EpilogueArgs> void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, EpilogueArgs&&... args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); using Cutlass3xGemmDefault = typename sm90_fp8_config_default::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm90_fp8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm90_fp8_config_M128::Cutlass3xGemm; uint32_t const m = a.size(0); uint32_t const mp2 = std::max(static_cast(64), next_pow_2(m)); // next power of 2 if (mp2 <= 64) { // m in [1, 64] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 128) { // m in (64, 128] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else { // m in (128, inf) return cutlass_gemm_caller( out, a, b, std::forward(args)...); } } template typename Epilogue, typename... EpilogueArgs> void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, EpilogueArgs&&... 
args) { static_assert(std::is_same()); TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); using Cutlass3xGemmDefault = typename sm90_int8_config_default::Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm90_int8_config_M128::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm90_int8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM32NBig = typename sm90_int8_config_M32_NBig::Cutlass3xGemm; using Cutlass3xGemmM32NSmall = typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; uint32_t const n = out.size(1); bool const is_small_n = n < 8192; uint32_t const m = a.size(0); uint32_t const mp2 = std::max(static_cast(32), next_pow_2(m)); // next power of 2 if (mp2 <= 32) { // m in [1, 32] if (is_small_n) { return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else { return cutlass_gemm_caller( out, a, b, std::forward(args)...); } } else if (mp2 <= 64) { // m in (32, 64] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else if (mp2 <= 128) { // m in (64, 128] return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else { // m in (128, inf) return cutlass_gemm_caller( out, a, b, std::forward(args)...); } } void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (a.dtype() == torch::kInt8) { TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { return cutlass_gemm_sm90_int8_dispatch( out, a, b, a_scales, b_scales); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); return cutlass_gemm_sm90_int8_dispatch( out, a, b, a_scales, b_scales); } } else { TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { return cutlass_gemm_sm90_fp8_dispatch< cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>( out, a, b, a_scales, b_scales); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); return cutlass_gemm_sm90_fp8_dispatch( out, a, b, a_scales, b_scales); } } } #endif ================================================ FILE: kernels/cuda/cutlass_gemm/readme.md ================================================ Currently the CPP extension builds with Cutlass 3.5.1 (credit to @SamirMoustafa for the update). 3.6 will fail atm due to a refactor in the TMA descriptor. 
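For reference, the scaled GEMM exposed by this extension computes D = (a_scales * A) @ (b_scales * B), with the scales broadcast elementwise (per-tensor, per-row of A, or per-column of B). Below is a minimal sanity-check sketch against plain PyTorch. It assumes the module is built as `pingpong_gemm` (as in setup.py) and that per-row/per-column scale tensors of shape (m, 1) and (1, n) are accepted; the bundled test only exercises per-tensor scales, and the tolerances are illustrative only.

```python
import torch
from pingpong_gemm import cutlass_scaled_mm

m, k, n = 16, 4096, 4096
a = torch.empty(m, k).normal_(std=0.5).to(dtype=torch.float8_e4m3fn, device="cuda")
bt = torch.empty(n, k).normal_(std=0.5).to(dtype=torch.float8_e4m3fn, device="cuda").t()

# Per-row scales for A and per-column scales for B (shapes are an assumption here;
# the bundled test_cutlass_gemm.py only passes single-element, per-tensor scales).
scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")
scale_b = torch.rand(1, n, dtype=torch.float32, device="cuda")

y = cutlass_scaled_mm(a, bt, scale_a, scale_b)

# Reference: apply the scales elementwise with broadcasting, then matmul in fp32.
ref = ((scale_a * a.float()) @ (scale_b * bt.float())).to(y.dtype)
torch.testing.assert_close(y, ref, rtol=1e-2, atol=1e-2)  # tolerances are illustrative
```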
================================================ FILE: kernels/cuda/cutlass_gemm/setup.py ================================================ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension setup( name='cutlass_gemm', ext_modules=[ CUDAExtension( name='pingpong_gemm', sources=['cutlass.cpp', 'cutlass_kernel.cu'], extra_compile_args={ 'nvcc': [ '-DNDEBUG', '-O3', '-g', '-lineinfo', '--keep', '--ptxas-options=--warn-on-local-memory-usage', '--ptxas-options=--warn-on-spills', '--resource-usage', '--source-in-ptx', '-DCUTLASS_DEBUG_TRACE_LEVEL=1', '-gencode=arch=compute_90a, code=sm_90a', ] }, include_dirs=[ '/home/adhoq26/cutlass/include', '/home/adhoq26/cutlass/tools/util/include', ], libraries=['cuda'], library_dirs=['/usr/local/cuda-12.4/lib64'], ) ], cmdclass={ 'build_ext': BuildExtension } ) ================================================ FILE: kernels/cuda/cutlass_gemm/test_cutlass_gemm.py ================================================ from pingpong_gemm import cutlass_scaled_mm import torch m, k, n = 16, 4096, 4096 dtype = torch.float8_e4m3fn out_dtype = torch.float16 a = torch.empty(m, k).normal_(mean=0.0, std=0.5).to(dtype=dtype, device='cuda') bt = torch.empty(n, k).normal_(mean=0.0, std=0.5).to(dtype=dtype, device='cuda').t() scale_a = torch.ones((1,)).to(dtype=torch.float32, device='cuda') scale_b = torch.ones((1,)).to(dtype=torch.float32, device='cuda') y = cutlass_scaled_mm(a, bt, scale_a, scale_b) print(y) ================================================ FILE: kernels/cuda/inference/README.md ================================================ cuda kernels ================================================ FILE: kernels/cuda/inference/hadamard_transform/hadamard_transform.cpp ================================================ #include #include #include #include using namespace torch::indexing; template void run_fht(void* a, void* out, uint32_t numel, uint32_t had_size, cudaStream_t stream); constexpr bool is_power_of_two(uint32_t x) { return x && !(x & (x - 1)); } torch::Tensor hadamard_transform(at::Tensor& in, bool inplace) { auto dtype = in.scalar_type(); TORCH_CHECK(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); TORCH_CHECK(in.is_cuda()); const int had_size = in.size(-1); TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size); const auto res_shape = in.sizes(); torch::Tensor x = in.reshape({-1, had_size}); auto numel = in.numel(); if (numel % 256 != 0) { x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size})); } if (x.stride(-1) != 1) { x = x.contiguous(); } torch::Tensor out = inplace ? 
x : torch::empty_like(x); at::cuda::CUDAGuard device_guard{(char)x.get_device()}; auto stream = at::cuda::getCurrentCUDAStream().stream(); if (dtype == torch::ScalarType::Half) { run_fht(x.data_ptr(), out.data_ptr(), x.numel(), had_size, stream); } else { run_fht(x.data_ptr(), out.data_ptr(), x.numel(), had_size, stream); } if (numel % 256 != 0) { out = out.index({Slice(0, numel / had_size)}); } if (inplace && out.data_ptr() != in.data_ptr()) { in.copy_(out.view(res_shape)); return in; } return out.reshape(res_shape); } namespace py = pybind11; PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("hadamard_transform", &hadamard_transform, "A function to perform a fast Hadamard transform", py::arg("x"), py::arg("inplace")=false); } ================================================ FILE: kernels/cuda/inference/hadamard_transform/hadamard_transform_cuda.cu ================================================ #include #include #include #include #include #include #ifndef __CUDACC__ #define __launch_bounds__(x,y) #endif #define MAX_WARPS_PER_SM 48 #define MIN(a, b) ((a) < (b) ? (a) : (b)) typedef uint32_t b32; typedef uint16_t b16; constexpr int launch_configs_big[7][3] = { // default {2, 1, 24}, {2, 2, 16}, {2, 4, 8}, {2, 8, 4}, {2, 16, 3}, {4, 16, 2}, {8, 16, 1} // // extra coalescing // {2, 1, 24}, // {2, 2, 16}, // {2, 4, 8}, // {2, 8, 4}, // {4, 8, 3}, // {8, 8, 2}, // {16, 8, 1} // // less coalescing // {2, 1, 24}, // {2, 2, 16}, // {2, 4, 8}, // {2, 8, 4}, // {1, 32, 1}, // {2, 32, 1}, // {4, 32, 1} }; // a 4x2, b 2x2, c 2x2 template __device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32& c0, b32& c1){ static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16); // d, a, b, c b32 zero = 0; if constexpr(dtype == torch::ScalarType::Half) { asm ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n\t" : "=r"(c0), "=r"(c1) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero) ); } else { b32 temp0, temp1, temp2, temp3; asm ( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n\t" : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero), "r"(zero), "r"(zero) ); asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); } } // a 4x2, b 4x2, c 4x2 template __device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32 b2, b32 b3, b32& c0, b32& c1, b32& c2, b32& c3){ mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b0, b1, c0, c1); mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b2, b3, c2, c3); } __device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(b32& a0) { asm ( "movmatrix.sync.aligned.m8n8.trans.b16 " "%0, %1;\n\t" : "=r"(a0) : "r"(a0) ); } #define p_p(i) ((val_1p[i] & 0x0000FFFF) | val_1p[i] << 16) #define p_n(i) ((val_1p[i] & 0x0000FFFF) | val_1n[i] << 16) #define n_p(i) ((val_1n[i] & 0x0000FFFF) | val_1p[i] << 16) #define n_n(i) ((val_1n[i] & 0x0000FFFF) | val_1n[i] << 16) template __global__ void __launch_bounds__(32 * warps_per_block, blocks_per_sm) // a is column major, b is row major hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) { static_assert(dtype == torch::ScalarType::Half || dtype == 
torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); b32 b_frag_all[num_chunks][4]; // for all chunks, holds matrix fragment (which takes 4 regs of b16x2 * 32 threads) uint blockid = blockIdx.x * warps_per_block + threadIdx.x / 32; uint threadid = threadIdx.x % 32; extern __shared__ b32 bfrag_arr[]; // num_chunks * warps_per_block * 128 int real_num_chunks = ((blockid + 1) * num_chunks) > total_num_chunks ? (total_num_chunks - (blockid * num_chunks)) : num_chunks; int diff_num_chunks = real_num_chunks - num_chunks; b32* a_start_ptr = (b32*) (a + blockid * num_chunks * 256); // offset a to where this warp starts b32* out_start_ptr = (b32*) (out + blockid * num_chunks * 256); b32* a_ptr = a_start_ptr + threadid * 4; b32* b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128 + threadid * 4; #if (__CUDA_ARCH__ < 900) // SM80, SM89 uint64_t cache_policy; asm volatile( "createpolicy.fractional.L2::evict_first.b64 %0, 1.0;\n" : "=l"(cache_policy) ); #endif #pragma unroll for (int k = 0; k < num_chunks; k++) { size_t shared_ptr = __cvta_generic_to_shared(b_frag_ptr); #if (__CUDA_ARCH__ >= 900) // SM90 asm volatile( "cp.async.cg.shared.global [%0], [%1], 16;\n" "cp.async.commit_group;\n" :: "l"(shared_ptr), "l"(a_ptr) ); #else // SM80, SM89 asm volatile( "cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2;\n" "cp.async.commit_group;\n" :: "l"(shared_ptr), "l"(a_ptr), "l"(cache_policy) ); #endif a_ptr += 128; b_frag_ptr += 128; } // generate hadamard 16x16 (up to 2 of them) constexpr b16 fp16_1p[4] = {0b0011100110101000, 0b0011100000000000, 0b0011010110101000, 0b0011010000000000}; constexpr b16 fp16_1n[4] = {0b1011100110101000, 0b1011100000000000, 0b1011010110101000, 0b1011010000000000}; constexpr b16 bf16_1p[4] = {0b0011111100110101, 0b0011111100000000, 0b0011111010110101, 0b0011111010000000}; constexpr b16 bf16_1n[4] = {0b1011111100110101, 0b1011111100000000, 0b1011111010110101, 0b1011111010000000}; #define val_type_1p(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i])) #define val_type_1n(i) (((dtype) == torch::ScalarType::Half) ? 
(fp16_1n[i]) : (bf16_1n[i])) constexpr b16 val_1p[4] = {val_type_1p(0), val_type_1p(1), val_type_1p(2), val_type_1p(3)}; constexpr b16 val_1n[4] = {val_type_1n(0), val_type_1n(1), val_type_1n(2), val_type_1n(3)}; constexpr b32 p_p[4] = {p_p(0), p_p(1), p_p(2), p_p(3)}; constexpr b32 p_n[4] = {p_n(0), p_n(1), p_n(2), p_n(3)}; constexpr b32 n_p[4] = {n_p(0), n_p(1), n_p(2), n_p(3)}; constexpr b32 n_n[4] = {n_n(0), n_n(1), n_n(2), n_n(3)}; const b32 had_16_p1[4][4] = { { 0b10001000010001000010001000010001, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b10001000010001000010001000010001 }, { 0b11001100100010000011001100100010, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b11001100100010000011001100100010 }, { 0b11111111101010101100110010011001, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b11111111101010101100110010011001 }, { 0b11111111101010101100110010011001, 0b11111111101010101100110010011001, 0b11111111101010101100110010011001, 0b00000000010101010011001101100110 } }; const b32 had_16_p2[4][4] = { { 0b10000000010000000010000000010000, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b10000000010000000010000000010000 }, { 0b11000000100001000011000000100001, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b11000000100001000011000000100001 }, { 0b11110000101001011100001110010110, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b11110000101001011100001110010110 }, { 0b11110000101001011100001110010110, 0b11110000101001011100001110010110, 0b11110000101001011100001110010110, 0b00001111010110100011110001101001 } }; const b32 had_16_mask[3][4] = { { 0b10001000010001000010001000010001, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b10001000010001000010001000010001 }, { 0b11001100110011000011001100110011, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b11001100110011000011001100110011 }, { 0b11111111111111111111111111111111, 0b00000000000000000000000000000000, 0b00000000000000000000000000000000, 0b11111111111111111111111111111111 } }; b32 had_frag[8]; #pragma unroll for (int i = 0; i < 2; i++) { int c_log_h = (i == 0) ? MIN(4, log_had_size) : log_had_size % 4; #pragma unroll for (int j = 0; j < 4; j++) { if (c_log_h < 4) { bool mask = had_16_mask[c_log_h - 1][j] & (1 << (31 - threadid)); if (!mask) { had_frag[i * 4 + j] = 0; continue; } } bool pred1 = had_16_p1[c_log_h - 1][j] & (1 << (31 - threadid)); bool pred2 = had_16_p2[c_log_h - 1][j] & (1 << (31 - threadid)); b32 val = pred1 ? (pred2 ? p_p[c_log_h - 1] : p_n[c_log_h - 1]) : (pred2 ? n_p[c_log_h - 1] : n_n[c_log_h - 1]); had_frag[i * 4 + j] = val; } if constexpr(log_had_size <= 4 || log_had_size % 4 == 0) break; } // log had size above 8, only used for above 2^8 = 256 size constexpr int part8_log_had_size = log_had_size - 8; b32* a_chunk_ptr = a_start_ptr; // first chunk starts at this warp's data starts b32* out_chunk_ptr = out_start_ptr; #pragma unroll for (int l = 0; l < 2; l++) { if constexpr(log_had_size <= 8) { // l == 0 guaranteed, redundant simplified version of else body, to help compiler warnings b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128; } else { b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * (l == 0 ? 
128 : (128 >> part8_log_had_size)); } if (l == 1) { if constexpr(log_had_size > 8) { __syncthreads(); // sync between first and second iterations if above size 256 if constexpr(log_had_size >= 12) { // sizes 4k and above // a + threadblock offset + warp offset // can then index into all chunks owned by this warp b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block)); #pragma unroll for (int j = 0; j < 4; j++) { #pragma unroll for (int k = 0; k < num_chunks; k++) { // here, j represents register, and k represents 8-offset/chunk int real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data int real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread # int chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data) int thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads) int thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads int reg_idx = (j / 2) * 8 + (j % 2); // index due to target register int idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index // fix idx for majorness int rowidx = idx % (1 << part8_log_had_size); int colidx = idx >> part8_log_had_size; // store[rowidx * 128 + colidx] = data; b32 data = store[rowidx * 128 + colidx]; // compiler generates excessive instructions, so we manually do the if statement #pragma unroll for (int i = 0; i < num_chunks; i++) { asm volatile ( "{\n\t" " .reg .pred p0;\n\t" " setp.eq.u32 p0, %1, %2;\n\t" " @p0 mov.b32 %0, %3;\n\t" "}\n\t" : "+r"(b_frag_all[i][j]) // Output operand %0 : "r"(real_chunk_num), "r"(i), "r"(data) // Input operands %1, %2, %3 ); } } } #pragma unroll for (int j = 0; j < 4; j++) { #pragma unroll for (int k = 1; k < num_chunks; k++) { int threadid_contig = threadid % num_chunks; int threadid_mul = threadid / num_chunks; int threadid2 = (threadid_contig + num_chunks - k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2); } } } } } #pragma unroll for (int k = 0; k < num_chunks; k++) { if constexpr(enable_mask) { if (k >= real_num_chunks) break; } if (l == 0) { // bad fix for k not being recognized as a constexpr by compiler // asm("cp.async.wait_group %0;\n" :: "n"(num_chunks - k - 1)); #define SWITCH_WAIT_ASYNC_LOAD_GROUP(i) case i: asm volatile("cp.async.wait_group %0;\n" :: "n"(num_chunks - i - 1)); break; if constexpr(enable_mask) { switch(k + diff_num_chunks) { SWITCH_WAIT_ASYNC_LOAD_GROUP(0) SWITCH_WAIT_ASYNC_LOAD_GROUP(1) SWITCH_WAIT_ASYNC_LOAD_GROUP(2) SWITCH_WAIT_ASYNC_LOAD_GROUP(3) SWITCH_WAIT_ASYNC_LOAD_GROUP(4) SWITCH_WAIT_ASYNC_LOAD_GROUP(5) SWITCH_WAIT_ASYNC_LOAD_GROUP(6) SWITCH_WAIT_ASYNC_LOAD_GROUP(7) SWITCH_WAIT_ASYNC_LOAD_GROUP(8) SWITCH_WAIT_ASYNC_LOAD_GROUP(9) SWITCH_WAIT_ASYNC_LOAD_GROUP(10) SWITCH_WAIT_ASYNC_LOAD_GROUP(11) SWITCH_WAIT_ASYNC_LOAD_GROUP(12) SWITCH_WAIT_ASYNC_LOAD_GROUP(13) SWITCH_WAIT_ASYNC_LOAD_GROUP(14) SWITCH_WAIT_ASYNC_LOAD_GROUP(15) SWITCH_WAIT_ASYNC_LOAD_GROUP(16) SWITCH_WAIT_ASYNC_LOAD_GROUP(17) SWITCH_WAIT_ASYNC_LOAD_GROUP(18) SWITCH_WAIT_ASYNC_LOAD_GROUP(19) SWITCH_WAIT_ASYNC_LOAD_GROUP(20) SWITCH_WAIT_ASYNC_LOAD_GROUP(21) SWITCH_WAIT_ASYNC_LOAD_GROUP(22) SWITCH_WAIT_ASYNC_LOAD_GROUP(23) 
SWITCH_WAIT_ASYNC_LOAD_GROUP(24) SWITCH_WAIT_ASYNC_LOAD_GROUP(25) SWITCH_WAIT_ASYNC_LOAD_GROUP(26) SWITCH_WAIT_ASYNC_LOAD_GROUP(27) SWITCH_WAIT_ASYNC_LOAD_GROUP(28) SWITCH_WAIT_ASYNC_LOAD_GROUP(29) SWITCH_WAIT_ASYNC_LOAD_GROUP(30) SWITCH_WAIT_ASYNC_LOAD_GROUP(31) } } else { switch(k) { SWITCH_WAIT_ASYNC_LOAD_GROUP(0) SWITCH_WAIT_ASYNC_LOAD_GROUP(1) SWITCH_WAIT_ASYNC_LOAD_GROUP(2) SWITCH_WAIT_ASYNC_LOAD_GROUP(3) SWITCH_WAIT_ASYNC_LOAD_GROUP(4) SWITCH_WAIT_ASYNC_LOAD_GROUP(5) SWITCH_WAIT_ASYNC_LOAD_GROUP(6) SWITCH_WAIT_ASYNC_LOAD_GROUP(7) SWITCH_WAIT_ASYNC_LOAD_GROUP(8) SWITCH_WAIT_ASYNC_LOAD_GROUP(9) SWITCH_WAIT_ASYNC_LOAD_GROUP(10) SWITCH_WAIT_ASYNC_LOAD_GROUP(11) SWITCH_WAIT_ASYNC_LOAD_GROUP(12) SWITCH_WAIT_ASYNC_LOAD_GROUP(13) SWITCH_WAIT_ASYNC_LOAD_GROUP(14) SWITCH_WAIT_ASYNC_LOAD_GROUP(15) SWITCH_WAIT_ASYNC_LOAD_GROUP(16) SWITCH_WAIT_ASYNC_LOAD_GROUP(17) SWITCH_WAIT_ASYNC_LOAD_GROUP(18) SWITCH_WAIT_ASYNC_LOAD_GROUP(19) SWITCH_WAIT_ASYNC_LOAD_GROUP(20) SWITCH_WAIT_ASYNC_LOAD_GROUP(21) SWITCH_WAIT_ASYNC_LOAD_GROUP(22) SWITCH_WAIT_ASYNC_LOAD_GROUP(23) SWITCH_WAIT_ASYNC_LOAD_GROUP(24) SWITCH_WAIT_ASYNC_LOAD_GROUP(25) SWITCH_WAIT_ASYNC_LOAD_GROUP(26) SWITCH_WAIT_ASYNC_LOAD_GROUP(27) SWITCH_WAIT_ASYNC_LOAD_GROUP(28) SWITCH_WAIT_ASYNC_LOAD_GROUP(29) SWITCH_WAIT_ASYNC_LOAD_GROUP(30) SWITCH_WAIT_ASYNC_LOAD_GROUP(31) } } } if (l == 0) { // loading for the first iteration // thread 0 loads [t0r0, t16r1, t0r2, t16r3] // thread 16 loads [t0r1, t16r0, t0r3, t16r2] // allows full coalescing, same for t1/t17, t2/t18, etc. #pragma unroll for (int j = 0; j < 4; j++) { int reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2)); int real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16); int real_row = real_thread_id % 4; int real_col = real_thread_id / 4; b_frag_all[k][j] = b_frag_ptr[(real_row + (reg % 2) * 4) + (real_col + (j / 2) * 8) * 8]; } // for t16 swap r0/r1 and r2/r3 to have [t16r0, t0r1, t16r2, t0r3] // so registers are in right order, same for t17, t18, etc. if ((threadid & 16) != 0) { b32 temp = b_frag_all[k][0]; b_frag_all[k][0] = b_frag_all[k][1]; b_frag_all[k][1] = temp; temp = b_frag_all[k][2]; b_frag_all[k][2] = b_frag_all[k][3]; b_frag_all[k][3] = temp; } // t0 and t16 swap r1 and r3 to have their own data, // same for t1/t17, t2/18, etc. #pragma unroll for (int j = 1; j < 4; j += 2) { b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16); } } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings if constexpr(log_had_size < 12) { // sizes 512, 1k, and 2k // for 512: // thread 0 loads [t0r0, t0r1, t16r2, t16r3] // thread 16 loads [t0r2, t0r3, t16r0, t16r1] // same for t1/t17, t2/t18, etc. // for 1k and 2k: // thread 0 loads [t0r0, t0r1, t1r2, t1r3] // thread 1 loads [t0r2, t0r3, t1r0, t1r1] // same for t2/t3, t4/t5, etc. // allows full coalescing for 512 and 1k, 16x coalescing for 2k constexpr int xor_val = log_had_size == 9 ? 16 : 1; #pragma unroll for (int j = 0; j < 4; j++) { int reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4; int real_thread_id = reg < 2 ? 
threadid : (threadid ^ xor_val); int idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2); int rowidx = idx % (1 << part8_log_had_size); int colidx = idx >> part8_log_had_size; b_frag_all[k][j] = b_frag_ptr[rowidx * 128 + colidx]; } if ((threadid & xor_val) != 0) { b32 temp = b_frag_all[k][0]; b_frag_all[k][0] = b_frag_all[k][2]; b_frag_all[k][2] = temp; temp = b_frag_all[k][1]; b_frag_all[k][1] = b_frag_all[k][3]; b_frag_all[k][3] = temp; } #pragma unroll for (int j = 2; j < 4; j++) { b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val); } } } if (l == 1) { // for second iteration, we load 2 consecutive b16s (1 b32) per register, // but tensor core register layout requires 2 b16s that are in the // same column/consecutive rows to be in the same register, so do the swap b32 f0 = ((b_frag_all[k][1] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF); b32 f1 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][2] & 0xFFFF); b32 f2 = (b_frag_all[k][1] & 0xFFFF0000) | (b_frag_all[k][0] >> 16); b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][2] >> 16); b_frag_all[k][0] = f0; b_frag_all[k][1] = f1; b_frag_all[k][2] = f2; b_frag_all[k][3] = f3; } #pragma unroll for(int i = 0, remaining_log_had_size = log_had_size - l * 8; i < 2 && remaining_log_had_size > 0; i++) { int had_off = ((remaining_log_had_size < 4) && !(log_had_size <= 4 || log_had_size % 4 == 0)) ? 4 : 0; mma_m16_n16_k16_b16_b16_b16_noacc(had_frag[had_off + 0], had_frag[had_off + 1], had_frag[had_off + 2], had_frag[had_off + 3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3]); remaining_log_had_size -= 4; if (remaining_log_had_size <= 0 && i == 0) { // TODO: consider different storing so no need for transpose matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][0]); matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][1]); matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][2]); matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][3]); } else { // swap and use output directly as b_frag for next iteration as an actually free transpose b32 temp = b_frag_all[k][1]; b_frag_all[k][1] = b_frag_all[k][2]; b_frag_all[k][2] = temp; } } if (l == 1) { // invert swap from above for second iteration b32 f0 = ((b_frag_all[k][2] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF); b32 f1 = (b_frag_all[k][2] & 0xFFFF0000) | (b_frag_all[k][0] >> 16); b32 f2 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][1] & 0xFFFF); b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][1] >> 16); b_frag_all[k][0] = f0; b_frag_all[k][1] = f1; b_frag_all[k][2] = f2; b_frag_all[k][3] = f3; } if (l == 0) { // inverse of coalesced load for first iteration to store result #pragma unroll for (int j = 1; j < 4; j += 2) { b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16); } if ((threadid & 16) != 0) { b32 temp = b_frag_all[k][0]; b_frag_all[k][0] = b_frag_all[k][1]; b_frag_all[k][1] = temp; temp = b_frag_all[k][2]; b_frag_all[k][2] = b_frag_all[k][3]; b_frag_all[k][3] = temp; } // if only going up to 256 size, store directly back to global memory, // otherwise store back to shared memory for next iteration b32* store = (log_had_size <= 8) ? out_chunk_ptr : b_frag_ptr; #pragma unroll for (int j = 0; j < 4; j++) { int reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2)); int real_thread_id = (reg == 0 || reg == 2) ? 
threadid : (threadid ^ 16); int real_row = real_thread_id % 4; int real_col = real_thread_id / 4; store[(real_row + (reg % 2) * 4) + (real_col + (reg / 2) * 8) * 8] = b_frag_all[k][j]; } } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings if (log_had_size < 12) { // inverse of coalesced load for sizes 512, 1k and 2k to store result constexpr int xor_val = log_had_size == 9 ? 16 : 1; #pragma unroll for (int j = 2; j < 4; j++) { b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val); } if ((threadid & xor_val) != 0) { b32 temp = b_frag_all[k][0]; b_frag_all[k][0] = b_frag_all[k][2]; b_frag_all[k][2] = temp; temp = b_frag_all[k][1]; b_frag_all[k][1] = b_frag_all[k][3]; b_frag_all[k][3] = temp; } b32* store = (b32*)(out + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 256 + (256 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block) + k)); #pragma unroll for (int j = 0; j < 4; j++) { int reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4; b32 data = b_frag_all[k][j]; int real_thread_id = reg < 2 ? threadid : (threadid ^ xor_val); int idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2); int rowidx = idx % (1 << part8_log_had_size); int colidx = idx >> part8_log_had_size; store[rowidx * 128 + colidx] = data; } } // for size 4k and above, wait to process all chunks so a final store can be performed coalesced } a_chunk_ptr += 128; // (only affects first 256 size) move on to next chunk by skipping 256 elements in b16 (= 128 in b32) out_chunk_ptr += 128; if constexpr(log_had_size > 8) { b_frag_ptr += (l == 0 ? 128 : (128 >> part8_log_had_size)); } else { // else is redundant, simplified version of if body, to help compiler warnings b_frag_ptr += 128; } } if (log_had_size <= 8) break; } if constexpr(log_had_size >= 12) { // for sizes 4k and above, perform final coalesced store after processing all chunks #pragma unroll for (int j = 0; j < 4; j++) { #pragma unroll for (int k = 1; k < num_chunks; k++) { int threadid_contig = threadid % num_chunks; int threadid_mul = threadid / num_chunks; int threadid2 = (threadid_contig + k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2); } } // a + threadblock offset + warp offset // can then index into all chunks owned by this warp b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block)); #pragma unroll for (int j = 0; j < 4; j++) { #pragma unroll for (int k = 0; k < num_chunks; k++) { // here, j represents register, and k represents 8-offset/chunk int real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data // b32 data = b_frag_all[real_chunk_num][j]; // target thread data b32 data; #pragma unroll for (int i = 0; i < num_chunks; i++) { if (real_chunk_num == i) data = b_frag_all[i][j]; } int real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread # int chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data) int thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads) int thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads int reg_idx = (j / 2) * 8 + (j % 2); // index due 
to target register int idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index // fix idx for majorness int rowidx = idx % (1 << part8_log_had_size); int colidx = idx >> part8_log_had_size; store[rowidx * 128 + colidx] = data; } } __syncthreads(); store = ((b32*) out) + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 128; int4* store4 = (int4*) store; int4* bfrag_arr4 = (int4*) bfrag_arr; // flush smem, simply linearly write to store // always divisible by 128*32b, so (32*4)*32b is ok #pragma unroll for (int warp_off = 0; warp_off < (num_chunks * warps_per_block * 128 / 4); warp_off += 32 * warps_per_block) { int total_off = warp_off + threadid + (blockid % warps_per_block) * 32; store4[total_off] = bfrag_arr4[total_off]; } } } constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; } template void __forceinline__ run_kernel(b16* a_mat, b16* out, int num_chunks, cudaStream_t stream) { int shared_size = chunks_per_warp * warps_per_block * 128 * 4; dim3 block_size = 32 * warps_per_block; #define CHECK_SHARED_LIM() { \ if (shared_size > 48 * 1024) { \ C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \ } \ } \ if constexpr(check_masking) { if (num_chunks % (chunks_per_warp * warps_per_block) != 0) { dim3 grid_size = ceil_div(ceil_div(num_chunks, chunks_per_warp), warps_per_block); auto kernel = hadamard_transform_kernel; CHECK_SHARED_LIM(); kernel<<>>(a_mat, out, num_chunks); } else { dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block; auto kernel = hadamard_transform_kernel; CHECK_SHARED_LIM(); kernel<<>>(a_mat, out, num_chunks); } } else { dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block; auto kernel = hadamard_transform_kernel; CHECK_SHARED_LIM(); kernel<<>>(a_mat, out, num_chunks); } C10_CUDA_KERNEL_LAUNCH_CHECK(); } template void run_fht(void* a_mat_ptr, void* out_ptr, uint32_t numel, uint32_t had_size, cudaStream_t stream) { uint32_t num_chunks = numel / 256; // caller required to ensure divisible by 256 // for size 256, use (2, 1) // for size 32k use (8, 16) constexpr int chunks_per_warp_small = 1;// 8; constexpr int warps_per_block_small = 1;//2;//16; constexpr int blocks_per_sm_small = 24; constexpr int chunks_per_warp_large = 2; constexpr int warps_per_block_large = 1; constexpr int blocks_per_sm_large = 24; // constexpr torch::ScalarType dtype = torch::ScalarType::Half; b16* a_mat = (b16*) a_mat_ptr; b16* out = (b16*) out_ptr; if (numel <= 256) { switch (had_size) { case (1<<1): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<2): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<3): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<4): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<5): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<6): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<7): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<8): run_kernel(a_mat, out, num_chunks, stream); break; } } else { switch (had_size) { case (1<<1): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<2): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<3): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<4): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<5): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<6): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<7): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<8): run_kernel(a_mat, 
out, num_chunks, stream); break; case (1<<9): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<10): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<11): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<12): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<13): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<14): run_kernel(a_mat, out, num_chunks, stream); break; case (1<<15): run_kernel(a_mat, out, num_chunks, stream); break; } } } template void run_fht(void* a_mat_ptr, void* out_ptr, uint32_t numel, uint32_t had_size, cudaStream_t stream); template void run_fht(void* a_mat_ptr, void* out_ptr, uint32_t numel, uint32_t had_size, cudaStream_t stream); ================================================ FILE: kernels/cuda/inference/hadamard_transform/setup.py ================================================ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension versions = [ "-gencode", "arch=compute_80,code=sm_80", "-gencode", "arch=compute_89,code=sm_89", "-gencode", "arch=compute_90,code=sm_90", ] # TODO: assumes installed CUDA toolkit supports sm_80 to sm_90 setup( name='faster_hadamard_transform', ext_modules=[ CUDAExtension( name="faster_hadamard_transform", sources=[ "hadamard_transform.cpp", "hadamard_transform_cuda.cu", ], extra_compile_args={ "cxx": ["-O3"], "nvcc": [ "-O3", "-lineinfo", '--ptxas-options=--warn-on-local-memory-usage', '--ptxas-options=--warn-on-spills', ] + versions } ), ], cmdclass={ 'build_ext': BuildExtension } ) ================================================ FILE: kernels/cuda/inference/hadamard_transform/test.py ================================================ import torch import faster_hadamard_transform import scipy.linalg import math # set to false to check performance correctness_check = True # set to warmup count + 1 to check performance # for quick testing, 2 is good. 
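# Quick usage illustration (an added sketch, not part of the original test flow; the
# `_example_*` names are new here). It assumes the extension above was built, e.g. via
# `python setup.py install`, and is importable as `faster_hadamard_transform`, as this
# script's imports already assume. The transform acts row-wise on a CUDA fp16/bf16
# tensor whose last dimension is the Hadamard size (a power of two up to 32768) and
# matches the scipy-based reference below, which includes the 1/sqrt(n) scaling.
_example_x = torch.randn(1024, 256, device='cuda', dtype=torch.float16)
_example_y = faster_hadamard_transform.hadamard_transform(_example_x)           # out-of-place
faster_hadamard_transform.hadamard_transform(_example_x, inplace=True)          # in-place variant
del _example_x, _example_y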
runs_per_size = 2 # hadamard sizes test_sizes_m = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] test_elem_counts = [1 << i for i in range(9, 26, 1)] # 32MB # 64MB # 2**28 = 256M print("test_sizes_m: ", test_sizes_m) print("test_elem_counts: ", test_elem_counts) test_count = len(test_sizes_m) * len(test_elem_counts) tests_done = 0 failed_tests = 0 def get_scale(size): return math.sqrt(1 / size) truth_hadamards = [torch.tensor(scipy.linalg.hadamard(size), device='cuda', dtype=torch.float32) * get_scale(size) for size in test_sizes_m] truth_hadamards = [(x.to(torch.float16), x.to(torch.bfloat16)) for x in truth_hadamards] truth_hadamards_fp16, truth_hadamards_bf16 = zip(*truth_hadamards) truth_hadamards_fp16 = list(truth_hadamards_fp16) truth_hadamards_bf16 = list(truth_hadamards_bf16) del truth_hadamards def truth_hadamard_transform_inplace(a: torch.Tensor, truth_hadamards): target_index = -1 for i in range(len(test_sizes_m)): if test_sizes_m[i] == a.shape[1]: target_index = i break return a @ truth_hadamards[int(target_index)] def test_hadamard_transform_inplace_rowmajor(a: torch.Tensor): faster_hadamard_transform.hadamard_transform(a, inplace=True) return a torch.manual_seed(0) def check_correctness(m, elem_c, a, result, truth, atol=1e-2, rtol=0): success = torch.allclose(truth, result, atol=atol, rtol=rtol) if not success: torch.set_printoptions(threshold=100) print(f'Failed test: {m}x{elem_c // m}') print(f'Input:') print(a) print(f'Expected:') print(truth) print(f'Got:') print(result) # worst element diff = torch.abs(truth - result) max_diff = torch.max(diff) print(f'Max diff: {max_diff}') print(f'Max diff index: {torch.argmax(diff)}') diff_input = torch.abs(a - result) max_diff_input = torch.max(diff_input) print(f'Max diff input: {max_diff_input}') print('') exit(1) for m in test_sizes_m: for elem_c in test_elem_counts: if elem_c < m: tests_done += runs_per_size if tests_done % 100 == 0 or tests_done == test_count: print(f'{tests_done}/{test_count} tests done') continue print(f'Testing size {m}x{elem_c // m}') a = torch.randn((elem_c // m, m), device='cuda', dtype=torch.float32) # a = torch.zeros((m, elem_c // m), device='cuda', dtype=torch.float16) # for i in range(min(a.shape[0], a.shape[1])): # a[i, i] = 1.0 if correctness_check: for i in range(runs_per_size): # run test here a_result_fp16 = a.clone().to(torch.float16) a_truth_fp16 = a.clone().to(torch.float16) result_fp16 = test_hadamard_transform_inplace_rowmajor(a_result_fp16) truth_fp16 = truth_hadamard_transform_inplace(a_truth_fp16, truth_hadamards_fp16) check_correctness(m, elem_c, a, result_fp16, truth_fp16, atol=1e-2) # TODO: NOTE: we are not accurate down to 3 decimal places (atol) a_result_bf16 = a.clone().to(torch.bfloat16) a_truth_bf16 = a.clone().to(torch.bfloat16) result_bf16 = test_hadamard_transform_inplace_rowmajor(a_result_bf16) truth_bf16 = truth_hadamard_transform_inplace(a_truth_bf16, truth_hadamards_bf16) check_correctness(m, elem_c, a, result_bf16, truth_bf16, atol=5e-2) # TODO: NOTE: need 5x atol to pass for bf16 else: # run in a row so that warmup is valid a_result = a # we can clobber the result cause we are only interested in timing for i in range(runs_per_size): a_result = test_hadamard_transform_inplace_rowmajor(a_result) a_truth = a for i in range(runs_per_size): a_truth = truth_hadamard_transform_inplace(a_truth) a_memcpy = a # also can compare timing to memcpy temp = torch.empty_like(a) for i in range(runs_per_size): temp.copy_(a_memcpy) # do nothing with results since 
we are only interested in timing # NOTE: make sure to disable clearing cache in Nsight Compute tests_done += 1 if tests_done % 100 == 0 or tests_done == test_count: print(f'{tests_done}/{test_count} size tests done')
================================================
FILE: kernels/cuda/training/README.md
================================================
kernels with backward pass support
================================================
FILE: kernels/cuda/tutorials/README.md
================================================
CUDA tutorials
================================================
FILE: kernels/cuda/tutorials/flash2.cu
================================================
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// flash2
// NOTE: incomplete tutorial kernel - only the Kj/Vj load loop is implemented so far.
__global__ void forward_kernel(const float* Q, const float* K, const float* V,
                               const int N, const int d, const int Tc, const int Tr,
                               const int Bc, const int Br, const float sm_scale,
                               float* l, float* m, float* O) {
    int tidx = threadIdx.x;
    int bidx = blockIdx.x;  // batch index
    int bidy = blockIdx.y;  // head index

    int qkv_offset = (bidx * gridDim.y * N * d) + (bidy * N * d);
    int lm_offset = (bidx * gridDim.y * N) + (bidy * N);  // l and m offset

    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S  = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {
        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tidx * d) + x] = K[qkv_offset + (tile_size * j) + (tidx * d) + x];
            Vj[(tidx * d) + x] = V[qkv_offset + (tile_size * j) + (tidx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj
    }
}
================================================
FILE: kernels/needs_perf_help/fp8_gemm_bench.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Callable, Tuple

#import click
import torch
import triton  # @manual
from fp8_gemm_rowwise import (
    matmul_fp8_block,
    matmul_fp8_row,
    quantize_fp8_block,
    quantize_fp8_row,
)
from torch._tensor import Tensor


#@click.command()
#@click.option("--cuda-graph", type=bool, default=True)
#@click.option("--rowwise-tma", is_flag=True, default=False)
def bench(cuda_graph: bool, rowwise_tma: bool = True) -> None:
    """Benchmark bf16 vs scale/cast + fp8."""

    def _run_benchmark(
        bench_factory: Callable[
            [torch.Tensor, torch.Tensor], Callable[[], torch.Tensor]
        ],
        shape: Tuple[int, int, int] = (1024, 1024, 1024),
        tag: str = "",
    ) -> None:
        # Benchmarks the function returned by bench_factory.
        # Any pre-processing that should not be benchmarked can occur inside bench_factory.
m, n, k = shape input_shape = (m, k) weight_shape = (n, k) base_dtype = torch.bfloat16 input_ = torch.randn(input_shape, device="cuda", dtype=base_dtype) weight_ = torch.randn(weight_shape, device="cuda", dtype=base_dtype) gemm_fn = bench_factory(input_, weight_) if cuda_graph: bench_stream = torch.cuda.Stream() with torch.cuda.stream(bench_stream): ms = triton.testing.do_bench_cudagraph( lambda: gemm_fn(), rep=100, ) else: ms = triton.testing.do_bench( lambda: gemm_fn(), warmup=25, rep=100, ) tflops = (2 * m * n * k) / 1e12 sec = ms / 1e3 perf_str = f"{tflops / sec:.2f}" print( f"{(tag + ':').ljust(40)}\tshape {str(shape):<25} tflops {perf_str:<8} ms {ms:.3f}" ) shapes = [ (8192, 8192, 512), (8192, 8192, 8192), (65536, 8192, 7168), (65536, 3584, 8192), (8192, 14336, 4096), ] for shape in shapes: _run_benchmark(bf16_bench, shape=shape, tag="bf16") _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm") _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm") _run_benchmark( row_gemm_bench, shape=shape, tag="fp8 row gemm only | fp8_fast_accum=True", ) _run_benchmark( row_gemm_bench_no_fast_acc, shape=shape, tag="fp8 row gemm only | fp8_fast_accum=False", ) _run_benchmark( row_gemm_bench_imprecise_acc, shape=shape, tag="fp8 row gemm only | max_num_imprecise_acc=32", ) _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only") if rowwise_tma: _run_benchmark( row_gemm_bench_tma, shape=shape, tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True", ) def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]: def gemm_fn() -> Tensor: return torch.matmul(x, w.T) return gemm_fn def scale_row_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]: # Benchmark quantize(x) + gemm for inference. def run_gemm() -> Tensor: x_fp8: Tensor w_fp8: Tensor x_scale: Tensor w_scale: Tensor x_fp8, x_scale = quantize_fp8_row(x) w_fp8, w_scale = quantize_fp8_row(w) return matmul_fp8_row( x_fp8, w_fp8, x_scale, w_scale, dot_out_dtype=torch.float32, allow_tf32=True, fp8_fast_accum=True, ) return run_gemm def row_gemm_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]: # Benchmark only row-wise gemm, caching scaling. x_fp8: Tensor w_fp8: Tensor x_scale: Tensor w_scale: Tensor x_fp8, x_scale = quantize_fp8_row(x) w_fp8, w_scale = quantize_fp8_row(w) def run_gemm() -> Tensor: return matmul_fp8_row( x_fp8, w_fp8, x_scale, w_scale, dot_out_dtype=torch.float32, allow_tf32=True, fp8_fast_accum=True, ) return run_gemm def row_gemm_bench_tma(x: Tensor, w: Tensor) -> Callable[[], Tensor]: # Benchmark only row-wise gemm with TMA persistent x_fp8: Tensor w_fp8: Tensor x_scale: Tensor w_scale: Tensor x_fp8, x_scale = quantize_fp8_row(x) w_fp8, w_scale = quantize_fp8_row(w) def run_gemm() -> Tensor: return matmul_fp8_row( x_fp8, w_fp8, x_scale, w_scale, dot_out_dtype=torch.float32, allow_tf32=True, fp8_fast_accum=True, tma_persistent=True, ) return run_gemm ================================================ FILE: kernels/needs_perf_help/fp8_rowwise_tma_persistent.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
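# Overview (an added illustrative note, not original to this file): the persistent TMA
# kernel below computes a row-wise-scaled fp8 GEMM. A rough sketch of the intended
# semantics, mirroring the cpu fallback further down (the helper name is hypothetical):
#
#   def matmul_fp8_row_reference(a, b, a_scale, b_scale):
#       # a: [M, K] fp8, b: [N, K] fp8, a_scale: [M], b_scale: [N] per-row scales
#       return (torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T)
#               * (a_scale[:, None] * b_scale[None, :]))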
# pyre-unsafe import logging from typing import List, Optional, Tuple import torch import triton # @manual import triton.language as tl # @manual from torch._tensor import Tensor from triton import Config # @manual from triton.ops.matmul_perf_model import ( # @manual early_config_prune, estimate_matmul_time, ) from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper # @manual logger: logging.Logger = logging.getLogger(__name__) def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]: """ Helper function to get constant values for the current platform. Returns: pt_dtype (torch.dtype): The correct torch fp8 datatype. tl_dtype (tl.dtype): The correct triton fp8 datatype. max_fp8 (float): The maximum reprsentable value for the fp8 datatype. eps (float): Minimum clip value to prevent divide by zero. """ if torch.version.hip is not None: pt_fp8_dtype = torch.float8_e4m3fnuz tl_fp8_dtype = tl.float8e4b8 else: pt_fp8_dtype = torch.float8_e4m3fn tl_fp8_dtype = tl.float8e4nv return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12 def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper: """ Converts tensor to triton fp8 type. Args: tensor (torch.Tensor): input tensor. dtype (tl.dtype): target triton dtype. Returns: triton.TensorWrapper: fp8 tensor. """ return tl_reinterpret(tensor, dtype=dtype) def init_to_zero(name): return lambda nargs: nargs[name].zero_() def get_configs_io_bound() -> List[Config]: """ Returns a list of configs for matmul that are IO bound. Returns: List[Config]: list of configs. """ configs = [] for num_stages in [2, 3, 4, 5, 6]: for block_m in [16, 32]: for block_k in [32, 64]: for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( Config( { "BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1, }, num_stages=num_stages, num_warps=num_warps, ) ) # split_k for split_k in []: # Disabled [2, 4, 8, 16]: configs.append( Config( { "BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k, }, num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero("C"), ) ) return configs @triton.jit def _kernel_matmul_fp8_row_tma_persistent( A_ptr, B_ptr, C_ptr, M, N, K, A_scale, B_scale, stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn, dot_out_dtype: tl.constexpr, allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, AB_DTYPE: tl.constexpr, NUM_SMS: tl.constexpr, ) -> None: """Matmul kernel of [M, K] @ [N, K] with row-wise scales performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles. Args: A (TensorWrapper): [M, K] input tensor. B (TensorWrapper): [N, K] input tensor. C (TensorWrapper): [M, N] output tensor. M (int): M dimension of input tensor. N (int): N dimension of input tensor. K (int): K dimension of input tensor. A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B stride_am (int): Stride of M dimension of A. stride_ak (int): Stride of K dimension of A. stride_bn (int): Stride of N dimension of B. stride_bk (int): Stride of K dimension of B. stride_cm (int): Stride of M dimension of C. stride_cn (int): Stride of N dimension of C. dot_out_dtype (torch.dtype): Output type of tensor core. allow_tf32 (bool): Whether to use TF32 for tensor core. 
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core. BLOCK_M (int): Block size for M dimension. BLOCK_N (int): Block size for N dimension. BLOCK_K (int): Block size for K dimension. GROUP_M (int): Number of groups for M dimension swizzle. SPLIT_K (int): Number of SM's to launch per row. EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K. AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core. """ # Matrix multiplication. start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) k_tiles = tl.cdiv(K, BLOCK_K) num_tiles = num_pid_m * num_pid_n tiles_per_SM = num_tiles // NUM_SMS if start_pid < num_tiles % NUM_SMS: tiles_per_SM += 1 tile_id = start_pid - NUM_SMS ki = -1 pid_m = 0 pid_n = 0 offs_am = 0 offs_bn = 0 num_pid_in_group = GROUP_M * num_pid_n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) dtype_fp8 = tl.float8e4nv scale_dtype = tl.float32 for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) if ki == 0: tile_id += NUM_SMS group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_M group_size_m = min(num_pid_m - first_pid_m, GROUP_M) pid_m = first_pid_m + (tile_id % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m offs_am = pid_m * BLOCK_M offs_bn = pid_n * BLOCK_N offs_am = tl.multiple_of(offs_am, BLOCK_M) offs_bn = tl.multiple_of(offs_bn, BLOCK_N) offs_k = ki * BLOCK_K a = tl._experimental_descriptor_load( A_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], dtype_fp8 ) b = tl._experimental_descriptor_load( B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8 ) acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) if ki == k_tiles - 1: # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M rn = pid_n * BLOCK_N # # Invert scaling. a_scale = tl._experimental_descriptor_load( A_scale, [rm], [BLOCK_M], scale_dtype ) b_scale = tl._experimental_descriptor_load( B_scale, [rn], [BLOCK_N], scale_dtype ) # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`. scale = a_scale[:, None] * b_scale[None, :] acc *= scale acc = acc.to(C_ptr.dtype.element_ty) tl._experimental_descriptor_store(C_ptr, acc, [rm, rn]) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) def matmul_fp8_row( a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, dot_out_dtype: Optional[torch.dtype] = None, allow_tf32: bool = True, fp8_fast_accum: bool = True, imprecise_acc: bool = False, tma_persistent: bool = False, ) -> torch.Tensor: """ Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N]. Args: a (torch.Tensor): [M, K] input tensor. b (torch.Tensor): [N, K] input tensor. a_scale (torch.Tensor): [M] reciprocal scale tensor per row. A * a_scale = original A b_scale (torch.Tensor): [N] reciprocal scale tensor per row. B * b_scale = original B dot_out_dtype (torch.dtype): Output type of tensor core. allow_tf32 (bool): Whether to use TF32 for tensor core. fp8_fast_accum (bool): Whether to use fast accumulation for tensor core. tma_persistent (bool): Whether to use TMA persistent kernel impl. Returns: torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :]) """ # Get datatypes and constants to use. _, tl_dtype, _, _ = get_fp8_constants() # Reinterpret inputs into proper triton fp8 dtype. 
a_tl = convert_fp8_type(a, tl_dtype) b_tl = convert_fp8_type(b, tl_dtype) M, N, K, m_key, n_key, k_key, c, dot_out_dtype_triton, device = prep_matmul( a_tl, b_tl, dot_out_dtype ) # launch kernel if a.device == torch.device("cpu"): logger.info( "FP8 Row-wise Triton kernel not supported on cpu, fallback to torch" ) return ( torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T) * (a_scale[:, None] * b_scale[None, :]) ).to(dtype=c.dtype) def grid(META): return ( triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"], ) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count def persistent_grid(META): return ( min( NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), ), ) if tma_persistent: # used by TMA persistent kernel TMA_SIZE = 128 import numpy as np # autotune doesn't work with TMA # https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312 BLOCK_M = 128 BLOCK_N = 256 BLOCK_K = 128 GROUP_M = 8 num_stages = 3 num_warps = 8 desc_a = np.empty(TMA_SIZE, dtype=np.int8) desc_b = np.empty(TMA_SIZE, dtype=np.int8) desc_c = np.empty(TMA_SIZE, dtype=np.int8) desc_a_scale = np.empty(TMA_SIZE, dtype=np.int8) desc_b_scale = np.empty(TMA_SIZE, dtype=np.int8) triton.runtime.driver.active.utils.fill_2d_tma_descriptor( a_tl.data_ptr(), M, K, BLOCK_M, BLOCK_K, a_tl.element_size(), desc_a, ) triton.runtime.driver.active.utils.fill_2d_tma_descriptor( b_tl.data_ptr(), N, K, BLOCK_N, BLOCK_K, b_tl.element_size(), desc_b, ) triton.runtime.driver.active.utils.fill_2d_tma_descriptor( c.data_ptr(), M, N, BLOCK_M, BLOCK_N, c.element_size(), desc_c, ) triton.runtime.driver.active.utils.fill_1d_tma_descriptor( a_scale.data_ptr(), M, BLOCK_M, a_scale.element_size(), desc_a_scale, ) triton.runtime.driver.active.utils.fill_1d_tma_descriptor( b_scale.data_ptr(), N, BLOCK_N, b_scale.element_size(), desc_b_scale, ) desc_a = torch.tensor(desc_a, device="cuda") desc_b = torch.tensor(desc_b, device="cuda") desc_c = torch.tensor(desc_c, device="cuda") desc_a_scale = torch.tensor(desc_a_scale, device="cuda") desc_b_scale = torch.tensor(desc_b_scale, device="cuda") # pyre-ignore[28]: _kernel_matmul_fp8_row_tma_persistent[persistent_grid]( desc_a, desc_b, desc_c, M, N, K, desc_a_scale, desc_b_scale, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), dot_out_dtype=dot_out_dtype_triton, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_M=GROUP_M, AB_DTYPE=False, NUM_SMS=NUM_SMS, num_stages=num_stages, num_warps=num_warps, ) return c ================================================ FILE: kernels/triton/inference/README.md ================================================ Triton Inference kernels ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/README.md ================================================ **MoE (Mixture of Experts) GEMM Kernels** Triton kernel supporting and accelerating MoE inference (Mixtral). This kernel was contributed by IBM Research. This kernel showcases the following optimizations: * Column-Major Launch Schedule (L2 Cache Optimization) * SplitK Work Decomposition (Parallel Work Strategy Optimization) See blog post: https://pytorch.org/blog/accelerating-moe-model/ * v0 = grouped MM * v1 = SplitK MM * v2 = Col Major MM This requires vLLM to be installed to run. 
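A rough sketch of the launch-schedule idea (illustrative only; the function names below are hypothetical and this is not the Triton kernel code itself): a linear program id `pid` is remapped to an output tile `(pid_m, pid_n)` so that consecutive programs walk down a column of output tiles and keep reusing the same block of expert weights from L2.

    # Hypothetical illustration of row-major vs. column-major tile ordering.
    def row_major(pid, grid_m, grid_n):
        # consecutive pids move across a row of C tiles
        return pid // grid_n, pid % grid_n

    def column_major(pid, grid_m, grid_n):
        # consecutive pids move down a column of C tiles, so they share the same
        # expert-weight (B) tile and are more likely to hit in L2
        return pid % grid_m, pid // grid_m

SplitK (v1) additionally splits the K reduction of each output tile across several programs that combine partial results, a strategy aimed at shapes where the output tile grid alone does not expose enough parallelism (small M during inference).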
================================================ FILE: kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import pytest import torch import triton from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.activation import SiluAndMul from v0_moe_fused import fused_moe as fused_moe_grouped from v2_moe_fused import fused_moe as fused_moe_col import time def torch_moe(a, w1, w2, topk_weight, topk_ids): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) out = torch.zeros(B * topk_ids.shape[1], w2.shape[1], dtype=a.dtype, device=a.device) topk_ids = topk_ids.view(-1) topk_weight = topk_weight.view(-1) for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1) def test_fused_moe( m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, ): torch.cuda.manual_seed(3227) a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 score = torch.randn((m, e), device='cuda', dtype=dtype) score = torch.softmax(score, dim=-1) topk_weight, topk_ids = torch.topk(score, topk) start = time.time() triton_output_gl = fused_moe_grouped(a, w1, w2, topk_weight, topk_ids, False) end = time.time() gl_time = end - start gl_time = gl_time * 1000 print("Grouped Launch Time (us): ", gl_time) start = time.time() triton_output_cm = fused_moe_col(a, w1, w2, topk_weight, topk_ids, False) end = time.time() cm_major_time = end - start cm_major_time = cm_major_time * 1000 print("Columm Major Time (us): ", cm_major_time) torch_base = torch_moe(a, w1, w2, topk_weight, topk_ids) torch.testing.assert_close(triton_output_cm, torch_base, atol=1e-2, rtol=0) # print(f"{triton_output_cm=}\n") # print(f"{triton_output_gl=}\n") print(f"Col Major Speedup {((gl_time - cm_major_time)/(gl_time))*100}") if __name__ == '__main__': # test_fused_moe(512, 14336//2, 4096, 8, 2, torch.float16) @triton.testing.perf_report( triton.testing.Benchmark( x_names=['m'], # Argument names to use as an x-axis for the plot x_vals=[ 2**i for i in range(0, 10) ], # Different possible values for `x_name` line_arg='provider', # Argument name whose value corresponds to a different line in the plot # Possible values for `line_arg` line_vals=['cm', 'gl'], # Label name for the lines line_names=["Fused MoE GEMM Kernel - Column Major", "vLLM MoE GEMM Kernel"], # Line styles styles=[('blue', '-'), ('green', '-')], ylabel="TFLOPS", # Label name for the y-axis plot_name="test", # Name for the plot, used also as a file name for saving the plot. 
args={}, ) ) def benchmark(m, provider): m = m n = 14336//2 k = 4096 e = 8 topk = 2 torch.cuda.manual_seed(3227) dtype = torch.float16 a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 score = torch.randn((m, e), device='cuda', dtype=dtype) score = torch.softmax(score, dim=-1) topk_weight, topk_ids = torch.topk(score, topk) quantiles = [0.5, 0.2, 0.8] if provider == 'cm': ms, min_ms, max_ms = triton.testing.do_bench(lambda: fused_moe_col(a, w1, w2, topk_weight, topk_ids, False), quantiles=quantiles) if provider == 'gl': ms, min_ms, max_ms = triton.testing.do_bench(lambda: fused_moe_grouped(a, w1, w2, topk_weight, topk_ids, False), quantiles=quantiles) perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) benchmark.run(show_plots=True, print_data=True, save_path='./') ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/profile_moe.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import pytest import torch from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.activation import SiluAndMul from v0_moe_fused import fused_moe as fused_moe_base from triton.kernels.mixtral.v1_moe_fused import fused_moe import time def torch_moe(a, w1, w2, topk_weight, topk_ids): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) out = torch.zeros(B * topk_ids.shape[1], w2.shape[1], dtype=a.dtype, device=a.device) topk_ids = topk_ids.view(-1) topk_weight = topk_weight.view(-1) for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1) def test_fused_moe( m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, ): a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 score = torch.randn((m, e), device='cuda', dtype=dtype) score = torch.softmax(score, dim=-1) topk_weight, topk_ids = torch.topk(score, topk) triton_output_splitk = fused_moe(a, w1, w2, topk_weight, topk_ids, False) triton_output_base = fused_moe_base(a, w1, w2, topk_weight, topk_ids, False) if __name__ == '__main__': test_fused_moe(2, 14336//2, 4096, 8, 2, torch.float16) ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/results.html ================================================ ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/test.csv ================================================ m,Fused MoE GEMM Kernel - Column Major,vLLM MoE GEMM Kernel 1.000000,0.412454,0.259585 2.000000,0.883064,0.269004 4.000000,1.751380,0.447645 8.000000,2.106783,0.571765 16.000000,4.121877,1.002326 32.000000,8.259988,1.991226 64.000000,16.105391,3.879061 128.000000,29.356460,7.191373 256.000000,50.550095,12.524316 512.000000,72.862390,19.934314 ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/test_moe_gemm.py 
================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import pytest import torch from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.activation import SiluAndMul from v0_moe_fused import fused_moe as fused_moe_v0 from v1_moe_fused import fused_moe as fused_moe_v1 from splitk_moe_fused import fused_moe import time def torch_moe(a, w1, w2, topk_weight, topk_ids): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) out = torch.zeros(B * topk_ids.shape[1], w2.shape[1], dtype=a.dtype, device=a.device) topk_ids = topk_ids.view(-1) topk_weight = topk_weight.view(-1) for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1) @pytest.mark.parametrize("m", [2, 4, 8, 16, 32, 64, 128, 512, 1024, 2048]) @pytest.mark.parametrize("n", [14336//2]) @pytest.mark.parametrize("k", [4096]) @pytest.mark.parametrize("e", [8]) @pytest.mark.parametrize("topk", [2]) @pytest.mark.parametrize("dtype", [torch.float16]) def test_fused_moe( m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, ): torch.cuda.manual_seed(3227) a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 score = torch.randn((m, e), device='cuda', dtype=dtype) score = torch.softmax(score, dim=-1) topk_weight, topk_ids = torch.topk(score, topk) start = time.time() triton_output_gl = fused_moe_v0(a, w1, w2, topk_weight, topk_ids, False) end = time.time() gl_time = end - start gl_time = gl_time * 1000 print("Grouped Launch Time (us): \n", gl_time) start = time.time() triton_output_cm = fused_moe_v1(a, w1, w2, topk_weight, topk_ids, False) end = time.time() cm_major_time = end - start cm_major_time = cm_major_time * 1000 print("Columm Major Time (us): \n", cm_major_time) torch_base = torch_moe(a, w1, w2, topk_weight, topk_ids) assert torch.allclose(triton_output_cm, torch_base, atol=1e-2, rtol=0) assert torch.allclose(triton_output_cm, triton_output_gl, atol=1e-2, rtol=0) # print(f"{triton_output_cm=}\n") # print(f"{triton_output_gl=}\n") # print(f"{torch_base=}\n") print(f"Col Major Speedup: {((gl_time/cm_major_time))} x\n") ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/v0_moe_fused.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # Credit: # Woosuk vLLM: https://github.com/vllm-project/vllm/blob/3d925165f2b18379640a63fbb42de95440d63b64/vllm/model_executor/layers/fused_moe/fused_moe.py """Fused MoE kernel.""" import torch import triton import triton.language as tl from vllm._C import ops @triton.jit def fused_moe_kernel( # Pointers to matrices a_ptr, b_ptr, c_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, # Matrix dimensions N, K, EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. E.g. 
`stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, stride_weight, stride_token_id, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. Key Parameters: - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, and N is the output feature dimension. - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. """ # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. # We will advance this pointer as we move in the K direction # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) off_experts = tl.load(expert_ids_ptr + pid_m) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. 
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the K dimension. a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Parameters: - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. Returns: - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - expert_ids: A tensor indicating the assigned expert index for each block. - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. Padding ensures that during block matrix multiplication, the dimensions align correctly. Example: Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. - As block_size is 4, we pad 1 token for each expert. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - Then append padding tokens [12, 12, 12, 12] for each block. - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. 
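A rough pure-PyTorch sketch of this alignment (illustrative only; the helper name is hypothetical, and the real implementation below calls vLLM's fused `ops.moe_align_block_size` on GPU with fixed-size output buffers):

    import torch

    def moe_align_block_size_reference(topk_ids, block_size):
        flat = topk_ids.view(-1)
        pad_id = flat.numel()  # sentinel index used for padding slots
        sorted_ids, expert_ids = [], []
        for e in flat.unique(sorted=True).tolist():
            idx = (flat == e).nonzero(as_tuple=True)[0].tolist()
            idx += [pad_id] * ((-len(idx)) % block_size)  # pad to a multiple of block_size
            sorted_ids += idx
            expert_ids += [e] * (len(idx) // block_size)
        return sorted_ids, expert_ids, len(sorted_ids)

    topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
    print(moe_align_block_size_reference(topk_ids, block_size=4))
    # -> ([3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12], [1, 2, 3, 4], 16)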
""" sorted_ids = torch.empty( (topk_ids.numel() + num_experts * (block_size - 1), ), dtype=torch.int32, device=topk_ids.device) expert_ids = torch.empty((topk_ids.numel() + num_experts, ), dtype=torch.int32, device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, config: dict): grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) print(f"Base {config}\n") fused_moe_kernel[grid]( A, B, C, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, B.shape[1], B.shape[2], sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), A.stride(1), B.stride(0), B.stride(2), B.stride(1), C.stride(1), C.stride(2), topk_weights.stride(1), sorted_token_ids.stride(0), MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, **config, ) def fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace=False): """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights (torch.Tensor): The weights for the top-k selected experts. - topk_ids (torch.Tensor): The indices of the top-k selected experts. - inplace (bool): If True, perform the operation in-place. Defaults to False. Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] M, _ = hidden_states.shape E, N, _ = w1.shape config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } if topk_ids.numel() <= w1.shape[0]: config = { 'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1 } intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, topk_ids.shape[1], config) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, config) if inplace: return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/v1_moe_fused.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # Credit: # Woosuk vLLM: https://github.com/vllm-project/vllm/blob/3d925165f2b18379640a63fbb42de95440d63b64/vllm/model_executor/layers/fused_moe/fused_moe.py """Fused MoE kernel.""" import torch import triton import triton.language as tl from vllm._C import ops from typing import Any, Dict, Optional import functools import json import os @triton.jit() def grouped_launch(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) width = group_m * grid_n group_id = pid // width group_size = tl.minimum(grid_m - group_id * group_m, group_m) pid_m = group_id * group_m + (pid % group_size) pid_n = (pid % width) // group_size return pid_m, pid_n @triton.jit() def fused_moe_kernel_splitk( # Pointers to matrices a_ptr, b_ptr, c_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, # Matrix dimensions N, K, EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). 
stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, stride_weight, stride_token_id, # Meta-parameters block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, group_m: tl.constexpr, split_k: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. Key Parameters: - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, and N is the output feature dimension. - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by block_m, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. """ # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. # Scheduling Problem pid = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) # print("num_tokens_post_padded: ", num_tokens_post_padded) pid_m, pid_n = grouped_launch(pid, EM, N, block_m, block_n, group_m) total_blocks_k = tl.cdiv(K, block_k*split_k) if pid_m * block_m >= num_tokens_post_padded: return offs_token_id = pid_m * block_m + tl.arange(0, block_m) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N offs_k = pid_k*block_k + tl.arange(0, block_k) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) off_experts = tl.load(expert_ids_ptr + pid_m) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) accumulator = tl.zeros((block_m, block_n), dtype=tl.float32) for k in range(0, total_blocks_k): a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * (block_k * split_k)), other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * (block_k * split_k), other=0.0) # We accumulate along the K dimension. accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. 
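        # Split-K scheduling: the launch grid has a second axis of size split_k,
        # so a program with index pid_k starts at K offset pid_k * block_k and
        # strides by block_k * split_k below. For example, with K = 4096,
        # block_k = 64 and split_k = 2, each pid_k runs 32 iterations covering
        # half of K; the partial sums are later combined with tl.atomic_add,
        # which is why the caller zero-initializes C.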
        a_ptrs += block_k * stride_ak * split_k
        b_ptrs += block_k * stride_bk * split_k

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * block_n + tl.arange(0, block_n)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

    tl.atomic_add(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Aligns the token distribution across experts to be compatible with block size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
""" sorted_ids = torch.empty( (topk_ids.numel() + num_experts * (block_size - 1), ), dtype=torch.int32, device=topk_ids.device) expert_ids = torch.empty((topk_ids.numel() + num_experts, ), dtype=torch.int32, device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, config: dict): N = B.shape[1] # 14336 K = B.shape[2] # 4096 EM = sorted_token_ids.shape[0] # 124 grid = lambda META: (triton.cdiv(EM, META['block_m']) * triton.cdiv(N, META['block_n']), META['split_k']) # print(f"SplitK {config}\n") k = fused_moe_kernel_splitk[grid]( A, B, C, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, # 64 N, K, EM, topk_ids.numel(), A.stride(0), A.stride(1), B.stride(0), B.stride(2), B.stride(1), C.stride(1), C.stride(2), topk_weights.stride(1), sorted_token_ids.stride(0), MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, **config, num_warps=8, ) # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") # with open('split_k_moe_ttir.txt', 'w') as f: # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) # print("IR", k.asm['ttir'], file=f) # print("TTGIR", k.asm['ttgir'], file=f) # print("PTX", k.asm['ptx'], file=f) # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) def fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace=False): """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights (torch.Tensor): The weights for the top-k selected experts. - topk_ids (torch.Tensor): The indices of the top-k selected experts. - inplace (bool): If True, perform the operation in-place. Defaults to False. Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] M, _ = hidden_states.shape E, N, _ = w1.shape # Prefill config_w1 = { 'block_m': 32, 'block_n': 64, 'block_k': 64, 'group_m': 8, 'split_k': 2, } config_w2 = { 'block_m': 32, 'block_n': 64, 'block_k': 64, 'group_m': 8, 'split_k': 2, } # Decoding if topk_ids.numel() <= w1.shape[0]: config_w1 = { 'block_m': 16, 'block_n': 64, 'block_k': 128, 'group_m': 8, 'split_k' : 2, } config_w2 = { 'block_m': 16, 'block_n': 128, 'block_k': 64, 'group_m': 8, 'split_k': 4, } intermediate_cache1 = torch.zeros((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = torch.zeros((M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache3 = torch.zeros((M, topk_ids.shape[1], w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config_w1['block_m'], E) invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, topk_ids.shape[1], config_w1) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, config_w2) if inplace: return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) ================================================ FILE: kernels/triton/inference/col_major_moe_gemm/v2_moe_fused.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # Credit: # Woosuk vLLM: https://github.com/vllm-project/vllm/blob/3d925165f2b18379640a63fbb42de95440d63b64/vllm/model_executor/layers/fused_moe/fused_moe.py """Fused MoE kernel.""" import torch import triton import triton.language as tl from vllm._C import ops @triton.jit() def col_major(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) pid_m = (pid % grid_n) pid_n = pid // grid_m return pid_m, pid_n @triton.jit def fused_moe_kernel( # Pointers to matrices a_ptr, b_ptr, c_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, # Matrix dimensions N, K, EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, stride_weight, stride_token_id, # Meta-parameters block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. 
Key Parameters: - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, and N is the output feature dimension. - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by block_m, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. """ pid = tl.program_id(axis=0) pid_m, pid_n = col_major(pid, EM, N, block_m, block_n,) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * block_m >= num_tokens_post_padded: return offs_token_id = pid_m * block_m + tl.arange(0, block_m) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N offs_k = tl.arange(0, block_k) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) off_experts = tl.load(expert_ids_ptr + pid_m) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[block_m, block_n]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. accumulator = tl.zeros((block_m, block_n), dtype=tl.float32) for k in range(0, tl.cdiv(K, block_k)): # Load the next block of A and B, generate a mask by checking the K dimension. a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * block_k), other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * block_k, other=0.0) # We accumulate along the K dimension. accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. a_ptrs += block_k * stride_ak b_ptrs += block_k * stride_bk if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * block_n + tl.arange(0, block_n) c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Parameters: - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. 
- block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. Returns: - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - expert_ids: A tensor indicating the assigned expert index for each block. - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. Padding ensures that during block matrix multiplication, the dimensions align correctly. Example: Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. - As block_size is 4, we pad 1 token for each expert. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - Then append padding tokens [12, 12, 12, 12] for each block. - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ sorted_ids = torch.empty( (topk_ids.numel() + num_experts * (block_size - 1), ), dtype=torch.int32, device=topk_ids.device) expert_ids = torch.empty((topk_ids.numel() + num_experts, ), dtype=torch.int32, device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, config: dict): EM = sorted_token_ids.shape[0] N = B.shape[1] grid = lambda META: (triton.cdiv(EM, META['block_m']) * triton.cdiv(N, META['block_n']), ) fused_moe_kernel[grid]( A, B, C, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, B.shape[1], B.shape[2], sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), A.stride(1), B.stride(0), B.stride(2), B.stride(1), C.stride(1), C.stride(2), topk_weights.stride(1), sorted_token_ids.stride(0), MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, **config, ) def fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace=False): """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights (torch.Tensor): The weights for the top-k selected experts. - topk_ids (torch.Tensor): The indices of the top-k selected experts. - inplace (bool): If True, perform the operation in-place. Defaults to False. Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
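    # Hypothetical call pattern (mirrors the routing setup used in the tests;
    # `router_logits` is a placeholder name, not defined in this file):
    #   score = torch.softmax(router_logits, dim=-1)        # (M, E)
    #   topk_weights, topk_ids = torch.topk(score, k=2)     # (M, 2) each
    #   out = fused_moe(hidden_states, w1, w2, topk_weights, topk_ids)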
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] M, _ = hidden_states.shape E, N, _ = w1.shape config = { 'block_m': 64, 'block_n': 64, 'block_k': 32, } if topk_ids.numel() <= w1.shape[0]: config = { 'block_m': 16, 'block_n': 32, 'block_k': 64, } intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['block_m'], E) invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, topk_ids.shape[1], config) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, config) if inplace: return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) ================================================ FILE: kernels/triton/inference/flash_attention/stay_attention.py ================================================ import triton.language as tl import triton import torch @triton.jit() def stay_attention( q_ptr, k_ptr, v_ptr, o_ptr, stride_b, stride_nh, stride_qs, stride_qh, stride_ks, stride_kh, stride_vs, stride_vh, stride_os, stride_oh, seq_len, head_dim, sm_scale, BLOCK_SEQ: tl.constexpr, BLOCK_HD: tl.constexpr, NUM_SM: tl.constexpr, ): pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid = tl.program_id(2) qkv_offset = pid_b*stride_b + pid_h*stride_nh num_tiles_seq_len = tl.cdiv(seq_len, BLOCK_SEQ) tiles_per_SM = num_tiles_seq_len // NUM_SM if pid < num_tiles_seq_len % NUM_SM: tiles_per_SM += 1 tile_id = pid - NUM_SM si = -1 pid_seq_m = 0 pid_seq_n = 0 offs_seq_m = tl.arange(0, BLOCK_SEQ) offs_seq_n = tl.arange(0, BLOCK_SEQ) offs_head = tl.arange(0, BLOCK_HD) q_ptrs = q_ptr + qkv_offset + offs_seq_n[:, None]*stride_qs + offs_head[None, :]*stride_qh # initialize pointer to m and l m_i = tl.zeros([BLOCK_SEQ], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_SEQ], dtype=tl.float32) qk_scale = sm_scale * 1.44269504 q = tl.load(q_ptrs) q = (q * qk_scale) pv = tl.zeros([BLOCK_SEQ, BLOCK_HD], dtype=tl.float32) for _ in range(0, num_tiles_seq_len * tiles_per_SM): si = tl.where(si == num_tiles_seq_len - 1, 0, si + 1) if si == 0: tile_id += NUM_SM pid_seq_m = pid // num_tiles_seq_len pid_seq_n = pid % num_tiles_seq_len offs_seq_m = pid_seq_m*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) offs_seq_n = pid_seq_n*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) offs_head = tl.arange(0, BLOCK_HD) q_ptrs = q_ptr + qkv_offset + offs_seq_n[:, None]*stride_qs + offs_head[None, :]*stride_qh qk_scale = sm_scale * 1.44269504 q = tl.load(q_ptrs) q = (q * qk_scale) offs_seq_m = si*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) offs_head = tl.arange(0, BLOCK_HD) k_ptrs = k_ptr + qkv_offset + offs_seq_m[:, 
None]*stride_ks + offs_head[None, :]*stride_kh v_ptrs = v_ptr + qkv_offset + offs_seq_m[:, None]*stride_vs + offs_head[None, :]*stride_vh k = tl.load(k_ptrs) v = tl.load(v_ptrs) qk = tl.dot(q.to(tl.float16), k.T, out_dtype=tl.float32) # -- compute scaling constant --- m_i_new = tl.maximum(m_i, tl.max(qk, 1)) alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) # -- scale and update acc -- pv *= alpha[:, None] pv += tl.dot(p.to(tl.float16), v, out_dtype=tl.float32) # -- update m_i and l_i -- l_i = l_i * alpha + tl.sum(p, 1) m_i = m_i_new if si == num_tiles_seq_len - 1: offs_seq_n = pid_seq_n*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) pv = pv / l_i[:, None] o_ptrs = o_ptr + qkv_offset + offs_seq_n[:, None]*stride_os + offs_head[None, :]*stride_oh tl.store(o_ptrs, pv) pv = tl.zeros([BLOCK_SEQ, BLOCK_HD], dtype=tl.float32) def flash_fn(q, k, v): batch, num_heads, seq_len, head_dim = q.shape sm_scale = 0.5 BLOCK_SEQ = 64 BLOCK_HD = 128 NUM_SM = torch.cuda.get_device_properties("cuda").multi_processor_count grid = (batch, num_heads, NUM_SM) o = torch.zeros(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') stay_attention[grid](q, k, v, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(2), k.stride(3), v.stride(2), v.stride(3), o.stride(2), o.stride(3), seq_len, head_dim, sm_scale, BLOCK_SEQ, BLOCK_HD, NUM_SM) return o if __name__ == '__main__': torch.manual_seed(0) batch, num_heads, seq_len, head_dim = 1, 32, 4096, 128 q = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10 k = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10 v = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10 sm_scale = 0.5 p = (q @ k.transpose(2, 3)) * sm_scale p = torch.softmax(p.float(), dim=-1) o_torch = torch.matmul(p.to(torch.float16), v) o_triton = flash_fn(q, k, v) print(f"{o_triton=}") print(f"{o_torch=}") torch.testing.assert_close(o_triton, o_torch, atol=1e-2, rtol=0) ================================================ FILE: kernels/triton/inference/fp8/float8_groupwise_quant.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from typing import Tuple import torch import triton import triton.language as tl from triton import Config # global constants FP8_MAX: tl.constexpr = 448.0 EPSILON: tl.constexpr = 1e-12 @triton.jit def _float8_groupwise_quant_kernel( in_ptr, out_ptr, scale_ptr, BLOCK_SIZE: tl.constexpr ): """ Quantizes the input tensor via BLOCK_SIZE groupwise scaling (i.e. 1x 128). Results: Stores 1 - float8_e4m3fn result in `out_ptr` 2 - scaling factor in `scale_ptr` """ pid = tl.program_id(axis=0) # load inputs offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x_vec = tl.load(in_ptr + offsets).to(tl.float32) # calc max and scale max_val = tl.max(tl.abs(x_vec)) safe_scale = tl.maximum(max_val, EPSILON) / FP8_MAX y_vec = x_vec / safe_scale # quantize y_clamped = tl.minimum(tl.maximum(y_vec, -FP8_MAX), FP8_MAX) y_fp8 = y_clamped.to(out_ptr.dtype.element_ty) # store quantized values and scale tl.store(out_ptr + offsets, y_fp8) tl.store(scale_ptr + pid, safe_scale) def float8_groupwise_quantize(x: torch.Tensor, block_size=128): """ Quantizes the input tensor via block_size groupwise scaling (i.e. 1x 128) to torch.float8_e4m3fn format. 
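    Example (illustrative): with block_size=128, an input of shape (4, 1024)
    produces an fp8 output of shape (4, 1024) and a float32 scales tensor of
    shape (4, 8), one scale per contiguous group of 128 elements along the
    last dimension.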
Results: Stores 1 - float8_e4m3fn result in `out_ptr` 2 - scaling factor in `scale_ptr` """ # verify input tensor x_last_dim_size = x.size(-1) # evenly divisible? if x_last_dim_size % block_size != 0: raise ValueError( f"Input tensor must have a last dimension that is a multiple of {block_size}" ) # contiguous? if x.stride(-1) != 1: x = x.contiguous() # allocate output tensors output = torch.empty_like(x, dtype=torch.float8_e4m3fn) scales = x.new_empty( *x.size()[:-1], x_last_dim_size // block_size, dtype=torch.float32 ) print(f"{scales.size()=}") grid = lambda meta: (x.numel() // block_size,) _float8_groupwise_quant_kernel[grid]( in_ptr=x, out_ptr=output, scale_ptr=scales, BLOCK_SIZE=block_size, ) return output, scales ================================================ FILE: kernels/triton/inference/fp8/scaled_fp8_gemm.py ================================================ import torch import triton import triton.language as tl import time import os os.environ['ENABLE_TMA'] = '1' @triton.jit def grouped_launch(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) width = group_m * grid_n group_id = pid // width group_size = tl.minimum(grid_m - group_id * group_m, group_m) pid_m = group_id * group_m + (pid % group_size) pid_n = (pid % width) // group_size return pid_m, pid_n @triton.jit() def column_major(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr): grid_m = tl.cdiv(m, block_m) pid_m = pid % grid_m pid_n = pid // grid_m return pid_m, pid_n @triton.jit def scaled_gemm_splitk(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, scale_a, scale_b, m, n, k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, split_k: tl.constexpr, group_m: tl.constexpr): pid = tl.program_id(0) pid_k = tl.program_id(1) grid_k = tl.cdiv(k, block_k*split_k) # Column Major produces speedup over Grouped Launch for small-to-medium M pid_m, pid_n = column_major(pid, m, n, block_m, block_n) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + tl.arange(0, block_n) offs_k = pid_k*block_k + tl.arange(0, block_k) offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) acc = tl.zeros((block_m, block_n), dtype=tl.float32) for k_ in range(0, grid_k): k_remaining = k - k_ * (block_k * split_k) a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) acc = tl.dot(a, b, acc, out_dtype=tl.float32) a_ptrs += block_k * split_k * stride_ak b_ptrs += block_k * split_k * stride_bk # Scaled in SRAM before write back to DRAM acc = scale_a * scale_b * acc acc.to(tl.float16) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + tl.arange(0, block_n) c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) mask = (offs_m < m)[:, None] & (offs_n < n)[None, :] tl.atomic_add(c_ptrs, acc, mask=mask) def scaled_mm_splitk(a, b, scale_a: float=1.0, scale_b: float=1.0): assert a.shape[1] == b.shape[0] m, k = a.shape _, n = b.shape block_m = 64 block_n = 64 block_k = 256 num_stages = 3 num_warps = 8 split_k = 4 group_m = 8 total_blocks_m = triton.cdiv(m, block_m) total_blocks_n = triton.cdiv(n, block_n) total_programs_mn = total_blocks_m * 
total_blocks_n total_programs_k = split_k grid = (total_programs_mn, total_programs_k) c = torch.zeros((m, n), device=a.device, dtype=torch.float16) k = scaled_gemm_splitk[grid](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), scale_a, scale_b, m, n, k, block_m, block_n, block_k, split_k, group_m, num_stages=num_stages, num_warps=num_warps) return c ================================================ FILE: kernels/triton/inference/fp8/splitk_gemm_fp8.py ================================================ import torch import triton import triton.language as tl import time import os os.environ['ENABLE_TMA'] = '1' @triton.jit def grouped_launch(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) width = group_m * grid_n group_id = pid // width group_size = tl.minimum(grid_m - group_id * group_m, group_m) pid_m = group_id * group_m + (pid % group_size) pid_n = (pid % width) // group_size return pid_m, pid_n @triton.jit() def col_major(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr): grid_m = tl.cdiv(m, block_m) pid_m = pid % grid_m pid_n = pid // grid_m return pid_m, pid_n @triton.jit def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, m, n, k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, split_k: tl.constexpr, group_m: tl.constexpr): pid = tl.program_id(0) pid_k = tl.program_id(1) grid_k = tl.cdiv(k, block_k*split_k) pid_m, pid_n = grouped_launch(pid, m, n, block_m, block_n, group_m) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + tl.arange(0, block_n) offs_k = pid_k*block_k + tl.arange(0, block_k) offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) acc = tl.zeros((block_m, block_n), dtype=tl.float32) for k_ in range(0, grid_k): k_remaining = k - k_ * (block_k * split_k) a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) acc = tl.dot(a, b, acc, out_dtype=tl.float32) a_ptrs += block_k * split_k * stride_ak b_ptrs += block_k * split_k * stride_bk acc.to(tl.float16) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + tl.arange(0, block_n) c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) mask = (offs_m < m)[:, None] & (offs_n < n)[None, :] tl.atomic_add(c_ptrs, acc, mask=mask) def gemm_split_k(a, b): m, k = a.shape _, n = b.shape block_m = 64 block_n = 64 block_k = 512 num_stages = 3 num_warps = 8 split_k = 4 group_m = 8 total_blocks_m = triton.cdiv(m, block_m) total_blocks_n = triton.cdiv(n, block_n) total_programs_mn = total_blocks_m * total_blocks_n total_programs_k = split_k grid = (total_programs_mn, total_programs_k) # print(f"problem m size: {m}, tile size m: {block_m}, total blocks m: {total_blocks_m}") # print(f"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}") # print(f"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}") # print(f"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} x {total_blocks_n} = {total_programs_mn}") # print(f"{total_programs_mn=}, {total_programs_k=}") c = 
torch.zeros((m, n), device=a.device, dtype=torch.float16) k = gemm_split_k_kernel[grid](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), m, n, k, block_m, block_n, block_k, split_k, group_m, num_stages=num_stages, num_warps=num_warps) # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") # with open('matmul_split_k.txt', 'w') as f: # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) # print("IR", k.asm['ttir'], file=f) # print("TTGIR", k.asm['ttgir'], file=f) # print("PTX", k.asm['ptx'], file=f) # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) return c ================================================ FILE: kernels/triton/inference/fp8/tma_gemm.py ================================================ import triton import triton.language as tl import numpy as np import torch @triton.jit def gemm_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # prob_m, prob_n, prob_k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(prob_m, block_m) num_pid_k = tl.cdiv(prob_k, block_k) pid_m = pid % num_pid_m pid_n = pid // num_pid_m offs_am = pid_m * block_m offs_bn = pid_n * block_n offs_k = 0 accumulator = tl.zeros((block_m, block_n), dtype=tl.float32) for kk in range(0, num_pid_k): a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv) b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv) accumulator = tl.dot(a, b.T, acc=accumulator, out_dtype=tl.float32) offs_k += block_k accumulator = accumulator.to(tl.float16) tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) def matmul(a, b, config=None): m, _ = a.shape n, k = b.shape if config: block_m = config["block_m"] block_n = config["block_n"] block_k = config["block_k"] num_warps = config["num_warps"] num_stages = config["num_stages"] block_m = 64 block_n = 64 block_k = 256 num_warps = 4 num_stages = 4 TMA_SIZE = 512 desc_a = np.empty(TMA_SIZE, dtype=np.int8) desc_b = np.empty(TMA_SIZE, dtype=np.int8) desc_c = np.empty(TMA_SIZE, dtype=np.int8) c = torch.empty((m, n), dtype=torch.float16, device='cuda') triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a) triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b) triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c) desc_a = torch.tensor(desc_a, device='cuda') desc_b = torch.tensor(desc_b, device='cuda') desc_c = torch.tensor(desc_c, device='cuda') total_blocks_m = triton.cdiv(m, block_m) total_blocks_n = triton.cdiv(n, block_n) grid = (total_blocks_m * total_blocks_n, 1, 1) k = gemm_kernel_tma[grid]( desc_a, desc_b, desc_c, m, n, k, block_m, block_n, block_k, num_warps=num_warps, num_stages=num_stages, ) # with open('tma_fp8.ttgir', 'w') as f: # print(k.asm['ttgir'], file=f) # with open('tma_fp8.ptx', 'w') as f: # print(k.asm['ptx'], file=f) return c if __name__ == '__main__': M = 128 N = 4096 K = 4096 a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) b = b.T.contiguous() c = matmul(a, b) 
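    # Optional sanity check (commented out; upcasting the fp8 operands gives a
    # usable reference because both paths consume the same quantized values,
    # and the tolerance below is a guess that may need loosening for large K):
    # ref = a.to(torch.float16) @ b.to(torch.float16).T
    # torch.testing.assert_close(c, ref, atol=1e-1, rtol=0)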
================================================ FILE: kernels/triton/inference/gptq/a100_qlinear.py ================================================ import triton import triton.language as tl import torch @triton.jit() def _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales_g, stride_scales_n, stride_zeros_g, stride_zeros_n, groupsize, m, n, k, block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, group_size_m: tl.constexpr, ): pid = tl.program_id(0) total_blocks_m = tl.cdiv(m, block_size_m) total_blocks_n = tl.cdiv(n, block_size_n) total_blocks_k = tl.cdiv(k, block_size_k) num_blocks_in_group = group_size_m * total_blocks_n group_id = pid // num_blocks_in_group group_size = min(total_blocks_m - group_id * group_size_m, group_size_m) pid_m = group_id * group_size_m + (pid % group_size) pid_n = (pid % num_blocks_in_group) // (group_size) offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m) offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n) offs_k = tl.arange(0, block_size_k) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) scales_ptrs = scales_ptr + offs_bn * stride_scales_n zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) shifter = (offs_k % 8) * 4 zeros_shifter = (offs_bn % 8) * 4 output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) for k in range(0, total_blocks_k): a = tl.load(a_ptrs) b = tl.load(b_ptrs) g_id = k // (groupsize // block_size_k) ptr = scales_ptrs + g_id * stride_scales_g scales = tl.load(ptr) ptr = zeros_ptrs + g_id * stride_zeros_g zeros = tl.load(ptr) zeros = (zeros >> zeros_shifter) & 0xF zeros = (zeros + 1) * scales b = (b >> shifter[:, None]) & 0xF # b -> int32 b = b * scales[None, :] - zeros[None, :] # b -> fp16 output += tl.dot(a, b) a_ptrs += stride_ak * block_size_k b_ptrs += (block_size_k//8) * stride_bk output.to(tl.float16) offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) tl.store(c_ptrs, output) class a100_qlinear(torch.autograd.Function): def forward(ctx, a, b, scales, zeros): m, k = a.shape _, n = b.shape quant_groupsize = 128 block_size_m = 16 block_size_n = 32 # [N = 4096 // 32] = 128 blocks block_size_k = 256 group_size_m = 8 num_warps = 4 num_stages = 8 total_blocks_m = triton.cdiv(m, block_size_m) total_blocks_n = triton.cdiv(n, block_size_n) total_programs = total_blocks_m * total_blocks_n grid = (total_programs, 1) c = torch.zeros((m, n), device=b.device, dtype=torch.float16) k = _a100_quantized_matmul[grid]( a, b, c, scales, zeros, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), scales.stride(0), scales.stride(1), zeros.stride(0), zeros.stride(1), quant_groupsize, m, n, k, block_size_m, block_size_n, block_size_k, group_size_m, num_warps = num_warps, num_stages = num_stages, ) print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") with open('dequant_simple.txt', 'w') as f: print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) print("IR", 
k.asm['ttir'], file=f) print("TTGIR", k.asm['ttgir'], file=f) print("PTX", k.asm['ptx'], file=f) print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") return c a100_qlinear = a100_qlinear.apply ================================================ FILE: kernels/triton/inference/gptq/benchmark.py ================================================ import argparse import time import logging from tqdm import tqdm import torch from transformers import AutoTokenizer from auto_gptq import AutoGPTQForCausalLM # Configure logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) def benchmark_generation_speed(model, tokenizer, prompt, batch_size, device, num_passes=5): token_dict = tokenizer([prompt] * batch_size, return_tensors="pt", padding="longest").to(device) total_generation_time = 0 total_num_generated_tokens = 0 # Warmup logger.info("Starting warmup...") for _ in tqdm(range(4), desc="Warmup", leave=False): with torch.inference_mode(): _ = model.generate(**token_dict, min_length=30, max_length=30) logger.info("Starting benchmark...") with tqdm(range(num_passes), desc="Benchmark Passes") as pbar: for pass_num in pbar: token_dict = tokenizer([prompt] * batch_size, return_tensors="pt", padding="longest").to(device) start = time.time() with torch.inference_mode(): outputs_ids = model.generate(**token_dict, min_length=30, max_length=30) end = time.time() generation_time = end - start num_generated_tokens = sum(len(output_ids) for output_ids in outputs_ids) - batch_size * len(token_dict['input_ids'][0]) tokens_per_second = num_generated_tokens / generation_time total_generation_time += generation_time total_num_generated_tokens += num_generated_tokens # Update tqdm post-fix with current iteration results pbar.set_postfix({"Time (s)": f"{generation_time:.2f}", "Tokens/s": f"{tokens_per_second:.2f}"}) # Calculate average statistics avg_generation_time = total_generation_time / num_passes avg_tokens_per_second = total_num_generated_tokens / total_generation_time avg_num_generated_tokens = total_num_generated_tokens / num_passes # Log average statistics logger.info(f"Batch size: {batch_size}, Avg Time: {avg_generation_time:.2f}s, Avg Tokens/s: {avg_tokens_per_second:.2f}, Avg Total tokens: {avg_num_generated_tokens}") return avg_generation_time, avg_tokens_per_second, avg_num_generated_tokens def main(): parser = argparse.ArgumentParser(description='Benchmark Llama-70B') parser.add_argument('--use_triton', type=lambda x: (str(x).lower() == 'true'), help='use Triton Kernel') parser.add_argument('--batch_size', type=int, required=True, help='Batch size for the benchmark') args = parser.parse_args() device = "cuda:5" quantized_model_dir = '/net/storage149/autofs/css22/ccyang/fm-models/llama-gptq/gptq_output_act0_grp128_bluewiki' tokenizer = AutoTokenizer.from_pretrained(quantized_model_dir, use_fast=True) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" if args.use_triton: torch.cuda.empty_cache() model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device, inject_fused_attention=False, inject_fused_mlp=False, use_triton=args.use_triton, disable_exllamaV2=True, low_cpu_mem_usage=True, warmup_triton=False) else: model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device, inject_fused_attention=False, inject_fused_mlp=False, use_triton=False, disable_exllamaV2=False, low_cpu_mem_usage=True, warmup_triton=False) 
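    # mode="reduce-overhead" enables CUDA graph capture in torch.compile, which
    # mainly helps the short, launch-bound decode steps this benchmark measures.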
    model = torch.compile(model, mode="reduce-overhead")

    benchmark_generation_speed(model, tokenizer, "auto-gptq is a", args.batch_size, device)


if __name__ == "__main__":
    main()


================================================
FILE: kernels/triton/inference/gptq/h100_qlinear.py
================================================
import triton
import triton.language as tl
import torch

@triton.jit()
def _h100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr,
                           stride_am, stride_ak,
                           stride_bk, stride_bn,
                           stride_cm, stride_cn,
                           stride_scales_g, stride_scales_n,
                           stride_zeros_g, stride_zeros_n,
                           groupsize,
                           m, n, k,
                           block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr,
                           group_size_m: tl.constexpr,
                           fp8_fast_accum: tl.constexpr,):

    pid = tl.program_id(0)

    total_blocks_m = tl.cdiv(m, block_size_m)
    total_blocks_n = tl.cdiv(n, block_size_n)
    total_blocks_k = tl.cdiv(k, block_size_k)

    num_blocks_in_group = group_size_m * total_blocks_n
    group_id = pid // num_blocks_in_group
    group_size = min(total_blocks_m - group_id * group_size_m, group_size_m)

    pid_m = group_id * group_size_m + (pid % group_size)
    pid_n = (pid % num_blocks_in_group) // (group_size)

    offs_n = pid_n * block_size_n + tl.arange(0, block_size_n)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n)
    offs_k = tl.arange(0, block_size_k)

    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(m, k), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * block_size_m, 0),
                                    block_shape=(block_size_m, block_size_k), order=(1, 0))

    b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)

    scales_ptrs = scales_ptr + offs_bn * stride_scales_n
    zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n)

    shifter = (offs_k % 8) * 4
    zeros_shifter = (offs_bn % 8) * 4

    acc = tl.zeros((block_size_m, block_size_n), dtype=tl.float32)

    for k in range(0, total_blocks_k):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_ptrs)

        g_id = k // (groupsize // block_size_k)

        ptr = scales_ptrs + g_id * stride_scales_g
        scales = tl.load(ptr)

        ptr = zeros_ptrs + g_id * stride_zeros_g
        zeros = tl.load(ptr)

        zeros = (zeros >> zeros_shifter) & 0xF
        zeros = (zeros + 1) * scales

        b = (b >> shifter[:, None]) & 0xF
        b = b * scales[None, :] - zeros[None, :]

        if fp8_fast_accum:
            # cast both operands to fp8 before the dot when fast accumulation is requested
            acc = tl.dot(a.to(tl.float8e4nv), b.to(tl.float8e4nv), acc)
        else:
            acc += tl.dot(a, b)

        a_block_ptr = tl.advance(a_block_ptr, (0, block_size_k))
        b_ptrs += (block_size_k // 8) * stride_bk

    acc.to(tl.float16)

    offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m)
    offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n)

    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m) & (offs_cn[None, :] < n)

    tl.store(c_ptrs, acc, mask=c_mask)


class h100_qlinear(torch.autograd.Function):

    def forward(ctx, a, b, scales, zeros):

        m, k = a.shape
        _, n = b.shape

        quant_groupsize = 128
        block_size_m = 16
        block_size_n = 32
        block_size_k = 256
        group_size_m = 8
        num_warps = 4
        num_stages = 4
        total_blocks_m = triton.cdiv(m, block_size_m)
        total_blocks_n = triton.cdiv(n, block_size_n)
        total_programs = total_blocks_m * total_blocks_n
        grid = (total_programs, 1)
        fp8_fast_accum = False

        c = torch.zeros((m, n), device=a.device, dtype=a.dtype)
        k = _h100_quantized_matmul[grid](
            a, b, c, scales, zeros,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            scales.stride(0), scales.stride(1),
            zeros.stride(0), zeros.stride(1),
            quant_groupsize, m, n, k,
            block_size_m, block_size_n, block_size_k, group_size_m,
fp8_fast_accum = fp8_fast_accum, num_warps = num_warps, num_stages = num_stages, ) print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") return c h100_qlinear = h100_qlinear.apply ================================================ FILE: kernels/triton/inference/gptq/mixtral/test_dequant_moe_gemm.py ================================================ import pytest import torch from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.activation import SiluAndMul from triton.kernels.gptq.mixtral.w4a16_fused_dequant_gemm import dequant_gemm_moe from v0_moe_fused import fused_moe as fused_moe_base import time def torch_moe(a, w1, w2, topk_weight, topk_ids): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) out = torch.zeros(B * topk_ids.shape[1], w2.shape[1], dtype=a.dtype, device=a.device) topk_ids = topk_ids.view(-1) topk_weight = topk_weight.view(-1) for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1) def test_dequant_moe( m: int, n: int, k: int, e: int, topk: int, ): m = m n = n k = k e = e topk = topk groupsize = 128 packed_k_dim = k // 8 packed_n_dim = n // 8 g = k // groupsize topk = 2 a = torch.randn((m, k), dtype=torch.float16, device='cuda') qw1 = torch.randint(0, 5, (e, packed_k_dim, n), device='cuda', dtype=torch.int32) qw2 = torch.randint(0, 5, (e, 2*n, packed_k_dim), device='cuda', dtype=torch.int32) qw1_zeros = torch.randint(0, 5, (e, g, packed_n_dim), device='cuda', dtype=torch.int32) qw2_zeros = torch.randint(0, 5, (e, g, packed_n_dim), device='cuda', dtype=torch.int32) qw1_scales = torch.randn((e, g, n), dtype=torch.float16, device='cuda') qw2_scales = torch.randn((e, g, n), dtype=torch.float16, device='cuda') score = torch.randn((m, e), device='cuda', dtype=torch.float16) score = torch.softmax(score, dim=-1) _, topk_ids = torch.topk(score, topk) # dtype = torch.float16 # a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 # w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 # w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 # score = torch.randn((m, e), device='cuda', dtype=dtype) # score = torch.softmax(score, dim=-1) # topk_weight, topk_ids = torch.topk(score, topk) # triton_output_base = fused_moe_base(a, w1, w2, topk_weight, topk_ids, False) # print(triton_output_base) # breakpoint() c = dequant_gemm_moe(a, qw1, qw2, qw1_scales, qw2_scales, qw1_zeros, qw2_zeros, topk_ids, ) # print(c) # assert torch.allclose(triton_output_splitk, torch_output, atol=1e-1, rtol=0) if __name__ == '__main__': test_dequant_moe(2, 14336//2, 4096, 8, 2) ================================================ FILE: kernels/triton/inference/gptq/mixtral/w4a16_fused_dequant_gemm.py ================================================ """Fused MoE W4A16 Kernel.""" import torch import triton import triton.language as tl from vllm._C import ops @triton.jit def print_tensor_dim(tensor, str_name): if tl.program_id(0) == 0 and tl.program_id(1) == 0: tl.static_print(str_name," ",tensor.shape," ",tensor.dtype) @triton.jit def print_value(value): if tl.program_id(0) == 0 and tl.program_id(1) == 0: tl.device_print(str(value)) @triton.jit() def grouped_launch(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) width = group_m * grid_n group_id = pid // width 
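    # Worked example: grid_m = 8, grid_n = 4, group_m = 2 gives width = 8, so
    # pids 0..7 all fall in group 0 with pid_m alternating between 0 and 1
    # while pid_n sweeps 0..3; a small set of A row-blocks stays hot in cache
    # while the B column-blocks are streamed past them.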
group_size = min(grid_m - group_id * group_m, group_m) pid_m = group_id * group_m + (pid % group_size) pid_n = (pid % width) // group_size return pid_m, pid_n @triton.jit() def col_major(pid, m, n, num_tokens_post_padded, block_m: tl.constexpr, block_n: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) pid_m = (pid % grid_n) pid_n = pid // grid_m return pid_m, pid_n @triton.jit() def w4a16_fused_moe_kernel( # Pointers to matrices a_ptr, b_ptr, c_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, # Quantization Scales and Zeros Ptr scales_ptr, zeros_ptr, # Matrix dimensions N, K, EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, # Quantization Scales and Zeros Strides stride_scales_e, stride_scales_g, stride_scales_n, stride_zeros_e, stride_zeros_g, stride_zeros_n, # Meta-parameters groupsize: tl.constexpr, top_k: tl.constexpr, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, group_m: tl.constexpr, ): pid = tl.program_id(0) # GEMM Schedule pid_m, pid_n = grouped_launch(pid, EM, N, block_m, block_n, group_m) grid_k = tl.cdiv(K, block_k) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * block_m >= num_tokens_post_padded: return # Offset Calculations offs_token_id = pid_m*block_m + tl.arange(0, block_m) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_bn = (pid_n * block_n + tl.arange(0, block_n)) % N # NOTE: No change needed here since weights are packed along K dim offs_k = tl.arange(0, block_k) off_experts = tl.load(expert_ids_ptr + pid_m) # Mask for Activations token_mask = offs_token < num_valid_tokens # Pointer Calculations a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) #NOTE: offs_token[:, None] // top_k -> since each row of activations repeats top_k times b_ptrs = b_ptr + off_experts * stride_be + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) #NOTE: offs_k[:, None] // 8 -> since B is packed along k dim is packed # We need to handle the e dim of the scales and zeros pointers # We can do this in the same fashion that the stacked expert weight matrix is handled # off_experts = tl.load(expert_ids_ptr + pid_m) # b_ptr + off_experts * stride_be + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) scales_ptrs = scales_ptr + off_experts * stride_scales_e + offs_bn * stride_scales_n zeros_ptrs = zeros_ptr + off_experts * stride_zeros_e + ((offs_bn//8) * stride_zeros_n) shifter = (offs_k % 8) * 4 zeros_shifter = (offs_bn % 8) * 4 acc = tl.zeros([block_m, block_n], dtype=tl.float32) for k in range(0, grid_k): a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * block_k), other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * block_k, other=0.0) g_id = k // (groupsize // block_k) ptr = scales_ptrs + g_id * stride_scales_g scales = tl.load(ptr) ptr = zeros_ptrs + g_id * stride_zeros_g zeros = tl.load(ptr) zeros = (zeros >> zeros_shifter) & 0xF zeros = (zeros + 1) * scales b = (b >> shifter[:, None]) & 0xF b = b * scales[None, :] - zeros[None, :] acc += tl.dot(a, b) a_ptrs += block_k * stride_ak b_ptrs += (block_k // 8) * stride_bk acc.to(tl.float16) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + 
tl.arange(0, block_n) c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) tl.store(c_ptrs, acc) def invoke_dequant_gemm_moe(activations: torch.Tensor, qweight: torch.Tensor, c: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, topk: torch.Tensor, ): EM = sorted_token_ids.shape[0] N = qweight.shape[1] K = qweight.shape[2] block_m = 32 block_n = 32 block_k = 32 group_m = 8 groupsize = 128 topk = 2 if topk_ids.numel() <= qweight.shape[0]: block_m = 16 block_n = 128 block_k = 128 group_m = 8 total_blocks_m = triton.cdiv(EM, block_m) total_blocks_n = triton.cdiv(N, block_n) grid = (total_blocks_m * total_blocks_n,) w4a16_fused_moe_kernel[grid]( activations, qweight, c, sorted_token_ids, expert_ids, num_tokens_post_padded, scales, qzeros, N, K, EM, topk_ids.numel(), activations.stride(0), activations.stride(1), qweight.stride(0), qweight.stride(2), qweight.stride(1), c.stride(1), c.stride(2), scales.stride(0), scales.stride(1), scales.stride(2), qzeros.stride(0), qzeros.stride(1), qzeros.stride(2), groupsize=groupsize, top_k=topk, block_m=block_m, block_n=block_n, block_k=block_k, group_m=group_m, ) def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Parameters: - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. Returns: - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - expert_ids: A tensor indicating the assigned expert index for each block. - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. Padding ensures that during block matrix multiplication, the dimensions align correctly. Example: Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. - As block_size is 4, we pad 1 token for each expert. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - Then append padding tokens [12, 12, 12, 12] for each block. - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. 
""" sorted_ids = torch.empty( (topk_ids.numel() + num_experts * (block_size - 1), ), dtype=torch.int32, device=topk_ids.device) expert_ids = torch.empty((topk_ids.numel() + num_experts, ), dtype=torch.int32, device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad def dequant_gemm_moe(hidden_states: torch.Tensor, qw1: torch.Tensor, qw2: torch.Tensor, scales_qw1: torch.Tensor, scales_qw2: torch.Tensor, zeros_qw1: torch.Tensor, zeros_qw2: torch.Tensor, topk_ids: torch.Tensor, ): # Check constraints. # assert hidden_states.shape[1] == qw1.shape[2], "Incompatible dimensions" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert qw1.is_contiguous(), "Expert weights1 must be contiguous" assert qw2.is_contiguous(), "Expert weights2 must be contiguous" # assert hidden_states.dtype in [torch.float16, torch.bfloat16] M, _ = hidden_states.shape E, N, _ = qw1.shape block_m = 32 if topk_ids.numel() <= qw1.shape[0]: block_m = 16 intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache3 = torch.empty((M, topk_ids.shape[1], qw2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, block_m, E) invoke_dequant_gemm_moe(hidden_states, qw1, intermediate_cache1, scales_qw1, zeros_qw1, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, topk_ids.shape[1],) # return torch.sum(intermediate_cache1.view(*intermediate_cache1.shape), dim=1) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) invoke_dequant_gemm_moe(intermediate_cache2, qw2, intermediate_cache3, scales_qw2, zeros_qw2, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, 1,) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) ================================================ FILE: kernels/triton/inference/gptq/small_benchmark_cuda_graphs.py ================================================ import torch import triton from triton import language as tl import sys import marlin import torch.nn as nn from auto_gptq.utils.import_utils import dynamically_import_QuantLinear from auto_gptq.modeling._utils import autogptq_post_init @triton.jit() def swizzle_tile(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) width = group_m * grid_n group_id = pid // width group_size = tl.minimum(grid_m - group_id * group_m, group_m) pid_m = group_id * group_m + (pid % group_size) pid_n = (pid % width) // group_size return pid_m, pid_n @triton.jit() def matmul_data_parallel_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales_g, stride_scales_n, stride_zeros_g, stride_zeros_n, groupsize, m, n, k, block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, group_size_m: tl.constexpr, fp8_fast_accum: tl.constexpr,): pid = tl.program_id(0) total_blocks_m = tl.cdiv(m, block_size_m) total_blocks_n = tl.cdiv(n, block_size_n) total_blocks_k = tl.cdiv(k, block_size_k) 
num_blocks_in_group = group_size_m * total_blocks_n group_id = pid // num_blocks_in_group group_size = min(total_blocks_m - group_id * group_size_m, group_size_m) pid_m = group_id * group_size_m + (pid % group_size) pid_n = (pid % num_blocks_in_group) // (group_size) offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m) offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n) offs_k = tl.arange(0, block_size_k) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (16, 64) b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) scales_ptrs = scales_ptr + offs_bn * stride_scales_n zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) shifter = (offs_k % 8) * 4 zeros_shifter = (offs_bn % 8) * 4 output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) for k in range(0, total_blocks_k): a = tl.load(a_ptrs) b = tl.load(b_ptrs) # tl.device_print("data parallel b: ", b) g_id = k // (groupsize // block_size_k) ptr = scales_ptrs + g_id * stride_scales_g scales = tl.load(ptr) ptr = zeros_ptrs + g_id * stride_zeros_g zeros = tl.load(ptr) zeros = (zeros >> zeros_shifter) & 0xF zeros = (zeros + 1) * scales b = (b >> shifter[:, None]) & 0xF # b is int32 b = b * scales[None, :] - zeros[None, :] # b is fp16 # output += tl.dot(a, b) # output += tl.sum(a, b, axis=0) # print(b.type) # result = a[:, None] * b # (1 x 64 x 64 x 32) x illegal # (NEED A SQUARE MATRIX for B) # b -> 64 x 64 instead 64 x 32 output += tl.dot(a, b) # a_block_ptr = tl.advance(a_block_ptr, (0, block_size_k)) a_ptrs += stride_ak * block_size_k b_ptrs += (block_size_k//8) * stride_bk output.to(tl.float16) offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) tl.store(c_ptrs, output) class small_qlinear(torch.autograd.Function): def forward(ctx, a, b, scales, zeros): m, k = a.shape _, n = b.shape quant_groupsize = 128 block_size_m = 64 block_size_n = 64 # [N = 4096 // 32] = 128 blocks block_size_k = 64 group_size_m = 8 num_warps = 4 num_stages = 8 total_blocks_m = triton.cdiv(m, block_size_m) total_blocks_n = triton.cdiv(n, block_size_n) total_programs = total_blocks_m * total_blocks_n grid = (total_programs, 1) fp8_fast_accum = False c = torch.zeros((m, n), device=b.device, dtype=torch.float16) # output = torch.em k = matmul_data_parallel_kernel[grid]( a, b, c, scales, zeros, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), scales.stride(0), scales.stride(1), zeros.stride(0), zeros.stride(1), quant_groupsize, m, n, k, block_size_m, block_size_n, block_size_k, group_size_m, fp8_fast_accum = fp8_fast_accum, num_warps = num_warps, num_stages = num_stages, ) print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") with open('dequant_simple.txt', 'w') as f: print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) print("IR", k.asm['ttir'], file=f) print("TTGIR", k.asm['ttgir'], file=f) print("PTX", k.asm['ptx'], file=f) print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") return c matmul_data_parallel = small_qlinear.apply @triton.jit() 
def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales_g, stride_scales_n, stride_zeros_g, stride_zeros_n, groupsize, m, n, k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, group_m: tl.constexpr, split_k: tl.constexpr): pid = tl.program_id(0) pid_k = tl.program_id(1) num_pid_k = tl.cdiv(k, block_k*split_k) pid_m, pid_n = swizzle_tile(pid, m, n, block_m, block_n, group_m) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + tl.arange(0, block_n) offs_k = pid_k*block_k + tl.arange(0, block_k) offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) scales_ptrs = scales_ptr + offs_bn * stride_scales_n zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) shifter = (offs_k % 8) * 4 zeros_shifter = (offs_bn % 8) * 4 acc = tl.zeros((block_m, block_n), dtype=tl.float32) for k in range(0, num_pid_k): a = tl.load(a_ptrs) b = tl.load(b_ptrs) g_id = k // (groupsize // (block_k*split_k)) ptr = scales_ptrs + g_id * stride_scales_g scales = tl.load(ptr) # -> 1D naive assumes no reordering ptr = zeros_ptrs + g_id * stride_zeros_g zeros = tl.load(ptr) # -> 1D naive assumes no reordering zeros = (zeros >> zeros_shifter) & 0xF zeros = (zeros + 1) * scales b = (b >> shifter[:, None]) & 0xF # b is int32 b = b * scales[None, :] - zeros[None, :] acc += tl.dot(a, b) a_ptrs += block_k * split_k * stride_ak b_ptrs += (block_k//8) * split_k * stride_bk acc.to(tl.float16) offs_cm = pid_m*block_m + tl.arange(0, block_m) offs_cn = pid_n*block_n + tl.arange(0, block_n) c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) tl.atomic_add(c_ptrs, acc) def matmul_split_k(a, b, scales, zeros): m, k = a.shape _, n = b.shape quant_groupsize = 128 block_m = 16 block_n = 32 block_k = 128 group_m = 8 num_stages = 3 num_warps = 4 split_k = 4 total_blocks_m = triton.cdiv(m, block_m) total_blocks_n = triton.cdiv(n, block_n) total_programs_mn = total_blocks_m * total_blocks_n total_programs_k = split_k grid = (total_programs_mn, total_programs_k) # print(f"problem m size: {m}, tile size m: {block_m}, total blocks m: {total_blocks_m}") # print(f"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}") # print(f"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}") # print(f"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} x {total_blocks_n} = {total_programs_mn}") # print(f"{total_programs_mn=}, {total_programs_k=}") c = torch.zeros((m, n), device=a.device, dtype=torch.float16) k = matmul_split_k_kernel[grid](a, b, c, scales, zeros, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), scales.stride(0), scales.stride(1), zeros.stride(0), zeros.stride(1), quant_groupsize, m, n, k, block_m, block_n, block_k, group_m, split_k, num_stages=num_stages, num_warps=num_warps) # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") # with open('matmul_split_k.txt', 'w') as f: # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) # print("IR", k.asm['ttir'], file=f) # print("TTGIR", k.asm['ttgir'], file=f) # print("PTX", k.asm['ptx'], 
file=f) # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) return c def make_tensor(M, N, dtype): if dtype == torch.int32: # Fill with random integers for int32 type res = torch.randint(low=-2**31, high=2**31, size=(M, N), dtype=dtype, device="cuda") else: # Fill with normally distributed random values for other types res = torch.empty((M, N), dtype=dtype, device="cuda") res.normal_(mean=0.0, std=0.5) return res def gen_quant4(m, n, groupsize=-1): tile = 16 maxq = 2 ** 4 - 1 w = torch.randn((m, n), dtype=torch.half, device="cuda") if groupsize != -1: w = w.reshape((-1, groupsize, n)) w = w.permute(1, 0, 2) w = w.reshape((groupsize, -1)) s = torch.max(torch.abs(w), 0, keepdim=True)[0] s *= 2 / maxq w = torch.round(w / s).int() w += (maxq + 1) // 2 w = torch.clamp(w, 0, maxq) ref = (w - (maxq + 1) // 2).half() * s if groupsize != -1: def reshape(w): w = w.reshape((groupsize, -1, n)) w = w.permute(1, 0, 2) w = w.reshape((m, n)).contiguous() return w ref = reshape(ref) w = reshape(w) s = s.reshape((-1, n)).contiguous() linear = nn.Linear(m, n) linear.weight.data = ref.t() # Workaround to test some special cases that are forbidden by the API layer = marlin.Layer(256, 256, groupsize=groupsize) if groupsize == -1: groupsize = m layer.k = m layer.n = n layer.groupsize = groupsize layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device="cuda") layer.s = torch.empty((m // groupsize, n), dtype=torch.half, device="cuda") layer.pack(linear, s.t()) q = layer.B s = layer.s return ref, q, s if __name__ == '__main__': m = 16 k = 4096 n = 4096 groupsize = 128 g = k // groupsize a = make_tensor(m, k, dtype=torch.float16) b = make_tensor(k//8, n, dtype=torch.int32) c = make_tensor(m, n, dtype=torch.float16) workspace = torch.zeros(n//128*16, device="cuda") zeros = make_tensor(g, n//8, torch.int32) scales = make_tensor(g, n, torch.float16) # Marlin # m, n, k = 16, 4096, 4096 # A = torch.randn((m, k), dtype=torch.half, device="cuda") # B_ref, B, s = gen_quant4(k, n) # C = torch.zeros((m, n), dtype=torch.half, device="cuda") # workspace = torch.zeros(n // 128*16, device="cuda") output_marlin = marlin.mul(a, b, c, scales, workspace, sms=108) output_split_k = matmul_split_k(a, b, scales, zeros) nbits = 4 group_size=128 disable_exllama=True disable_exllamav2=False use_triton = False linear_class = dynamically_import_QuantLinear( disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_triton=use_triton, desc_act=False, group_size=group_size, bits=nbits) linear = linear_class( bits=nbits, group_size=group_size, infeatures=k, outfeatures=n, bias=0, ) device = torch.device('cuda') linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32) linear.scales = linear.scales + 0.002 linear = linear.eval().to(device) linear = autogptq_post_init(linear, use_act_order=False) b_fake = torch.randn((k, n), dtype=torch.float16, device="cuda") # Warmup for i in range(3): linear(a) matmul_split_k(a, b, scales, zeros) torch.matmul(a, b_fake) s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): matmul_split_k(a, b, scales, zeros) torch.cuda.current_stream().wait_stream(s) # capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): matmul_split_k(a, b, scales, zeros) for i in range(7): torch.matmul(a, b_fake) for i in range(7): linear(a) for i in range(7): g.replay() # This replays the captured operations in the graph for i in range(7): matmul_data_parallel(a, b, scales, zeros) for i 
in range(7): matmul_split_k(a, b, scales, zeros) ================================================ FILE: kernels/triton/inference/gptq/splitk_dequant_gemm.py ================================================ import torch import triton from triton import language as tl # from actual_base_gptq_4 import triton_matmul4 @triton.jit() def swizzle_tile(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): grid_m = tl.cdiv(m, block_m) grid_n = tl.cdiv(n, block_n) width = group_m * grid_n group_id = pid // width group_size = tl.minimum(grid_m - group_id * group_m, group_m) pid_m = group_id * group_m + (pid % group_size) pid_n = (pid % width) // group_size return pid_m, pid_n @triton.jit() def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales_g, stride_scales_n, stride_zeros_g, stride_zeros_n, groupsize, m, n, k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, group_m: tl.constexpr, split_k: tl.constexpr): pid = tl.program_id(0) pid_k = tl.program_id(1) total_blocks_k = tl.cdiv(k, block_k*split_k) pid_m, pid_n = swizzle_tile(pid, m, n, block_m, block_n, group_m) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + tl.arange(0, block_n) offs_k = pid_k*block_k + tl.arange(0, block_k) offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) scales_ptrs = scales_ptr + offs_bn * stride_scales_n zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) shifter = (offs_k % 8) * 4 zeros_shifter = (offs_bn % 8) * 4 acc = tl.zeros((block_m, block_n), dtype=tl.float32) for k in range(0, total_blocks_k): a = tl.load(a_ptrs) b = tl.load(b_ptrs) g_id = (k * split_k + pid_k) // (groupsize // block_k) ptr = scales_ptrs + g_id * stride_scales_g scales = tl.load(ptr) ptr = zeros_ptrs + g_id * stride_zeros_g zeros = tl.load(ptr) zeros = (zeros >> zeros_shifter) & 0xF zeros = (zeros + 1) * scales b = (b >> shifter[:, None]) & 0xF b = b * scales[None, :] - zeros[None, :] acc += tl.dot(a, b) a_ptrs += block_k * split_k * stride_ak b_ptrs += (block_k // 8) * split_k * stride_bk acc.to(tl.float16) offs_m = pid_m*block_m + tl.arange(0, block_m) offs_n = pid_n*block_n + tl.arange(0, block_n) c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) tl.atomic_add(c_ptrs, acc, sem='release') def matmul_split_k(a, b, scales, zeros): m, k = a.shape _, n = b.shape quant_groupsize = 128 block_m = 16 block_n = 32 block_k = 128 group_m = 8 num_stages = 3 num_warps = 4 split_k = 4 total_blocks_m = triton.cdiv(m, block_m) total_blocks_n = triton.cdiv(n, block_n) total_programs_mn = total_blocks_m * total_blocks_n total_programs_k = split_k grid = (total_programs_mn, total_programs_k) print(f"problem m size: {m}, tile size m: {block_m}, total blocks m: {total_blocks_m}") print(f"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}") print(f"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}") print(f"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} x {total_blocks_n} = {total_programs_mn}") print(f"{total_programs_mn=}, {total_programs_k=}") c = torch.zeros((m, n), device=a.device, dtype=torch.float16) k = 
matmul_split_k_kernel[grid](a, b, c, scales, zeros, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), scales.stride(0), scales.stride(1), zeros.stride(0), zeros.stride(1), quant_groupsize, m, n, k, block_m, block_n, block_k, group_m, split_k, num_stages=num_stages, num_warps=num_warps) print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") with open('matmul_split_k.txt', 'w') as f: print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) print("IR", k.asm['ttir'], file=f) print("TTGIR", k.asm['ttgir'], file=f) print("PTX", k.asm['ptx'], file=f) print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) return c def make_tensor(M, N, dtype): if dtype == torch.int32: # Fill with random integers for int32 type res = torch.randint(low=-2147483648, high=2147483647, size=(M, N), dtype=dtype, device="cuda") else: # Fill with normally distributed random values for other types res = torch.empty((M, N), dtype=dtype, device="cuda") res.normal_(mean=0.0, std=0.5) return res if __name__ == '__main__': m = 16 k = 4096 n = 4096 groupsize = 128 g = k // groupsize a = make_tensor(m, k, dtype=torch.float16) b = make_tensor(k//8, n, dtype=torch.int32) c = make_tensor(m, n, dtype=torch.float16) zeros = make_tensor(g, n//8, torch.int32) scales = make_tensor(g, n, torch.float16) # base = no_autotune(groupsize, a, b, scales, zeros) # print(f"{base.shape=}, {base[0][0:4]}") # c = custom_qlinear(a, b, scales, zeros) # print(f"{c.shape=}, {c[0][0:4]}") split_k_output = matmul_split_k(a, b, scales, zeros) print(f"{split_k_output.shape=}, {split_k_output[0][0:4]}") ================================================ FILE: kernels/triton/inference/mamba/causal_1d_conv/causal_1d_conv/causal_1d_conv.py ================================================ # Copyright (c) 2025, IBM Research import torch import triton import triton.language as tl from einops import rearrange from typing import Literal, Optional @triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=3, num_warps=8), triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_stages=3, num_warps=8), ], key=["seqlen", "dim", "batch"], ) @triton.jit() def _causal_conv1d_fwd_kernel( # Pointers to matrices x_ptr, # (batch, dim, seqlen) w_ptr, # (dim, width) bias_ptr, initial_states_ptr, o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch, dim, seqlen, # Strides stride_x_seq, # stride to get to next sequence, stride_x_dim, # stride to get to next feature-value, stride_x_token, # stride to get to next token (same feature-index, same sequence-index) stride_weight_dim, # stride to get to next dim-axis value stride_weight_width, # stride to get to next width-axis value stride_istate_seq, stride_istate_dim, stride_istate_token, 
stride_o_seq, stride_o_dim, stride_o_token, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, # maybe using this we don't need 'width' SILU_ACTIVATION: tl.constexpr, HAS_INITIAL_STATES: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): indices_0 = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) idx_seqs = indices_0 // seqlen idx_tokens = indices_0 % seqlen x_base = x_ptr + (idx_seqs * stride_x_seq)[:, None] # the beginning features at all tokens at all sequences processed by this Triton program idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) w_base = w_ptr + (idx_feats * stride_weight_dim) # first kernel column, configured for weights to handle BLOCK_N features in range load_init_state = False if HAS_INITIAL_STATES: load_init_state = tl.min(idx_tokens) < KERNEL_WIDTH - 1 initial_states_base = initial_states_ptr + (idx_seqs * stride_istate_seq)[:, None] + (idx_feats * stride_istate_dim)[None, :] # store output data at the corresponding tokens (BLOCK_M of them) and feature-indices (BLOCK_N of them) in these tokens if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim acc = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32)[None, :] # [BLOCK_N] acc = tl.broadcast_to(acc, (BLOCK_M, BLOCK_N)) else: acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) PADDING_W = KERNEL_WIDTH - 1 for j in range(KERNEL_WIDTH): idx_x_w = j - PADDING_W + idx_tokens # the token index to multiply with kernel[:, 0], given kernel with width-columns, i.e. kernel[:, 0..(width-1)] x_ptrs = x_base + ((idx_x_w * stride_x_token)[:, None] + (idx_feats * stride_x_dim)[None, :]) # [BLOCK_M, BLOCK_N] mask_x = ((idx_seqs < batch)[:, None] # sequence-index & (idx_x_w >= 0)[:, None] # token-index & (idx_x_w < seqlen)[:, None] # token-index & (idx_feats < dim)[None, :] # feature-index ) if HAS_INITIAL_STATES: if load_init_state: initial_states_ptrs = initial_states_base + ((idx_x_w + KERNEL_WIDTH - 1) * stride_istate_token)[:, None] # [BLOCK_M, BLOCK_N] mask_w = (idx_seqs < batch)[:, None] & (idx_x_w < 0)[:, None] & (idx_feats < dim)[None, :] # sequence-index # token-index # feature-index initial_states = tl.load(initial_states_ptrs, mask_w, 0.0) else: initial_states = tl.zeros((BLOCK_M, BLOCK_N), dtype=x_ptr.dtype.element_ty) matrix_x = tl.load(x_ptrs, mask=mask_x, other=initial_states) else: matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) w_ptrs = w_base[None, :] + \ (j * stride_weight_width) # [1, BLOCK_N] tensor mask_w = (idx_feats < dim)[None, :] matrix_w = tl.load(w_ptrs, mask_w, other=0.0) acc += matrix_x * matrix_w if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) mask = ( (idx_seqs < batch)[:, None] # sequence-index & (idx_tokens < seqlen)[:, None] # token-index & (idx_feats < dim)[None, :] # feature-index ) o_ptrs = ( o_ptr + (idx_seqs * stride_o_seq)[:, None] + (idx_tokens * stride_o_token)[:, None] + (idx_feats * stride_o_dim)[None, :] ) tl.store(o_ptrs, acc, mask=mask) def causal_conv1d_fwd( x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, seq_idx: Optional[torch.Tensor] = None, initial_states: Optional[torch.Tensor] = None, return_final_states: Optional[torch.Tensor] = False, final_states_out: Optional[torch.Tensor] = None, activation: Optional[Literal["silu", "swish"]] = None, ): batch, dim, seqlen = x.shape _, width = weight.shape assert (dim, width) == weight.shape assert x.stride(2) == 1 or x.stride(1) == 1 # TODO: we may want to use weight such that weight.stride(dim)==1 assert weight.stride(1) == 1 # Tensor 
layout as NHWC is called channel last with 'C' is time-dimension is_channel_last = (x.stride(1) == 1) & (x.stride(2) > 1) stride_w_dim = weight.stride(0) stride_w_width = weight.stride(1) # effort to make data contiguous along dim-axis: weight = weight.transpose(0, 1).contiguous() stride_w_dim = weight.stride(1) stride_w_width = weight.stride(0) # assert initial_states is None # only this for now assert return_final_states is False stride_istate_seq = 0 stride_istate_dim = 0 stride_istate_token = 0 if initial_states is not None: assert (batch, dim, width - 1) == initial_states.shape stride_istate_seq = initial_states.stride(0) stride_istate_dim = initial_states.stride(1) stride_istate_token = initial_states.stride(2) assert stride_istate_dim == 1 out = torch.empty_like(x) if not is_channel_last: assert 0, "Need to run in channel-last layout" else: def grid(META): return ( triton.cdiv(batch * seqlen, META["BLOCK_M"]), triton.cdiv(dim, META["BLOCK_N"]), ) with torch.cuda.device(x.device.index): _causal_conv1d_fwd_kernel[grid]( # Pointers to matrices x, weight, bias, initial_states, out, # Matrix dimensions batch, dim, seqlen, # stride x.stride(0), x.stride(1), x.stride(2), stride_w_dim, stride_w_width, stride_istate_seq, stride_istate_dim, stride_istate_token, out.stride(0), out.stride(1), out.stride(2), # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], HAS_INITIAL_STATES=initial_states is not None, ) return out class CausalConv1dFn(torch.autograd.Function): @staticmethod def forward( ctx, x, weight, bias=None, seq_idx=None, initial_states=None, return_final_states: bool = False, final_states_out=None, activation: Optional[Literal["silu", "swish"]] = None, ): # NOTE: in fact, 'beta=1' would turn swish into silu - and only silu form is used if x.stride(2) != 1 and x.stride(1) != 1: x = x.contiguous() bias = bias.contiguous() if bias is not None else None if seq_idx is not None: assert initial_states is None, "initial_states must be None if seq_idx is not None" assert not return_final_states, "If seq_idx is not None, we don't return final_states_out" seq_idx = seq_idx.contiguous() if seq_idx is not None else None if initial_states is not None and ((initial_states.stride(2) != 1) and (initial_states.stride(1) != 1)): initial_states = initial_states.contiguous() if return_final_states: assert ( x.stride(1) == 1 ), "Only channel-last layout support returning final_states_out" if final_states_out is not None: assert ( (final_states_out.stride(2) == 1) or ( final_states_out.stride(1) == 1) ) else: batch, dim, seqlen = x.shape width = weight.shape[1] final_states_out = torch.empty( batch, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) else: final_states_out = None ctx.activation = activation out = causal_conv1d_fwd( x, weight, bias=bias, seq_idx=seq_idx, initial_states=initial_states, return_final_states=return_final_states, final_states_out=final_states_out, activation=ctx.activation, ) ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) ctx.return_final_states = return_final_states ctx.return_dinitial_states = initial_states is not None and initial_states.requires_grad return out if not return_final_states else (out, final_states_out) @staticmethod def backward(ctx, dout, *args): """dout = dL/dy RETURN: dL/dx, dL/dweight, dL/dbias, ... GIVEN THAT: def forward(ctx, x, weight, bias=None...) 
""" x, weight, bias, seq_idx, initial_states = ctx.saved_tensors dfinal_states = args[0] if ctx.return_final_states else None if dout.stride(2) != 1 and dout.stride(1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). # Here we just pass in None and dx will be allocated in the C++ code. dx, dweight, dbias, dinitial_states = causal_conv1d_bwd( x, weight, bias, dout, seq_idx, initial_states, dfinal_states, None, ctx.return_dinitial_states, ctx.activation, ) return ( dx, dweight, dbias if bias is not None else None, None, dinitial_states if initial_states is not None else None, None, None, None, ) def causal_conv1d_fn( x, # channel last, i.e. (batch, dim, seqlen) weight, # (dim, w) bias=None, # (dim, )scalar seq_idx=None, initial_states=None, return_final_states=False, final_states_out=None, activation: Optional[Literal["silu", "swish"]] = None, ): """causal_conv1d_fn. :param x: (batch, dim, seqlen) tensor :param weight: (dim, w) tensor :param bias: (dim,) tensor :param activation: ["silu", "swish"] :param seq_idx=None :param initial_states=None :param return_final_states=False :param final_states_out=None Return: (batch, dim, seqlen) tensor """ if weight.dim() == 3: assert weight.shape[1] == 1 weight = rearrange(weight, "d 1 w -> d w") return CausalConv1dFn.apply( x, weight, bias, seq_idx, initial_states, return_final_states, final_states_out, activation, ) ================================================ FILE: kernels/triton/inference/mamba/causal_1d_conv/tests/test_causal_1d_conv.py ================================================ # Copyright (C) 2025, IBM Research. # python -m pytest tests/test_causal_conv1d.py import sys from einops import rearrange import pytest import torch.nn.functional as F import torch import math import os from pathlib import Path base_path = Path(os.path.abspath(os.path.dirname(os.path.realpath(__file__)))) sys.path.insert(0, str(base_path / "../causal_1d_conv")) try: from causal_1d_conv import causal_conv1d_fn except ImportError: raise def _undecorated_test_causal_conv1d( batch, dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states, check_backward, ): if not channel_last and (has_initial_states or return_final_states): pytest.skip("Only channel_last support initial_states or return_final_states") device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 rtolw, atolw = (1e-3, 1e-3) # set seed torch.random.manual_seed(0) if not channel_last: x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[ :, 4096: 4096 + dim, : ].requires_grad_() else: x = rearrange( torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096: 4096 + dim], "b s d -> b d s", ).requires_grad_() weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) if has_bias: bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) else: bias = None if has_initial_states: initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_() else: initial_states = None x_ref = x.detach().clone().requires_grad_() weight_ref = weight.detach().clone().requires_grad_() bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None initial_states_ref = initial_states.detach().clone().requires_grad_() if 
initial_states is not None else None activation = None if not silu_activation else "silu" out = causal_conv1d_fn( x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, activation=activation ) out_ref = causal_conv1d_ref( x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation, ) if return_final_states: out, final_states = out out_ref, final_states_ref = out_ref print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}") print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}") assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) if return_final_states: out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) if check_backward: g = torch.randn_like(out) out.backward(g) out_ref.backward(g) print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") if has_bias: print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") if has_initial_states: print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}") assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) if has_bias: assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) if has_initial_states: assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) torch.cuda.empty_cache() del x_ref, x, weight, weight_ref, bias, bias_ref, out, out_ref def causal_conv1d_ref( x, weight, bias=None, initial_states=None, return_final_states=False, final_states_out=None, activation=None, ): """[copied from causal_conv1d/causal_conv1d_interface.py] x: (batch, dim, seqlen) weight: (dim, width) bias: (dim,) initial_states: (batch, dim, width - 1) final_states_out: (batch, dim, width - 1) out: (batch, dim, seqlen) """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") dtype_in = x.dtype x = x.to(weight.dtype) seqlen = x.shape[-1] dim, width = weight.shape if initial_states is None: out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) else: x = torch.cat([initial_states, x], dim=-1) out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) out = out[..., :seqlen] if return_final_states: final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in) # (batch, dim, width - 1) if final_states_out is not None: final_states_out.copy_(final_states) else: final_states_out = final_states out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) return out if not return_final_states else (out, final_states_out) @pytest.mark.parametrize("batch", [1, 2, 3, 8, 16, 32, 64]) # END-GOAL # @pytest.mark.parametrize("batch", [2]) @pytest.mark.parametrize("dim", [64, 4096 + 32]) # END-GOAL # @pytest.mark.parametrize('dim', [64]) # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) # @pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize( "seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 
151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] ) # END-GOAL @pytest.mark.parametrize("width", [2, 3, 4, 5]) # END-GOAL # @pytest.mark.parametrize('width', [3]) @pytest.mark.parametrize("has_bias", [False, True]) # END-GOAL # @pytest.mark.parametrize('has_bias', [True]) # @pytest.mark.parametrize('has_bias', [False]) @pytest.mark.parametrize("silu_activation", [False, True]) # END-GOAL # @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize('itype', [torch.float16]) # @pytest.mark.parametrize("channel_last", [False, True]) @pytest.mark.parametrize("channel_last", [True]) # END-GOAL @pytest.mark.parametrize("has_initial_states", [False, True]) # END-GOAL # @pytest.mark.parametrize("has_initial_states", [False]) # @pytest.mark.parametrize("return_final_states", [False, True]) # END-GOAL @pytest.mark.parametrize("return_final_states", [False]) # @pytest.mark.parametrize('check_backward', [True]) # END-GOAL @pytest.mark.parametrize("check_backward", [False]) def test_causal_conv1d( batch, dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states, check_backward, ): return _undecorated_test_causal_conv1d( batch, dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states, check_backward, ) ================================================ FILE: kernels/triton/inference/paged_attention/attention_triton.py ================================================ #from einops import rearrange import torch import triton import triton.language as tl # Credit: # vedantroy https://github.com/openai/triton/issues/2200#issuecomment-1815471999 # Expect block table to map # logical bid (block id) -> (physical bid, # filled) # In tests, it maps: logical pid -> physical bid @triton.jit def print_tensor_dim(tensor, str_name): if tl.program_id(0) == 0 and tl.program_id(1) == 0: tl.static_print(str_name," ",tensor.shape," ",tensor.dtype) #tl.static_print('*************** program id: ', tl.program_id(0), tl.program_id(1)) @triton.jit def print_value(value): if tl.program_id(0) == 0 and tl.program_id(1) == 0: tl.device_print(str(value)) #tl.static_print('*************** program id: ', tl.program_id(0), tl.program_id(1)) #tl.static_print(str_name+" ") @triton.jit def print_line(str_line): if tl.program_id(0) == 0 and tl.program_id(1) == 0: print(str_line) #Paged Attention V1: basic version, has a memory limitation error @triton.jit def paged_attention_v1( # need these b/c we can't use view/reshape scratchpad_key_ptr, # [num_seqs, max_seq_len, num_heads, head_size] scratchpad_value_ptr, # [num_seqs, max_seq_len, num_heads, head_size] output_ptr, # [num_seqs, num_query_heads, head_size] query_ptr, # [num_seqs, num_query_heads, head_size] key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] context_lens_ptr, # [num_seqs] scale, # float32 num_seqs, # int num_heads, # int cache_block_stride, # int MAX_SEQ_LEN: tl.constexpr, # int (same as max_seq_len) BLOCK_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 MAX_NUM_BLOCKS_PER_SEQ: tl.constexpr, # int, must be power of 2 ): seq_idx = tl.program_id(0).to(tl.int64) head_idx = tl.program_id(1).to(tl.int64) #Compute the offsets of the query using the strides #TODO(amorari) use the strides as returned from 
tensor.stride() instead query_offset = seq_idx * num_seqs + head_idx * HEAD_SIZE query_head = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE)) #print_tensor_dim(query_head, "query_head") block_table_offset = seq_idx * MAX_NUM_BLOCKS_PER_SEQ #load the context len for this q vector context_len = tl.load(context_lens_ptr + seq_idx) #print_tensor_dim(block_tables_ptr, "block_tables_ptr") #iterate on the tokens for tok_idx in range(0, context_len): logical_block_idx = tok_idx // BLOCK_SIZE #physical block starting pointer for token physical_block_idx = tl.load( block_tables_ptr + block_table_offset + logical_block_idx ) start_of_block_offset = ( physical_block_idx.to(tl.int64) * cache_block_stride + head_idx * HEAD_SIZE * BLOCK_SIZE ) tok_idx_within_block = tok_idx % BLOCK_SIZE tok_offsets = ( start_of_block_offset + BLOCK_SIZE * tl.arange(0, HEAD_SIZE) + tok_idx_within_block ) #Get all blocks for this token tok_key = tl.load(key_cache_ptr + tok_offsets) tok_value = tl.load(value_cache_ptr + tok_offsets) #print_tensor_dim(tok_key, "tok_key") #print_tensor_dim(tok_value, "tok_value") #Compute offsets to write in the scratchpad scratchpad_offset = ( seq_idx.to(tl.int64) * (MAX_SEQ_LEN * num_heads.to(tl.int64) * HEAD_SIZE) + tok_idx.to(tl.int64) * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE ) tl.store( scratchpad_key_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_key ) tl.store( scratchpad_value_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_value, ) # TODO: Not sure if this is necessary tl.debug_barrier() # scratchpad_key_ptr, # [num_seqs, max_seq_len, num_heads, head_size] start_seq_offset = (MAX_SEQ_LEN * num_heads * HEAD_SIZE) * seq_idx start_tok_offset = start_seq_offset + tl.arange(0, MAX_SEQ_LEN) \ * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE # [seq_len, head_size] # zero out keys that aren't part of the sequence mask = tl.arange(0, MAX_SEQ_LEN)[:, None] < context_len kv_offs = start_tok_offset[:, None] + tl.arange(0, HEAD_SIZE)[None, :] print_tensor_dim(kv_offs, "kv_offs_v1") keys = tl.load(scratchpad_key_ptr + kv_offs, mask=mask, other=0.0) print_tensor_dim(keys, "keys_v1") values = tl.load(scratchpad_value_ptr + kv_offs, mask=mask, other=0.0) print_tensor_dim(values, "values_v1") # keys shape [seq_len x head_size], query shape = [head_size] # Can't do below b/c minimum size on all dimensions is 16 # scores = tl.dot(query_head[None, :], keys.T) scores = tl.sum(scale * keys * query_head[None, :], axis=1) # This mask is necessary b/c even though we mask out the keys on load # that just results in 0s in the attention dot product, # which then get softmaxed and result in non-zero values # in the softmax output (which is wrong) # -inf guarantees that the softmax output will be 0 for masked values mask = tl.full([MAX_SEQ_LEN], -float('inf'), dtype=tl.float32) cond = tl.arange(0, MAX_SEQ_LEN) < context_len scores_masked = tl.where(cond, scores, mask) # do a numerically stable softmax on the scores scores_minus_max = scores_masked - tl.max(scores_masked, axis=0) numerator = tl.exp(scores_minus_max) denominator = tl.sum(numerator, axis=0) + float(1e-6) logits = numerator / denominator print_tensor_dim(logits, "logits_v1") weighted_values = tl.sum(values * logits[:, None], axis=0) print_tensor_dim(weighted_values, "weighted_values_v1") output_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), weighted_values) def paged_attention_triton_v1( output, query, key_cache, value_cache, 
#head_mapping, scale, block_tables, context_lens, block_size, #max_seq_len, #alibi_slopes, num_seqs, num_query_heads, max_seq_len, max_num_blocks_per_seq, head_size ): scratchpad_key = torch.zeros( (num_seqs, max_seq_len, num_query_heads, head_size), dtype=torch.float32, device="cuda", ) scratchpad_value = torch.zeros_like(scratchpad_key) paged_attention_v1[(num_seqs, num_query_heads)]( scratchpad_key_ptr=scratchpad_key, scratchpad_value_ptr=scratchpad_value, output_ptr=output, query_ptr=query, key_cache_ptr=key_cache, value_cache_ptr=value_cache, block_tables_ptr=block_tables, context_lens_ptr=context_lens, scale=scale, num_seqs=num_seqs, num_heads=num_query_heads, cache_block_stride=key_cache.stride(0), MAX_SEQ_LEN=max_seq_len, BLOCK_SIZE=block_size, HEAD_SIZE=head_size, MAX_NUM_BLOCKS_PER_SEQ=max_num_blocks_per_seq, ) #Paged Attention V2: Iterate on kv vectors to avoid memory limitation error (sram) @triton.jit def paged_attention_v2( # need these b/c we can't use view/reshape scratchpad_key_ptr, # [num_seqs, max_seq_len, num_heads, head_size] scratchpad_value_ptr, # [num_seqs, max_seq_len, num_heads, head_size] partition_buf_ptr, output_ptr, # [num_seqs, num_query_heads, head_size] query_ptr, # [num_seqs, num_query_heads, head_size] key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] context_lens_ptr, # [num_seqs] scale, # float32 num_seqs, # int num_heads, # int cache_block_stride, # int num_partitions, #int PARTITION_SIZE: tl.constexpr, #int MAX_SEQ_LEN: tl.constexpr, # int BLOCK_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 MAX_NUM_BLOCKS_PER_SEQ: tl.constexpr, # int, must be power of 2 ): seq_idx = tl.program_id(0).to(tl.int64) head_idx = tl.program_id(1).to(tl.int64) partition_idx = tl.program_id(2).to(tl.int64) #Compute the offsets of the query using the strides #TODO(amorari) use the strides as returned from tensor.stride() instead query_offset = seq_idx * num_seqs + head_idx * HEAD_SIZE #load one q vector query_head = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE)) print_tensor_dim(query_head, "query_head") block_table_offset = seq_idx * MAX_NUM_BLOCKS_PER_SEQ #load the context len for this q vector context_len = tl.load(context_lens_ptr + seq_idx) assert(context_len <= MAX_SEQ_LEN) #iterate on the tokens in this partition token_start_idx = partition_idx * PARTITION_SIZE token_end_idx = min((partition_idx + 1) * PARTITION_SIZE, context_len) #NOTE: For some sequence, it is possible that context_len < token_start_idx for tok_idx in range(token_start_idx, token_end_idx): logical_block_offset = tok_idx // BLOCK_SIZE #physical block starting pointer for token physical_block_idx = tl.load( block_tables_ptr + block_table_offset + logical_block_offset ) start_of_block_offset = ( physical_block_idx * cache_block_stride + head_idx * HEAD_SIZE * BLOCK_SIZE ) tok_idx_within_block = tok_idx % BLOCK_SIZE tok_offsets = ( start_of_block_offset + BLOCK_SIZE * tl.arange(0, HEAD_SIZE) + tok_idx_within_block ) tok_key = tl.load(key_cache_ptr + tok_offsets) #print_tensor_dim(tok_key, "tok_key") tok_value = tl.load(value_cache_ptr + tok_offsets) #print_tensor_dim(tok_key, "tok_value") scratchpad_offset = ( seq_idx.to(tl.int64) * (MAX_SEQ_LEN * num_heads.to(tl.int64) * HEAD_SIZE) + tok_idx.to(tl.int64) * (num_heads.to(tl.int64) * HEAD_SIZE) + head_idx * HEAD_SIZE ) print_tensor_dim(scratchpad_key_ptr, 
"scratchpad_key_ptr") mask=tl.full([HEAD_SIZE], 1, dtype=tl.float32) > 0 #store the keys in line tl.store( scratchpad_key_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_key, mask ) #store the values in line tl.store( scratchpad_value_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_value, mask ) # TODO: Not sure if this is necessary tl.debug_barrier() #start of the sequence start_seq_offset = (MAX_SEQ_LEN * num_heads.to(tl.int64) * HEAD_SIZE) * seq_idx.to(tl.int64) #offsets with the start of the token start_tok_offsets = start_seq_offset.to(tl.int64) \ + tl.arange(0, PARTITION_SIZE) * (num_heads.to(tl.int64) * HEAD_SIZE) \ + head_idx.to(tl.int64) * HEAD_SIZE # [seq_len, head_size] # zero out keys that aren't part of the sequence mask = tl.arange(0, PARTITION_SIZE)[:, None] < context_len kv_offs = start_tok_offsets[:, None] + tl.arange(0, HEAD_SIZE)[None, :] print_tensor_dim(kv_offs, "kv_offs_v2") keys = tl.load(scratchpad_key_ptr + kv_offs, mask=mask, other=0.0) print_tensor_dim(keys, "keys_v2") # Can't do below b/c minimum size on all dimensions is 16 # scores = tl.dot(query_head[None, :], keys.T) scores = tl.sum(scale * keys * query_head[None, :], axis=1) print_tensor_dim(keys, "scores_v2") partition_buf_offset = start_seq_offset \ + head_idx.to(tl.int64) * HEAD_SIZE + partition_idx.to(tl.int64) * PARTITION_SIZE print_tensor_dim(partition_buf_offset, "partition_buf_offset_v2") tl.store(partition_buf_ptr + partition_buf_offset + tl.arange(0, PARTITION_SIZE), scores) #weighted_values = tl.zeros(HEAD_SIZE, dtype=tl.float32) # This mask is necessary b/c even though we mask out the keys on load # that just results in 0s in the attention dot product, # which then get softmaxed and result in non-zero values # in the softmax output (which is wrong) # -inf guarantees that the softmax output will be 0 for masked values mask = tl.full([PARTITION_SIZE], -float('inf'), dtype=tl.float32) cond = tl.arange(0, PARTITION_SIZE) < context_len scores_masked = tl.where(cond, scores, mask) # do a numerically stable softmax on the scores scores_minus_max = scores_masked - tl.max(scores_masked, axis=0) numerator = tl.exp(scores_minus_max) denominator = tl.sum(numerator, axis=0) + float(1e-6) logits = numerator / denominator print_tensor_dim(logits, "logits_v2") values = tl.load(scratchpad_value_ptr + kv_offs, mask=mask, other=0.0) print_tensor_dim(values, "values_v2") weighted_values += tl.sum(values * logits[:, None], axis=0) print_tensor_dim(weighted_values, "weighed_values_v2") #output_offset = seq_idx.to(tl.int64) * (num_heads.to(tl.int64) * HEAD_SIZE) \ # + head_idx.to(tl.int64) * HEAD_SIZE + seq_partition_idx.to(tl.int64) * PARTITION_SIZE #to_store_values=weighted_values.to(tl.float32) #mask = tl.full([HEAD_SIZE], 1, dtype=tl.float32) > 0 #tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), to_store_values, mask) output_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), weighted_values) def paged_attention_triton_v2( output, query, key_cache, value_cache, #head_mapping, scale, block_tables, context_lens, block_size, partition_size, #alibi_slopes, num_seqs, num_query_heads, max_seq_len, max_num_blocks_per_seq, head_size ): scratchpad_key = torch.zeros( (num_seqs, max_seq_len, num_query_heads, head_size), dtype=torch.float32, device="cuda", ) scratchpad_value = torch.zeros_like(scratchpad_key) num_partitions = max_seq_len//partition_size assert(max_seq_len % partition_size == 0) partition_buf_ptr = 
torch.zeros((num_seqs,max_seq_len,num_query_heads,head_size), dtype=torch.float32, device="cuda") #print(f"started_v2 num_seqs: {num_seqs} num_query_heads: {num_query_heads}") paged_attention_v2[(num_seqs, num_query_heads, num_partitions)]( scratchpad_key_ptr=scratchpad_key, scratchpad_value_ptr=scratchpad_value, partition_buf_ptr=partition_buf_ptr, output_ptr=output, query_ptr=query, key_cache_ptr=key_cache, value_cache_ptr=value_cache, block_tables_ptr=block_tables, context_lens_ptr=context_lens, scale=scale, num_seqs=num_seqs, num_heads=num_query_heads, cache_block_stride=key_cache.stride(0), num_partitions=num_partitions, PARTITION_SIZE=partition_size, MAX_SEQ_LEN=max_seq_len, BLOCK_SIZE=block_size, HEAD_SIZE=head_size, MAX_NUM_BLOCKS_PER_SEQ=max_num_blocks_per_seq, ) #print("finished_v2") ================================================ FILE: kernels/triton/inference/torch_compile/flash_backward.py ================================================ #!/usr/bin/env python """ Code copied from https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-attention.py """ """ Fused Attention =============== This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) Credits: OpenAI kernel team, AMD ML Frameworks Triton team Features supported: 1) Fwd with causal masking 2) Any sequence lengths without padding (currently fwd kernel only) 3) Support for different sequence lengths for q and k 4) Nested tensor API currently does not support dropout or bias. Not currently supported: 1) Non power of two head dims """ import argparse import random import sys import torch import triton import triton.language as tl torch_dtype:tl.constexpr = torch.float16 #TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') #if TORCH_HAS_FP8E5: # torch_dtype:tl.constexpr = torch.float8_e5m2fnuz TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2') if TORCH_HAS_FP8E5: torch_dtype:tl.constexpr = torch.float8_e5m2 class MetaData(): cu_seqlens_q = None cu_seqlens_k = None max_seqlens_q = 0 max_seqlens_k = 0 bias = None alibi_slopes = None causal = False num_contexts = 0 varlen = False dropout_p, return_encoded_softmax = 0.0, False def __init__(self, sm_scale=1.0): self.sm_scale = sm_scale def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): self.varlen = True self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k # Without "varlen", there should still be one sequence. 
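# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): cu_seqlens_q /
# cu_seqlens_k are cumulative sequence-length tensors of length
# num_contexts + 1, so sequence i occupies rows [cu_seqlens[i], cu_seqlens[i+1])
# of the packed (total_tokens, nheads, head_dim) layout used in varlen mode.
# The tiny helper below (name is an assumption, torch is imported above) shows
# the same max-length computation the loop below performs.
def _varlen_max_seqlen_example():
    cu = torch.tensor([0, 5, 12, 20])   # 3 sequences packed back to back
    lengths = cu[1:] - cu[:-1]          # tensor([5, 7, 8])
    return int(lengths.max())           # 8 -> max_seqlens for this example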
assert len(cu_seqlens_q) >= 2 assert len(cu_seqlens_q) == len(cu_seqlens_k) self.num_contexts = len(cu_seqlens_q) - 1 for i in range (0, self.num_contexts): self.max_seqlens_q = max(cu_seqlens_q[i+1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) self.max_seqlens_k = max(cu_seqlens_k[i+1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): assert bias.is_cuda assert bias.dim() == 4 assert bias.shape[0] == 1 assert bias.shape[2:] == (seqlen_q, seqlen_k) self.bias = bias def need_alibi(self, alibi_slopes, batch, nheads): assert alibi_slopes.is_cuda assert alibi_slopes.dim() == 2 assert alibi_slopes.shape[0] == batch assert alibi_slopes.shape[1] == nheads self.alibi_slopes = alibi_slopes def need_causal(self): self.causal = True def need_dropout(dropout_p, return_encoded_softmax): self.dropout_p = dropout_p self.return_encoded_softmax = return_encoded_softmax def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() if self.varlen: assert q.dim() == 3 total_q, nheads_q, head_size = q.shape total_k, nheads_k, _ = k.shape assert self.cu_seqlens_q is not None assert self.cu_seqlens_k is not None assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) # TODO: Remove once bias is supported with varlen assert self.bias == None # TODO:Remove once dropout is supported with varlen assert self.dropout_p == 0.0 assert not self.return_encoded_softmax else: assert q.dim() == 4 batch, nheads_q, seqlen_q, head_size = q.shape _, nheads_k, seqlen_k, _ = k.shape assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 assert self.cu_seqlens_q is None and self.cu_seqlens_k is None assert k.shape == v.shape assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 @triton.jit def cdiv_fn(x,y): return (x + y - 1) // y @triton.jit def max_fn(x, y): return tl.math.max(x, y) @triton.jit def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): ms = tl.arange(0, m) ns = tl.arange(0, n) return philox_offset + ms[:, None] * stride + ns[None, :] @triton.jit def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, rng_offsets) @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) rng_keep = rng_output > dropout_p return rng_keep @triton.jit def load_fn(block_ptr, first, second, pad): if first and second: tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad) elif first: tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) elif second: tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) else: tensor = tl.load(block_ptr) return tensor @triton.jit def _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: 
tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr ): # loop over k, v, and update accumulator for start_n in range (block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") if PRE_LOAD_V: v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. if MASK_STEPS: # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. # last step might get wasted but that is okay. check if this masking works For # that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) size_n = start_n + OFFS_N[None,:] mask = size_n < boundary_m[:,None] qk = tl.where(mask, qk, float("-inf")) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) if bias_ptr is not None: bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") # While bias is added after multiplying qk with sm_scale, # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. 
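# Concretely, 1.44269504089 == log2(e) and exp2(x * log2(e)) == exp(x), so folding
# log2(e) into the bias keeps it consistent with the already-rescaled qk scores.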
qk += (bias * 1.44269504089) if alibi_slope is not None: # Compute the global position of each token within the sequence global_m_positions = start_m*BLOCK_M + tl.arange(0, BLOCK_M) global_n_positions = start_n + tl.arange(0, BLOCK_N) # Compute the relative position using the global positions relative_pos_block = global_m_positions[:,None] + actual_seqlen_k - global_n_positions[None,:] - actual_seqlen_q relative_pos_block = tl.abs(relative_pos_block) alibi_block = -1 * alibi_slope * relative_pos_block qk += (alibi_block * 1.44269504089) # scale factor of log2(e) # softmax m_ij = tl.maximum(m_i, tl.max(qk,1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) if bias_ptr is not None: bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i @triton.jit def attn_fwd( Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax, hq, hk, alibi_slopes, ACTUAL_BLOCK_DMODEL:tl.constexpr, MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, BATCH_SIZE: tl.constexpr, ): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start # We have a one-size-fits-all grid in id(0). Some seqlens might be too # small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K # Now we compute whether we need to exit early due to causal masking. 
# This is because for seqlen_q > seqlen_k, M rows of the attn scores # are completely masked, resulting in 0s written to the output, and # inf written to LSE. We don't need to do any GEMMs in this case. # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N) if (IS_CAUSAL): # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. # This captures the decrease in n_blocks if we have a rectangular attn matrix n_blocks_seqlen = cdiv_fn( (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N ) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) # If we have no blocks after adjusting for seqlen deltas, this WG is part of # the blocks that are all 0. We exit early. if n_blocks <= 0: o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) # We still need to write 0s to the result tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # We store inf to LSE, not -inf because in the bwd pass, we subtract this # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) tl.store(l_ptrs, l) # TODO: Should dropout and return encoded softmax be handled here too? return is_mqa = hq != hk off_h_k = off_h_q % hk if is_mqa else off_h_q need_padding = False n_extra_tokens = 0 if seqlen_k < BLOCK_N: need_padding = True n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: need_padding = True n_extra_tokens = seqlen_k % BLOCK_N padded_head = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. 
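# Q and V tiles are laid out (rows, head_dim) while K is addressed transposed as
# (head_dim, cols), so tl.dot(q, k) in the inner loop directly yields a
# (BLOCK_M, BLOCK_N) tile of attention scores.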
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm Q_block_ptr = tl.make_block_ptr( base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn K_block_ptr = tl.make_block_ptr( base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk V_block_ptr = tl.make_block_ptr( base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0) ) if BIAS_TYPE != 0: b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs bias_ptr = tl.make_block_ptr( base=bias + b_offset, shape=(seqlen_q, seqlen_k), strides=(stride_bm, stride_bn), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0), ) else: bias_ptr = None if USE_ALIBI != 0: a_offset = off_z * stride_az + off_h_q * stride_ah alibi_slope = tl.load(alibi_slopes + a_offset) else: alibi_slope = None if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k else: batch_philox_offset = 0 # We can ask to return the dropout mask without actually doing any dropout. In # this case, we return an invalid pointer so indicate the mask is not valid. # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.make_block_ptr( base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, shape=(seqlen_q, seqlen_k), strides=(seqlen_k, 1), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) ) else: encoded_softmax_block_ptr = 0 # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # scale sm_scale by log_2(e) and use 2^x in the loop as we do not # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. q = load_fn(Q_block_ptr, True, padded_head, "zero") q = (q * qk_scale).to(Q_block_ptr.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) if IS_CAUSAL: # There are always at least BLOCK_M // BLOCK_N masked blocks. # Additionally there might be one more due to dissimilar seqlens. masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. # In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N # Compute for full blocks. Here we set causal to false regardless of its actual # value because there is no masking. Similarly we do not need padding. 
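# The masked tail blocks (the causal boundary and/or a ragged final K block) are
# handled by the second _attn_fwd_inner call further below with MASK_STEPS=True.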
if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head ) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. if (masked_blocks > 0): if IS_CAUSAL: offs_n_causal = offs_n + (seqlen_q - seqlen_k) else: offs_n_causal = 0 K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0)) if bias_ptr is not None: bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N)) if RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks)) acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head ) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, # then we have one block with a row of all NaNs which come from computing # softmax over a row of all -infs (-inf - inf = NaN). We check for that here # and store 0s where there are NaNs as these rows should've been zeroed out. end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. # This is only true for the last M block. For others, overflow_size will be -ve overflow_size = end_m_idx - seqlen_q if overflow_size > 0: boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) # This is a > check because mask being 0 blocks the store. l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) else: tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) # Need boundary check on this to make sure the padding from the # Q and KV tensors in both dims are not part of what we store back. # TODO: Do the boundary check optionally. 
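# boundary_check=(0, 1) clips both the BLOCK_M overhang past seqlen_q (dim 0) and
# the padded head-dim columns beyond ACTUAL_BLOCK_DMODEL (dim 1).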
tl.store(O_block_ptr, acc, boundary_check=(0,1)) def attention(q, k, v, sm_scale): o = torch.empty_like(q, dtype=v.dtype) batch, nheads_q, seqlen_q, head_size = q.shape _, nheads_k, seqlen_k, _ = k.shape max_seqlens_q = seqlen_q q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) # Get closest power of 2 over or equal to 32. unpadded_head_dims = {32, 64, 128, 256} if head_size not in unpadded_head_dims: padded_d_model = None for i in unpadded_head_dims: if i > head_size: padded_d_model = i break assert padded_d_model is not None else: padded_d_model = head_size # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4) BLOCK_M = 128 BLOCK_N = 128 PRE_LOAD_V = False num_stages = 1 num_warps = 4 grid = (triton.cdiv(max_seqlens_q, BLOCK_M), nheads_q, batch) # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing # only. This return holds no useful output aside from debugging. encoded_softmax = None M = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) # Seed the RNG so we get reproducible results for testing. philox_seed = 0x1BF52 philox_offset = 0x1D4B42 bias_strides = (0,0,0,0) alibi_strides = (0, 0) attn_fwd[grid]( q, k, v, None, sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, None, None, BLOCK_M=BLOCK_M, PRE_LOAD_V=PRE_LOAD_V, BLOCK_N=BLOCK_N, dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, hq=nheads_q, hk=nheads_k, alibi_slopes = None, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=seqlen_q, MAX_SEQLENS_K=seqlen_k, IS_CAUSAL=False, ######################## VARLEN=False, BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0, USE_ALIBI=0, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, BATCH_SIZE= q.shape[0], ) return o @triton.jit def _attn_bwd_preprocess( Out, DO, Delta, stride_oz, stride_oh, stride_om, stride_on, stride_doz, stride_doh, stride_dom, stride_don, seqlen_q, head_dim, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, ): # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) # off_n = tl.arange(0, D_HEAD) off_m = tl.program_id(0) * BLOCK_M off_h = tl.program_id(1) # head index off_z = tl.program_id(2) # batch index num_h = tl.num_programs(1) o_offset = off_h * stride_oh + off_z * stride_oz O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0) ) do_offset = off_h * stride_doh + off_z * stride_doz DO_block_ptr = tl.make_block_ptr( base=DO + do_offset, shape=(seqlen_q, head_dim), strides=(stride_dom, stride_don), offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0) ) # load # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) o = tl.load(O_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32) do = tl.load(DO_block_ptr, boundary_check=(0,1), 
padding_option="zero").to(tl.float32) # compute delta = tl.sum(o * do, axis=1) # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) off_zh = off_z * num_h + off_h * 1 # Check for OOB accesses delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M) overflow = off_m + BLOCK_M - seqlen_q if overflow > 0: boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32) mask = boundary > tl.arange(0, BLOCK_M) tl.store(delta_ptrs, delta, mask=mask) else: tl.store(delta_ptrs, delta) @triton.jit def _bwd_kernel_dk_dv( dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, # shared by Q/K/V/DO. stride_tok, stride_d, H, N_CTX, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # Filled in by the wrapper. start_n, start_m, num_steps, MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) offs_k = tl.arange(0, BLOCK_DMODEL) QT_block_ptr = tl.make_block_ptr( base=Q, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0,1) ) DO_block_ptr = tl.make_block_ptr( base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1,0) ) # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) curr_m = start_m step_m = BLOCK_M1 for blk_idx in range(num_steps): qT = tl.load(QT_block_ptr) # Load m before computing qk to reduce pipeline stall. offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) kqT = tl.dot(k, qT) if alibi_slope is not None: alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) kqT += alibi_block * 1.44269504089 pT = tl.math.exp2(kqT - m[None, :]) # Autoregressive masking. if MASK: mask = (offs_m[None, :] >= offs_n[:, None]) pT = tl.where(mask, pT, 0.0) do = tl.load(DO_block_ptr) # Compute dV. ppT = pT ppT = ppT.to(tl.bfloat16) dv += tl.dot(ppT, do) # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # Compute dP and dS. dpT = tl.dot(v, tl.trans(do)) dsT = pT * (dpT - Di[None, :]) dsT = dsT.to(tl.bfloat16) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) return dk, dv @triton.jit def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, # shared by Q/K/V/DO. stride_tok, stride_d, H, N_CTX, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # Filled in by the wrapper. start_m, start_n, num_steps, MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) offs_k = tl.arange(0, BLOCK_DMODEL) KT_block_ptr = tl.make_block_ptr( base=K, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1) ) VT_block_ptr = tl.make_block_ptr( base=V, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1) ) # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) curr_n = start_n step_n = BLOCK_N2 for blk_idx in range(num_steps): kT = tl.load(KT_block_ptr) qk = tl.dot(q, kT) p = tl.math.exp2(qk - m) # Autoregressive masking. 
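# p here is the recomputed softmax: M stores m_i + log2(l_i) from the forward pass
# and K was pre-scaled by sm_scale * log2(e) in the wrapper, so exp2(qk - m) needs
# no renormalization. The mask below re-zeroes future positions for causal attention.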
if MASK: offs_n = curr_n + tl.arange(0, BLOCK_N2) mask = (offs_m[:, None] >= offs_n[None, :]) p = tl.where(mask, p, 0.0) # Compute dP and dS. vT = tl.load(VT_block_ptr) dp = tl.dot(do, vT).to(tl.float32) ds = p * (dp - Di[:, None]) ds = ds.to(tl.bfloat16) # Compute dQ.0. # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += step_n KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq @triton.jit def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # shared by Q/K/V/DO. stride_z, stride_h, stride_tok, stride_d, # H = 16, N_CTX = 1024 H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) bhid = tl.program_id(2) off_chz = (bhid * N_CTX).to(tl.int64) adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) pid = tl.program_id(0) # offset pointers for batch/head Q += adj K += adj V += adj DO += adj DQ += adj DK += adj DV += adj M += off_chz D += off_chz offs_k = tl.arange(0, BLOCK_DMODEL) start_n = pid * BLOCK_N1 # This assignment is important. It is what allows us to pick the diagonal # blocks. Later, when we want to do the lower triangular, we update start_m # after the first dkdv call. start_m = start_n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR offs_n = start_n + tl.arange(0, BLOCK_N1) dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) K_block_ptr = tl.make_block_ptr( base=K, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0), ) V_block_ptr = tl.make_block_ptr( base=V, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0), ) # load K and V: they stay in SRAM throughout the inner loop for dkdv. k = tl.load(K_block_ptr) v = tl.load(V_block_ptr) if USE_ALIBI: a_offset = bhid alibi_slope = tl.load(alibi_slopes + a_offset) else: alibi_slope = None # compute dK and dV for blocks close to the diagonal that need to be masked num_steps = BLOCK_N1 // MASK_BLOCK_M1 dk, dv = _bwd_kernel_dk_dv( dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True ) # compute dK and dV for blocks that don't need masking further from the diagonal start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 dk, dv = _bwd_kernel_dk_dv( dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False ) DV_block_ptrs = tl.make_block_ptr( base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1,0) ) tl.store(DV_block_ptrs, dv.to(v.dtype)) # Write back dK. 
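# K was pre-scaled by sm_scale * log2(e) in the Python wrapper, so dK only needs the
# softmax scale applied here; dQ is instead rescaled by ln(2) at the end of this kernel.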
dk *= sm_scale DK_block_ptrs = tl.make_block_ptr( base=DK, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1,0) ) tl.store(DK_block_ptrs, dk.to(k.dtype)) # THIS BLOCK DOES DQ: start_m = pid * BLOCK_M2 end_n = start_m + BLOCK_M2 MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) Q_block_ptr = tl.make_block_ptr( base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0) ) DO_block_ptr = tl.make_block_ptr( base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0) ) q = tl.load(Q_block_ptr) do = tl.load(DO_block_ptr) dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) m = tl.load(M + offs_m) m = m[:, None] # Compute dQ for masked (diagonal) blocks. # NOTE: This code scans each row of QK^T backward (from right to left, # but inside each call to _attn_bwd_dq, from left to right), but that's # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. num_steps = BLOCK_M2 // MASK_BLOCK_N2 dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, MASK=True ) end_n -= num_steps * MASK_BLOCK_N2 # stage 2 num_steps = end_n // BLOCK_N2 dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, start_m, end_n - num_steps * BLOCK_N2, num_steps, MASK=False ) # Write back dQ. DQ_block_ptr = tl.make_block_ptr( base=DQ, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0) ) dq *= LN2 tl.store(DQ_block_ptr, dq.to(q.dtype)) @torch.library.custom_op("triton::flash_bwd", mutates_args=()) def flash_bwd(q: torch.Tensor, k: torch.Tensor, v:torch.Tensor, o: torch.Tensor, M:torch.Tensor, do: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: BLOCK = 128 sm_scale = q.shape[-1] ** -0.5 batch, nheads_q, seqlen_q, head_size = q.shape _, nheads_k, seqlen_k, _ = k.shape Lk = k.shape[-1] max_seqlens_q = seqlen_q padded_d_model = head_size # assert do.is_contiguous() assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() seqlen_q = q.shape[2] dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 NUM_WARPS, NUM_STAGES = 4, 1 BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) arg_k = k arg_k = arg_k * (sm_scale * RCP_LN2) assert N_CTX % PRE_BLOCK == 0 delta = torch.empty_like(M) grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) _attn_bwd_preprocess[grid_preprocess]( o, do, delta, o.stride(0), o.stride(1), o.stride(2), o.stride(3), do.stride(0), do.stride(1), do.stride(2), do.stride(3), seqlen_q, head_dim=Lk, BLOCK_M=BLOCK, D_HEAD=padded_d_model, ) grid = (triton.cdiv(N_CTX, BLOCK_N1), 1, BATCH * N_HEAD) _attn_bwd[grid]( q, arg_k, v, sm_scale, None, do, dq, dk, dv, M, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), N_HEAD, N_CTX, BLOCK_DMODEL=padded_d_model, BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, USE_ALIBI=False, 
num_warps=NUM_WARPS, num_stages=NUM_STAGES, ) return dq, dk, dv @flash_bwd.register_fake def _(q, k, v, o, M, do): dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) return dq, dk, dv @torch.library.custom_op("triton::flash", mutates_args=()) def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, M: torch.Tensor) -> torch.Tensor: sm_scale = q.shape[-1] ** -0.5 batch, nheads_q, seqlen_q, head_size = q.shape _, nheads_k, seqlen_k, _ = k.shape max_seqlens_q = seqlen_q q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) # Get closest power of 2 over or equal to 32. unpadded_head_dims = {32, 64, 128, 256} if head_size not in unpadded_head_dims: padded_d_model = None for i in unpadded_head_dims: if i > head_size: padded_d_model = i break assert padded_d_model is not None else: padded_d_model = head_size # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4) BLOCK_M = 128 BLOCK_N = 128 PRE_LOAD_V = False grid = (triton.cdiv(max_seqlens_q, BLOCK_M), nheads_q, batch) # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing # only. This return holds no useful output aside from debugging. encoded_softmax = None # Seed the RNG so we get reproducible results for testing. philox_seed = 0x1BF52 philox_offset = 0x1D4B42 bias_strides = (0, 0, 0, 0) alibi_strides = (0, 0) attn_fwd[grid]( q, k, v, None, sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, None, None, BLOCK_M=BLOCK_M, PRE_LOAD_V=PRE_LOAD_V, BLOCK_N=BLOCK_N, dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, hq=nheads_q, hk=nheads_k, alibi_slopes = None, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=seqlen_q, MAX_SEQLENS_K=seqlen_k, IS_CAUSAL=True, ######################## VARLEN=False, BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0, USE_ALIBI=0, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, BATCH_SIZE= q.shape[0], ) out = o.clone() return out @flash.register_fake def _(q, k, v, o, M): return torch.empty_like(q, dtype=v.dtype) def setup_context(ctx, inputs, output) -> torch.Tensor: q, k, v, o, M = inputs ctx.save_for_backward(q, k, v, o, M) def backward(ctx, do): q, k, v, o, M = ctx.saved_tensors dq, dk, dv = flash_bwd(q, k, v, o, M, do) return dq, dk, dv, None, None flash.register_autograd(backward, setup_context=setup_context) if __name__ == "__main__": b, nh, s, hd = 1, 32, 128, 128 q = torch.randn(b, nh, s, hd, dtype=torch.float16, device='cuda').requires_grad_() k = torch.randn(b, nh, s, hd, dtype=torch.float16, device='cuda').requires_grad_() v = torch.randn(b, nh, s, hd, dtype=torch.float16, device='cuda').requires_grad_() sm_scale = q.shape[-1] ** -0.5 @torch.compile(fullgraph=True) def f(q, k, v): return flash(q, k, v) o = f(q, k, v) print(f"{o=}") dout = torch.randn_like(q) o.backward(dout) tri_dq = q.grad.clone() tri_dk = k.grad.clone() tri_dv = v.grad.clone() ================================================ FILE: kernels/triton/training/README.md 
================================================ Triton training kernels ================================================ FILE: kernels/triton/training/fused_softmax/README.md ================================================ Fused Softmax in Triton, supporting both inference (fwd) and training (fwd/backward). Perf testing on A100: fused_softmax_a100 ================================================ FILE: kernels/triton/training/fused_softmax/softmax.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # ---- Fused Softmax written in Triton ------ # Extra Credits: # Triton Softmax Tutorial # LucidRains Triton_Transformers import torch import triton import triton.language as tl from torch import autograd def _get_num_warps(block_size: int)-> int: num_warps = 4 if block_size > 2047: num_warps = 8 if block_size > 4095: num_warps=16 return num_warps @triton.jit def _softmax_kernel_fwd( output_ptr, output_row_stride, input_ptr, input_row_stride, n_cols, block_size: tl.constexpr, ): # setup input location row_index = tl.program_id(0) input_row_ptr = input_ptr + (row_index * input_row_stride) col_offsets = tl.arange(0, block_size) input_ptrs = input_row_ptr + col_offsets rw_mask = col_offsets < n_cols row = tl.load(input_ptrs, mask = rw_mask, other=float("-inf")) # safe softmax proper safe_row = row - tl.max(row, axis=0) numerator = tl.exp(safe_row) denom = tl.sum(numerator, axis=0) sm_out = numerator / denom # write results to HBM out_row_ptr = output_ptr + (row_index * output_row_stride) out_row_ptrs = out_row_ptr + col_offsets tl.store(out_row_ptrs, sm_out, mask = rw_mask) @triton.jit def _softmax_kernel_bwd( output_ptr, stride_output_row, grad_ptr, stride_grad_row, input_ptr, stride_input_row, n_cols, block_size: tl.constexpr, ): # setup input locations - need both grad and input access row_index = tl.program_id(0) input_row_ptr = input_ptr + (row_index * stride_input_row) grad_row_ptr = grad_ptr + (row_index * stride_grad_row) col_offsets = tl.arange(0,block_size) rw_mask = col_offsets < n_cols input_row_ptrs = input_row_ptr + col_offsets grad_row_ptrs = grad_row_ptr + col_offsets probs_row =tl.load(input_row_ptrs, mask=rw_mask, other = 0) grads_row = tl.load(grad_row_ptrs, mask = rw_mask, other=0) # compute derivatives dx = probs_row * grads_row dsm_out = dx - probs_row * (tl.sum(dx, axis=0)) # write to HBM output_row_ptr = output_ptr + (row_index * stride_output_row) output_ptrs = output_row_ptr + col_offsets tl.store(output_ptrs, dsm_out, mask=rw_mask) class triton_softmax(autograd.Function): @staticmethod def forward(ctx, x): orig_shape = x.shape x = x.view(-1, orig_shape[-1]) nrows, ncols = x.shape block_size = triton.next_power_of_2(ncols) num_warps = _get_num_warps(block_size) res = torch.empty_like(x) grid = (nrows,) _softmax_kernel_fwd[grid]( res, res.stride(0), x, x.stride(0), ncols, block_size=block_size, num_warps=num_warps, ) if x.requires_grad: ctx.save_for_backward(res) return res.view(*orig_shape) @staticmethod def backward(ctx, grad_probs): orig_shape = grad_probs.shape probs, = ctx.saved_tensors grad_probs = grad_probs.view(-1, orig_shape[-1]) nrows, ncols = grad_probs.shape block_size = triton.next_power_of_2(ncols) num_warps = _get_num_warps(block_size) dx = torch.empty_like(probs) grid = (nrows,) _softmax_kernel_bwd[grid]( dx, dx.stride(0), probs, probs.stride(0), 
grad_probs, grad_probs.stride(0), ncols, block_size=block_size, num_warps=num_warps, ) return dx.view(*orig_shape), None fused_softmax = triton_softmax.apply if __name__ == '__main__': sample = torch.tensor([[1,2,3,4,5], [5,4,3,2,1]], dtype = torch.float32, device="cuda", requires_grad=True) from torch.nn.functional import softmax as torch_softmax res_torch = torch_softmax(sample, dim=1) res_triton = fused_softmax(sample) torch.testing.assert_close(res_torch, res_triton, rtol=0, atol=1e-4) # backward dout = torch.randn_like(sample) bwd_torch = res_torch.backward(dout) bwd_triton = res_triton.backward(dout) torch.testing.assert_close(bwd_triton, bwd_torch, rtol=0, atol=1e-4) ================================================ FILE: kernels/triton/training/rms_norm/fused_rms_norm.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # Credit # Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py # Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html # pylint: skip-file # flake8: noqa import math import torch import triton import triton.language as tl @triton.autotune( configs=[ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), triton.Config({}, num_warps=8), triton.Config({}, num_warps=16), triton.Config({}, num_warps=32), ], key=["N"], ) @triton.jit def _rms_norm_fwd_kernel( X, stride_x, Y, stride_y, W, Rstd, eps, M, # num rows N, # num cols block_N: tl.constexpr, ): row = tl.program_id(0) cols = tl.arange(0, block_N) # Load input data and weights mask = cols < N x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) # Compute mean and variance # xbar = tl.sum(x, axis=0) / tl.max(tl.sum(mask, axis=0), 1) xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) # Store the reciprocal standard deviation tl.store(Rstd + row, rstd) # Normalize and apply linear transformation x_hat = x * rstd y = x_hat * w # Write output tl.store(Y + row * stride_y + cols, y, mask=mask) @triton.autotune( configs=[ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), triton.Config({}, num_warps=8), triton.Config({}, num_warps=16), triton.Config({}, num_warps=32), ], key=["N"], ) @triton.jit def _rms_norm_bwd_kernel_sm( X, stride_x, W, DY, stride_dy, DX, stride_dx, Rstd, DW, eps, M, # num rows N, # num cols rows_per_program, block_N: tl.constexpr, ): row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program cols = tl.arange(0, block_N) mask = cols < N # Load weights w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) # Accumulate gradients for weights dw = tl.zeros((block_N,), dtype=tl.float32) row_end = min(row_start + rows_per_program, M) for row in range(row_start, row_end): # Load input, output gradient, and reciprocal standard deviation x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) rstd = tl.load(Rstd + row) # Compute normalized input and gradients x_hat = x * rstd wdy = w * dy dw += dy * x_hat c1 = tl.sum(x_hat * wdy, 
axis=0) / N dx = (wdy - x_hat * c1) * rstd # Store input gradient tl.store(DX + row * stride_dx + cols, dx, mask=mask) # Store weight gradients tl.store(DW + row_block_id * N + cols, dw, mask=mask) """ # using the sm count to determine the number of rows per program # appears to be slightly faster than this bwd kernel below. @triton.jit def _rms_norm_bwd_kernel( X, stride_x, W, DY, stride_dy, DX, stride_dx, Rstd, DW, eps, M, # num rows N, # num cols rows_per_program, block_N: tl.constexpr, ): row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program cols = tl.arange(0, block_N) mask = cols < N # Load weights w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) # Accumulate gradients for weights dw = tl.zeros((block_N,), dtype=tl.float32) row_end = min(row_start + rows_per_program, M) for row in range(row_start, row_end): # Load input, output gradient, and reciprocal standard deviation x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) rstd = tl.load(Rstd + row) # Compute normalized input and gradients x_hat = x * rstd wdy = w * dy dw += dy * x_hat c1 = tl.sum(x_hat * wdy, axis=0) / N dx = (wdy - x_hat * c1) * rstd # Store input gradient tl.store(DX + row * stride_dx + cols, dx, mask=mask) # Store weight gradients tl.store(DW + cols, dw, mask=mask) """ class ttt_RMSNorm(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, eps): x_shape_start = x.shape # Flatten input x = x.reshape(-1, x.shape[-1]) if x.stride(-1) != 1: x = x.contiguous() if weight.stride(-1) != 1: weight = weight.contiguous() M, N = x.shape y = torch.empty_like(x) rstd = torch.empty((M,), dtype=torch.float32, device=x.device) max_size = 65536 // x.element_size() block_N = min(max_size, triton.next_power_of_2(N)) if N > block_N: raise ValueError(f"N {N} must be <= {block_N=}") grid = lambda meta: (M,) _rms_norm_fwd_kernel[grid]( x, x.stride(0), y, y.stride(0), weight, rstd, eps, M, N, block_N, ) ctx.eps = eps ctx.save_for_backward(x, weight, rstd) ctx.x_shape_start = x_shape_start y = y.reshape(x_shape_start) return y @staticmethod def backward(ctx, dy): x, weight, rstd = ctx.saved_tensors eps = ctx.eps x_shape_start = ctx.x_shape_start # Flatten input and output gradients dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() M, N = dy.shape dx = torch.empty_like(x) dw = torch.empty_like(weight) sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) max_size = 65536 // x.element_size() block_N = min(max_size, triton.next_power_of_2(N)) rows_per_sm = math.ceil(M / sm_count) if N > block_N: raise ValueError(f"N {N} must be <= {block_N=}") grid = lambda meta: (sm_count,) _rms_norm_bwd_kernel_sm[grid]( x, x.stride(0), weight, dy, dy.stride(0), dx, dx.stride(0), rstd, _dw, eps, M, N, rows_per_sm, block_N, ) dw = _dw.sum(0).to(weight.dtype) dx = dx.reshape(x_shape_start) return dx, dw, None """ # this is an alternative approach - but it seems to be just slightly slower than sm approach. 
@staticmethod def backward(ctx, dy): x, weight, rstd = ctx.saved_tensors eps = ctx.eps x_shape_start = ctx.x_shape_start # Flatten input and output gradients dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() M, N = dy.shape dx = torch.empty_like(x) dw = torch.empty_like(weight) max_size = 65536 // x.element_size() block_N = min(max_size, triton.next_power_of_2(N)) rows_per_program = 1024 if N > block_N: raise ValueError(f"N {N} must be <= {block_N=}") grid = lambda meta: (triton.cdiv(M, rows_per_program),) _rms_norm_bwd_kernel[grid]( x, x.stride(0), weight, dy, dy.stride(0), dx, dx.stride(0), rstd, dw, eps, M, N, rows_per_program, block_N, ) dx = dx.reshape(x_shape_start) return dx, dw, None """ def fused_rms_norm_fn( x, weight, eps=1e-6, ): return ttt_RMSNorm.apply( x, weight, eps, ) class FusedRMSNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.reset_parameters() def reset_parameters(self): torch.nn.init.ones_(self.weight) def forward( self, x, ): return fused_rms_norm_fn( x, self.weight, eps=self.eps, ) ================================================ FILE: kernels/triton/tutorials/README.md ================================================ Triton tutorials ================================================ FILE: readme.md ================================================ ### Applied AI repo For experiments and research on Applied AI. ### Projects #### Kernels Housing a variety of Triton and CUDA kernels for training and inference. Inference kernels = no backward pass support. ##### Triton Kernels #### 1 - Triton - MoE (Mixtral) GEMM for accelerating inference. Uses col major access pattern to increase locality. moe_gemm_a100 #### 2 - Triton - Fused Softmax for both training and inference. softmax_fused #### 3 - Triton - Fused RMSNorm for both training and inference. [Fused RMSNorm Kernel](https://github.com/meta-pytorch/applied-ai/blob/main/kernels/triton/training/rms_norm/fused_rms_norm.py) #### Other projects from Applied AI 1. [CUDA Mode](https://github.com/cuda-mode) - Reading group for learning CUDA programming - ([Discord](https://discord.gg/cudamode), [Lecture Materials](https://github.com/cuda-mode/lectures), [Lecture recordings](https://www.youtube.com/@CUDAMODE)) 2. [llama-recipes](https://github.com/meta-llama/llama-recipes) - Recipes for fine-tuning and inference for Llama model series 3. NeurIPS'23 [LLM Efficiency Challenge](https://llm-efficiency-challenge.github.io/) - 1LLM + 1GPU + 1Day competition - ([website](https://llm-efficiency-challenge.github.io/), [code](https://github.com/llm-efficiency-challenge), [NeurIPS Workshop recordings](https://neurips.cc/virtual/2023/competition/66594)) ### Papers and Publications 1. PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation [paper](https://pytorch.org/assets/pytorch2-2.pdf) 2. Accelerating a Triton Fused Kernel for W4A16 Quantized Inference with SplitK Work Decomposition [paper](https://ai.meta.com/research/publications/accelerating-a-triton-fused-kernel-for-w4a16-quantized-inference-with-splitk-work-decomposition/) 3. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel [paper](https://arxiv.org/abs/2304.11277) 4. 
Sustainable AI: Environmental Implications, Challenges and Opportunities [paper](https://arxiv.org/abs/2111.00364) ### License The applied-ai repo is released under the [BSD 3](LICENSE) license. ================================================ FILE: tutorials/triton/kernels/__init__.py ================================================ ================================================ FILE: tutorials/triton/kernels/flash_attention_fwd.py ================================================ # flash forward v2 ================================================ FILE: tutorials/triton/kernels/fused_softmax.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # ---- Fused Softmax written in Triton ------ # Extra Credits: # Triton Softmax Tutorial # LucidRains Triton_Transformers import torch import triton import triton.language as tl from torch import autograd def _get_num_warps(block_size: int)-> int: num_warps = 4 if block_size > 2047: num_warps = 8 if block_size > 4095: num_warps=16 return num_warps @triton.jit def _softmax_kernel_fwd( output_ptr, output_row_stride, input_ptr, input_row_stride, n_cols, block_size: tl.constexpr, ): # setup input location row_index = tl.program_id(0) input_row_ptr = input_ptr + (row_index * input_row_stride) col_offsets = tl.arange(0, block_size) input_ptrs = input_row_ptr + col_offsets rw_mask = col_offsets < n_cols row = tl.load(input_ptrs, mask = rw_mask, other=float("-inf")) # safe softmax proper safe_row = row - tl.max(row, axis=0) numerator = tl.exp(safe_row) denom = tl.sum(numerator, axis=0) sm_out = numerator / denom # write results to HBM out_row_ptr = output_ptr + (row_index * output_row_stride) out_row_ptrs = out_row_ptr + col_offsets tl.store(out_row_ptrs, sm_out, mask = rw_mask) @triton.jit def _softmax_kernel_bwd( output_ptr, stride_output_row, grad_ptr, stride_grad_row, input_ptr, stride_input_row, n_cols, block_size: tl.constexpr, ): # setup input locations - need both grad and input access row_index = tl.program_id(0) input_row_ptr = input_ptr + (row_index * stride_input_row) grad_row_ptr = grad_ptr + (row_index * stride_grad_row) col_offsets = tl.arange(0,block_size) rw_mask = col_offsets < n_cols input_row_ptrs = input_row_ptr + col_offsets grad_row_ptrs = grad_row_ptr + col_offsets probs_row =tl.load(input_row_ptrs, mask=rw_mask, other = 0) grads_row = tl.load(grad_row_ptrs, mask = rw_mask, other=0) # compute derivatives dx = probs_row * grads_row dsm_out = dx - probs_row * (tl.sum(dx, axis=0)) # write to HBM output_row_ptr = output_ptr + (row_index * stride_output_row) output_ptrs = output_row_ptr + col_offsets tl.store(output_ptrs, dsm_out, mask=rw_mask) class triton_softmax(autograd.Function): @staticmethod def forward(ctx, x): orig_shape = x.shape x = x.view(-1, orig_shape[-1]) nrows, ncols = x.shape block_size = triton.next_power_of_2(ncols) num_warps = _get_num_warps(block_size) res = torch.empty_like(x) grid = (nrows,) _softmax_kernel_fwd[grid]( res, res.stride(0), x, x.stride(0), ncols, block_size=block_size, num_warps=num_warps, ) if x.requires_grad: ctx.save_for_backward(res) return res.view(*orig_shape) @staticmethod def backward(ctx, grad_probs): orig_shape = grad_probs.shape probs, = ctx.saved_tensors grad_probs = grad_probs.view(-1, orig_shape[-1]) nrows, ncols = grad_probs.shape block_size = triton.next_power_of_2(ncols) num_warps = _get_num_warps(block_size) dx = torch.empty_like(probs) grid = (nrows,) _softmax_kernel_bwd[grid]( dx, dx.stride(0), probs, 
probs.stride(0),
grad_probs,
grad_probs.stride(0),
ncols,
block_size=block_size,
num_warps=num_warps,
)
return dx.view(*orig_shape), None

fused_softmax = triton_softmax.apply

if __name__ == '__main__':
    sample = torch.tensor([[1,2,3,4,5], [5,4,3,2,1]], dtype=torch.float32, device="cuda", requires_grad=True)
    from torch.nn.functional import softmax as torch_softmax
    res_torch = torch_softmax(sample, dim=1)
    res_triton = fused_softmax(sample)
    torch.testing.assert_close(res_torch, res_triton, rtol=0, atol=1e-4)
    # backward: compare gradients w.r.t. the input (Tensor.backward() returns None,
    # so comparing its return values would be vacuous)
    dout = torch.randn_like(sample)
    res_torch.backward(dout)
    grad_torch = sample.grad.detach().clone()
    sample.grad = None
    res_triton.backward(dout)
    grad_triton = sample.grad.detach().clone()
    torch.testing.assert_close(grad_triton, grad_torch, rtol=0, atol=1e-4)

================================================
FILE: tutorials/triton/kernels/readme.md
================================================
Triton tutorials

1 - Vector Add - Starting tutorial on a simple first kernel
2 - Fused Softmax - Full fused softmax with both forward and backward (training ready)

================================================
FILE: tutorials/triton/kernels/vector_add.py
================================================
# coding up a Triton vector addition kernel
# links to
import triton
import triton.language as tl
import torch

@triton.jit
def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
                           num_elems: tl.constexpr,
                           block_size: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * block_size  # 0 * 2 = 0, 1 * 2 = 2,
    thread_offsets = block_start + tl.arange(0, block_size)
    mask = thread_offsets < num_elems
    a_pointers = tl.load(a_ptr + thread_offsets, mask=mask)
    b_pointers = tl.load(b_ptr + thread_offsets, mask=mask)
    res = a_pointers + b_pointers
    tl.store(out_ptr + thread_offsets, res, mask=mask)

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def vector_addition(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    output_buffer = torch.empty_like(a)
    assert a.is_cuda and b.is_cuda  # is_cuda is a property, not a method
    num_elems = a.numel()
    assert num_elems == b.numel()  # todo - handle mismatched sizes
    block_size = 128
    grid_size = ceil_div(num_elems, block_size)
    grid = (grid_size,)
    kernel_vector_addition[grid](a, b, output_buffer, num_elems, block_size)
    return output_buffer

================================================
FILE: tutorials/triton/tests/test_softmax.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
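# Unit tests for the fused Triton softmax: forward parity against torch.softmax in
# fp32/bf16 for 2D and 3D inputs, plus backward tests via autograd, built on the
# helpers in test_utils.py.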
import pytest import torch import sys sys.path.append('..') from triton_kernels.softmax import fused_softmax from test_utils import assert_expected, set_rng_seed, gpu_test @pytest.fixture(autouse=True) def set_seed(): set_rng_seed(2020) @gpu_test() class TestForwardSoftMax: def test_forward_2D_float32(self,): # float32 seq_len = 768 sample_constant_float32 = torch.ones((seq_len, seq_len), dtype=torch.float32, device='cuda') sample_random_float32 = torch.randn_like(sample_constant_float32) expected_out_constant32 = torch.softmax(sample_constant_float32, dim=1) expected_out_random32 = torch.softmax(sample_random_float32, dim=1) triton_out_c32 = fused_softmax(sample_constant_float32) triton_out_random32 = fused_softmax(sample_random_float32) assert_expected(triton_out_c32, expected_out_constant32 ) assert_expected(triton_out_random32, expected_out_random32) def test_forward_2D_bfloat16(self,): # bfloat16 seq_len = 2048 sample_constant_bf16 = torch.ones((seq_len, seq_len), dtype=torch.bfloat16, device='cuda') sample_random_bf16 = torch.randn_like(sample_constant_bf16) expected_out_c_bf16 = torch.softmax(sample_constant_bf16, dim=1) expected_out_rand_bf16 = torch.softmax(sample_random_bf16, dim=1) triton_out_c_bf16 = fused_softmax(sample_constant_bf16) triton_out_rand_bf16 = fused_softmax(sample_random_bf16) assert_expected(triton_out_c_bf16, expected_out_c_bf16 ) assert_expected(triton_out_rand_bf16, expected_out_rand_bf16) def test_forward_3D_bfloat16(self,): # bfloat16 seq_len = 2048 batch = 12 sample_constant_bf16 = torch.ones((batch, seq_len, seq_len), dtype=torch.bfloat16, device='cuda') sample_random_bf16 = torch.randn_like(sample_constant_bf16) expected_out_c_bf16 = torch.softmax(sample_constant_bf16, dim=1) expected_out_rand_bf16 = torch.softmax(sample_random_bf16, dim=1) triton_out_c_bf16 = fused_softmax(sample_constant_bf16) triton_out_rand_bf16 = fused_softmax(sample_random_bf16) assert_expected(triton_out_c_bf16, expected_out_c_bf16, atol=1e-2 ) assert_expected(triton_out_rand_bf16, expected_out_rand_bf16, atol=1e-2) @gpu_test() class TestBackwardSoftMax: def test_backward_2D(self,): seq_len = 1024 sample_constant_float32 = torch.ones((seq_len, seq_len), dtype=torch.float32, device='cuda', requires_grad=True) sample_random_float32 = torch.randn_like(sample_constant_float32, requires_grad=True) expected_fwd_constant32 = torch.softmax(sample_constant_float32, dim=1) expected_fwd_random32 = torch.softmax(sample_random_float32, dim=1) triton_fwd_c32 = fused_softmax(sample_constant_float32) triton_fwd_random32 = fused_softmax(sample_random_float32) dout = torch.randn_like(sample_constant_float32) expected_bwd_c32 = expected_fwd_constant32.backward(dout) expected_bwd_r32 = expected_fwd_random32.backward(dout) triton_bwd_c32 = triton_fwd_c32.backward(dout) triton_bwd_r32 = triton_fwd_random32.backward(dout) assert_expected(triton_bwd_c32, expected_bwd_c32 ) assert_expected(triton_bwd_r32, expected_bwd_r32) def test_bwd_3D(self,): seq_len = 2048 batch = 4 sample_constant_float32 = torch.ones((batch, seq_len, seq_len), dtype=torch.float32, device='cuda', requires_grad=True) sample_random_float32 = torch.randn_like(sample_constant_float32, requires_grad=True) expected_fwd_constant32 = torch.softmax(sample_constant_float32, dim=1) expected_fwd_random32 = torch.softmax(sample_random_float32, dim=1) triton_fwd_c32 = fused_softmax(sample_constant_float32) triton_fwd_random32 = fused_softmax(sample_random_float32) dout = torch.randn_like(sample_constant_float32) expected_bwd_c32 = 
expected_fwd_constant32.backward(dout) expected_bwd_r32 = expected_fwd_random32.backward(dout) triton_bwd_c32 = triton_fwd_c32.backward(dout) triton_bwd_r32 = triton_fwd_random32.backward(dout) assert_expected(triton_bwd_c32, expected_bwd_c32 ) assert_expected(triton_bwd_r32, expected_bwd_r32)
================================================
FILE: tutorials/triton/tests/test_utils.py
================================================
from pathlib import Path
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import pytest
import torch
import torch.distributed as dist
from torch import Tensor, nn


def assert_expected(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = 0,
    atol: Optional[float] = 1e-4,
    check_device=True,
):
    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        check_device=check_device,
        msg=f"actual: {actual}, expected: {expected}",
    )


def set_rng_seed(seed):
    """Sets the seed for pytorch random number generators"""
    torch.manual_seed(seed)


def gpu_test(gpu_count: int = 1):
    """
    Decorator for GPU tests; skips the test if the required number of GPUs is not available.
    """
    message = f"Not enough GPUs to run the test: required {gpu_count}"
    return pytest.mark.skipif(torch.cuda.device_count() < gpu_count, reason=message)
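# Usage sketch (illustrative only, mirroring the import used in test_softmax.py):
# these helpers compose in a test module roughly as follows.
#
#   from triton_kernels.softmax import fused_softmax
#
#   @gpu_test()
#   def test_softmax_matches_torch():
#       set_rng_seed(2020)
#       x = torch.randn(64, 128, device="cuda")
#       assert_expected(fused_softmax(x), torch.softmax(x, dim=-1))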