Repository: usyd-fsalab/fp6_llm Branch: main Commit: 12e83379f16a Files: 30 Total size: 145.7 KB Directory structure: gitextract_xbh0tnpv/ ├── .gitignore ├── LICENSE ├── README.md ├── examples/ │ └── README.md ├── fp6_llm/ │ ├── Makefile │ ├── __init__.py │ └── csrc/ │ ├── fp6_linear.cu │ ├── fp6_linear.cuh │ ├── include/ │ │ ├── configs.h │ │ ├── kernel_matmul.cuh │ │ ├── kernel_reduction.cuh │ │ ├── ptx_cp.async.cuh │ │ ├── ptx_mma.cuh │ │ ├── utils_core.cuh │ │ ├── utils_gmem.cuh │ │ └── utils_parallel_dequant.cuh │ ├── pybind.cpp │ └── utils/ │ ├── common.h │ ├── weight_dequant.h │ ├── weight_prepacking.h │ └── weight_quant.h ├── setup.py └── tests/ ├── cpp/ │ ├── Makefile │ ├── kernel_test.h │ ├── kernel_test_fp6.cu │ ├── kernel_test_fpx.cu │ └── run.sh └── python/ ├── kernel_test_fp6.py ├── kernel_test_fpx.py └── run.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *ncu *.so *.o *.ncu-rep *.nsys-rep *.sqlite build kernel_test *.egg-info .vscode ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Quant-LLM (FP6, FP5, FPx...) Six-bit quantization (FP6) can achieve **better trade-offs** between [*model quality*](#1-model-accuracy) and [*inference cost*](#2-speedups-on-linear-layers) compared to 4-bit and 8-bit quantization counterparts, reducing the size of large language models (LLMs) effectively and preserving the model quality consistently across varied applications. To support **effective 6-bit inference of LLMs on modern GPUs**, we provide the official implementation of [**FP6-LLM**](https://arxiv.org/pdf/2401.14112.pdf), achieving significant *speedups of linear layers* and *GPU memory reduction* over the fp16/int8 baselines.
Our long-term goal is to support **various quantization methods** by providing **extensible & high-performance** GPU kernels for mixed-input matrix multiplication, using the **unified design scheme** presented in our [paper](https://arxiv.org/pdf/2401.14112.pdf), which was recently accepted by [USENIX ATC24](https://www.usenix.org/conference/atc24/presentation/xia). Currently, we have tested **FP6_e3m2** and **FP5_e2m2**. However, our code is templated and can easily be extended to other combinations of eXmY if necessary. ![Overview of FP6-LLM.](./docs/figures/banner.png) ## Roadmap The current release contains: - Support for model weights in **FP6_e3m2** or **FP5_e2m2** format with activations in FP16 format. - An efficient CUDA implementation for mixed-input matrix multiplication of linear layers (weights in FP6 and activations in FP16 format) with Tensor Cores enabled. - C++ APIs and PyTorch APIs to use the CUDA kernels. - Test code to demonstrate the usage of FP6-LLM and verify its correctness. - **End-to-end inference** support of [LLaMA2](https://arxiv.org/abs/2307.09288) models, released at [Deepspeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024). Our future plans include, but are not limited to: - [ ] Currently, FP6-LLM only supports **FP6** quantization due to its accuracy/performance trade-off benefits. However, the technology of FP6-LLM can be easily applied to other quantization methods, e.g., **FP4, INT5**. - [ ] Currently, FP6-LLM is only tested and verified on **A100 GPUs**, but the core design methods can also be applied to other Tensor Core GPUs like **NVIDIA H100 and GH200**. Furthermore, W6A8 quantization can be supported on H100 GPUs by exploiting the FP8 Tensor Cores. ## Installation 1. Clone this repository. ```sh git clone https://github.com/usyd-fsalab/fp6_llm.git cd fp6_llm ``` 2. Install the Python package. [Enabling PyTorch APIs] ```sh pip install . ``` 3. Compile the .so file. [Enabling C++ APIs] ```sh cd fp6_llm && make ``` ## Tests We provide scripts to verify the correctness of FP6-LLM. The outputs of FP6-LLM are compared to the outputs of the FP16 baselines (the official implementation of the linear layer in PyTorch and NVIDIA cuBLAS). #### 1. Run tests for PyTorch APIs ``` cd ../tests/python ./run.sh ``` #### 2. Run tests for C++ APIs ``` cd ../cpp make ./run.sh ``` ## How to use our FP6 CUDA kernels We implement a CUDA kernel for the matrix multiplication C = A × B, where A is the weight matrix of shape [OC, IC] in FP6 and B is the activation matrix of shape [IC, BS] in FP16. C and B are column-major matrices. The CUDA kernels can be launched via **PyTorch APIs** or **C++ APIs**. Currently: - OC (Output Channel) must be a multiple of 256, and IC (Input Channel) must be a multiple of 64. - BS (Batch Size) can be an arbitrary value; smaller BS is preferred for a larger speedup over the FP16 baseline. For more details on using the FP6-LLM APIs, please see the source code of the [C++/Python Test](#tests). #### 1. PyTorch APIs To use the PyTorch APIs of FP6-LLM, you only need to import the Python module: ``` import fp6_llm ``` * Ahead-of-time weight prepacking: ``` fp6_packed_weight = fp6_llm.weight_prepacking_cpu(fp6_weight) Arguments: fp6_weight: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contain 16 FP6 weights. Return: fp6_packed_weight: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contain 16 FP6 weights. ```
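For illustration, here is a minimal sketch of calling the prepacking API (the shapes and the random weight tensor are placeholders; in practice, `fp6_weight` comes from an FP6 quantizer):

```python
import torch
import fp6_llm

OC, IC = 4096, 4096  # OC must be a multiple of 256; IC must be a multiple of 64
# Placeholder FP6 weights: every 3 INT32 words encode 16 FP6 values.
fp6_weight = torch.randint(2**31 - 1, (OC, IC // 16 * 3), dtype=torch.int32)
# Re-pack the weights at the bit level; the packed tensor keeps the same shape
# and is what linear_forward_cuda (shown below) expects.
fp6_packed_weight = fp6_llm.weight_prepacking_cpu(fp6_weight)
```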
* Execute FP6-FP16 mixed input GEMM on GPUs: ``` fp16_output = fp6_llm.linear_forward_cuda(fp6_packed_weight, fp16_scale, fp16_input, splitK) Arguments: fp6_packed_weight: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contain 16 FP6 weights. fp16_scale: tensor of shape [OC]; // half tensor fp16_input: tensor of shape [B, IC]; // half tensor splitK: int value, splitting the MatMul problem along the K dimension for higher GPU utilization, default 1. Return: fp16_output: tensor of shape [B, OC]; // half tensor ``` * To help users determine the SplitK for other MatMul shapes, we provide a heuristic function in this repo that suggests a good SplitK candidate. To achieve the best performance, please try different SplitK values when executing the MatMul of the given shape and choose the SplitK with the lowest measured kernel latency. ``` import fp6_llm SplitK = fp6_llm.HeuristicFuntion_SplitK(M, N, Number_GPU_SMs) Arguments: M is the number of rows of the weight matrix, N is the inference batch size. More specifically, the shape of the MatMul is (M, K) * (K, N) -> (M, N). Number_GPU_SMs is the number of Streaming Multiprocessors of the GPU you are using. For A100 GPUs, Number_GPU_SMs=108. ``` * Dequantize an FP6 matrix back to an FP16 matrix on the CPU (a useful tool to construct input matrices for the FP16 GEMM baseline): ``` fp16_tensor = fp6_llm.weight_dequant_cpu(fp6_tensor, fp16_scale) Arguments: fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contain 16 FP6 weights. fp16_scale: half tensor of shape [OC]; // for row-wise quantization. Return: fp16_tensor: half tensor of shape [OC, IC]. ``` #### 2. C++ APIs To use the C++ APIs of FP6-LLM, you need to include [this header file](fp6_llm/csrc/fp6_linear.cuh) of FP6-LLM in your C++ code: ``` #include "fp6_linear.cuh" ``` Besides, you need to link the shared library (*libfp6.so*) under [this directory](fp6_llm/) during compilation. A minimal end-to-end sketch of the C++ path is given below, after the FP5 note. * Ahead-of-time weight prepacking: ``` void weight_matrix_prepacking(int* packed_weights, // [Output] prepacked FP6 weight matrix int *FP6Weights, // [Input] original FP6 weight matrix size_t M, // OC size_t K); // IC ``` * Execute FP6-FP16 mixed input GEMM on GPUs: ``` cudaError_t fp6_linear_kernel(cudaStream_t stream, // CUDA stream to execute the GPU kernel. const uint4 *Weight, // [Input] Pointer to the FP6 weight matrix. const half *Scales, // [Input] Pointer to the FP16 quantization scales. const half *B, // [Input] Pointer to the FP16 input activation. half *C, // [Output] Pointer to the FP16 output activation. const size_t M_Global, // OC const size_t N_Global, // BS const size_t K_Global, // IC float *Reduction_Workspace, // Pointer to the temporary workspace for splitK reduction. Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) int Split_K); // splitK ``` * Dequantize an FP6 matrix back to an FP16 matrix on the CPU: ``` void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, // [Output] Pointer to the dequantized FP16 weight matrix. unsigned char* A_6bit_h, // [Input] Pointer to the quantized FP6 weight matrix. size_t M, // OC size_t K, // IC half* scale); // [Input] Pointer to the FP16 quantization scales. ``` ## How to use our FP5 CUDA kernels Similar to the FP6 APIs, except that two more parameters (EXPONENT, MANTISSA) are required. More details will be provided soon.
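As referenced in the C++ APIs section above, below is a minimal, illustrative sketch of driving the FP6 C++ API; it is not part of the repository, and the wrapper function and variable names are placeholders. It allocates the SplitK reduction workspace and launches `fp6_linear_kernel` on a user-provided stream:

```cpp
// Minimal sketch, assuming all device buffers are already allocated and filled,
// and that M (OC) and K (IC) satisfy the divisibility constraints described above.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "fp6_linear.cuh"

cudaError_t run_fp6_gemm(cudaStream_t stream,
                         const uint4* d_packed_weight,  // prepacked FP6 weights (GPU)
                         const half*  d_scales,         // per-row FP16 scales (GPU)
                         const half*  d_input,          // FP16 activations, column-major (GPU)
                         half*        d_output,         // FP16 output, column-major (GPU)
                         size_t M, size_t N, size_t K, int split_k)
{
    // The SplitK reduction needs a float workspace of Split_K * M_Global * N_Global elements.
    float* d_workspace = nullptr;
    cudaMalloc(&d_workspace, sizeof(float) * split_k * M * N);

    cudaError_t status = fp6_linear_kernel(stream, d_packed_weight, d_scales, d_input,
                                           d_output, M, N, K, d_workspace, split_k);

    cudaStreamSynchronize(stream);
    cudaFree(d_workspace);
    return status;
}
```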
## Performance #### 1. Model Accuracy > **FP6 quantization is a practical alternative to further democratize the deployment of LLMs without significantly sacrificing model quality on complex tasks and various model sizes.** - While 4-bit quantization unavoidably causes degradation in model quality, near-lossless model compression can be achieved with 6-bit quantization. As shown in Table 1 and Table 2, FP6 displays **strong and consistent performance across various tasks** including code generation and zero-shot perplexity performance. - It also shows **high robustness across various model sizes**, e.g., 1B, 13B, and 65B LLaMA models. - FP6 quantization already works well with coarse-grained quantization, while INT4 quantization heavily relies on Fine-Grained Quantization (FGQ) methods to maintain high model quality. - *More details can be found in these two papers ([FP6-LLM](https://arxiv.org/pdf/2401.14112.pdf) & [ZeroQuant(4+2)](https://arxiv.org/abs/2312.08583)).* ![ModelAccuracy](docs/figures/Accuracy.png) #### 2. Speedups on linear layers > **Compared to the FP16 baselines (cuBLAS), INT8 quantization (TensorRT_LLM), and FP4 quantization (bitsandbytes):** - FP6-LLM outperforms bitsandbytes, cuBLAS, and TensorRT_LLM by **up to** 8.9×, 2.6×, and 1.9×. - FP6-LLM outperforms bitsandbytes, cuBLAS, and TensorRT_LLM by 7.2×, 2.1×, and 1.3× **on average**. ![Speedup_To_8bit](docs/figures/Speedup_to_8bit.png) > **Compared to baselines for INT4 quantization (TensorRT_LLM):** - FP6-LLM is **1.06×/1.04×/0.94× faster** than **Fine-grained_W4A16** at batch sizes 8/16/32, respectively. - FP6-LLM is **only 16%/17%/24% slower** than **Coarse-grained_W4A16** at batch sizes 8/16/32. - It is **a worthwhile trade-off** between model quality and inference speed, since FP6 quantization can provide [higher model quality](#1-model-accuracy) than INT4 quantization. ![Speedup_To_4bit](docs/figures/Speedup_to_4bit.png) #### 3. Speedups on end-to-end inference > **Experimental settings and baselines:** - We integrate our FP6-LLM kernel into DeepSpeed for end-to-end evaluation. The baseline for comparison is the FP16 execution of the original DeepSpeed inference system. - We set the prefill/prompt length of each request to 0.5K, and generate 1.5K tokens for each request ignoring the "EOS" (end of sequence) token. ![End-to-end Inference](./docs/figures/e2e_inference.png) > **Results of LLaMA-70b:** - Both FP6-LLM and the FP16 baseline can set the inference **batch size to at most 32** before running out of GPU memory, whereas **FP6-LLM** only requires **a single GPU** and the **baseline** uses **two GPUs**. - FP6-LLM achieves **1.69×-2.65× higher normalized inference throughput** than the FP16 baseline. > **Results of OPT-30b:** - FP6-LLM can set the inference batch size to at most 16 before running out of GPU memory, while the FP16 baseline can serve at most 4 requests in a batch. - Using an 80GB A100 GPU, FP6-LLM/FP16-baseline achieve at most 319.1/78.8 tokens per GPU-second with batch sizes 16/4. - FP6-LLM achieves 1.91×/1.84×/1.72× higher generation throughput compared to the FP16 baseline when their batch sizes are set to 1/2/4. > **Inference latency breakdown of LLaMA-70b:** - The execution of linear layers (MatMul) implemented with FP6-LLM is 1.20× faster than the FP16 baseline on average, even with half the number of GPUs. - The FP16 baseline is faster at running multi-head attention (MHA) with 2-way tensor parallelism. - Cross-GPU communication (NCCL) is avoided with FP6-LLM since only a single GPU is required.
> **Inference latency breakdown of OPT-30b:** - End-to-end performance improvements mainly come from the reduced time spent executing linear layers. - The linear layers implemented with FP6-LLM are 2.39× faster than the FP16 baseline on average. ## Key Innovations > We propose **TC-FPx**, **the first full-stack GPU system design scheme with unified Tensor Core support of floating-point weights for various quantization bit-widths (6-bit, 5-bit, 3-bit, etc.), mitigating the "memory wall" issues during LLM inference.** TC-FPx breaks the limitations of the underlying GPU hardware, allowing the GPU to support linear layer calculations involving model weights of arbitrary bit width. In TC-FPx, Tensor Cores are utilized for the intensive computation of matrix multiplications, while SIMT cores are effectively leveraged for weight dequantization, transforming the x-bit model weights to FP16 at runtime before feeding them to the Tensor Cores. - We propose **Ahead-of-time Bit-level Pre-packing** to resolve the challenge of unfriendly memory access for weights with irregular bit-widths, enabling optimal GPU memory access. - We also propose a **SIMT-Efficient GPU Runtime** to minimize the runtime overhead of weight de-quantization. - Last but not least, we present the software pipeline of the TC-FPx kernel, where SIMT cores, Tensor Cores, and the GPU memory hierarchy cooperate efficiently with high performance. ![](./docs/figures/designs.png) ## FP6-LLM Community FP6-LLM is already integrated into [DeepSpeed](https://github.com/microsoft/DeepSpeed), and this new feature will be available soon. Given that easy-to-use PyTorch and C++ APIs are provided, FP6-LLM can be easily integrated into any inference framework as a useful component. We welcome collaborations to integrate FP6-LLM into other inference frameworks. We also welcome all AI developers/practitioners/researchers to join this ongoing project, fully exploring the potential of different quantization methods. ## Citation If you find FP6-LLM useful or relevant to your research, please kindly cite [our paper](https://arxiv.org/pdf/2401.14112.pdf): ``` @misc{xia2024fp6llm, title={FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design}, author={Haojun Xia and Zhen Zheng and Xiaoxia Wu and Shiyang Chen and Zhewei Yao and Stephen Youn and Arash Bakhtiari and Michael Wyatt and Donglin Zhuang and Zhongzhu Zhou and Olatunji Ruwase and Yuxiong He and Shuaiwen Leon Song}, year={2024}, eprint={2401.14112}, archivePrefix={arXiv}, primaryClass={cs.LG} } ``` ## Change Logs - **[28th May, 2024]**: Release of FP6-LLM v0.2. Changed the project name back to QuantLLM. - **[31st April, 2024]**: Our paper was accepted by [USENIX ATC24](https://www.usenix.org/conference/atc24/presentation/xia). - **[4th March, 2024]**: Release of FP6-LLM v0.1.
## Related Projects - [ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks](https://arxiv.org/abs/2312.08583) - [DeepSpeed: Extreme Speed and Scale for DL Training and Inference](https://github.com/microsoft/DeepSpeed) - [cuBLAS: Basic Linear Algebra on NVIDIA GPUs](https://developer.nvidia.com/cublas) - [TensorRT-LLM: A TensorRT Toolbox for Optimized Large Language Model Inference](https://github.com/NVIDIA/TensorRT-LLM) - [bitsandbytes: the library including quantization primitives for 8-bit & 4-bit operations](https://github.com/TimDettmers/bitsandbytes) - [Mixed-input matrix multiplication performance optimizations](https://blog.research.google/2024/01/mixed-input-matrix-multiplication.html) - [Flash-LLM: Enabling Cost-Effective and Highly-Efficient Large Generative Model Inference with Unstructured Sparsity](https://www.vldb.org/pvldb/vol17/p211-xia.pdf) ================================================ FILE: examples/README.md ================================================ # Example of LLM inference using FP6-LLM Example scripts of using FP6-LLM for end-to-end inference are coming soon. ================================================ FILE: fp6_llm/Makefile ================================================ # host compiler HOST_COMPILER ?= g++ NVCC := nvcc -ccbin $(HOST_COMPILER) # internal flags NVCCFLAGS := -m$(shell getconf LONG_BIT) CCFLAGS := -fPIC LDFLAGS := ALL_CCFLAGS := ALL_CCFLAGS += $(NVCCFLAGS) ALL_CCFLAGS += $(addprefix -Xcompiler ,$(CCFLAGS)) ################################################################################ ALL_CCFLAGS += -DNO_PYTORCH ALL_CCFLAGS += --std=c++17 ALL_CCFLAGS += -maxrregcount=255 ALL_CCFLAGS += --use_fast_math ALL_CCFLAGS += --ptxas-options=-v,-warn-lmem-usage,--warn-on-spills ################################################################################ ALL_LDFLAGS := ALL_LDFLAGS += $(ALL_CCFLAGS) ALL_LDFLAGS += $(addprefix -Xlinker ,$(LDFLAGS)) # Common includes and paths for CUDA INCLUDES := -I/usr/local/cuda/include/ # Gencode arguments SMS ?= 80 # Generate SASS code for each SM architecture listed in $(SMS) $(foreach sm,$(SMS),$(eval GENCODE_FLAGS += -gencode arch=compute_$(sm),code=sm_$(sm))) HEAD_FILES = csrc/fp6_linear.cuh \ csrc/include/configs.h \ csrc/include/kernel_matmul.cuh \ csrc/include/kernel_reduction.cuh \ csrc/include/ptx_cp.async.cuh \ csrc/include/ptx_mma.cuh \ csrc/include/utils_core.cuh \ csrc/include/utils_gmem.cuh \ csrc/include/utils_parallel_dequant.cuh \ csrc/utils/common.h \ csrc/utils/weight_prepacking.h \ csrc/utils/weight_quant.h \ csrc/utils/weight_dequant.h # Target rules all: libfp6.so libfp6.so: fp6.o $(EXEC) $(NVCC) --shared $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES) fp6.o: csrc/fp6_linear.cu $(HEAD_FILES) $(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $< clean: rm -f libfp6.so fp6.o ================================================ FILE: fp6_llm/__init__.py ================================================ from fp6_llm_cuda import linear_forward_cuda, weight_prepacking_cpu, weight_dequant_cpu, linear_forward_eXmY_cuda, weight_prepacking_eXmY_cpu, weight_dequant_eXmY_cpu import math def Num_Wave(M, N, SplitK, Num_GPU_SMs): Num_Wave = math.ceil(M/512) * math.ceil(N/64) * SplitK / Num_GPU_SMs return Num_Wave # The shape of the MatMul: (M, K)*(K, N)->(M, N). # Typically, M is the number of rows of weight matrix, and N is the inference batch size. 
# Num_GPU_SMs is the number of SMs (Streaming Multiprocessors) within the GPUs, e,g. each A100 GPU has 108 SMs. def HeuristicFuntion_SplitK(M, N, Num_GPU_SMs): SplitK=1 Efficiency=0.0 for i in range(1, 100): numWave = Num_Wave(M, N, i, Num_GPU_SMs) eff = numWave / math.ceil(numWave) if eff >= 0.8: SplitK = i Efficiency = eff break return SplitK ================================================ FILE: fp6_llm/csrc/fp6_linear.cu ================================================ #include "include/kernel_matmul.cuh" #include "include/kernel_reduction.cuh" #include "utils/weight_prepacking.h" #include "utils/weight_dequant.h" #include "utils/weight_quant.h" #include // For CUDA stream management #include #include #include template static void Kernel_Ex(cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, OutputDataType *C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { #ifdef DEBUG_MODE printf("\n"); printf("Launcher.cu->Kernel_Ex():\n"); printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); #endif static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE); cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; dim3 GridDim(dimN, dimM, 1); dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); // #ifdef DEBUG_MODE printf("GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\n", GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); printf("\n"); #endif QUANT_GEMM_Kernel<<>> (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); } template cudaError_t fpx_linear_kernel(cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, half *C, const size_t M_Global, const size_t N_Global, const size_t K_Global, float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) int Split_K) { assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); assert(N_Global>0); // Work around to support more N shapes: size_t N_PowerOf2; if(N_Global>0 && N_Global<=8) N_PowerOf2 = 8; if(N_Global>8 && N_Global<=16) N_PowerOf2 = 16; if(N_Global>16 && N_Global<=32) N_PowerOf2 = 32; if(N_Global>32 && N_Global<=64) N_PowerOf2 = 64; if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; if (Split_K == 1) { switch (N_PowerOf2) { case 8: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 16: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 32: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 64: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, 
Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } } else { switch (N_PowerOf2) { case 8: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 16: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 32: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 64: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } // Reduction for SplitK dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); dim3 BlockDim(WARP_SIZE, 1, 1); SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); } return cudaGetLastError(); } cudaError_t fp6_linear_kernel( cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, half *C, const size_t M_Global, const size_t N_Global, const size_t K_Global, float *Reduction_Workspace, int Split_K) { // return fpx_linear_kernel<3,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Reduction_Workspace, Split_K); } cudaError_t fp_eXmY_linear_kernel( const int EXPONENT, const int MANTISSA, cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, half *C, const size_t M_Global, const size_t N_Global, const size_t K_Global, float *Reduction_Workspace, int Split_K) { // if(EXPONENT==2 && MANTISSA==2) return fpx_linear_kernel<2,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Reduction_Workspace, Split_K); if(EXPONENT==3 && MANTISSA==2) return fpx_linear_kernel<3,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Reduction_Workspace, Split_K); printf("QuantLLM_API Error: Unsupported EXPONENT=%d, MANTISSA=%d!\n", EXPONENT, MANTISSA); exit(-1); } #ifndef NO_PYTORCH #include #include /////////////////////////////////////////////////// Old Interface only Supporting FP6 ///////////////////////////////////////////////////////////////////// /* Computes FP6-FP16 GEMM (PyTorch interface). [Mathmatical Formula] Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. [Inputs] _in_feats: tensor of shape [B, IC]; // half _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. _scales: tensor of shape [OC]; // half splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. 
[Outputs] _out_feats: tensor of shape [B, OC]; // half */ torch::Tensor fp6_linear_forward_cuda( torch::Tensor _in_feats, torch::Tensor _weights, torch::Tensor _scales, int splitK=1) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); int num_out_channels = _weights.size(0); assert( num_in_channels%64 == 0 ); assert( (num_in_channels/16*3) == _weights.size(1) ); // Making sure the K dimension is matched. // int M = num_out_channels; int K = num_in_channels; int N = num_in_feats; // Input Tensors auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. auto in_feats = reinterpret_cast(_in_feats.data_ptr()); auto scales = reinterpret_cast(_scales.data_ptr()); // Output Tensors auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); auto out_feats = reinterpret_cast(_out_feats.data_ptr()); options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) // Get the current device and stream int device_id = _in_feats.device().index(); auto stream = at::cuda::getCurrentCUDAStream(device_id).stream(); fp6_linear_kernel(stream, // Using current stream. weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); return _out_feats; } /* * Weight prepacking (Pytorch interface). * [Input & Output] * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. * [Output] * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; */ torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor) { size_t OC = fp6_tensor.size(0); size_t IC = fp6_tensor.size(1); assert (IC%3==0); IC = IC*16/3; assert( (OC%256==0) && (IC%64==0) ); auto packed_tensor = torch::empty_like(fp6_tensor); auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC); return packed_tensor; } /* * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. * A useful tool to construct input matrices for the FP16 GEMM baseline. * [Input] * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. * [Output] * fp16_tensor: half tensor of shape [OC, IC]. 
*/ torch::Tensor weight_matrix_dequant_cpu(torch::Tensor fp6_tensor, torch::Tensor fp16_scale) { int OC = fp6_tensor.size(0); assert(fp6_tensor.size(1) % 3 == 0); int IC = fp6_tensor.size(1) / 3 * 16; assert(fp16_scale.size(0)==OC); // auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); // auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device()); at::Tensor fp16_tensor = torch::empty({OC, IC}, options); auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); // DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, (unsigned char*)fp6_tensor_ptr, OC, IC, fp16_scale_ptr); // return fp16_tensor; } /////////////////////////////////////////////////// New Interface Supporting FPx ///////////////////////////////////////////////////////////////////// /* Computes FPx-FP16 GEMM (PyTorch interface). [Mathmatical Formula] Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. [Inputs] _in_feats: tensor of shape [B, IC]; // half _weights: int tensor of shape [OC, IC // 32 * x]; // x INT32 words contains 32 FPx weights. _scales: tensor of shape [OC]; // half splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. [Outputs] _out_feats: tensor of shape [B, OC]; // half */ torch::Tensor fp_eXmY_linear_forward_cuda( int EXPONENT, int MANTISSA, torch::Tensor _in_feats, torch::Tensor _weights, torch::Tensor _scales, int splitK=1) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); int num_out_channels = _weights.size(0); assert( num_in_channels%64 == 0 ); assert( (num_in_channels/32*(1+EXPONENT+MANTISSA)) == _weights.size(1) ); // Making sure the K dimension is matched. // int M = num_out_channels; int K = num_in_channels; int N = num_in_feats; // Input Tensors auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. auto in_feats = reinterpret_cast(_in_feats.data_ptr()); auto scales = reinterpret_cast(_scales.data_ptr()); // Output Tensors auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); auto out_feats = reinterpret_cast(_out_feats.data_ptr()); options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) // Get the current device and stream int device_id = _in_feats.device().index(); auto stream = at::cuda::getCurrentCUDAStream(device_id).stream(); // fp_eXmY_linear_kernel( EXPONENT, MANTISSA, stream, // Using torch's current stream here. weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); return _out_feats; } /* * Weight prepacking (Pytorch interface). 
* [Input & Output] * fpx_tensor: int tensor of shape [OC, IC // 32 * x]; * [Output] * packed_tensor: int tensor of shape [OC, IC // 32 * x]; */ torch::Tensor weight_matrix_prepacking_fp_eXmY_cpu( int EXPONENT, int MANTISSA, torch::Tensor fpx_tensor) { int BIT_WIDTH = 1 + EXPONENT + MANTISSA; // size_t OC = fpx_tensor.size(0); size_t IC = fpx_tensor.size(1); assert (IC%BIT_WIDTH==0); IC = IC*32/BIT_WIDTH; assert( (OC%256==0) && (IC%64==0) ); auto packed_tensor = torch::empty_like(fpx_tensor); auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); auto fpx_tensor_ptr = reinterpret_cast(fpx_tensor.data_ptr()); // weight_matrix_prepacking_fp_eXmY(EXPONENT, MANTISSA, packed_tensor_ptr, fpx_tensor_ptr, OC, IC); return packed_tensor; } /* * Dequant a FPx matrix to a equivalent FP16 matrix using CPUs. * A useful tool to construct input matrices for the FP16 GEMM baseline. * [Input] * fpx_tensor: int tensor of shape [OC, IC // 32 * x]; // * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. * [Output] * fp16_tensor: half tensor of shape [OC, IC]. */ torch::Tensor weight_matrix_dequant_fp_eXmY_cpu( int EXPONENT, int MANTISSA, torch::Tensor fpx_tensor, torch::Tensor fp16_scale) { int BIT_WIDTH = 1 + EXPONENT + MANTISSA; // int OC = fpx_tensor.size(0); assert(fpx_tensor.size(1) % BIT_WIDTH == 0); int IC = fpx_tensor.size(1) / BIT_WIDTH * 32; assert(fp16_scale.size(0)==OC); // auto fpx_tensor_ptr = reinterpret_cast(fpx_tensor.data_ptr()); auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); // auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device()); at::Tensor fp16_tensor = torch::empty({OC, IC}, options); auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); // dequant_matrix_fp_eXmY_to_fp16(EXPONENT, MANTISSA, fp16_tensor_ptr, (unsigned char*)fpx_tensor_ptr, OC, IC, fp16_scale_ptr); // return fp16_tensor; } #endif ================================================ FILE: fp6_llm/csrc/fp6_linear.cuh ================================================ #include #include #include /* * Computes FP6-FP16 GEMM (C++ interface). */ cudaError_t fp6_linear_kernel( cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, half *C, const size_t M_Global, const size_t N_Global, const size_t K_Global, float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) int Split_K); cudaError_t fp_eXmY_linear_kernel( const int EXPONENT, const int MANTISSA, cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, half *C, const size_t M_Global, const size_t N_Global, const size_t K_Global, float *Reduction_Workspace, int Split_K); /* * In-place weight prepacking (C++ interface). */ void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K); void weight_matrix_prepacking_fp_eXmY(const int EXPONENT, const int MANTISSA, int* packed_weights, int *FPxWeights, size_t M, size_t K); /* * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. */ void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale); void dequant_matrix_fp_eXmY_to_fp16(const int EXPONENT, const int MANTISSA, half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale); #ifndef NO_PYTORCH #include /* * Computes FP6-FP16 GEMM (PyTorch interface). 
*/ torch::Tensor fp6_linear_forward_cuda( torch::Tensor _in_feats, torch::Tensor _weights, torch::Tensor _scales, int splitK=1); torch::Tensor fp_eXmY_linear_forward_cuda( int EXPONENT, int MANTISSA, torch::Tensor _in_feats, torch::Tensor _weights, torch::Tensor _scales, int splitK=1); /* * Weight prepacking (Pytorch interface). */ torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor); torch::Tensor weight_matrix_prepacking_fp_eXmY_cpu( int EXPONENT, int MANTISSA, torch::Tensor fpx_tensor); /* * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. * A useful tool to construct input matrices for the FP16 GEMM baseline. * [Input] * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. * [Output] * fp16_tensor: half tensor of shape [OC, IC]. */ torch::Tensor weight_matrix_dequant_cpu( torch::Tensor fp6_tensor, torch::Tensor fp16_scale); torch::Tensor weight_matrix_dequant_fp_eXmY_cpu( int EXPONENT, int MANTISSA, torch::Tensor fpx_tensor, torch::Tensor fp16_scale); #endif ================================================ FILE: fp6_llm/csrc/include/configs.h ================================================ #ifndef CONFIGS_H #define CONFIGS_H //#define DEBUG_MODE #define PIPELINE_LEVEL_GMEM 2 #define PIPELINE_LEVEL_SMEM 2 // only support 2 /************************ Hardware Parameters ************************/ #define WARP_SIZE 32 #define REG_BIT_WIDTH 32 // mma: M=16 K=16 N=8 #define MMA_8 8 #define MMA_16 16 // for memory access #define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... #define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 /******************** Register Allocation For GEMM ********************/ #define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation /********************** Memory Padding Parameters **********************/ // Eliminating bank-conflict #define PADDING_BYTES_16 16 // Padding 16 bytes each column #define PADDING_SHARED_MEM_FOR_B_8 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B #define PADDING_SHARED_MEM_FOR_C_4 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C /************************* WARP Tiling part-1 *************************/ #define WARP_ROW_MMA_TENSORS 4 #define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 #define WARP_K_MMA_TENSORS 4 #define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 template struct TilingConfig { // Depending on "n" dimension of the GEMM static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; /************************* WARP Tiling part-2 *************************/ static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; /*************************Thread Block Tiling *************************/ static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; static constexpr int TILE_K = WARP_K; /********************** #Thread per Thread Block **********************/ static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; /******************************* Others *******************************/ static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 static constexpr int 
SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 }; #endif // CONFIGS_H ================================================ FILE: fp6_llm/csrc/include/kernel_matmul.cuh ================================================ #include "configs.h" #include "utils_gmem.cuh" #include "utils_core.cuh" /************************** Bitwidth of Weight Segments ************************/ #define BIT_WIDTH_1 1 #define BIT_WIDTH_2 2 #define BIT_WIDTH_4 4 /*************************** 64*64 Weghts of Weight Matrix *********************/ #define WEIGHT_PER_WARP (WARP_M*WARP_K) // 64*64 = 4096 #define SMEM_SIZE_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/8) // 512 Bytes, doubleBuffer not taken into consideration #define SMEM_SIZE_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/8) // 1024 Bytes, doubleBuffer not taken into consideration #define SMEM_SIZE_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/8) // 2048 Bytes, doubleBuffer not taken into consideration #define SMEM_SIZE_PER_TB_1BIT (SMEM_SIZE_PER_WARP_1BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 6 KB; double buffer for 2-level pipeline A= 4 KB. #define SMEM_SIZE_PER_TB_2BIT (SMEM_SIZE_PER_WARP_2BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. #define SMEM_SIZE_PER_TB_4BIT (SMEM_SIZE_PER_WARP_4BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. #define SMEM_SIZE_PER_TB_A_TILE (SMEM_SIZE_PER_TB_1BIT+SMEM_SIZE_PER_TB_2BIT+SMEM_SIZE_PER_TB_4BIT) // used in fp6_linear.cu, Kernel_Ex(). /******************** Gloabl Memory Layout For QUANTIZED DATA *******************/ #define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/128) // 32 #define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/128) // 64 #define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/128) // 128 /* * C = A*B * A: row major with ahead-of-time layout transformation, FP6 * B: col major, FP16 * C: col major, FP16 */ template __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); #endif // 1+2+4 weight split constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; const uint4* Weight_1bit = Weight; const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0); const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? 
M_Global*K_Global*BIT_WIDTH_2/128 : 0); // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned extern __shared__ __align__(128) half smem[]; half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) const size_t Tile_Start_M = y * TilingConfig::TILE_M; const size_t Tile_Start_N = x * TilingConfig::TILE_N; const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? (N_Global-Tile_Start_N) : TilingConfig::TILE_N; const size_t NumBlock_K = K_Global/TilingConfig::TILE_K; const size_t AverageNumBlock_K = NumBlock_K/Split_K; const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; size_t NumIter = AverageNumBlock_K; size_t StartBlockID_K = AverageNumBlock_K*BatchID; if(BatchID(smem); uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT/4; uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR + SMEM_SIZE_PER_TB_2BIT/4; // 8 buffers including double buffers, 12 for trible buffers // StartSPTR for each WARP AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT/4; AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT/4; AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4; // Pre-fetch of A tile for(int i=0; i(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT); if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; } // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); BTile_GPTR += TilingConfig::TILE_K; } // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; for(int i=0; i(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) { // Trible-Buffer for A Tile uint32_t* __restrict__ read_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 uint32_t* __restrict__ read_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 uint32_t* __restrict__ read_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 uint32_t* __restrict__ read2_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; uint32_t* __restrict__ read2_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; uint32_t* __restrict__ read2_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; uint32_t* __restrict__ write_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 uint32_t* __restrict__ write_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 // Trible-Buffer for B Tile half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; // Copying A tile from Global to Register, Bypassing L1, using double-buffer if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); cp_async_group_commit(); core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, 
read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); // Barriers and Synchronizations cp_async_wait_group(); __syncthreads(); core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); // Updating global PTRs WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 BTile_GPTR += TilingConfig::TILE_K; } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Store the C fragments to shared memory. float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] = reinterpret_cast (smem); StoreToSharedMemoryFromRegister(smem_CFrag, c); __syncthreads(); // Now that shared memory contains all the D tiles, stream them to global memory. OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global; for(size_t i=warpId; i::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } } ================================================ FILE: fp6_llm/csrc/include/kernel_reduction.cuh ================================================ /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
***************************************************************************/ // Used for the reduction of result matrix if Split-K is used // Reduction_Workspace: (Split_K, M_Global, N_Global), column major // C: (M_Global, N_Global), column major // Each thread deals with 8 output elements, each elements is the sum of Split_K elements // Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp // Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> 256 half per warp // GridSize = (M_Global*N_Global) / 256 #include #include #include #define REDUCTION_ELEMENT_PER_THREADBLOCK 256 #define HALF_PER_128BIT 8 __global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) { half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; // Initializing Thread-Local Results float Results[HALF_PER_128BIT]; #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; // Reduction for (int i = 0; i < Split_K; i++) { #pragma unroll for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; THREAD_GPTR_R += M_Global * N_Global; } // Writing to global memory #pragma unroll for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); } ================================================ FILE: fp6_llm/csrc/include/ptx_cp.async.cuh ================================================ /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ***************************************************************************/ // Extended from CUTLASS's source code #ifndef PTX_CP_ASYNC_CUH #define PTX_CP_ASYNC_CUH #include #include #include template __device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) { static_assert(SizeInBytes == 16, "Size is not supported"); unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); asm volatile("{ \n" " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n" "}\n" ::"r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); } /// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. __device__ __forceinline__ void cp_async_group_commit() { asm volatile("cp.async.commit_group;\n" ::); } /// Blocks until all but previous cp.async.commit_group operations have committed. template __device__ __forceinline__ void cp_async_wait_group() { asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); } /// Blocks until all previous cp.async.commit_group operations have committed. 
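/// Typical usage in the GEMM main loop (see kernel_matmul.cuh above): issue the cp.async
/// copies for the next K-tile, call cp_async_group_commit(), run the MMA slices on the
/// tile already resident in shared memory, then call cp_async_wait_group followed by
/// __syncthreads() before the newly copied tile is read.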
// cp.async.wait_all is equivalent to : // cp.async.commit_group; // cp.async.wait_group 0; __device__ __forceinline__ void cp_async_wait_all() { asm volatile("cp.async.wait_all;\n" ::); } #endif ================================================ FILE: fp6_llm/csrc/include/ptx_mma.cuh ================================================ /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ***************************************************************************/ #ifndef PTX_MMA_CUH #define PTX_MMA_CUH #include #include #include #include #include "configs.h" template __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], int slice_id) { #ifdef DEBUG_MODE static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); #endif const int warpId = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory #ifdef DEBUG_MODE assert( warp_start_col==0 ); #endif int col = (lane_id%8) + (lane_id/16)*8; int row = (lane_id%16) / 8 * 8; uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row])); if(TilingConfig::WARP_COL_MMA_TENSORS==1) { asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(Reg[0][0]), "=r"(Reg[0][1]) : "r"(smem_local_ptr)); } else { #pragma unroll for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) { asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) : "r"(smem_local_ptr)); smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); } } } __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b) { asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{ %0, %1, %2, %3}," "{ %4, %5, %6, %7 }," "{ %8, %9 }," "{ %10, %11, %12, %13 };" : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); } #endif ================================================ FILE: fp6_llm/csrc/include/utils_core.cuh ================================================ #ifndef UTILS_CORE_CUH #define UTILS_CORE_CUH #include #include "configs.h" #include "ptx_mma.cuh" #include "utils_parallel_dequant.cuh" template __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); int lane_id = threadIdx.x % WARP_SIZE; #pragma unroll for(int i=0; i __device__ __forceinline__ void 
initialize_mma_slice(uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read, uint32_t* __restrict__ A_2BIT_SPTR_read, uint32_t* __restrict__ A_4BIT_SPTR_read, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales) { // 1+2+4 weight split constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; // Writing registers // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; uint32_t a_1bit[1]; // NO double buffer uint32_t a_2bit[2]; // NO double buffer uint32_t a_4bit[4]; // NO double buffer if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1BIT_SPTR_read, 0); if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2BIT_SPTR_read, 0); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4BIT_SPTR_read, 0); Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } template __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read, uint32_t* __restrict__ A_2bit_SPTR_read, uint32_t* __restrict__ A_4bit_SPTR_read, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales, int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { // 1+2+4 weight split constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; #ifdef DEBUG_MODE assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block #endif const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); // Reigsters for accumulated FP32 results // Setting RPTRs for double buffers uint32_t (*a_read )[4] = a; uint32_t (*a_write)[4] = a; uint32_t (*b_read )[4] = b; uint32_t (*b_write)[4] = b; if(slice_id%2==1) { b_write += NumRegSets_b; a_write += NumRegSets_a;} else { b_read += NumRegSets_b; a_read += NumRegSets_a;} // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory) #pragma unroll for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { if(TilingConfig::WARP_COL_MMA_TENSORS==1) { MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); } else { #pragma unroll for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 } } } // Writing registers // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; uint32_t a_1bit[1]; // NO double buffer uint32_t a_2bit[2]; // NO double buffer uint32_t a_4bit[4]; // NO double buffer if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1bit_SPTR_read, slice_id); if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2bit_SPTR_read, slice_id); if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4bit_SPTR_read, slice_id); Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers } template __device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], float c[][REG_PER_THREAD_C_TENSOR_16_16]) { const int lane_id = threadIdx.x % WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE; int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); #pragma unroll for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { #pragma unroll for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) { // Dealing with one 16*8 Tensor int RegSetID = i + (j/2)*WARP_ROW_MMA_TENSORS; int RegOffset = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2); int Tensor_row_offset = warp_row_offset + i * MMA_16; int Tensor_col_offset = j * MMA_8; #pragma unroll for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) { int row_offset = lane_id / 4; if (r >= 2) row_offset += 8; int col_offset = (lane_id % 4) * 2; if (r%2==1) col_offset += 1; smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset]; } } } } #endif ================================================ FILE: fp6_llm/csrc/include/utils_gmem.cuh ================================================ #ifndef UTILS_GMEM_CUH #define UTILS_GMEM_CUH #include #include "configs.h" #include "ptx_cp.async.cuh" /* * Copying A1/A2 from global memory to shared memory. 
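 * Each lane issues 16-Byte cp.async transactions (8 halfs per lane), so one loop
 * iteration below moves 512 Bytes per warp; SMEM_SIZE_IN_BYTES_PER_WARP is therefore
 * expected to be a multiple of 512 (checked by the static_assert in DEBUG_MODE).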
* Usually 1024 or 2048 Bytes */ template __device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) { #ifdef DEBUG_MODE static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); #endif int lane_id = threadIdx.x % WARP_SIZE; half* SPTR_HALF = reinterpret_cast(SPTR); const half* GPTR_HALF = reinterpret_cast(GPTR); SPTR_HALF += lane_id*8; GPTR_HALF += lane_id*8; #pragma unroll for(int i=0; i( SPTR_HALF, GPTR_HALF, pred_guard); SPTR_HALF += 256; // Forward 512 Bytes GPTR_HALF += 256; // Forward 512 Bytes } } /* * Copying 64 Quant Scales (FP16) from global memory to shared memory. */ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, const half* GPTR_A_Scales) { int lane_id = threadIdx.x % WARP_SIZE; int Offset_Shared = lane_id*2; int Offset_Global = lane_id/4 + (lane_id%4)*16; for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8]; } /* * (1) Copying X rows * 64 columns of FP16 values, originally in row major * (2) Copying 64 rows * X columns of FP16 values, originally in column major * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads */ template __device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], const half* GlobalPTR, const int GlobalStride, const int NumOfLinesLeft, // To support arbitrary N dimensions. bool Pred = true) { // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; const int NumOfGroups = NumOfThreads / 8; const int MaxIteration = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1; // runtime variables const int line_id = threadIdx.x / 8; const int line_offset = (threadIdx.x%8) * 8; // PTR for source global memory and target shared memory GlobalPTR += line_id * GlobalStride + line_offset; SharedPTR += line_id; #pragma unroll for (int i = 0; i < MaxIteration; i++) { bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred; cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); // GlobalPTR += NumOfGroups * GlobalStride; SharedPTR += NumOfGroups; } } #endif ================================================ FILE: fp6_llm/csrc/include/utils_parallel_dequant.cuh ================================================ #ifndef UTILS_PARALLELDEQUANT_CUH #define UTILS_PARALLELDEQUANT_CUH #include #include #include /* * Input: R1 * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. 
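 * The cast below only repositions the sign bit and the EXPONENT+MANTISSA payload within
 * each 16-bit lane; the exponent-bias difference between FP16 (bias 15) and FPx
 * (bias 2^(EXPONENT-1)-1) is deliberately not applied here. It is compensated later in
 * MultScale(), which multiplies by BIAS = 2^BIAS_OFFSET.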
*/ template __device__ __forceinline__ void FPx_FP16_Cast_4Way(u_int32_t *In, u_int32_t *Out1, u_int32_t *Out2) { // constexpr int RIGHT_SHIFT = 5 - EXPONENT; constexpr int MASK1 = 0x80000000; constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; constexpr int MASK3 = MASK2 & 0x7fffffff; constexpr int MASK = MASK3 | MASK3 >> 16; // *Out1 = *In & 0x80008000; *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT; // *In = (*In) << 8; *Out2 = *In & 0x80008000; *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT; } template __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) { constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1)); constexpr int BIAS = int(1) << BIAS_OFFSET; // half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; uint32_t output; half* output_half_ptr = reinterpret_cast(&output); output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); return output; } template __device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], u_int32_t __restrict__ *read_RPTR_1bit, u_int32_t __restrict__ *read_RPTR_2bit, u_int32_t __restrict__ *read_RPTR_4bit, u_int32_t *Scales) { // 1+2+4 weight split constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; // u_int32_t *OutputRegs = reinterpret_cast (Reg); u_int32_t *Frag_PTR_1bit = read_RPTR_1bit; u_int32_t *Frag_PTR_2bit = read_RPTR_2bit; u_int32_t *Frag_PTR_4bit = read_RPTR_4bit; half *Scale_RPTR = reinterpret_cast(Scales); // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) for(int i=0; i<8; i++) { u_int32_t Packed_FP6 = 0; u_int32_t tmp = 0; // 1bit Frag if(USE_SEG_1BIT) { tmp = (*Frag_PTR_1bit) & 0x80808080; Packed_FP6 |= tmp >> (BIT_WIDTH & 0); if(i%8==7) Frag_PTR_1bit++; else (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1; } // 2bit Frag if(USE_SEG_2BIT) { tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0; Packed_FP6 |= tmp >> (BIT_WIDTH & 1); if(i%4==3) Frag_PTR_2bit++; else (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2; } // 4bit Frag2 if(USE_SEG_4BIT) { tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0; Packed_FP6 |= tmp >> (BIT_WIDTH & 3); if(i%2==1) Frag_PTR_4bit++; else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; } // u_int32_t out1, out2; FPx_FP16_Cast_4Way(&Packed_FP6, &out1, &out2); // *OutputRegs = MultScale(out1, Scale_RPTR[0] ); // Muliply FP16 scales OutputRegs += 1; *OutputRegs = MultScale(out2, Scale_RPTR[1]); // Muliply FP16 scales OutputRegs += 1; // Updating offset for FP16 scales for every two iterations if(i%2==1) Scale_RPTR += 2; } } /* * */ __device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { int lane_id = threadIdx.x % WARP_SIZE; uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); uint32_t tmpReg = SPTR_uint[lane_id]; #pragma unroll for(int i=0; i<4; i++) { // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); } } #endif ================================================ FILE: fp6_llm/csrc/pybind.cpp ================================================ #include #include #include "fp6_linear.cuh" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Old Interfaces. 
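    // The "old" entry points below are hard-wired to FP6 (e3m2); the "new" eXmY entry
    // points take EXPONENT and MANTISSA as runtime arguments. See
    // tests/python/kernel_test_fp6.py and tests/python/kernel_test_fpx.py for how each
    // interface is called from Python.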
m.def("linear_forward_cuda", &fp6_linear_forward_cuda, "Computes FP6-FP16 GEMM."); m.def("weight_prepacking_cpu", &weight_matrix_prepacking_cpu, "Weight prepacking."); m.def("weight_dequant_cpu", &weight_matrix_dequant_cpu, "Dequantize weight from fp6 to fp16."); // New Interfaces. m.def("linear_forward_eXmY_cuda", &fp_eXmY_linear_forward_cuda, "Computes FPx-FP16 GEMM."); m.def("weight_prepacking_eXmY_cpu", &weight_matrix_prepacking_fp_eXmY_cpu, "FPx Weight prepacking."); m.def("weight_dequant_eXmY_cpu", &weight_matrix_dequant_fp_eXmY_cpu, "Dequantize weight from fpx to fp16."); } ================================================ FILE: fp6_llm/csrc/utils/common.h ================================================ #ifndef UTILS_COMMON_H #define UTILS_COMMON_H template unsigned char Extract_X_Bits_To_A_Byte(unsigned char* Bytes, int ByteOffset, int BitOffset){ assert (sizeof(unsigned int)==4); unsigned int tmp_int32_word=0; unsigned char* uchar_ptr = reinterpret_cast(&tmp_int32_word); uchar_ptr[3] = Bytes[ByteOffset+0]; uchar_ptr[2] = Bytes[ByteOffset+1]; tmp_int32_word = tmp_int32_word << BitOffset; // signed int mask = 0x80000000; mask = mask >> (EXPONENT+MANTISSA); tmp_int32_word &= mask; // unsigned char out = uchar_ptr[3]; return out; } #endif ================================================ FILE: fp6_llm/csrc/utils/weight_dequant.h ================================================ #include #include #include #include #include #include #include "common.h" template void DeQuantMatrix_FPx_To_FP16(half* A_16bit_h, unsigned char* A_x_bit_h, size_t M, size_t K, half* scale) { // assert(M%64==0); // Currently, M must be a multiple of 64. assert(K%64==0); // Currently, K must be a multiple of 64. constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; assert(BIT_WIDTH<=8); size_t TotalSizeInByte = M * K * BIT_WIDTH / 8; // half* OutPTR = A_16bit_h; for(size_t i=0; i(Bytes, ByteOffset, BitOffset); } // Dequant constexpr int MASK1 = 0x80000000; constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; constexpr int MASK = MASK2 & 0x7fffffff; constexpr int RIGHT_SHIFT = 5 - EXPONENT; constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1)); constexpr int BIAS = int(1) << BIAS_OFFSET; for(int x=0; x<8; x++) { unsigned int OUT_fp16; // Storing fp16 in the high 16 bits. 
OUT_fp16 = int(OUT[x]) << 24; OUT_fp16 = (OUT_fp16 & 0x80000000) | ( (OUT_fp16 & MASK) >> RIGHT_SHIFT ); OUT_fp16 = OUT_fp16 >> 16; // half* OUT_FP16_PTR = reinterpret_cast(&OUT_fp16); OutPTR[x] = __float2half_rn ( __half2float(*OUT_FP16_PTR) * (1.0f*BIAS) * __half2float(scale[(8*i)/K]) ); } // OutPTR +=8; } } void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { DeQuantMatrix_FPx_To_FP16<3, 2>(A_16bit_h, A_6bit_h, M, K, scale); } void dequant_matrix_fp_eXmY_to_fp16(const int EXPONENT, const int MANTISSA, half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale){ if(EXPONENT==2 && MANTISSA==2) return DeQuantMatrix_FPx_To_FP16<2, 2>(A_16bit_h, A_6bit_h, M, K, scale); if(EXPONENT==3 && MANTISSA==2) return DeQuantMatrix_FPx_To_FP16<3, 2>(A_16bit_h, A_6bit_h, M, K, scale); printf("DeQuantMatrix Error: Unsupported EXPONENT=%d, MANTISSA=%d!\n", EXPONENT, MANTISSA); exit(-1); } ================================================ FILE: fp6_llm/csrc/utils/weight_prepacking.h ================================================ #include #include #include #include #include "common.h" /* * Inputs: * (1) unsigned char Weight_6bit [M*K*6/8] * Outputs: * (1) unsigned char Weight_2bit [M*K*2/8] * (2) unsigned char Weight_4bit [M*K*4/8] * * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. * 8 FP6 = 6 Bytes * 8 FP4 = 4 Bytes * 8 FP2 = 2 Bytes */ using namespace std; void Extract_segments_from_8_padded_fpx(unsigned char Seg_xbit[], unsigned char Padded_8_FPx[], int bit_width, int bit_offset){ for(int i=0; i< bit_width; i++) Seg_xbit[i] = 0; for(int i=0; i<8; i++){ unsigned int seg = (Padded_8_FPx[i] << bit_offset) & 0x000000ff; int mask = 0xffffff00; seg &= mask >> bit_width; // int Seg_idx = (i * bit_width) / 8; int Seg_off = (i * bit_width) % 8; Seg_xbit[Seg_idx] |= seg >> Seg_off; } } // dealing with 4 1*8 blocks of FPx template void Assign_32_FPx_To_4_Thread(vector Vec_Seg_1bit[], vector Vec_Seg_2bit[], vector Vec_Seg_4bit[], unsigned char* PTR[]) { constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; assert(BIT_WIDTH<8); constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; // constexpr int nTHREADS = 4; constexpr int FPx_PER_THREAD = 8; unsigned char Padded_8_FPx[nTHREADS][FPx_PER_THREAD]; for(int i=0; i(PTR[j/2], ByteOffset, BitOffset); } } // unsigned char Seg_1bit[nTHREADS][1]; unsigned char Seg_2bit[nTHREADS][2]; unsigned char Seg_4bit[nTHREADS][4]; for(int t=0; t void BitInterleaving_x_bit(unsigned char* PTR_4Bytes) { unsigned int *PTR_UINT = reinterpret_cast(PTR_4Bytes); unsigned int input = *PTR_UINT; // int* order = NULL; int order_1bit[32] = {2,6,10,14,18,22,26,30, 4,8,12,16,20,24,28,32, 1,5,9, 13,17,21,25,29, 3,7,11,15,19,23,27,31}; // pre-defined order for bit-interleaving in FP6-LLM int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15}; // pre-defined order for bit-interleaving in FP6-LLM int order_4bit[8] = {2,6,4,8,1,5,3,7}; // pre-defined order for bit-interleaving in FP6-LLM if(BIT_WIDTH==1) order = order_1bit; if(BIT_WIDTH==2) order = order_2bit; if(BIT_WIDTH==4) order = order_4bit; assert(order); // int mask = 0x80000000; assert(BIT_WIDTH>=1); mask = mask >> (BIT_WIDTH-1); // unsigned int output = 0x00000000; for(int i=0; i<32/BIT_WIDTH; i++){ unsigned int Frag_xbit = ( input << BIT_WIDTH*(order[i]-1) ) & mask; // The highest x bits are used to store the extracted fragments. 
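        // order[i] is 1-based and counted from the MSB of the input word: fragment number
        // order[i] is moved into the i-th x-bit slot of the output, also counted from the MSB.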
output |= Frag_xbit >> (i*BIT_WIDTH); } // *PTR_UINT = output; } template void weight_matrix_prepacking_x_bit(int* packed_weights, int *FPxWeights, size_t M, size_t K) { assert(M % 64 == 0); assert(K % 64 == 0); // constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; assert(BIT_WIDTH<8); constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; // unsigned char* Weight_xbit = reinterpret_cast(FPxWeights); unsigned char* Weight_1bit = reinterpret_cast(packed_weights); unsigned char* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M*K*1/8 : 0); unsigned char* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M*K*2/8 : 0); // vector A_Segment_1bit[32]; vector A_Segment_2bit[32]; vector A_Segment_4bit[32]; // size_t BytesPerRow = K*BIT_WIDTH/8; // Pass-1: (1) 1+2+4 split; (2) assign weights to 32 threads. for (size_t i = 0; i < M / 64; i++){ for (size_t j = 0; j < K / 16; j++){ for(size_t k=0; k<64/16; k++){ size_t row = i*64 + k*16; size_t col = j*16; unsigned char* StartPTR_1 = Weight_xbit + row*BytesPerRow + col*(BIT_WIDTH)/8; unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow; unsigned char* StartPTR_3 = StartPTR_1 + 8*(BIT_WIDTH)/8; unsigned char* StartPTR_4 = StartPTR_2 + 8*(BIT_WIDTH)/8; // Dealing with each 16*16 blocks then... for(int l=0; l<8; l++) { unsigned char* PTR[4]={StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow}; Assign_32_FPx_To_4_Thread(&A_Segment_1bit[l*4], &A_Segment_2bit[l*4], &A_Segment_4bit[l*4], PTR); } } } } // Verifying the length of 1/2/4_bit segments. size_t BytesPerThread_1bit = M*K*1/8/32; size_t BytesPerThread_2bit = M*K*2/8/32; size_t BytesPerThread_4bit = M*K*4/8/32; for(int i=0; i<32; i++){ if(USE_SEG_1BIT) assert(A_Segment_1bit[i].size()==BytesPerThread_1bit); else assert(A_Segment_1bit[i].size()==0); if(USE_SEG_2BIT) assert(A_Segment_2bit[i].size()==BytesPerThread_2bit); else assert(A_Segment_2bit[i].size()==0); if(USE_SEG_4BIT) assert(A_Segment_4bit[i].size()==BytesPerThread_4bit); else assert(A_Segment_4bit[i].size()==0); } // Pass-2: Optimizing coleasced global memory access if(USE_SEG_1BIT) for(size_t i=0; i(Weight_1bit+4*i); if(USE_SEG_2BIT) for(size_t i=0; i(Weight_2bit+4*i); if(USE_SEG_4BIT) for(size_t i=0; i(Weight_4bit+4*i); } void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K){ weight_matrix_prepacking_x_bit<3, 2>(packed_weights, FP6Weights, M, K); } // void weight_matrix_prepacking_fp_eXmY(const int EXPONENT, const int MANTISSA, int* packed_weights, int *FPxWeights, size_t M, size_t K){ if(EXPONENT==2 && MANTISSA==2) return weight_matrix_prepacking_x_bit<2, 2>(packed_weights, FPxWeights, M, K); if(EXPONENT==3 && MANTISSA==2) return weight_matrix_prepacking_x_bit<3, 2>(packed_weights, FPxWeights, M, K); printf("Weight_prepacking Error: Unsupported EXPONENT=%d, MANTISSA=%d!\n", EXPONENT, MANTISSA); exit(-1); } ================================================ FILE: fp6_llm/csrc/utils/weight_quant.h ================================================ // Author: Zhen Zheng // To be used in the future as a tool to generating the FP6 matrix from the FP16 matrix. #include /* * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. 
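 * Inputs are expected to be "fake quantized", i.e. already representable in FP6 e3m2;
 * nonzero values with magnitude outside [0.0625, 28] make the function throw.
 * The first FP6 value lands in the high 6 bits of the first output byte, and so on
 * across the three packed bytes.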
*/ void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) { // Constants for FP6 constexpr int exponent_nbits_fp6 = 3; constexpr int mantissa_nbits_fp6 = 2; constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; // Constants for FP16 constexpr int exponent_nbits_fp16 = 5; constexpr int mantissa_nbits_fp16 = 10; constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1; int fp6_temp[4]; float absmin_nonzero_fp6 = 0.0625; // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is // the same with that in qtorch. float absmax_fp6 = 28; for (int i = 0; i < 4; ++i) { uint16_t source = FP16x4[i]; float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || fp6_value_abs > absmax_fp6) { // TODO(zhen): a better way may be rounding it to the nearest FP6 value. throw std::invalid_argument("Input value out of range for FP6."); } // It is not safe to do shift operation on uint16_t. So we promote it to int. int source_promote = int(source); int sign_bit = (source_promote >> 15); // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; // Extracting mantissa represented in FP16 int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1); int new_exp_bit; int new_mant_bit; if (exp_bit == 0) { // Subnormal FP16 number. Too small for FP6. new_exp_bit = 0; new_mant_bit = 0; } else { new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6; // Deal with subnormal FP6 values. int target_exp_val = exp_bit - exp_bias_fp16; int min_fp6_exp_val = -exp_bias_fp6 + 1; bool subnormal_fp6 = target_exp_val < min_fp6_exp_val; if (subnormal_fp6) { // TODO(zhen): add the rounding logic. new_exp_bit = 0; // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we // need to add it new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >> (min_fp6_exp_val - target_exp_val); } } fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit; } // Pack the values FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; } /* * Function to prepack FP16 weights into continuous FP6 values. * * Parameters: * weight_16bit: input weight in FP16, size M*K * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 * M, K: the shape of the weight */ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, uint8_t* weight_6bit_packed, size_t M, size_t K) { // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). 
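    // Equivalently: each group of four FP6 values occupies exactly three bytes, so K must
    // be a multiple of 4 for every row to pack without padding.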
if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } size_t K_fp6_packed = K * 6 / 8; // #pragma omp parallel for for (auto m = 0; m < M; m++) { uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed; uint16_t* ptr_16bit = weight_16bit + m * K; for (auto k = 0; k < K; k += 4) { cast_fp16_fp6(ptr_16bit, ptr_6bit); ptr_16bit += 4; ptr_6bit += 3; } } } ================================================ FILE: setup.py ================================================ from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension extra_compile_args = { "cxx": [ "-O3", "-std=c++17" ], "nvcc": [ "-O3", "--use_fast_math", "-std=c++17", "-maxrregcount=255", "--ptxas-options=-v,-warn-lmem-usage,--warn-on-spills", "-gencode=arch=compute_80,code=sm_80" ], } setup( name="fp6_llm", author="Haojun Xia, Zhen Zheng, Xiaoxia Wu, Shiyang Chen, Zhewei Yao, Stephen Youn, Arash Bakhtiari, Michael Wyatt, Donglin Zhuang, Zhongzhu Zhou, Olatunji Ruwase, Yuxiong He, Shuaiwen Leon Song", version="0.2", author_email="xhjustc@gmail.com", description ="An efficient GPU support for LLM inference with x-bit quantization (e.g., FP6 and FP5).", python_requires=">=3.8", install_requires=[ "torch", "transformers" ], packages=find_packages(), ext_modules=[ CUDAExtension( name="fp6_llm_cuda", sources=[ "fp6_llm/csrc/pybind.cpp", "fp6_llm/csrc/fp6_linear.cu" ], extra_compile_args=extra_compile_args, ), ], cmdclass={"build_ext": BuildExtension} ) ================================================ FILE: tests/cpp/Makefile ================================================ # host compiler HOST_COMPILER ?= g++ NVCC := nvcc -ccbin $(HOST_COMPILER) # internal flags NVCCFLAGS := -m$(shell getconf LONG_BIT) CCFLAGS := -DNO_PYTORCH LDFLAGS := -rpath=../../fp6_llm ALL_CCFLAGS := ALL_CCFLAGS += $(NVCCFLAGS) ALL_CCFLAGS += $(addprefix -Xcompiler ,$(CCFLAGS)) ALL_LDFLAGS := ALL_LDFLAGS += $(ALL_CCFLAGS) ALL_LDFLAGS += $(addprefix -Xlinker ,$(LDFLAGS)) # Common includes and paths for CUDA INCLUDES := -I/usr/local/cuda/include/ LIBRARIES := -lcublas # INCLUDES += -I../../fp6_llm/csrc LIBRARIES += -L../../fp6_llm -lfp6 ################################################################################ # Gencode arguments SMS ?= 80 # Generate SASS code for each SM architecture listed in $(SMS) $(foreach sm,$(SMS),$(eval GENCODE_FLAGS += -gencode arch=compute_$(sm),code=sm_$(sm))) ################################################################################ # Target rules all: kernel_test_fp6 kernel_test_fpx kernel_test_fp6.o: kernel_test_fp6.cu kernel_test.h $(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $< kernel_test_fpx.o: kernel_test_fpx.cu kernel_test.h $(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $< kernel_test_fp6: kernel_test_fp6.o $(EXEC) $(NVCC) $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES) kernel_test_fpx: kernel_test_fpx.o $(EXEC) $(NVCC) $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES) clean: rm -f kernel_test_fp6 kernel_test_fp6.o kernel_test_fpx kernel_test_fpx.o ================================================ FILE: tests/cpp/kernel_test.h ================================================ #include #include #include #include #include #include #include // Performance Benchmark #define WARM_UP_ITERATION 100 #define BENCHMARK_ITERATION 10000 void __forceinline__ CheckMallocCPU(void* PTR, int line = -1) { if (PTR == NULL) { printf("Error in CPU Malloc, line %d!\n", line); exit(-1); } } void 
__forceinline__ CheckMallocCUDA(void* PTR, int line = -1) { if (PTR == NULL) { printf("Error in cudaMalloc, line %d!\n", line); exit(-1); } } void checkCublasError(cublasStatus_t status, int line) { if (status != CUBLAS_STATUS_SUCCESS) { printf("Cublas Error at line %d, Error Code: %d\n", line, status); exit(EXIT_FAILURE); } } void checkLastCudaError(int line) { cudaError_t error = cudaGetLastError(); if (error != cudaSuccess) { printf("Last Cuda Error Detected at line: %d, Error: %s.\n", line, cudaGetErrorString(error)); exit(EXIT_FAILURE); } } // Note: totalAbsSum might overflow if (1)the shape of output matrix are large and (2)the value of each element within the output matrix is large. // The overflow might result in NaN for the "TotalAbsError/TotalAbsSum", while the fp6_llm is working correctly. // This problem can be fixed by setting the quantizaiton scales to a smaller value. double ComputeTotalError(half* CuBlas, half* Other, size_t m, size_t n) { long double totalError = 0.0; for (size_t i = 0; i < m * n; i++) totalError += fabs(__half2float(CuBlas[i]) - __half2float(Other[i])); long double totalAbsSum = 0.0; for (size_t i = 0; i < m * n; i++) totalAbsSum += fabs(__half2float(CuBlas[i])); return totalError/totalAbsSum; } void PrintPerformance(const char* KernelName, float milliseconds, float tflops, double error) { printf("%-10s \t -> \t\t Time/ms: %5.3f \t Performance/TFLOPs: %4.2f \t TotalAbsError/TotalAbsSum: %.8lf\n", KernelName, milliseconds, tflops, error); } void PrintMismatch(const char* KernelName, size_t MaxNumMismatch, float RelativeErrorThreshold, half* CuBlas, half* Other, size_t M_GLOBAL, size_t N_GLOBAL) { //printf("First %d Mismatches between Cublas and %s:\n", MaxNumMismatch, KernelName); size_t count = 0; for (size_t i = 0; i < M_GLOBAL; i++) { for (size_t j = 0; j < N_GLOBAL; j++) { if (fabs(__half2float(CuBlas[i + j * M_GLOBAL]) - __half2float(Other[i + j * M_GLOBAL]))/fabs(__half2float(CuBlas[i + j * M_GLOBAL])) > RelativeErrorThreshold) { count++; printf("(%d,%d) CuBlas=%f %s=%f\n", i, j, __half2float(CuBlas[i + j * M_GLOBAL]), KernelName, __half2float(Other[i + j * M_GLOBAL])); } if (count == MaxNumMismatch) break; } if (count == MaxNumMismatch) break; } } ================================================ FILE: tests/cpp/kernel_test_fp6.cu ================================================ #include "kernel_test.h" #include "fp6_linear.cuh" int main(int argc, char** argv) { // Parsing the inputs from CLI. if (argc != 5) { printf("Wrong Inputs! Correct input format: ./kernel_test #Row_Weight #Column_Weight BatchSize SplitK\n"); return -1; } size_t M_GLOBAL = atoi(argv[1]); size_t K_GLOBAL = atoi(argv[2]); size_t N_GLOBAL = atoi(argv[3]); int SPLIT_K = atoi(argv[4]); assert(M_GLOBAL%256==0); // Currently, M_GLOBAL must be a multiple of 256. assert(K_GLOBAL%64==0); // Currently, K_GLOBAL must be a multiple of 64. //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Matrices in quantized FP6 models with faked values. unsigned char* A_6bit_h = (unsigned char*)malloc(M_GLOBAL*K_GLOBAL*6/8); CheckMallocCPU(A_6bit_h, __LINE__); // Weight matrix with FP6 values, stored in row-major. 
for(size_t i=0; i(&A_6bit), M_GLOBAL*K_GLOBAL*6/8); CheckMallocCUDA(A_6bit, __LINE__); cudaMalloc(reinterpret_cast(&A_Scale), M_GLOBAL*sizeof(half)); CheckMallocCUDA(A_Scale, __LINE__); cudaMalloc(reinterpret_cast(&A_16bit), M_GLOBAL*K_GLOBAL*sizeof(half)); CheckMallocCUDA(A_16bit, __LINE__); // Memory Copy from CPU to GPU cudaMemcpy(A_6bit, A_6bit_h, M_GLOBAL*K_GLOBAL*6/8, cudaMemcpyHostToDevice); cudaMemcpy(A_Scale, A_Scale_h, M_GLOBAL*sizeof(half), cudaMemcpyHostToDevice); cudaMemcpy(A_16bit, A_16bit_h, M_GLOBAL*K_GLOBAL*sizeof(half), cudaMemcpyHostToDevice); checkLastCudaError(__LINE__); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // B Matrix: Activations half* B_h = (half*)malloc(sizeof(half) * K_GLOBAL * N_GLOBAL); CheckMallocCPU(B_h); // col major for (size_t i = 0; i < N_GLOBAL * K_GLOBAL; i++) B_h[i] = __float2half_rn(static_cast((rand() % 5)) / 5 - 0.5f); // Device memory half* B = NULL; cudaMalloc(reinterpret_cast(&B), sizeof(half) * N_GLOBAL * K_GLOBAL); CheckMallocCUDA(B, __LINE__); // Memory Copy from CPU to GPU cudaMemcpy(B, B_h, sizeof(half) * N_GLOBAL * K_GLOBAL, cudaMemcpyHostToDevice); checkLastCudaError(__LINE__); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// cublasStatus_t cublas_status; cudaEvent_t start, stop; cudaEventCreate(&start); cudaEventCreate(&stop); checkLastCudaError(__LINE__); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// half* D_cublas = NULL; cudaMalloc(reinterpret_cast(&D_cublas), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_cublas, __LINE__); cudaMemset(D_cublas, 0, sizeof(half) * M_GLOBAL * N_GLOBAL); cublasHandle_t handle; cublasCreate(&handle); cublasSetStream(handle, 0); //cublasSetMathMode(handle, CUBLAS_PEDANTIC_MATH); // Tensor core NOT enabled cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH); // Tensor core enabled cudaDeviceSynchronize(); int m = M_GLOBAL, n = N_GLOBAL, k = K_GLOBAL; const float alpha = 1.0; const float beta = 0.0; cublasGemmAlgo_t CuBlasALG = static_cast(0); for (int i = 0; i < WARM_UP_ITERATION; i++) { cublas_status = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, A_16bit, CUDA_R_16F, k, B, CUDA_R_16F, k, &beta, D_cublas, CUDA_R_16F, m, CUDA_R_32F, CuBlasALG); checkCublasError(cublas_status, __LINE__); } cudaEventRecord(start); for (int i = 0; i < BENCHMARK_ITERATION; i++) cublas_status = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, A_16bit, CUDA_R_16F, k, B, CUDA_R_16F, k, &beta, D_cublas, CUDA_R_16F, m, CUDA_R_32F, CuBlasALG); cudaEventRecord(stop); cudaEventSynchronize(stop); // float milliseconds_cublas = 0; cudaEventElapsedTime(&milliseconds_cublas, start, stop); milliseconds_cublas = milliseconds_cublas / BENCHMARK_ITERATION; float tflops_cublas = static_cast((static_cast(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_cublas / 1000.)) / 1e12; // half* D_cublas_h = NULL; // col major D_cublas_h = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCPU(D_cublas_h); cudaMemcpy(D_cublas_h, D_cublas, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost); // Col Major cudaFree(D_cublas); checkLastCudaError(__LINE__); ///////////////////////////////////////////////////////////////////////////////////////////////// half* D_fp6 = NULL; cudaMalloc(reinterpret_cast(&D_fp6), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_fp6); 
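    // The float Reduction_Workspace allocated below holds one M*N partial-result slice per
    // Split-K partition, laid out as (Split_K, M_Global, N_Global); SplitK_Reduction() in
    // kernel_reduction.cuh sums those slices into the final fp16 output.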
cudaMemset(D_fp6, 0, sizeof(half) * M_GLOBAL * N_GLOBAL); // int Split_K = SPLIT_K; float* Reduction_Workspace = NULL; cudaMalloc(reinterpret_cast(&Reduction_Workspace), sizeof(float) * M_GLOBAL * N_GLOBAL * Split_K); CheckMallocCUDA(Reduction_Workspace, __LINE__); // for (int i = 0; i < WARM_UP_ITERATION; i++) fp6_linear_kernel( 0, (uint4*)A_6bit, A_Scale, B, D_fp6, M_GLOBAL, N_GLOBAL, K_GLOBAL, Reduction_Workspace, Split_K); cudaEventRecord(start); for (int i = 0; i < BENCHMARK_ITERATION; i++) fp6_linear_kernel( 0, (uint4*)A_6bit, A_Scale, B, D_fp6, M_GLOBAL, N_GLOBAL, K_GLOBAL, Reduction_Workspace, Split_K); cudaEventRecord(stop); cudaEventSynchronize(stop); checkLastCudaError(__LINE__); // float milliseconds_fp6 = 0.0f; cudaEventElapsedTime(&milliseconds_fp6, start, stop); milliseconds_fp6 = milliseconds_fp6 / BENCHMARK_ITERATION; float tflops_fp6 = static_cast((static_cast(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_fp6 / 1000.)) / 1e12; half* D_fp6_h = NULL; // col major D_fp6_h = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL); cudaMemcpy(D_fp6_h, D_fp6, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost); // Col Major cudaFree(D_fp6); cudaFree(Reduction_Workspace); ///////////////////////////////////////////////////////////////////////////////////////////////// double totalRelativeError_fp6 = ComputeTotalError(D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL); printf("************************************* "); printf("[%d-bit Weights, e%dm%d] M: %d N: %d K: %d SplitK: %d", 6, 3, 2, M_GLOBAL, N_GLOBAL, K_GLOBAL, SPLIT_K); printf(" ************************************\n"); PrintPerformance("cuBLAS", milliseconds_cublas, tflops_cublas, 0.0); PrintPerformance("fp6_llm", milliseconds_fp6, tflops_fp6, totalRelativeError_fp6); //PrintMismatch("fp6", 100, 0.002, D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL); free(D_cublas_h); free(D_fp6_h); cudaFree(B); return 0; } ================================================ FILE: tests/cpp/kernel_test_fpx.cu ================================================ #include "kernel_test.h" #include "fp6_linear.cuh" int main(int argc, char** argv) { // Parsing the inputs from CLI. if (argc != 7) { printf("Wrong Inputs! Correct input format: ./kernel_test EXPONENT MANTISSA #Row_Weight #Column_Weight BatchSize SplitK\n"); return -1; } int EXPONENT = atoi(argv[1]); int MANTISSA = atoi(argv[2]); size_t M_GLOBAL = atoi(argv[3]); size_t K_GLOBAL = atoi(argv[4]); size_t N_GLOBAL = atoi(argv[5]); int SPLIT_K = atoi(argv[6]); int BIT_WIDTH = 1 + EXPONENT + MANTISSA; assert(EXPONENT==2 || EXPONENT==3); assert(MANTISSA==2); assert(M_GLOBAL%256==0); // Currently, M_GLOBAL must be a multiple of 256. assert(K_GLOBAL%64==0); // Currently, K_GLOBAL must be a multiple of 64. //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Matrices in quantized FPx models with faked values. unsigned char* A_xbit_h = (unsigned char*)malloc(M_GLOBAL*K_GLOBAL*BIT_WIDTH/8); CheckMallocCPU(A_xbit_h, __LINE__); // Weight matrix with FP6 values, stored in row-major. 
for(size_t i=0; i(&A_xbit), M_GLOBAL*K_GLOBAL*BIT_WIDTH/8); CheckMallocCUDA(A_xbit, __LINE__); cudaMalloc(reinterpret_cast(&A_Scale), M_GLOBAL*sizeof(half)); CheckMallocCUDA(A_Scale, __LINE__); cudaMalloc(reinterpret_cast(&A_16bit), M_GLOBAL*K_GLOBAL*sizeof(half)); CheckMallocCUDA(A_16bit, __LINE__); // Memory Copy from CPU to GPU cudaMemcpy(A_xbit, A_xbit_h, M_GLOBAL*K_GLOBAL*BIT_WIDTH/8, cudaMemcpyHostToDevice); cudaMemcpy(A_Scale, A_Scale_h, M_GLOBAL*sizeof(half), cudaMemcpyHostToDevice); cudaMemcpy(A_16bit, A_16bit_h, M_GLOBAL*K_GLOBAL*sizeof(half), cudaMemcpyHostToDevice); checkLastCudaError(__LINE__); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // B Matrix: Activations half* B_h = (half*)malloc(sizeof(half) * K_GLOBAL * N_GLOBAL); CheckMallocCPU(B_h); // col major for (size_t i = 0; i < N_GLOBAL * K_GLOBAL; i++) B_h[i] = __float2half_rn(static_cast((rand() % 5)) / 5 - 0.5f); // Device memory half* B = NULL; cudaMalloc(reinterpret_cast(&B), sizeof(half) * N_GLOBAL * K_GLOBAL); CheckMallocCUDA(B, __LINE__); // Memory Copy from CPU to GPU cudaMemcpy(B, B_h, sizeof(half) * N_GLOBAL * K_GLOBAL, cudaMemcpyHostToDevice); checkLastCudaError(__LINE__); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// cublasStatus_t cublas_status; cudaEvent_t start, stop; cudaEventCreate(&start); cudaEventCreate(&stop); checkLastCudaError(__LINE__); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// //printf("Launching CuBlas...\n"); half* D_cublas = NULL; cudaMalloc(reinterpret_cast(&D_cublas), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_cublas, __LINE__); cudaMemset(D_cublas, 0, sizeof(half) * M_GLOBAL * N_GLOBAL); cublasHandle_t handle; cublasCreate(&handle); cublasSetStream(handle, 0); //cublasSetMathMode(handle, CUBLAS_PEDANTIC_MATH); // Tensor core NOT enabled cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH); // Tensor core enabled cudaDeviceSynchronize(); int m = M_GLOBAL, n = N_GLOBAL, k = K_GLOBAL; const float alpha = 1.0; const float beta = 0.0; cublasGemmAlgo_t CuBlasALG = static_cast(0); for (int i = 0; i < WARM_UP_ITERATION; i++) { cublas_status = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, A_16bit, CUDA_R_16F, k, B, CUDA_R_16F, k, &beta, D_cublas, CUDA_R_16F, m, CUDA_R_32F, CuBlasALG); checkCublasError(cublas_status, __LINE__); } cudaEventRecord(start); for (int i = 0; i < BENCHMARK_ITERATION; i++) cublas_status = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, A_16bit, CUDA_R_16F, k, B, CUDA_R_16F, k, &beta, D_cublas, CUDA_R_16F, m, CUDA_R_32F, CuBlasALG); cudaEventRecord(stop); cudaEventSynchronize(stop); // float milliseconds_cublas = 0; cudaEventElapsedTime(&milliseconds_cublas, start, stop); milliseconds_cublas = milliseconds_cublas / BENCHMARK_ITERATION; float tflops_cublas = static_cast((static_cast(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_cublas / 1000.)) / 1e12; // half* D_cublas_h = NULL; // col major D_cublas_h = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCPU(D_cublas_h); cudaMemcpy(D_cublas_h, D_cublas, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost); // Col Major cudaFree(D_cublas); checkLastCudaError(__LINE__); ///////////////////////////////////////////////////////////////////////////////////////////////// //printf("Launching FP6-LLM...\n"); half* D_fp6 = NULL; 
cudaMalloc(reinterpret_cast(&D_fp6), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_fp6); cudaMemset(D_fp6, 0, sizeof(half) * M_GLOBAL * N_GLOBAL); // int Split_K = SPLIT_K; float* Reduction_Workspace = NULL; cudaMalloc(reinterpret_cast(&Reduction_Workspace), sizeof(float) * M_GLOBAL * N_GLOBAL * Split_K); CheckMallocCUDA(Reduction_Workspace, __LINE__); // for (int i = 0; i < WARM_UP_ITERATION; i++) fp_eXmY_linear_kernel( EXPONENT, MANTISSA, 0, (uint4*)A_xbit, A_Scale, B, D_fp6, M_GLOBAL, N_GLOBAL, K_GLOBAL, Reduction_Workspace, Split_K); cudaEventRecord(start); for (int i = 0; i < BENCHMARK_ITERATION; i++) fp_eXmY_linear_kernel( EXPONENT, MANTISSA, 0, (uint4*)A_xbit, A_Scale, B, D_fp6, M_GLOBAL, N_GLOBAL, K_GLOBAL, Reduction_Workspace, Split_K); cudaEventRecord(stop); cudaEventSynchronize(stop); checkLastCudaError(__LINE__); // float milliseconds_fp6 = 0.0f; cudaEventElapsedTime(&milliseconds_fp6, start, stop); milliseconds_fp6 = milliseconds_fp6 / BENCHMARK_ITERATION; float tflops_fp6 = static_cast((static_cast(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_fp6 / 1000.)) / 1e12; half* D_fp6_h = NULL; // col major D_fp6_h = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL); cudaMemcpy(D_fp6_h, D_fp6, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost); // Col Major cudaFree(D_fp6); cudaFree(Reduction_Workspace); ///////////////////////////////////////////////////////////////////////////////////////////////// double totalRelativeError_fp6 = ComputeTotalError(D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL); printf("************************************* "); printf("[%d-bit Weights, e%dm%d] M: %d N: %d K: %d SplitK: %d", BIT_WIDTH, EXPONENT, MANTISSA, M_GLOBAL, N_GLOBAL, K_GLOBAL, SPLIT_K); printf(" ************************************\n"); PrintPerformance("cuBLAS", milliseconds_cublas, tflops_cublas, 0.0); PrintPerformance("quant_llm", milliseconds_fp6, tflops_fp6, totalRelativeError_fp6); //PrintMismatch("fp6", 100, 0.002, D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL); free(D_cublas_h); free(D_fp6_h); cudaFree(B); return 0; } ================================================ FILE: tests/cpp/run.sh ================================================ #! 
/bin/bash # Batch sizes to test N=(1 2 3 4 5 6 7 8) # Benchmarking the specific Matrix Shape from llama-1 65b M=(13824 5120 22016 8192) K=(5120 13824 8192 22016) SplitK=(4 10 5 6) # SplitK for smaller Batch Sizes #SplitK=(3) # SplitK for Batch Sizes 128 #SplitK=(3) # SplitK for Batch Sizes 512 #SplitK=(1) # SplitK for Batch Sizes 2048 # Benchmarking Matrix Shapes from OPT models #M=(21504 7168 28672 7168 27648 9216 36864 9216 36864 12288 49152 12288) #K=(7168 7168 7168 28672 9216 9216 9216 36864 12288 12288 12288 49152) #SplitK=(2 7 7 7 2 6 3 6 3 4 1 4) # SplitK for smaller Batch Sizes #SplitK=(5 7 7 7 1 3 3 3 3 2 1 2) # SplitK for Batch Sizes 128 #SplitK=(3 2 3 3 1 3 1 3 1 1 1 1) # SplitK for Batch Sizes 512 #SplitK=(1 2 1 2 1 1 1 1 1 1 1 1) # SplitK for Batch Sizes 2048 # Benchmarking Matrix Shapes from llama-1 models #M=(12288 4096 11008 4096 15360 5120 13824 5120 19968 6656 17920 6656 24576 8192 22016 8192) #K=(4096 4096 4096 11008 5120 5120 5120 13824 6656 6656 6656 17920 8192 8192 8192 22016) #SplitK=(9 13 5 13 3 10 4 10 5 8 3 8 2 6 5 6) # SplitK for smaller Batch Sizes #SplitK=(2 6 5 6 3 5 2 5 4 4 3 4 1 3 5 3) # SplitK for Batch Sizes 128 #SplitK=(1 3 3 3 2 4 1 4 1 1 3 1 1 3 2 3) # SplitK for Batch Sizes 512 #SplitK=(1 2 1 2 1 1 1 1 1 1 1 1 1 1 1 1) # SplitK for Batch Sizes 2048 #mkdir -p Profiling for ((i=0;i<${#M[@]};i++)) do echo "Processing Shape ${i}..." for BS in ${N[@]} do echo "BS=${BS}" #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \ #./kernel_test_fp6 ${M[i]} ${K[i]} ${BS} ${SplitK[i]} ./kernel_test_fpx 2 2 ${M[i]} ${K[i]} ${BS} ${SplitK[i]} ./kernel_test_fpx 3 2 ${M[i]} ${K[i]} ${BS} ${SplitK[i]} done done ================================================ FILE: tests/python/kernel_test_fp6.py ================================================ import argparse import torch import fp6_llm WARMUP = 10 REPEAT = 1000 parser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).') parser.add_argument('--OC', type=int, required=False, default=4096, help='number of rows of the weight matrix.') parser.add_argument('--IC', type=int, required=False, default=4096, help='number of columns of the weight matrix.') parser.add_argument('--BS', type=int, required=False, default=32, help='inference batch size.') parser.add_argument('--splitK', type=int, required=False, default=1, help='Split-K parameters allow users to split the GEMM computation along the K dimension so that more CTAs will be created with a better SM utilization.') args = parser.parse_args() assert(args.OC%256==0) assert(args.IC%64==0) print("#"*64) print(args) fp6_weight = torch.randint(4294967295, (args.OC,args.IC//16*3)).to(torch.int) # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
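# Each weight row holds IC FP6 values, i.e. IC*6 bits = (IC//16)*3 32-bit words,
# which is why the packed tensor has shape (OC, IC//16*3).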
fp16_scale = torch.rand(args.OC).to(torch.half)+0.5
fp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# fp6-fp16 GEMM (fp6-llm)
####################################################################################################################################
torch.cuda.synchronize()
fp6_weight_packed = fp6_llm.weight_prepacking_cpu(fp6_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fp6_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
for i in range(WARMUP):
    results_fp6_llm = fp6_llm.linear_forward_cuda(act_cuda, weight_cuda, scale_cuda, args.splitK)
start_event.record()
for i in range(REPEAT):
    results_fp6_llm = fp6_llm.linear_forward_cuda(act_cuda, weight_cuda, scale_cuda, args.splitK)
end_event.record()
torch.cuda.synchronize()
fp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT
fp6_llm_tflops = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9
####################################################################################################################################

# baseline fp16 GEMM (cuBLAS)
####################################################################################################################################
torch.cuda.synchronize()
fp16_weight = fp6_llm.weight_dequant_cpu(fp6_weight, fp16_scale)
cuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False)
results_cublas = None
with torch.no_grad():
    cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda())
    act_cuda = fp16_activation.cuda()
    for i in range(WARMUP):
        results_cublas = cuBLAS_MatMul(act_cuda)
    start_event.record()
    for i in range(REPEAT):
        results_cublas = cuBLAS_MatMul(act_cuda)
    end_event.record()
torch.cuda.synchronize()
cublas_time_ms = start_event.elapsed_time(end_event)/REPEAT
cublas_tflops = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9
####################################################################################################################################

# Performance
print( 'cuBLAS time: {:.3f} ms \t\t cuBLAS TFLOPs: {:.1f}'.format(cublas_time_ms, cublas_tflops) )
print( 'fp6-llm time: {:.3f} ms \t\t fp6-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) )
print( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) )

# Correctness
error = results_cublas.cpu() - results_fp6_llm.cpu()
ground_truth = results_cublas.cpu()
mean_error = torch.mean(abs(error))
mean_ground_truth = torch.mean(abs(ground_truth))
relative_error = mean_error.item()/mean_ground_truth.item()
print( "relative error: {:.6f}".format(relative_error) )

================================================
FILE: tests/python/kernel_test_fpx.py
================================================
import argparse
import torch
import fp6_llm

WARMUP = 10
REPEAT = 1000

parser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).')
parser.add_argument('--OC', type=int, required=False, default=4096, help='number of rows of the weight matrix.')
parser.add_argument('--IC', type=int, required=False, default=4096, help='number of columns of the weight matrix.')
parser.add_argument('--BS', type=int, required=False, default=32, help='inference batch size.')
parser.add_argument('--splitK', type=int, required=False, default=1, help='Split-K parameters allow users to split the GEMM computation along the K dimension so that more CTAs will be created with a better SM utilization.')
parser.add_argument('--EXP', type=int, required=False, default=2, help='number of bits of the fpx Exponent; can be set to 2 or 3.')
parser.add_argument('--MAN', type=int, required=False, default=2, help='number of bits of the fpx Mantissa; can only be set to 2.')
args = parser.parse_args()

EXPONENT = args.EXP
MANTISSA = args.MAN
BIT_WIDTH = 1 + EXPONENT + MANTISSA
assert EXPONENT in [2,3]
assert MANTISSA in [2]
assert(args.OC%256==0)
assert(args.IC%64==0)

print("#"*64)
print(args)

fpx_weight = torch.randint(4294967295, (args.OC, args.IC//32*BIT_WIDTH)).to(torch.int)   # Randomly initialize each byte. The highest value for randint() is set to the max value of uint32_t.
fp16_scale = torch.rand(args.OC).to(torch.half)+0.5
fp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# fpx-fp16 GEMM (fp6-llm)
####################################################################################################################################
torch.cuda.synchronize()
fpx_weight_packed = fp6_llm.weight_prepacking_eXmY_cpu(EXPONENT, MANTISSA, fpx_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fpx_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
for i in range(WARMUP):
    results_fp6_llm = fp6_llm.linear_forward_eXmY_cuda(EXPONENT, MANTISSA, act_cuda, weight_cuda, scale_cuda, args.splitK)
start_event.record()
for i in range(REPEAT):
    results_fp6_llm = fp6_llm.linear_forward_eXmY_cuda(EXPONENT, MANTISSA, act_cuda, weight_cuda, scale_cuda, args.splitK)
end_event.record()
torch.cuda.synchronize()
fp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT
fp6_llm_tflops = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9
####################################################################################################################################

# baseline fp16 GEMM (cuBLAS)
####################################################################################################################################
torch.cuda.synchronize()
fp16_weight = fp6_llm.weight_dequant_eXmY_cpu(EXPONENT, MANTISSA, fpx_weight, fp16_scale)
cuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False)
results_cublas = None
with torch.no_grad():
    cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda())
    act_cuda = fp16_activation.cuda()
    for i in range(WARMUP):
        results_cublas = cuBLAS_MatMul(act_cuda)
    start_event.record()
    for i in range(REPEAT):
        results_cublas = cuBLAS_MatMul(act_cuda)
    end_event.record()
torch.cuda.synchronize()
cublas_time_ms = start_event.elapsed_time(end_event)/REPEAT
cublas_tflops = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9
####################################################################################################################################

# Performance
print( 'cuBLAS time: {:.3f} ms \t\t cuBLAS TFLOPs: {:.1f}'.format(cublas_time_ms, cublas_tflops) )
print( 'quant-llm time: {:.3f} ms \t\t quant-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) )
print( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) )

# Correctness
error = results_cublas.cpu() - results_fp6_llm.cpu()
ground_truth = results_cublas.cpu()
mean_error = torch.mean(abs(error))
mean_ground_truth = torch.mean(abs(ground_truth))
relative_error = mean_error.item()/mean_ground_truth.item()
print( "relative error: {:.6f}".format(relative_error) )
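
# Throughput model behind the TFLOPs printed above (illustrative note): a (BS x IC) @ (IC x OC)
# GEMM performs 2*OC*IC*BS floating-point operations, so TFLOP/s = 2*OC*IC*BS / (time_ms * 1e9).
# For the default shape (OC=IC=4096, BS=32) that is 2 * 4096 * 4096 * 32 = 1,073,741,824 FLOPs,
# i.e. roughly 1.07 GFLOP per kernel call.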

# [FP5_e2m2] Setting each element of the input matrices to 1.0f instead of a random value.
#fpx_weight = torch.zeros(args.OC, args.IC//32*BIT_WIDTH).to(torch.int64)
#for i in range(args.OC):
#    for j in range(args.IC//32):
#        fpx_weight[i][j*BIT_WIDTH+0] = 272762913
#        fpx_weight[i][j*BIT_WIDTH+1] = 1107829124
#        fpx_weight[i][j*BIT_WIDTH+2] = 136414224
#        fpx_weight[i][j*BIT_WIDTH+3] = 562303042
#        fpx_weight[i][j*BIT_WIDTH+4] = 2215657992
#fpx_weight = fpx_weight.to(torch.int)
#fp16_scale = torch.zeros(args.OC).to(torch.half)+1.0
#fp16_activation = torch.zeros(args.BS, args.IC).to(torch.half)+1.0

# 1.0f values in fp5_e2m2
# 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100 00100
# "00100001 00001000 01000010 00010000" "10000100 00100001 00001000 01000010" "00010000 10000100 00100001 00001000" "01000010 00010000 10000100 00100001" "00001000 01000010 00010000 10000100"
# Considering Byte order within an INT32
# 00010000 01000010 00001000 00100001 01000010 00001000 00100001 10000100 00001000 00100001 10000100 00010000 00100001 10000100 00010000 01000010 10000100 00010000 01000010 00001000
# 00010000010000100000100000100001 01000010000010000010000110000100 00001000001000011000010000010000 00100001100001000001000001000010 10000100000100000100001000001000
# 272762913 1107829124 136414224 562303042 2215657992

# [FP6_e3m2] Setting each element of the input matrices to 1.0f instead of a random value.
#fp6_weight = torch.zeros(args.OC, args.IC//32*BIT_WIDTH).to(torch.int64)
#for i in range(args.OC):
#    for j in range(args.IC//32):
#        fp6_weight[i][j*6+0] = 806142768
#        fp6_weight[i][j*6+1] = 3274706115
#        fp6_weight[i][j*6+2] = 214118412
#        fp6_weight[i][j*6+3] = 806142768
#        fp6_weight[i][j*6+4] = 3274706115
#        fp6_weight[i][j*6+5] = 214118412
#fp6_weight = fp6_weight.to(torch.int32)
#fp16_scale = torch.zeros(args.OC).to(torch.half)+1.0
#fp16_activation = torch.zeros(args.BS, args.IC).to(torch.half)+1.0

# 1.0f values in fp6_e3m2
# 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100
# "00110000110000110000110000110000" "11000011000011000011000011000011" "00001100001100001100001100001100"
# 00110000 11000011 00001100 00110000 11000011 00001100 00110000 11000011 00001100 00110000 11000011 00001100
# Considering Byte order within an INT32
# 00110000 00001100 11000011 00110000 11000011 00110000 00001100 11000011 00001100 11000011 00110000 00001100
# 00110000000011001100001100110000 11000011001100000000110011000011 00001100110000110011000000001100
# 806142768 3274706115 214118412
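
# A short sketch (illustrative only, helper not used by the test) of how the packed constants
# listed above can be reproduced: concatenate the per-value bit patterns, cut the stream into
# 32-bit chunks, and reverse the byte order within each chunk, as the derivation comments do.
def _pack_ones_to_uint32(one_pattern, n_values):
    bitstring = one_pattern * n_values                          # e.g. "00100" * 32 for fp5_e2m2
    words = []
    for i in range(0, len(bitstring), 32):
        chunk = bitstring[i:i+32]
        chunk_bytes = [chunk[j:j+8] for j in range(0, 32, 8)]
        words.append(int("".join(reversed(chunk_bytes)), 2))    # byte order reversed within the INT32
    return words
# _pack_ones_to_uint32("00100", 32)  -> [272762913, 1107829124, 136414224, 562303042, 2215657992]
# _pack_ones_to_uint32("001100", 16) -> [806142768, 3274706115, 214118412]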

================================================
FILE: tests/python/run.sh
================================================
#! /bin/bash

# [Batch sizes to test]
# If you want to test the performance of FP6-LLM for larger inference batch sizes,
# which typically happens during prompt processing,
# please revise this file by simply "commenting" and "uncommenting".

# BS <= 64
N=(1 2 4 8 16 32 64)
SplitK=(5 6 7 6)
# BS = 128
#N=(128)
#SplitK=(5 3 3 3)
# BS = 256
#N=(256)
#SplitK=(4 3 2 3)
# BS = 512
#N=(512)
#SplitK=(2 5 2 4)
# BS = 1024
#N=(1024)
#SplitK=(1 2 1 2)
# BS >= 2048
#N=(2048 4096 8192 16384)
#SplitK=(1 1 1 1)

# Benchmarking the specific Matrix Shape from llama2-70b
M=(10240 8192 57344 8192)
K=(8192 8192 8192 28672)

#mkdir -p Profiling
for ((i=0;i<${#M[@]};i++))
do
    for BS in ${N[@]}
    do
        #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \
        #python kernel_test_fp6.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]}
        python kernel_test_fpx.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]}
        python kernel_test_fpx.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]} --EXP=3 --MAN=2
    done
done