[
  {
    "path": ".gitignore",
    "content": "*ncu\n*.so\n*.o\n*.ncu-rep\n*.nsys-rep\n*.sqlite\nbuild\nkernel_test\n*.egg-info\n.vscode"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Quant-LLM (FP6, FP5, FPx...)\nSix-bit quantization (FP6) can achieve **better trade-offs** between [*model quality*](#1-model-accuracy) and [*inference cost*](#2-speedups-on-linear-layers) compared to 4-bit and 8-bit quantization counterparts, reducing the size of large language models (LLMs) effectively and preserving the model quality consistently across varied applications.\nTo support **6-bit inference of LLMs effective on modern GPUs**, we provide the official implementation of [**FP6-LLM**](https://arxiv.org/pdf/2401.14112.pdf), achieving significant *speedups of linear layers* and *GPU memory reduction* over the fp16/int8 baselines.\n\nOur long-term goal is to support **various quantization methods** by providing **extensible & high-performance** GPU kernels for mixed-input matrix multiplication, using the **unified design scheme** presented in our [paper](https://arxiv.org/pdf/2401.14112.pdf), which was recently accepted by [USENIX ATC24](https://www.usenix.org/conference/atc24/presentation/xia).\n\nCurrently, we have tested **FP6_e3m2** and **FP5_e2m2**.\nHowever, our code is templated and easy to support different combinations of eXmY if necessary.\n\n![Overview of FP6-LLM.](./docs/figures/banner.png)\n\n## Roadmap\n\nThe current release contains:\n- We support model weights in **FP6_e3m2** or **FP5_e2m2** and the activations in FP16 format.\n- Efficient CUDA implementation for mixed-input matrix multiplication of linear layers (weights in FP6 and activations in FP16 format) with Tensor Core enabled.\n- C++ APIs and PyTorch APIs to use the CUDA kernels.\n- Test codes to demonstrate the usage of FP6-LLM and verify its correctness.\n- **End-to-end inference** support of [LLaMA2](https://arxiv.org/abs/2307.09288) models is released at [Deepspeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024).\n\nOur future plan includes but is not limited to:\n- [ ] Currently, FP6-LLM only supports **FP6** quantization due to 
its accuracy/performance tradeoff benefits. However, the technology of FP6-LLM can be easily applied to other quantization methods, e.g., **FP4, INT5**.\n- [ ] Currently, FP6-LLM is only tested and verified on **A100 GPUs**, but the core design methods can also be applied to other Tensor Core GPUs like **NVIDIA H100 and GH200**. Furthermore, W6A8 quantization can be supported on H100 GPUs by exploiting the FP8 Tensor Cores.\n\n## Installation\n1. Clone this repository.\n```sh\ngit clone https://github.com/usyd-fsalab/fp6_llm.git\ncd fp6_llm\n```\n2. Install the python package. [Enabling PyTorch APIs]\n```sh\npip install .\n```\n\n3. Compiling the .so file. [Enabling C++ APIs]\n```sh\ncd fp6_llm && make\n```\n\n## Tests\nWe provide scripts to verify the correctness of FP6-LLM.\nThe outputs of FP6-LLM are compared to the outputs of the FP16 baselines (the official implementation of linear layer in PyTorch and NVIDIA cuBLAS).\n\n#### 1. Run tests for PyTorch APIs\n```\ncd ../tests/python\n./run.sh\n```\n\n#### 2. Run tests for C++ APIs\n```\ncd ../cpp\nmake\n./run.sh\n```\n\n## How to use our FP6 CUDA kernels\nWe implemented the CUDA kernel supporting matrix multiply C = A × B, where A is the weight matrix of shape [OC, IC] in FP6 and B is the activation matrix of shape [IC, BS] in FP16.\nC and B are column-major matrices.\nThe CUDA kernels can be launched via **PyTorch APIs** or **C++ APIs**.\nCurrently:\n- OC(Output Channel) must be a multiple of 256, and IC(Input Channel) must be a multiple of 64.\n- BS(Batch Size) can be arbitrary values, smaller BS is preferred for better performance than FP16 baseline.\n\nFor more details of using FP6-LLM APIs, please see the source code of the [C++/Python Test](#tests).\n\n#### 1. 
PyTorch APIs\nTo use the PyTorch APIs of FP6-LLM, you only need to import the python module:\n```\nimport fp6_llm\n```\n\n* Ahead of time weight prepacking:\n```\nfp6_packed_weight = fp6_llm.weight_prepacking_cpu(fp6_weight)\n\nArguments:\n    fp6_weight:         int tensor of shape [OC, IC // 16 * 3];     // 3 INT32 words contains 16 FP6 weights.\nReturn:\n    fp6_packed_weight:  int tensor of shape [OC, IC // 16 * 3];     // 3 INT32 words contains 16 FP6 weights.\n```\n\n* Execute FP6-FP16 mixed input GEMM on GPUs:\n```\nfp16_output=fp6_llm.linear_forward_cuda(fp6_packed_weight, fp16_scale, fp16_input, splitK)\n\nArguments:\n    fp6_packed_weight:  int tensor of shape [OC, IC // 16 * 3];     // 3 INT32 words contains 16 FP6 weights.\n    fp16_scale:         tensor of shape [OC];                       // half tensor\n    fp16_input:         tensor of shape [B, IC];                    // half tensor\n    splitK:             int value, splitting the MatMul problem along K dimension for higher GPU utilization, default 1.\nReturn:\n    fp16_output:        tensor of shape [B, OC];                    // half tensor\n```\n\n* To help users determine the SplitK for other MatMul shapes, we provide a heuristic function in this repo to calculate a SplitK.\nThis function can provide you a pretty good candidate for SplitK.\nTo achieve best performance, please try different SplitK executing the MatMul of the given shape and choose the best SplitK according to the measured kernel latency.\n```\nimport fp6_llm\nSplitK = fp6_llm.HeuristicFuntion_SplitK(M, N, Number_GPU_SMs)\n\nArguments:\n    M is the number of rows of the weight matrix, N is the inference batch size. More specifically, the shape of the MatMul: (M, K) * (K, N) -> (M, N), \n    Number_GPU_SMs is the number of Stream Multiprocessors of the GPU you are using. 
For A100 GPUs, Number_GPU_SMs=108.\n```\n\n* Dequantize an FP6 matrix back to FP16 matrix with CPUs (a useful tool to construct input matrices for the FP16 GEMM baseline):\n```\nfp16_tensor=fp6_llm.weight_dequant_cpu(fp6_tensor, fp16_scale)\n\nArguments:\n    fp6_tensor:         int  tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contains 16 FP6  weights.\n    fp16_scale:         half tensor of shape [OC];                 // for row-wise quantization.\nReturn:\n    fp16_tensor:        half tensor of shape [OC, IC].\n```\n\n#### 2. C++ APIs\nTo use the C++ APIs of FP6-LLM, you need to include [this head file](fp6_llm/csrc/fp6_linear.cuh) of FP6-LLM in your C++ codes:\n```\n#include \"fp6_linear.cuh\"\n```\nBesides, you need to link the dynamic linkable library (*fp6.so*) under [this directory](fp6_llm/) during compilation.\n\n* Ahead of time weight prepacking:\n```\nvoid weight_matrix_prepacking(int* packed_weights,                  // [Output] prepacked FP6 weight matrix\n                              int *FP6Weights,                      // [Input]  original FP6 weight matrix\n                              size_t M,                             // OC\n                              size_t K);                            // IC\n```\n\n* Execute FP6-FP16 mixed input GEMM on GPUs:\n```\ncudaError_t fp6_linear_kernel(cudaStream_t    stream,               // CUDA stream to execute the GPU kernel.\n                              const uint4     *Weight,              // [Input]  Pointer to the FP6 weight matrix.\n                              const half      *Scales,              // [Input]  Pointer to the FP16 quantization scales.\n                              const half      *B,                   // [Input]  Pointer to the FP16 input activation.\n                              half            *C,                   // [Output] Pointer to the FP16 output activation.\n                              const size_t    M_Global,             // OC\n                           
   const size_t    N_Global,             // BS\n                              const size_t    K_Global,             // IC\n                              float           *Reduction_Workspace, // Pointer to the temporary workspace for splitK reduction. Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)\n                              int             Split_K);             // splitK\n```\n\n* Dequantize an FP6 matrix back to FP16 matrix with CPUs\n```\nvoid DeQuantMatrix_FP6_To_FP16(half* A_16bit_h,                     // [Output] Pointer to the dequantized FP16 weight matrix.\n                               unsigned char* A_6bit_h,             // [Input]  Pointer to the quantized FP6 weight matrix. \n                               size_t M,                            // OC\n                               size_t K,                            // IC\n                               half* scale);                        // [Input]  Pointer to the FP16 quantization scales.\n```\n\n## How to use our FP5 CUDA kernels\nSimilar to FP6 APIs, except that two more parameters (EXPONENT, MANTISSA) are required. More details will be provided soon.\n\n## Performance\n\n#### 1. Model Accuracy\n> **FP6 quantization is a practical alternative to further democratize the deployment of LLMs without significantly sacrificing model quality on complex tasks and various model sizes.**\n\n- While 4-bit quantization unavoidably causes degradation in model quality, near-lossless model compression can be achieved with 6-bit quantization. As shown in Table 1 and Table 2, FP6 displays **strong and consistent performance across various tasks** including code generation and zero-shot perplexity performance.\n- It also shows **high robustness across various model sizes**, e.g., 1B, 13B, and 65B LLaMA models. 
\n- FP6 quantization already works well on coarse-grained quantization, while INT4 quantization heavily relies on Fine-Grained Quantization (FGQ) methods to maintain high model quality.\n- *More details can be found in these two papers ([FP6-LLM](https://arxiv.org/pdf/2401.14112.pdf) & [ZeroQuant(4+2)](https://arxiv.org/abs/2312.08583)).*\n\n![ModelAccuracy](docs/figures/Accuracy.png)\n\n\n#### 2. Speedups on linear layers\n\n> **Compared to the FP16 baselines (cuBLAS), INT8 quantization (TensorRT_LLM),\nand FP4 quantization (bitsandbytes):**\n\n- FP6-LLM outperforms bitsandbytes, cuBLAS, and TensorRT_LLM by **up to** 8.9×, 2.6×, and 1.9×. \n- FP6-LLM outperforms bitsandbytes, cuBLAS and TensorRT_LLM by 7.2×, 2.1×, and 1.3× **on average**.\n\n![Speedup_To_8bit](docs/figures/Speedup_to_8bit.png)\n\n> **Compared to baselines for INT4 quantization (TensorRT_LLM):**\n\n- FP6-LLM is **1.06×/1.04×/0.94× faster** than **Fine-grained_W4A16** at batch size 8/16/32, respectively. \n- FP6-LLM is **only 16%/17%/24% slower** than **Coarse-grained_W4A16** at batch size 8/16/32. \n- It is **a worthwhile trade-off** between model quality and inference speed, since FP6 quantization can provide [higher model quality](#1-model-accuracy) than INT4 quantization. \n\n![Speedup_To_4bit](docs/figures/Speedup_to_4bit.png)\n\n#### 3. Speedups on end-to-end inference\n\n> **Experimental settings and baselines:**\n\n- We integrate our FP6-LLM kernel into DeepSpeed for end-to-end evaluation. The baseline for comparison is the FP16 execution of the original DeepSpeed inference system.\n- We set the prefill/prompt length of each request to 0.5K, and generate 1.5K tokens for each request ignoring the \"EOS\" (end of sequence) token. 
\n\n![End-to-end Inference](./docs/figures/e2e_inference.png)\n\n\n> **Results of LLaMA-70b:**\n- Both FP6-LLM and FP16 baseline can at most set the inference **batch size to 32** before running out of GPU memory, whereas **FP6-LLM** only requires **a single GPU** and the **baseline** uses **two GPUs**. \n- FP6-LLM achieves **1.69×-2.65× higher normalized inference throughput** than the FP16 baseline.\n\n> **Results of OPT-30b:**\n- FP6-LLM can set the inference batch size at most to 16 before running out of GPU memory while the FP16 baseline can at most serve 4 requests in a batch. \n- Using an 80GB A100 GPU, FP6-LLM/FP16-baseline can at most achieve 319.1/78.8 tokens per GPU-second with batch size 16/4.\n- FP6-LLM achieves 1.91×/1.84×/1.72× higher generation throughput compared to the FP16 baseline when their batch sizes are set to 1/2/4.\n\n> **Inference latency breakdown of LLaMA-70b:**\n- The execution of linear layers (MatMul) implemented with FP6-LLM is 1.20× faster than the FP16 baseline on average, even with half the number of GPUs.\n- The FP16 baseline is faster running multi-head attention (MHA) with 2-way tensor parallelism.\n- Cross-GPU communications (NCCL) is avoided using FP6-LLM since only a\nsingle GPU is required. \n\n> **Inference latency breakdown of OPT-30b:**\n- End-to-end performance improvements mainly come from time reduction in executing linear layers.\n- The linear layers implemented with FP6-LLM are 2.39× faster than the FP16 baselines on average.\n\n\n## Key Innovations\n> We propose **TC-FPx**, **the first full-stack GPU system design scheme with unified Tensor Core support of floating-point weights for various quantization bit-width (6-bit, 5-bit, 3-bit, etc.), mitigating the \"memory wall\" issues during LLM inference.**\nTC-FPx breaks the limitations of the underlying GPU hardware, allowing the GPU to support linear layer calculations involving model weights of arbitrary bit width. 
\nIn TC-FPx, Tensor Cores are utilized for intensive computation of matrix multiplications,\nwhile SIMT cores are effectively leveraged for weight dequantization, transforming the x-bit model weights to FP16 type during runtime before feeding them to Tensor Cores. \n\n- We propose **Ahead-of-time Bit-level Pre-packing** to resolve the challenge of unfriendly memory access for weights with irregular bit-width, enabling optimal GPU memory access.\n- Besides, we propose **SIMT-Efficient GPU Runtime** to minimize the runtime overhead of weight de-quantization. \n- Last but not least, we present the software pipeline of TC-FPx kernel, where SIMT cores, Tensor Cores, and the GPU memory hierarchy cooperate efficiently with high performance.\n\n![](./docs/figures/designs.png)\n\n## FP6-LLM Community\nFP6-LLM is already integrated in [DeepSpeed](https://github.com/microsoft/DeepSpeed) and this new feature will be available soon.\nGiven that easy-to-use PyTorch and C++ APIs are provided, FP6-LLM can be easily integrated to any inference frameworks as an useful component.\nWe welcome collaborations to integrate FP6-LLM into other inference frameworks.\nWe also welcome all AI developers/practitioners/researchers to join this on-going project, fully exploring the potential of different quantization methods.\n\n## Citation\n\nIf you find FP6-LLM useful or relevant to your research, please kindly cite [our paper](https://arxiv.org/pdf/2401.14112.pdf):\n\n```\n@misc{xia2024fp6llm,\n      title={FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design}, \n      author={Haojun Xia and Zhen Zheng and Xiaoxia Wu and Shiyang Chen and Zhewei Yao and Stephen Youn and Arash Bakhtiari and Michael Wyatt and Donglin Zhuang and Zhongzhu Zhou and Olatunji Ruwase and Yuxiong He and Shuaiwen Leon Song},\n      year={2024},\n      eprint={2401.14112},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG}\n}\n```\n\n## Change Logs\n- **[28th May, 
2024]**: Release of FP6-LLM v0.2. Change the project name back to QuantLLM.\n- **[31th April, 2024]**: Our paper is accepted by [USENIX ATC24](https://www.usenix.org/conference/atc24/presentation/xia).\n- **[4th March, 2024]**: Release of FP6-LLM v0.1.\n\n## Related Projects\n- [ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks](https://arxiv.org/abs/2312.08583)\n- [DeepSpeed: Extreme Speed and Scale for DL Training and Inference](https://github.com/microsoft/DeepSpeed)\n- [cuBLAS: Basic Linear Algebra on NVIDIA GPUs](https://developer.nvidia.com/cublas)\n- [TensorRT-LLM: A TensorRT Toolbox for Optimized Large Language Model Inference](https://github.com/NVIDIA/TensorRT-LLM)\n- [bitsandbytes: the library including quantization primitives for 8-bit & 4-bit operations](https://github.com/TimDettmers/bitsandbytes)\n- [Mixed-input matrix multiplication performance optimizations](https://blog.research.google/2024/01/mixed-input-matrix-multiplication.html)\n- [Flash-LLM: Enabling Cost-Effective and Highly-Efficient Large Generative Model Inference with Unstructured Sparsity](https://www.vldb.org/pvldb/vol17/p211-xia.pdf)\n"
  },
  {
    "path": "examples/README.md",
    "content": "# Example of LLM inference using FP6-LLM\n\nExample scripts of using FP6-LLM for end-to-end inference are coming soon."
  },
  {
    "path": "fp6_llm/Makefile",
    "content": "# host compiler\nHOST_COMPILER ?= g++\nNVCC          := nvcc -ccbin $(HOST_COMPILER)\n\n# internal flags\nNVCCFLAGS   := -m$(shell getconf LONG_BIT)\nCCFLAGS     := -fPIC\nLDFLAGS     :=\n\nALL_CCFLAGS :=\nALL_CCFLAGS += $(NVCCFLAGS)\nALL_CCFLAGS += $(addprefix -Xcompiler ,$(CCFLAGS))\n\n################################################################################\nALL_CCFLAGS += -DNO_PYTORCH\nALL_CCFLAGS += --std=c++17\nALL_CCFLAGS += -maxrregcount=255\nALL_CCFLAGS += --use_fast_math\nALL_CCFLAGS += --ptxas-options=-v,-warn-lmem-usage,--warn-on-spills\n################################################################################\n\nALL_LDFLAGS :=\nALL_LDFLAGS += $(ALL_CCFLAGS)\nALL_LDFLAGS += $(addprefix -Xlinker ,$(LDFLAGS))\n\n# Common includes and paths for CUDA\nINCLUDES  := -I/usr/local/cuda/include/\n\n# Gencode arguments\nSMS ?= 80\n# Generate SASS code for each SM architecture listed in $(SMS)\n$(foreach sm,$(SMS),$(eval GENCODE_FLAGS += -gencode arch=compute_$(sm),code=sm_$(sm)))\n\n\nHEAD_FILES = csrc/fp6_linear.cuh \\\n\t\t\t csrc/include/configs.h \\\n\t\t\t csrc/include/kernel_matmul.cuh \\\n\t\t\t csrc/include/kernel_reduction.cuh \\\n\t\t\t csrc/include/ptx_cp.async.cuh \\\n\t\t\t csrc/include/ptx_mma.cuh \\\n\t\t\t csrc/include/utils_core.cuh \\\n\t\t\t csrc/include/utils_gmem.cuh \\\n\t\t\t csrc/include/utils_parallel_dequant.cuh \\\n\t\t\t csrc/utils/common.h \\\n\t\t\t csrc/utils/weight_prepacking.h \\\n\t\t\t csrc/utils/weight_quant.h \\\n\t\t\t csrc/utils/weight_dequant.h\n\n# Target rules\nall: libfp6.so\n\nlibfp6.so: fp6.o\n\t$(EXEC) $(NVCC) --shared $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES)\n\nfp6.o:  csrc/fp6_linear.cu $(HEAD_FILES)  \n\t$(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $<\n\nclean:\n\trm -f libfp6.so fp6.o"
  },
  {
    "path": "fp6_llm/__init__.py",
    "content": "from fp6_llm_cuda import linear_forward_cuda, weight_prepacking_cpu, weight_dequant_cpu, linear_forward_eXmY_cuda, weight_prepacking_eXmY_cpu, weight_dequant_eXmY_cpu\n\n\n\nimport math\ndef Num_Wave(M, N, SplitK, Num_GPU_SMs):\n    Num_Wave = math.ceil(M/512) * math.ceil(N/64) * SplitK / Num_GPU_SMs\n    return Num_Wave\n\n# The shape of the MatMul: (M, K)*(K, N)->(M, N).\n# Typically, M is the number of rows of weight matrix, and N is the inference batch size.\n# Num_GPU_SMs is the number of SMs (Streaming Multiprocessors) within the GPUs, e,g. each A100 GPU has 108 SMs.\ndef HeuristicFuntion_SplitK(M, N, Num_GPU_SMs):\n    SplitK=1\n    Efficiency=0.0\n    for i in range(1, 100):\n        numWave = Num_Wave(M, N, i, Num_GPU_SMs)\n        eff = numWave / math.ceil(numWave)\n        if eff >= 0.8:\n            SplitK = i\n            Efficiency = eff\n            break\n    return SplitK\n    "
  },
  {
    "path": "fp6_llm/csrc/fp6_linear.cu",
    "content": "#include \"include/kernel_matmul.cuh\"\n#include \"include/kernel_reduction.cuh\"\n#include \"utils/weight_prepacking.h\"\n#include \"utils/weight_dequant.h\"\n#include \"utils/weight_quant.h\"\n#include <ATen/cuda/CUDAContext.h> // For CUDA stream management\n\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n\ntemplate<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>\nstatic void Kernel_Ex(cudaStream_t    stream,\n                      const uint4     *Weight,\n                      const half      *Scales,\n                      const half      *B,\n                      OutputDataType  *C,\n                      const size_t    M_Global,\n                      const size_t    N_Global,\n                      const size_t    K_Global, \n                      int             Split_K) \n{\n    #ifdef DEBUG_MODE\n        printf(\"\\n\");\n        printf(\"Launcher.cu->Kernel_Ex():\\n\");\n        printf(\"M: %d, N: %d, K: %d, SplitK: %d\\n\", M_Global, N_Global, K_Global, Split_K);\n        printf(\"TILE_M: %d, TILE_K: %d, TILE_N: %d\\n\", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N);\n    #endif\n    static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE);\n    cudaFuncSetAttribute(QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);\n    size_t  dimN = (N_Global-1) / TilingConfig::TILE_N + 1;\n    size_t  dimM = M_Global * Split_K / TilingConfig::TILE_M;\n    dim3    GridDim(dimN, dimM, 1);\n    dim3    BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);\n    //\n    #ifdef DEBUG_MODE\n        printf(\"GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\\n\",\n                GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ);\n        printf(\"\\n\");\n    #endif\n    
QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA><<<GridDim, BlockDim, SHMEM_SZ, stream>>>\n                    (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);\n}\n\ntemplate<int EXPONENT, int MANTISSA>\ncudaError_t fpx_linear_kernel(cudaStream_t    stream,\n                              const uint4     *Weight,\n                              const half      *Scales,\n                              const half      *B,\n                              half            *C,\n                              const size_t    M_Global,\n                              const size_t    N_Global,\n                              const size_t    K_Global, \n                              float           *Reduction_Workspace,  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)\n                              int             Split_K)\n{\n    assert(M_Global % 256 == 0);\n    assert(K_Global % 64 == 0);\n    assert(N_Global>0);\n\n    // Work around to support more N shapes:\n    size_t N_PowerOf2;\n    if(N_Global>0 &&  N_Global<=8)      N_PowerOf2 = 8;\n    if(N_Global>8 &&  N_Global<=16)     N_PowerOf2 = 16;\n    if(N_Global>16 && N_Global<=32)     N_PowerOf2 = 32;\n    if(N_Global>32 && N_Global<=64)     N_PowerOf2 = 64;\n    if(N_Global>64 && N_Global<=128)    N_PowerOf2 = 128;\n    if(N_Global>128)                    N_PowerOf2 = ((N_Global-1)/128+1) * 128;\n\n    if (Split_K == 1) {\n        switch (N_PowerOf2) {\n            case 8:     Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;\n            case 16:    Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;\n            case 32:    Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;\n            case 64:    
Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;\n            case 128:   Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;\n            default:    if (N_PowerOf2 % 128 != 0) {\n                            printf(\"FP6LLM_API Error: Unsupported N dimension %d!\\n\", N_PowerOf2);\n                            return cudaErrorUnknown;\n                        }\n                        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;\n        }\n    }\n    else {\n        switch (N_PowerOf2) {\n            case 8:     Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;\n            case 16:    Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;\n            case 32:    Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;\n            case 64:    Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;\n            case 128:   Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;\n            default:    if (N_PowerOf2 % 128 != 0) {\n                            printf(\"FP6LLM_API Error: Unsupported N dimension %d!\\n\", N_PowerOf2);\n                            return cudaErrorUnknown;\n                        }\n                        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, 
Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;\n        }\n        // Reduction for SplitK\n        dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1);\n        dim3 BlockDim(WARP_SIZE, 1, 1);\n        SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(C, Reduction_Workspace, M_Global, N_Global, Split_K);\n    }\n    return cudaGetLastError();\n}\n\ncudaError_t fp6_linear_kernel(\n    cudaStream_t    stream,\n    const uint4     *Weight,\n    const half      *Scales,\n    const half      *B,\n    half            *C,\n    const size_t    M_Global,\n    const size_t    N_Global,\n    const size_t    K_Global, \n    float           *Reduction_Workspace,\n    int             Split_K) {\n    //\n    return fpx_linear_kernel<3,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,  Reduction_Workspace, Split_K);\n}\n\ncudaError_t fp_eXmY_linear_kernel(\n    const int       EXPONENT,\n    const int       MANTISSA,\n    cudaStream_t    stream,\n    const uint4     *Weight,\n    const half      *Scales,\n    const half      *B,\n    half            *C,\n    const size_t    M_Global,\n    const size_t    N_Global,\n    const size_t    K_Global, \n    float           *Reduction_Workspace,\n    int             Split_K) {\n    //\n    if(EXPONENT==2 && MANTISSA==2)\n        return fpx_linear_kernel<2,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,  Reduction_Workspace, Split_K);\n    if(EXPONENT==3 && MANTISSA==2)\n        return fpx_linear_kernel<3,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,  Reduction_Workspace, Split_K);\n    printf(\"QuantLLM_API Error: Unsupported EXPONENT=%d, MANTISSA=%d!\\n\", EXPONENT, MANTISSA);\n    exit(-1);\n}\n\n#ifndef NO_PYTORCH\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n\n/////////////////////////////////////////////////// Old Interface only Supporting FP6 
/////////////////////////////////////////////////////////////////////\n/*\nComputes FP6-FP16 GEMM (PyTorch interface).\n\n[Mathematical Formula]\nStandard definition of linear layer:    Out = In * trans(W), where In, Out, and W are stored in row-major.\nAfter Equivalent transformation    :    trans(Out) = W * trans(In). Note that we do not perform \"transpose\" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel.\n\n[Inputs]\n  _in_feats:  tensor of shape [B, IC];                  // half \n  _weights:   int tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contains 16 FP6 weights.\n  _scales:    tensor of shape [OC];                     // half\n  splitK:     splitting the MatMul problem along K dimension for higher GPU utilization, default 1.\n[Outputs]\n  _out_feats: tensor of shape [B, OC];                  // half\n*/\ntorch::Tensor fp6_linear_forward_cuda(\n    torch::Tensor _in_feats,\n    torch::Tensor _weights,\n    torch::Tensor _scales,\n    int           splitK=1)\n{\n    int num_in_feats      = _in_feats.size(0);\n    int num_in_channels   = _in_feats.size(1);\n    int num_out_channels  = _weights.size(0);\n    assert( num_in_channels%64 == 0 );\n    assert( (num_in_channels/16*3) == _weights.size(1) );    // Making sure the K dimension is matched.\n    //\n    int M = num_out_channels;\n    int K = num_in_channels;\n    int N = num_in_feats;\n    // Input Tensors\n    auto weight = reinterpret_cast<const uint4*>(_weights.data_ptr<int>());  // weights is [OC, IC] but in FP6.\n    auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());\n    auto scales   = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());\n    // Output Tensors\n    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());\n    at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options);\n    auto out_feats = 
reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());\n\n    options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());\n    at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options);\n    auto Reduction_Workspace = reinterpret_cast<float*>(_workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)\n\n    // Get the current device and stream\n    int device_id = _in_feats.device().index();\n    auto stream = at::cuda::getCurrentCUDAStream(device_id).stream();\n\n    fp6_linear_kernel(stream, // Using current stream.\n                      weight,\n                      scales,\n                      in_feats,\n                      out_feats,\n                      M,\n                      N,\n                      K, \n                      Reduction_Workspace,  \n                      splitK);\n\n    return _out_feats;\n}\n\n\n/*\n * Weight prepacking (Pytorch interface).\n * [Input & Output]\n *  fp6_tensor: int tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contains 16 FP6 weights.\n * [Output]\n *  packed_tensor: int tensor of shape [OC, IC // 16 * 3];\n */\ntorch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor)\n{\n    size_t OC = fp6_tensor.size(0);\n    size_t IC = fp6_tensor.size(1);\n    assert (IC%3==0);   \n    IC = IC*16/3;\n    assert( (OC%256==0) && (IC%64==0) );\n    auto packed_tensor = torch::empty_like(fp6_tensor);\n    auto packed_tensor_ptr = reinterpret_cast<int*>(packed_tensor.data_ptr<int>());\n    auto fp6_tensor_ptr = reinterpret_cast<int*>(fp6_tensor.data_ptr<int>());\n    weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC);\n    return packed_tensor;\n}\n\n/*\n * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs.\n * A useful tool to construct input matrices for the FP16 GEMM baseline.\n * [Input]\n *  fp6_tensor:  int  tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contains 
16 FP6  weights.\n *  fp16_scale:  half tensor of shape [OC];                 // for row-wise quantization.\n * [Output]\n *  fp16_tensor: half tensor of shape [OC, IC].     \n */\ntorch::Tensor weight_matrix_dequant_cpu(torch::Tensor fp6_tensor, torch::Tensor fp16_scale) \n{\n    int OC = fp6_tensor.size(0);\n    assert(fp6_tensor.size(1) % 3 == 0);\n    int IC = fp6_tensor.size(1) / 3 * 16;\n    assert(fp16_scale.size(0)==OC);\n    //\n    auto fp6_tensor_ptr = reinterpret_cast<int*>(fp6_tensor.data_ptr<int>());\n    auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>());\n    //\n    auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device());\n    at::Tensor fp16_tensor = torch::empty({OC, IC}, options);\n    auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>());\n    //\n    DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, (unsigned char*)fp6_tensor_ptr, OC, IC, fp16_scale_ptr);\n    //\n    return fp16_tensor;\n}\n\n/////////////////////////////////////////////////// New Interface Supporting FPx /////////////////////////////////////////////////////////////////////\n/*\nComputes FPx-FP16 GEMM (PyTorch interface).\n\n[Mathematical Formula]\nStandard definition of linear layer:    Out = In * trans(W), where In, Out, and W are stored in row-major.\nAfter Equivalent transformation    :    trans(Out) = W * trans(In). 
Note that we do not perform \"transpose\" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel.\n\n[Inputs]\n  _in_feats:  tensor of shape [B, IC];                  // half \n  _weights:   int tensor of shape [OC, IC // 32 * x];   // x INT32 words contains 32 FPx weights.\n  _scales:    tensor of shape [OC];                     // half\n  splitK:     spliting the MatMul problem along K dimension for higher GPU utilization, default 1.\n[Outputs]\n  _out_feats: tensor of shape [B, OC];                  // half\n*/\ntorch::Tensor fp_eXmY_linear_forward_cuda(\n    int             EXPONENT,\n    int             MANTISSA,\n    torch::Tensor   _in_feats,\n    torch::Tensor   _weights,\n    torch::Tensor   _scales,\n    int             splitK=1)\n{\n    int num_in_feats      = _in_feats.size(0);\n    int num_in_channels   = _in_feats.size(1);\n    int num_out_channels  = _weights.size(0);\n    assert( num_in_channels%64 == 0 );\n    assert( (num_in_channels/32*(1+EXPONENT+MANTISSA)) == _weights.size(1) );    // Making sure the K dimension is matched.\n    //\n    int M = num_out_channels;\n    int K = num_in_channels;\n    int N = num_in_feats;\n    // Input Tensors\n    auto weight = reinterpret_cast<const uint4*>(_weights.data_ptr<int>());  // weights is [OC, IC] but in FP6.\n    auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());\n    auto scales   = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());\n    // Output Tensors\n    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());\n    at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options);\n    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());\n\n    options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());\n    at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options);\n    auto 
Reduction_Workspace = reinterpret_cast<float*>(_workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)\n\n    // Get the current device and stream\n    int device_id = _in_feats.device().index();\n    auto stream = at::cuda::getCurrentCUDAStream(device_id).stream();\n\n    //\n    fp_eXmY_linear_kernel(\n        EXPONENT,\n        MANTISSA,\n        stream, // Using torch's current stream here.\n        weight,\n        scales,\n        in_feats,\n        out_feats,\n        M,\n        N,\n        K, \n        Reduction_Workspace,  \n        splitK);\n    return _out_feats;\n}\n\n\n/*\n * Weight prepacking (Pytorch interface).\n * [Input & Output]\n *  fpx_tensor: int tensor of shape [OC, IC // 32 * x];\n * [Output]\n *  packed_tensor: int tensor of shape [OC, IC // 32 * x];\n */\ntorch::Tensor weight_matrix_prepacking_fp_eXmY_cpu(\n    int EXPONENT,\n    int MANTISSA,\n    torch::Tensor fpx_tensor)\n{\n    int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    //\n    size_t OC = fpx_tensor.size(0);\n    size_t IC = fpx_tensor.size(1);\n    assert (IC%BIT_WIDTH==0);   \n    IC = IC*32/BIT_WIDTH;\n    assert( (OC%256==0) && (IC%64==0) );\n    auto packed_tensor = torch::empty_like(fpx_tensor);\n    auto packed_tensor_ptr = reinterpret_cast<int*>(packed_tensor.data_ptr<int>());\n    auto fpx_tensor_ptr = reinterpret_cast<int*>(fpx_tensor.data_ptr<int>());\n    //\n    weight_matrix_prepacking_fp_eXmY(EXPONENT, MANTISSA, packed_tensor_ptr, fpx_tensor_ptr, OC, IC);\n    return packed_tensor;\n}\n\n/*\n * Dequant a FPx matrix to a equivalent FP16 matrix using CPUs.\n * A useful tool to construct input matrices for the FP16 GEMM baseline.\n * [Input]\n *  fpx_tensor:  int  tensor of shape [OC, IC // 32 * x];   //\n *  fp16_scale:  half tensor of shape [OC];                 // for row-wise quantization.\n * [Output]\n *  fp16_tensor: half tensor of shape [OC, IC].     
\n */\ntorch::Tensor weight_matrix_dequant_fp_eXmY_cpu(\n    int EXPONENT,\n    int MANTISSA,\n    torch::Tensor fpx_tensor,\n    torch::Tensor fp16_scale) \n{\n    int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    //\n    int OC = fpx_tensor.size(0);\n    assert(fpx_tensor.size(1) % BIT_WIDTH == 0);\n    int IC = fpx_tensor.size(1) / BIT_WIDTH * 32;\n    assert(fp16_scale.size(0)==OC);\n    //\n    auto fpx_tensor_ptr = reinterpret_cast<int*>(fpx_tensor.data_ptr<int>());\n    auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>());\n    //\n    auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device());\n    at::Tensor fp16_tensor = torch::empty({OC, IC}, options);\n    auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>());\n    //\n    dequant_matrix_fp_eXmY_to_fp16(EXPONENT, MANTISSA, fp16_tensor_ptr, (unsigned char*)fpx_tensor_ptr, OC, IC, fp16_scale_ptr);\n    //\n    return fp16_tensor;\n}\n#endif"
  },
  {
    "path": "fp6_llm/csrc/fp6_linear.cuh",
    "content": "#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n/*\n* Computes FP6-FP16 GEMM (C++ interface).\n*/\ncudaError_t fp6_linear_kernel(\n    cudaStream_t    stream,\n    const uint4     *Weight,\n    const half      *Scales,\n    const half      *B,\n    half            *C,\n    const size_t    M_Global,\n    const size_t    N_Global,\n    const size_t    K_Global, \n    float           *Reduction_Workspace,  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)\n    int             Split_K);\n\ncudaError_t fp_eXmY_linear_kernel(\n    const int       EXPONENT,\n    const int       MANTISSA,\n    cudaStream_t    stream,\n    const uint4     *Weight,\n    const half      *Scales,\n    const half      *B,\n    half            *C,\n    const size_t    M_Global,\n    const size_t    N_Global,\n    const size_t    K_Global, \n    float           *Reduction_Workspace,\n    int             Split_K);\n/*\n * In-place weight prepacking (C++ interface).\n */\nvoid weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K);\nvoid weight_matrix_prepacking_fp_eXmY(const int EXPONENT, const int MANTISSA, int* packed_weights, int *FPxWeights, size_t M, size_t K);\n\n/*\n * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs.\n */\nvoid DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale);\nvoid dequant_matrix_fp_eXmY_to_fp16(const int EXPONENT, const int MANTISSA, half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale);\n\n#ifndef NO_PYTORCH\n#include <torch/extension.h>\n/*\n* Computes FP6-FP16 GEMM (PyTorch interface).\n*/\ntorch::Tensor fp6_linear_forward_cuda(\n    torch::Tensor _in_feats,\n    torch::Tensor _weights,\n    torch::Tensor _scales,\n    int           splitK=1);\ntorch::Tensor fp_eXmY_linear_forward_cuda(\n    int             EXPONENT,\n    int             MANTISSA,\n    torch::Tensor   _in_feats,\n    
torch::Tensor   _weights,\n    torch::Tensor   _scales,\n    int             splitK=1);\n\n\n/*\n * Weight prepacking (Pytorch interface).\n */\ntorch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor);\ntorch::Tensor weight_matrix_prepacking_fp_eXmY_cpu(\n    int EXPONENT,\n    int MANTISSA,\n    torch::Tensor fpx_tensor);\n\n\n/*\n * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs.\n * A useful tool to construct input matrices for the FP16 GEMM baseline.\n * [Input]\n *  fp6_tensor:  int  tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contains 16 FP6  weights.\n *  fp16_scale:  half tensor of shape [OC];                 // for row-wise quantization.\n * [Output]\n *  fp16_tensor: half tensor of shape [OC, IC].     \n */\ntorch::Tensor weight_matrix_dequant_cpu(\n    torch::Tensor fp6_tensor, \n    torch::Tensor fp16_scale);\ntorch::Tensor weight_matrix_dequant_fp_eXmY_cpu(\n    int EXPONENT,\n    int MANTISSA,\n    torch::Tensor fpx_tensor,\n    torch::Tensor fp16_scale);\n#endif"
  },
  {
    "path": "fp6_llm/csrc/include/configs.h",
    "content": "#ifndef CONFIGS_H\n#define CONFIGS_H\n\n//#define DEBUG_MODE\n#define PIPELINE_LEVEL_GMEM 2\n#define PIPELINE_LEVEL_SMEM 2       // only support 2\n\n/************************ Hardware Parameters ************************/\n#define WARP_SIZE                           32\n#define REG_BIT_WIDTH                       32\n// mma: M=16 K=16 N=8\n#define MMA_8                               8\n#define MMA_16                              16\n// for memory access\n#define THREAD_OPT_ACCESS_BIT_WIDTH_128     128 // LDS.128, cp_async.128, ...\n#define BIT_WIDTH_PER_HALF                  16  // Half precision: FP16\n\n/******************** Register Allocation For GEMM ********************/\n#define REG_PER_THREAD_C_TENSOR_16_16       8   // 8 for FP32 Accumulation\n/********************** Memory Padding Parameters **********************/   \n// Eliminating bank-conflict\n#define PADDING_BYTES_16                    16 // Padding 16 bytes each column\n#define PADDING_SHARED_MEM_FOR_B_8          8  // Padding 8 half  each column, during CopyFromGlobalToShared() for B\n#define PADDING_SHARED_MEM_FOR_C_4          4  // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C\n/************************* WARP Tiling part-1 *************************/\n#define WARP_ROW_MMA_TENSORS                4\n#define WARP_M                              (WARP_ROW_MMA_TENSORS * MMA_16)       // 64\n#define WARP_K_MMA_TENSORS                  4\n#define WARP_K                              (WARP_K_MMA_TENSORS   * MMA_16)       // 64\ntemplate<int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>\nstruct TilingConfig {\n    // Depending on \"n\" dimension of the GEMM\n    static constexpr int BLOCK_ROW_WARPS        = BLOCK_ROW_WARPS_;\n    static constexpr int BLOCK_COL_WARPS        = BLOCK_COL_WARPS_;\n    static constexpr int WARP_COL_MMA_TENSORS   = WARP_COL_MMA_TENSORS_;    \n    /************************* WARP Tiling part-2 
*************************/\n    static constexpr int WARP_N                 = WARP_COL_MMA_TENSORS * MMA_8;\n    /*************************Thread Block Tiling *************************/\n    static constexpr int TILE_M                 = WARP_M * BLOCK_ROW_WARPS;\n    static constexpr int TILE_N                 = MMA_8  * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;\n    static constexpr int TILE_K                 = WARP_K;\n    /********************** #Thread per Thread Block **********************/\n    static constexpr int BLOCK_WARPS   = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;\n    static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;\n    /******************************* Others *******************************/\n    static constexpr int SMEM_SIZE_B_TILE   = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM;          // sizeof(half)=2, doubleBuffer=2\n    static constexpr int SMEM_SIZE_C_TILE   = TILE_N * (TILE_M + PADDING_BYTES_16) * 4;                             // sizeof(float)=4\n};\n\n\n\n#endif  // CONFIGS_H"
  },
  {
    "path": "fp6_llm/csrc/include/kernel_matmul.cuh",
    "content": "#include \"configs.h\"\n#include \"utils_gmem.cuh\"\n#include \"utils_core.cuh\"\n\n/************************** Bitwidth of Weight Segments ************************/\n#define BIT_WIDTH_1                 1\n#define BIT_WIDTH_2                 2\n#define BIT_WIDTH_4                 4\n/*************************** 64*64 Weghts of Weight Matrix *********************/\n#define WEIGHT_PER_WARP            (WARP_M*WARP_K)                                                            // 64*64 = 4096\n#define SMEM_SIZE_PER_WARP_1BIT    (WEIGHT_PER_WARP*BIT_WIDTH_1/8)                                            // 512 Bytes,  doubleBuffer not taken into consideration\n#define SMEM_SIZE_PER_WARP_2BIT    (WEIGHT_PER_WARP*BIT_WIDTH_2/8)                                            // 1024 Bytes, doubleBuffer not taken into consideration\n#define SMEM_SIZE_PER_WARP_4BIT    (WEIGHT_PER_WARP*BIT_WIDTH_4/8)                                            // 2048 Bytes, doubleBuffer not taken into consideration\n#define SMEM_SIZE_PER_TB_1BIT      (SMEM_SIZE_PER_WARP_1BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM)    // #WARP=4; Trible-Buffer for 3-level pipeline for A = 6 KB;  double buffer for 2-level pipeline A= 4  KB. \n#define SMEM_SIZE_PER_TB_2BIT      (SMEM_SIZE_PER_WARP_2BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM)    // #WARP=4; Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8  KB.   
\n#define SMEM_SIZE_PER_TB_4BIT      (SMEM_SIZE_PER_WARP_4BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM)    // #WARP=4; Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB.\n#define SMEM_SIZE_PER_TB_A_TILE    (SMEM_SIZE_PER_TB_1BIT+SMEM_SIZE_PER_TB_2BIT+SMEM_SIZE_PER_TB_4BIT)        // used in fp6_linear.cu, Kernel_Ex().\n/******************** Gloabl Memory Layout For QUANTIZED DATA *******************/\n#define NUM_INT4_PER_WARP_1BIT     (WEIGHT_PER_WARP*BIT_WIDTH_1/128)                                          // 32\n#define NUM_INT4_PER_WARP_2BIT     (WEIGHT_PER_WARP*BIT_WIDTH_2/128)                                          // 64\n#define NUM_INT4_PER_WARP_4BIT     (WEIGHT_PER_WARP*BIT_WIDTH_4/128)                                          // 128\n\n/*\n * C = A*B\n * A: row major with ahead-of-time layout transformation, FP6\n * B: col major, FP16\n * C: col major, FP16\n */ \n template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>\n__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,\n                                  const half *B,\n                                  OutputDataType* C,\n                                  const size_t M_Global, const size_t N_Global, const size_t K_Global,\n                                  int Split_K)\n{\n  #ifdef DEBUG_MODE\n    assert(K_Global%TilingConfig::TILE_K==0);\n    assert(M_Global%TilingConfig::TILE_M==0);\n    assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M));\n  #endif\n  // 1+2+4 weight split\n  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;\n  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;\n  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;\n  const uint4* Weight_1bit = Weight;\n  const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0);\n  const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? 
M_Global*K_Global*BIT_WIDTH_2/128 : 0);\n  // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned\n  extern __shared__ __align__(128) half smem[];   \n  half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast<half (*)[WARP_K+PADDING_SHARED_MEM_FOR_B_8]> ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles\n  __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS];  // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes\n  // Thread Block Mapping, considering SplitK\n  const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M);\n  const size_t x = blockIdx.x;                                     // Output Block ID: (BlockID_Row = y; BlockID_Col = x )\n  const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M);   // Output Block ID: (BlockID_Row = y; BlockID_Col = x )\n  const size_t Tile_Start_M = y * TilingConfig::TILE_M;\n  const size_t Tile_Start_N = x * TilingConfig::TILE_N;\n  const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? 
(N_Global-Tile_Start_N) : TilingConfig::TILE_N;\n  const size_t NumBlock_K = K_Global/TilingConfig::TILE_K;\n  const size_t AverageNumBlock_K = NumBlock_K/Split_K;\n  const size_t ExtraNumBlock_K   = NumBlock_K - AverageNumBlock_K * Split_K;\n  size_t NumIter = AverageNumBlock_K;\n  size_t StartBlockID_K = AverageNumBlock_K*BatchID;  \n  if(BatchID<ExtraNumBlock_K){\n    NumIter ++; \n    StartBlockID_K += BatchID;\n  }\n  else\n    StartBlockID_K += ExtraNumBlock_K;\n  // Warp ID.\n  const int warpId = threadIdx.x / WARP_SIZE;\n  int WARP_i = warpId / TilingConfig::BLOCK_COL_WARPS;  // WARP_i: row number;  WARP_j: column number\n  //int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;\n  // Global Memory Address for Matrix A (Weight) /////////////////////////////////////////////////////////////////////////\n  // StartPTR for each ThreadBlock(TB)\n  const uint4* TB_StartGPTR_A_1BIT = Weight_1bit + (y*TilingConfig::BLOCK_ROW_WARPS)*NumBlock_K * NUM_INT4_PER_WARP_1BIT;\n  const uint4* TB_StartGPTR_A_2BIT = Weight_2bit + (y*TilingConfig::BLOCK_ROW_WARPS)*NumBlock_K * NUM_INT4_PER_WARP_2BIT;\n  const uint4* TB_StartGPTR_A_4BIT = Weight_4bit + (y*TilingConfig::BLOCK_ROW_WARPS)*NumBlock_K * NUM_INT4_PER_WARP_4BIT;\n  // StartPTR for each WARP.\n  const uint4* WARP_StartGPTR_A_1BIT  = TB_StartGPTR_A_1BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_1BIT;\n  const uint4* WARP_StartGPTR_A_2BIT  = TB_StartGPTR_A_2BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_2BIT;\n  const uint4* WARP_StartGPTR_A_4BIT  = TB_StartGPTR_A_4BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_4BIT;\n  // StartPTR for each WARP, considering SplitK\n  const size_t WARP_Start_UnitID_K = StartBlockID_K;\n  WARP_StartGPTR_A_1BIT  += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_1BIT;\n  WARP_StartGPTR_A_2BIT  += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_2BIT;\n  WARP_StartGPTR_A_4BIT  += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_4BIT;\n  // Copying A tile from Global to Shared, using double-buffer 
//////////////////////////////////////////////////////////\n  // StartSPTR for each ThreadBlock\n  uint32_t* AFrag_1BIT_SPTR = reinterpret_cast<uint32_t*>(smem);\n  uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT/4;\n  uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR + SMEM_SIZE_PER_TB_2BIT/4;   // 8 buffers including double buffers, 12 for trible buffers\n  // StartSPTR for each WARP\n  AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT/4;\n  AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT/4;\n  AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4;\n  // Pre-fetch of A tile\n  for(int i=0; i<PIPELINE_LEVEL_GMEM-1; i++) {\n    if(USE_SEG_1BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT);\n    if(USE_SEG_2BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT);\n    if(USE_SEG_4BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT);\n    WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16;\n    WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16;\n    WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16;\n  }\n  // Global Memory Address for Matrix A (QuantScale) /////////////////////////////////////////////////////////////////////\n  const half* TB_StartGPTR_A_Scale    = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64;\n  const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64;\n  CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales);\n  // Copying B tile from Global to Shared, considering SplitK /////////////////////////////////////////////////////////////\n  const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K;\n  for(int i=0; i<PIPELINE_LEVEL_GMEM-1; i++) {\n    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS> 
(smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy);\n    BTile_GPTR += TilingConfig::TILE_K;    \n  }\n  // Register Allocation for A,B, and C, Initilazed to Zeros /////////////////////////////////////////////////////////////////////\n  constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS;                                                                  // 1 set = 4 registers, containing a 16*16 MMA block\n  constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2;    // 1 set = 4 registers, containing a 16*16 MMA block\n  uint32_t a  [NumRegSets_a * PIPELINE_LEVEL_SMEM][4];      // double/Trible buffer is used // Registers to store decompressed FP6\n  uint32_t b  [NumRegSets_b * PIPELINE_LEVEL_SMEM][4];      // double/Triple buffer is used // Register to store FP16 B matrix (a slice)\n  float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16];\n  for(int i=0; i<NumRegSets_a * NumRegSets_b; i++)\n    for(int j=0; j<REG_PER_THREAD_C_TENSOR_16_16; j++)\n      c[i][j] = 0.0f;\n  //\n  cp_async_wait_all();\n  __syncthreads();\n\n  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n  uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales\n  ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64);\n  // Initializing the Software Pipeline: writing registers. ////////////////////////////////////////////////////////////////////////////////////////////////\n  initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR);\n  // The outer loop. 
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n  #pragma unroll(1)\n  for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++)\n  {\n    // Trible-Buffer for A Tile\n    uint32_t* __restrict__ read_SPTR_Frag_1bit  = AFrag_1BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16\n    uint32_t* __restrict__ read_SPTR_Frag_2bit  = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16\n    uint32_t* __restrict__ read_SPTR_Frag_4bit  = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16\n    uint32_t* __restrict__ read2_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4;\n    uint32_t* __restrict__ read2_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4;\n    uint32_t* __restrict__ read2_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4;\n    uint32_t* __restrict__ write_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16\n    uint32_t* __restrict__ write_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16\n    uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16\n    // Trible-Buffer for B Tile\n    half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0)  % 
PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;\n    half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;\n    half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;\n    //\n    bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter;\n    // Copying A tile from Global to Register, Bypassing L1, using double-buffer\n    if(USE_SEG_1BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy);\n    if(USE_SEG_2BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy);\n    if(USE_SEG_4BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);\n    // copying B tile from GlobalMemory to SharedMemory\n    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS> (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);\n    cp_async_group_commit();\n    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs\n    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2);\n    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3);\n    // Barriers and Synchronizations\n    cp_async_wait_group<PIPELINE_LEVEL_GMEM-2>();\n    __syncthreads();\n    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);\n    // Updating 
global PTRs\n    WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16;  // 2KB/16=128 (1)/16: int4*+1 = char*+16\n    WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16;  // 4KB/16=256 (1)/16: int4*+1 = char*+16\n    WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16;  // 8KB/16=512 (1)/16: int4*+1 = char*+16\n    BTile_GPTR += TilingConfig::TILE_K;\n  }\n  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n  // Store the C fragments to shared memory.\n  float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] =\n        reinterpret_cast <float (*)[TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4]> (smem);\n  StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c);\n  __syncthreads();\n  // Now that shared memory contains all the D tiles, stream them to global memory.\n  OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global;\n  for(size_t i=warpId; i<NumColumnToCopy; i+=TilingConfig::BLOCK_WARPS)    // i-th column\n    #pragma unroll\n    for(size_t j=threadIdx.x%WARP_SIZE; j<TilingConfig::TILE_M; j+=WARP_SIZE) // j-th row\n    {\n      if constexpr (std::is_same<OutputDataType, half>::value)   BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]);\n      else                                            BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j];\n    }\n}"
  },
  {
    "path": "fp6_llm/csrc/include/kernel_reduction.cuh",
    "content": "/***************************************************************************\n * Copyright 2023 The FLash-LLM Authors. All rights reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n * http://www.apache.org/licenses/LICENSE-2.0\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n ***************************************************************************/\n// Used for the reduction of result matrix if Split-K is used\n// Reduction_Workspace:     (Split_K, M_Global, N_Global),  column major\n// C:                       (M_Global, N_Global),           column major\n// Each thread deals with 8 output elements, each elements is the sum of Split_K elements\n//      Read  Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp\n//      Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread  (128bit) -> 256 half  per warp\n// GridSize = (M_Global*N_Global) / 256\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#define REDUCTION_ELEMENT_PER_THREADBLOCK   256\n#define HALF_PER_128BIT                     8\n\n__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K)\n{\n    half*  WARP_GPTR_C      = C                    + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;\n    float* WARP_GPTR_R      = Reduction_Workspace  + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;\n    half*  THREAD_GPTR_C    = WARP_GPTR_C          + threadIdx.x * HALF_PER_128BIT;    \n    float* THREAD_GPTR_R    = 
WARP_GPTR_R          + threadIdx.x * HALF_PER_128BIT;\n    // Initializing Thread-Local Results\n    float Results[HALF_PER_128BIT];\n    #pragma unroll\n    for (int i = 0; i < HALF_PER_128BIT; i++)       Results[i] = 0.0f;\n    // Reduction\n    for (int i = 0; i < Split_K; i++) {\n        #pragma unroll\n        for (int j = 0; j < HALF_PER_128BIT; j++)   Results[j] += THREAD_GPTR_R[j];\n        THREAD_GPTR_R += M_Global * N_Global;\n    }\n    // Writing to global memory\n    #pragma unroll\n    for (int i = 0; i < HALF_PER_128BIT; i++)       THREAD_GPTR_C[i] = __float2half_rn(Results[i]);\n}\n"
  },
  {
    "path": "fp6_llm/csrc/include/ptx_cp.async.cuh",
    "content": "/***************************************************************************\n * Copyright 2023 The FLash-LLM Authors. All rights reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n * http://www.apache.org/licenses/LICENSE-2.0\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n ***************************************************************************/\n// Extended from CUTLASS's source code\n\n#ifndef PTX_CP_ASYNC_CUH\n#define PTX_CP_ASYNC_CUH\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\ntemplate<int SizeInBytes>\n__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true)\n{\n    static_assert(SizeInBytes == 16, \"Size is not supported\");\n    unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr);\n    asm volatile(\"{ \\n\"\n                 \"  .reg .pred p;\\n\"\n                 \"  setp.ne.b32 p, %0, 0;\\n\"\n                 \"  @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n                 \"}\\n\" ::\"r\"((int)pred_guard),\n                 \"r\"(smem_int_ptr),\n                 \"l\"(global_ptr),\n                 \"n\"(SizeInBytes));\n}\n\n/// Establishes an ordering w.r.t previously issued cp.async instructions. 
Does not block.\n__device__ __forceinline__ void cp_async_group_commit()\n{\n    asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\n/// Blocks until all but <N> previous cp.async.commit_group operations have committed.\ntemplate<int N>\n__device__ __forceinline__ void cp_async_wait_group()\n{\n    asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(N));\n}\n\n/// Blocks until all previous cp.async.commit_group operations have committed.\n// cp.async.wait_all is equivalent to :\n// cp.async.commit_group;\n// cp.async.wait_group 0;\n__device__ __forceinline__ void cp_async_wait_all()\n{\n    asm volatile(\"cp.async.wait_all;\\n\" ::);\n}\n\n#endif"
  },
  {
    "path": "fp6_llm/csrc/include/ptx_mma.cuh",
    "content": "/***************************************************************************\n * Copyright 2023 The FLash-LLM Authors. All rights reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n * http://www.apache.org/licenses/LICENSE-2.0\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n ***************************************************************************/\n#ifndef PTX_MMA_CUH\n#define PTX_MMA_CUH\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <assert.h>\n#include \"configs.h\"\n\ntemplate <typename TilingConfig>\n__device__ __forceinline__ void B_FromSharedToReg(uint32_t  __restrict__    Reg[][4],\n                                                  half      __restrict__    (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],\n                                                  int                       slice_id) {\n    #ifdef DEBUG_MODE\n        static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );\n    #endif\n    \n    const int   warpId  = threadIdx.x / WARP_SIZE;\n    int         lane_id = threadIdx.x % WARP_SIZE;\n    int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;\n    int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j;   // each warp may start from reading warp_start_col'th column of the B tile in shared memory\n    #ifdef DEBUG_MODE\n        assert( warp_start_col==0 );\n    #endif    \n\n    int col = (lane_id%8) + (lane_id/16)*8;\n    int row = (lane_id%16) / 8 * 8;\n    uint32_t smem_local_ptr = 
static_cast<uint32_t>(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row]));\n    if(TilingConfig::WARP_COL_MMA_TENSORS==1) {\n        asm volatile(\"ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\\n\"\n                     : \"=r\"(Reg[0][0]), \"=r\"(Reg[0][1])\n                     : \"r\"(smem_local_ptr));\n    }\n    else {\n        #pragma unroll\n        for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++)\n        {\n            asm volatile(\"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\\n\"\n                         : \"=r\"(Reg[i][0]), \"=r\"(Reg[i][1]), \"=r\"(Reg[i][2]), \"=r\"(Reg[i][3])\n                         : \"r\"(smem_local_ptr));\n            smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);\n        }\n    }\n}\n\n__device__ __forceinline__ void\nMMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b)\n{\n    asm volatile(\"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\"\n                 \"{ %0, %1, %2, %3},\"\n                 \"{ %4, %5, %6, %7 },\"\n                 \"{ %8, %9 },\"\n                 \"{ %10, %11, %12, %13 };\"\n                 : \"=r\"(c[0]), \"=r\"(c[1]), \"=r\"(c[2]), \"=r\"(c[3])\n                 : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]),\n                   \"r\"(b[0]), \"r\"(b[1]),\n                   \"r\"(c[0]), \"r\"(c[1]), \"r\"(c[2]), \"r\"(c[3]));\n}\n\n#endif"
  },
  {
    "path": "fp6_llm/csrc/include/utils_core.cuh",
    "content": "#ifndef UTILS_CORE_CUH\n#define UTILS_CORE_CUH\n\n#include <assert.h>\n\n#include \"configs.h\"\n#include \"ptx_mma.cuh\"\n#include \"utils_parallel_dequant.cuh\"\n\n\ntemplate<int NUM_INT_PER_THREAD>\n__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) {\n    SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE);\n    int     lane_id = threadIdx.x % WARP_SIZE;\n    #pragma unroll\n    for(int i=0; i<NUM_INT_PER_THREAD; i++) {\n        Reg[i] = SPTR[lane_id+i*WARP_SIZE];\n    }\n}\n\ntemplate <typename TilingConfig, int EXPONENT, int MANTISSA>\n__device__ __forceinline__ void initialize_mma_slice(uint32_t                  (*a)[4],\n                                                     uint32_t                  (*b)[4],\n                                                     uint32_t* __restrict__    A_1BIT_SPTR_read,\n                                                     uint32_t* __restrict__    A_2BIT_SPTR_read,\n                                                     uint32_t* __restrict__    A_4BIT_SPTR_read,\n                                                     half      __restrict__    (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],\n                                                     uint32_t*                 RPTR_Scales)\n{\n    // 1+2+4 weight split\n    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;\n    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;\n    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;\n    // Writing registers\n    // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread;\n    uint32_t a_1bit[1];                      // NO double buffer\n    uint32_t a_2bit[2];                      // NO double buffer\n    uint32_t a_4bit[4];                      // NO double buffer\n    if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1>   (a_1bit, A_1BIT_SPTR_read, 0);\n    
if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2>   (a_2bit, A_2BIT_SPTR_read, 0);\n    if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4>   (a_4bit, A_4BIT_SPTR_read, 0);\n    Dequant_32FP6_4Way<EXPONENT, MANTISSA>(a, a_1bit, a_2bit, a_4bit, RPTR_Scales);   // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time \n    B_FromSharedToReg<TilingConfig>(b, B_SPTR_read, 0); // Loading B from shared to registers\n}\n\ntemplate <typename TilingConfig, int EXPONENT, int MANTISSA>\n__device__ __forceinline__ void core_mma_slice(float                     c[][REG_PER_THREAD_C_TENSOR_16_16],\n                                               uint32_t                  (*a)[4],\n                                               uint32_t                  (*b)[4],\n                                               uint32_t* __restrict__    A_1bit_SPTR_read,\n                                               uint32_t* __restrict__    A_2bit_SPTR_read,\n                                               uint32_t* __restrict__    A_4bit_SPTR_read,\n                                               half      __restrict__    (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],\n                                               uint32_t*                 RPTR_Scales,\n                                               int                       slice_id)      // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching\n{\n    // 1+2+4 weight split\n    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;\n    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;\n    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;\n\n    #ifdef DEBUG_MODE\n        assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0));   // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block\n    #endif\n    const int NumRegSets_a = WARP_ROW_MMA_TENSORS;                                               
                               // 1 set = 4 registers, containing a 16*16 MMA block\n    const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2;                // 1 set = 4 registers, containing a 16*16 MMA block\n    uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast<uint32_t(*)[REG_PER_THREAD_C_TENSOR_16_16]>(c);    // Reigsters for accumulated FP32 results\n\n    // Setting RPTRs for double buffers\n    uint32_t (*a_read )[4] = a;\n    uint32_t (*a_write)[4] = a;\n    uint32_t (*b_read )[4] = b;\n    uint32_t (*b_write)[4] = b;\n    if(slice_id%2==1)   { b_write += NumRegSets_b; a_write += NumRegSets_a;} \n    else                { b_read  += NumRegSets_b; a_read  += NumRegSets_a;}\n\n    // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory)\n    #pragma unroll\n    for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {\n        if(TilingConfig::WARP_COL_MMA_TENSORS==1) {\n            MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] );\n        }\n        else {            \n            #pragma unroll\n            for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) {\n                MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS],     a_read[i], b_read[j]     );\n                MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2\n            }\n        }\n    }\n    // Writing registers\n    // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread;\n    uint32_t a_1bit[1];                      // NO double buffer\n    uint32_t a_2bit[2];                      // NO double buffer\n    uint32_t a_4bit[4];                      // NO double buffer\n    if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1>   (a_1bit, A_1bit_SPTR_read, slice_id);\n    if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2>   (a_2bit, 
A_2bit_SPTR_read, slice_id);\n    if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4>   (a_4bit, A_4bit_SPTR_read, slice_id);\n    Dequant_32FP6_4Way<EXPONENT, MANTISSA>(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales);   // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time \n    B_FromSharedToReg<TilingConfig>     (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers\n}\n\ntemplate <typename TilingConfig>\n__device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4],\n                                                                float c[][REG_PER_THREAD_C_TENSOR_16_16])\n{\n    const int   lane_id             = threadIdx.x % WARP_SIZE;\n    const int   warpId              = threadIdx.x / WARP_SIZE;\n    int         warp_row_offset     = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS);\n    #pragma unroll\n    for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {\n        #pragma unroll\n        for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) {    // Dealing with one 16*8 Tensor\n            int RegSetID            = i + (j/2)*WARP_ROW_MMA_TENSORS;\n            int RegOffset           = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2);\n            int Tensor_row_offset   = warp_row_offset + i * MMA_16;\n            int Tensor_col_offset   = j * MMA_8;\n            #pragma unroll\n            for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) {\n                int row_offset = lane_id / 4;\n                if (r >= 2) row_offset += 8;\n                int col_offset = (lane_id % 4) * 2;\n                if (r%2==1) col_offset += 1;\n                smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset];\n            }\n        }\n    }\n}\n\n#endif"
  },
  {
    "path": "fp6_llm/csrc/include/utils_gmem.cuh",
    "content": "#ifndef UTILS_GMEM_CUH\n#define UTILS_GMEM_CUH\n\n#include <assert.h>\n#include \"configs.h\"\n#include \"ptx_cp.async.cuh\"\n\n/* \n * Copying A1/A2 from global memory to shared memory.\n * Usually 1024 or 2048 Bytes\n */\ntemplate<int SMEM_SIZE_IN_BYTES_PER_WARP>\n__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, \n                                                        const uint4* GPTR,\n                                                        bool pred_guard = true) {\n    #ifdef DEBUG_MODE\n        static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0);\n    #endif\n    int lane_id      = threadIdx.x % WARP_SIZE;\n    half* SPTR_HALF = reinterpret_cast<half*>(SPTR);\n    const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR);\n    SPTR_HALF += lane_id*8;\n    GPTR_HALF += lane_id*8;\n    #pragma unroll\n    for(int i=0; i<SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE/16; i++) {\n        cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard);\n        SPTR_HALF += 256;   // Forward 512 Bytes\n        GPTR_HALF += 256;   // Forward 512 Bytes\n    }\n\n}\n\n/* \n * Copying 64 Quant Scales (FP16) from global memory to shared memory.\n */\n__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales,\n                                                              const half* GPTR_A_Scales) {\n    int lane_id         = threadIdx.x % WARP_SIZE;\n    int Offset_Shared   = lane_id*2; \n    int Offset_Global   = lane_id/4 + (lane_id%4)*16;\n    for(int i=0; i<2; i++)  SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8];\n}\n\n/* \n * (1) Copying X  rows * 64 columns of FP16 values, originally in row    major\n * (2) Copying 64 rows * X  columns of FP16 values, originally in column major\n * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads\n */\ntemplate<int MaxNumOfLinesToCopy, int BLOCK_WARPS>\n__device__ __forceinline__ void CopyFromGlobalToShared(half 
__restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],\n                                                       const half*       GlobalPTR,\n                                                       const int         GlobalStride,\n                                                       const int         NumOfLinesLeft,        // To support arbitrary N dimensions.\n                                                       bool              Pred = true) {\n    // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time\n    const int NumOfThreads  = BLOCK_WARPS * WARP_SIZE;\n    const int NumOfGroups   = NumOfThreads / 8; \n    const int MaxIteration  = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1;\n    // runtime variables   \n    const int line_id       = threadIdx.x / 8;\n    const int line_offset   = (threadIdx.x%8) * 8;\n    // PTR for source global memory and target shared memory\n    GlobalPTR += line_id * GlobalStride + line_offset;\n    SharedPTR += line_id;\n    #pragma unroll\n    for (int i = 0; i < MaxIteration; i++) {\n        bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred;\n        cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);\n        //\n        GlobalPTR += NumOfGroups * GlobalStride;\n        SharedPTR += NumOfGroups;\n    }\n}\n\n#endif"
  },
  {
    "path": "fp6_llm/csrc/include/utils_parallel_dequant.cuh",
    "content": "#ifndef UTILS_PARALLELDEQUANT_CUH\n#define UTILS_PARALLELDEQUANT_CUH\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n/*\n * Input:   R1\n * Outputs: R1, R2\n * Note:    Simplified Exponent calculation is applied.\n */\ntemplate<int EXPONENT, int MANTISSA>\n__device__ __forceinline__ void FPx_FP16_Cast_4Way(u_int32_t *In, u_int32_t *Out1, u_int32_t *Out2) {\n    //\n    constexpr int RIGHT_SHIFT = 5 - EXPONENT;\n    constexpr int MASK1 = 0x80000000;\n    constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;\n    constexpr int MASK3 = MASK2 & 0x7fffffff;\n    constexpr int MASK  = MASK3 | MASK3 >> 16;\n    //\n    *Out1  = *In & 0x80008000;\n    *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT;\n    //\n    *In    = (*In) << 8;\n    *Out2  = *In & 0x80008000;\n    *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT;\n}\n\ntemplate<int EXPONENT, int MANTISSA>\n__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) {\n    constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1));\n    constexpr int BIAS        = int(1) << BIAS_OFFSET;\n    //\n    half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);\n    half* FP16_2 = FP16_1 + 1;\n    uint32_t output;\n    half* output_half_ptr = reinterpret_cast<half*>(&output);\n    output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale);\n    output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale);   \n    return output;\n}\n\ntemplate<int EXPONENT, int MANTISSA>\n__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__   Reg[][4], \n                                                   u_int32_t __restrict__   *read_RPTR_1bit,\n                                                   u_int32_t __restrict__   *read_RPTR_2bit, \n                                                   u_int32_t __restrict__   *read_RPTR_4bit,\n                                                   u_int32_t                *Scales) {\n    // 
1+2+4 weight split\n    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;\n    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;\n    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;\n    //\n    u_int32_t *OutputRegs    = reinterpret_cast<u_int32_t*> (Reg);\n    u_int32_t *Frag_PTR_1bit = read_RPTR_1bit;\n    u_int32_t *Frag_PTR_2bit = read_RPTR_2bit;\n    u_int32_t *Frag_PTR_4bit = read_RPTR_4bit;\n    half      *Scale_RPTR    = reinterpret_cast<half*>(Scales);\n    // Dequantizing 32 FP6, each Loop dequantizing 4 FP6\n    #pragma unroll(8)\n    for(int i=0; i<8; i++) { \n        u_int32_t Packed_FP6 = 0;\n        u_int32_t tmp        = 0;\n        // 1bit Frag\n        if(USE_SEG_1BIT) {\n            tmp = (*Frag_PTR_1bit) & 0x80808080;\n            Packed_FP6 |= tmp >> (BIT_WIDTH & 0);\n            if(i%8==7)  Frag_PTR_1bit++;\n            else        (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1;\n        }\n        // 2bit Frag\n        if(USE_SEG_2BIT) {\n            tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0;\n            Packed_FP6 |= tmp >> (BIT_WIDTH & 1);\n            if(i%4==3)  Frag_PTR_2bit++;\n            else        (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2;\n        }\n        // 4bit Frag2\n        if(USE_SEG_4BIT) {\n            tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0;\n            Packed_FP6 |= tmp >> (BIT_WIDTH & 3);\n            if(i%2==1)  Frag_PTR_4bit++;\n            else        (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4;\n        }\n        //\n        u_int32_t out1, out2;\n        FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);\n        //\n        *OutputRegs = MultScale<EXPONENT, MANTISSA>(out1, Scale_RPTR[0]  );       // Muliply FP16 scales\n        OutputRegs += 1;\n        *OutputRegs = MultScale<EXPONENT, MANTISSA>(out2, Scale_RPTR[1]);         // Muliply FP16 scales\n        OutputRegs += 1;\n        // Updating offset for FP16 scales for every two iterations\n        if(i%2==1)  Scale_RPTR 
+= 2;\n    }\n    \n}\n\n/*\n * \n */\n__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) {\n    int lane_id = threadIdx.x % WARP_SIZE;\n    uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales);\n    uint32_t tmpReg = SPTR_uint[lane_id];\n    #pragma unroll\n    for(int i=0; i<4; i++) {\n        // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); \n        Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); \n    }\n}\n\n#endif"
  },
  {
    "path": "fp6_llm/csrc/pybind.cpp",
    "content": "#include <pybind11/pybind11.h>\n#include <torch/extension.h>\n\n#include \"fp6_linear.cuh\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n    // Old Interfaces.\n    m.def(\"linear_forward_cuda\", &fp6_linear_forward_cuda, \"Computes FP6-FP16 GEMM.\");\n    m.def(\"weight_prepacking_cpu\", &weight_matrix_prepacking_cpu, \"Weight prepacking.\");\n    m.def(\"weight_dequant_cpu\", &weight_matrix_dequant_cpu, \"Dequantize weight from fp6 to fp16.\");\n    // New Interfaces.\n    m.def(\"linear_forward_eXmY_cuda\", &fp_eXmY_linear_forward_cuda, \"Computes FPx-FP16 GEMM.\");\n    m.def(\"weight_prepacking_eXmY_cpu\", &weight_matrix_prepacking_fp_eXmY_cpu, \"FPx Weight prepacking.\");\n    m.def(\"weight_dequant_eXmY_cpu\", &weight_matrix_dequant_fp_eXmY_cpu, \"Dequantize weight from fpx to fp16.\");    \n}"
  },
  {
    "path": "fp6_llm/csrc/utils/common.h",
    "content": "#ifndef UTILS_COMMON_H\n#define UTILS_COMMON_H\n\ntemplate<int EXPONENT, int MANTISSA>\nunsigned char Extract_X_Bits_To_A_Byte(unsigned char* Bytes, int ByteOffset, int BitOffset){\n    assert (sizeof(unsigned int)==4);\n    unsigned int tmp_int32_word=0;\n    unsigned char* uchar_ptr = reinterpret_cast<unsigned char*>(&tmp_int32_word);\n    uchar_ptr[3] = Bytes[ByteOffset+0];\n    uchar_ptr[2] = Bytes[ByteOffset+1];\n    tmp_int32_word = tmp_int32_word << BitOffset;\n    //\n    signed int mask = 0x80000000;\n    mask = mask >> (EXPONENT+MANTISSA);\n    tmp_int32_word &= mask;\n    //\n    unsigned char out = uchar_ptr[3];\n    return out;\n}\n\n#endif"
  },
  {
    "path": "fp6_llm/csrc/utils/weight_dequant.h",
    "content": "#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include \"common.h\"\n\ntemplate<int EXPONENT, int MANTISSA>\nvoid DeQuantMatrix_FPx_To_FP16(half* A_16bit_h, unsigned char* A_x_bit_h, size_t M, size_t K, half* scale) {\n    //\n    assert(M%64==0);                 // Currently, M must be a multiple of 64.\n    assert(K%64==0);                 // Currently, K must be a multiple of 64.\n    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    assert(BIT_WIDTH<=8);\n    size_t TotalSizeInByte = M * K * BIT_WIDTH / 8;\n    //\n    half* OutPTR = A_16bit_h;\n    for(size_t i=0; i<TotalSizeInByte/BIT_WIDTH; i++) {    // Processing BIT_WIDTH Bytes for each Loop, generating 8 FP16.\n        unsigned char Bytes[BIT_WIDTH];\n        for(int x=0; x<BIT_WIDTH; x++)  Bytes[x] = A_x_bit_h[i*BIT_WIDTH+x];\n        unsigned char OUT[8];\n        for(int x=0; x<8; x++) {                        // Prepare Initial memory layout for Dequant\n            int ByteOffset  = BIT_WIDTH * x / 8;\n            int BitOffset   = BIT_WIDTH * x % 8;\n            OUT[x] = Extract_X_Bits_To_A_Byte<EXPONENT, MANTISSA>(Bytes, ByteOffset, BitOffset);\n        }\n        // Dequant\n        constexpr int MASK1 = 0x80000000;\n        constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;\n        constexpr int MASK  = MASK2 & 0x7fffffff;\n        constexpr int RIGHT_SHIFT = 5 - EXPONENT;\n        constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1));\n        constexpr int BIAS        = int(1) << BIAS_OFFSET;\n        for(int x=0; x<8; x++) {\n            unsigned int OUT_fp16;        // Storing fp16 in the high 16 bits.\n            OUT_fp16 = int(OUT[x]) << 24;\n            OUT_fp16 = (OUT_fp16 & 0x80000000) | ( (OUT_fp16 & MASK) >> RIGHT_SHIFT );\n            OUT_fp16 = OUT_fp16 >> 16;\n            //\n            half* OUT_FP16_PTR = reinterpret_cast<half*>(&OUT_fp16);\n         
   OutPTR[x] = __float2half_rn ( __half2float(*OUT_FP16_PTR) * (1.0f*BIAS) * __half2float(scale[(8*i)/K]) );\n        }   \n        //\n        OutPTR +=8;\n    }\n}\n\n\nvoid DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) {\n    DeQuantMatrix_FPx_To_FP16<3, 2>(A_16bit_h, A_6bit_h, M, K, scale);\n}\nvoid dequant_matrix_fp_eXmY_to_fp16(const int EXPONENT, const int MANTISSA, half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale){\n    if(EXPONENT==2 && MANTISSA==2)\n        return DeQuantMatrix_FPx_To_FP16<2, 2>(A_16bit_h, A_6bit_h, M, K, scale);\n    if(EXPONENT==3 && MANTISSA==2)\n        return DeQuantMatrix_FPx_To_FP16<3, 2>(A_16bit_h, A_6bit_h, M, K, scale);\n    printf(\"DeQuantMatrix Error: Unsupported EXPONENT=%d, MANTISSA=%d!\\n\", EXPONENT, MANTISSA);\n    exit(-1);\n}"
  },
  {
    "path": "fp6_llm/csrc/utils/weight_prepacking.h",
    "content": "#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n#include <vector>\n#include \"common.h\"\n\n/*\n * Inputs:\n * (1) unsigned char Weight_6bit [M*K*6/8]\n * Outputs:\n * (1) unsigned char Weight_2bit [M*K*2/8]\n * (2) unsigned char Weight_4bit [M*K*4/8]\n * \n * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major.\n * 8 FP6 = 6 Bytes\n * 8 FP4 = 4 Bytes\n * 8 FP2 = 2 Bytes\n */\n\nusing namespace std;\n\n\nvoid Extract_segments_from_8_padded_fpx(unsigned char Seg_xbit[], unsigned char Padded_8_FPx[], int bit_width, int bit_offset){\n    for(int i=0; i< bit_width; i++)\n        Seg_xbit[i] = 0;\n    for(int i=0; i<8; i++){\n        unsigned int seg = (Padded_8_FPx[i] << bit_offset) & 0x000000ff;\n        int mask = 0xffffff00;\n        seg &= mask >> bit_width;\n        //\n        int Seg_idx = (i * bit_width) / 8;\n        int Seg_off = (i * bit_width) % 8;\n        Seg_xbit[Seg_idx] |= seg >> Seg_off;\n    }\n}\n\n\n// dealing with 4 1*8 blocks of FPx\ntemplate<int EXPONENT, int MANTISSA>\nvoid Assign_32_FPx_To_4_Thread(vector<unsigned char> Vec_Seg_1bit[], vector<unsigned char> Vec_Seg_2bit[], vector<unsigned char> Vec_Seg_4bit[], unsigned char* PTR[]) \n{\n    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    assert(BIT_WIDTH<8);\n    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;\n    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;\n    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;\n    //\n    constexpr int nTHREADS = 4;\n    constexpr int FPx_PER_THREAD = 8;\n    unsigned char Padded_8_FPx[nTHREADS][FPx_PER_THREAD];\n    for(int i=0; i<nTHREADS; i++){                             // 4 threads\n        for(int j=0; j<FPx_PER_THREAD; j++){                   // 8 FPx per thread\n            int offset = (i*2 + j%2) * BIT_WIDTH;\n            int ByteOffset = offset / 8;\n            int BitOffset  = offset % 8;\n            Padded_8_FPx[i][j] = Extract_X_Bits_To_A_Byte<EXPONENT, MANTISSA>(PTR[j/2], 
ByteOffset, BitOffset);\n        }\n    }\n    //\n    unsigned char Seg_1bit[nTHREADS][1];\n    unsigned char Seg_2bit[nTHREADS][2];\n    unsigned char Seg_4bit[nTHREADS][4];\n    for(int t=0; t<nTHREADS; t++){\n        Extract_segments_from_8_padded_fpx(Seg_1bit[t], Padded_8_FPx[t], 1, int(BIT_WIDTH & 0));\n        Extract_segments_from_8_padded_fpx(Seg_2bit[t], Padded_8_FPx[t], 2, int(BIT_WIDTH & 1));\n        Extract_segments_from_8_padded_fpx(Seg_4bit[t], Padded_8_FPx[t], 4, int(BIT_WIDTH & 3));\n    }\n    //\n    for(int t=0; t<4; t++)\n    {\n        if (USE_SEG_1BIT) {\n            Vec_Seg_1bit[t].push_back(Seg_1bit[t][0]);\n        }\n        if (USE_SEG_2BIT) {\n            Vec_Seg_2bit[t].push_back(Seg_2bit[t][0]);\n            Vec_Seg_2bit[t].push_back(Seg_2bit[t][1]);\n        }\n        if (USE_SEG_4BIT) {\n            Vec_Seg_4bit[t].push_back(Seg_4bit[t][0]);\n            Vec_Seg_4bit[t].push_back(Seg_4bit[t][1]);\n            Vec_Seg_4bit[t].push_back(Seg_4bit[t][2]);\n            Vec_Seg_4bit[t].push_back(Seg_4bit[t][3]);\n        }\n    }\n}\n\ntemplate<int BIT_WIDTH>\nvoid BitInterleaving_x_bit(unsigned char* PTR_4Bytes)\n{\n    unsigned int *PTR_UINT = reinterpret_cast<unsigned int*>(PTR_4Bytes);\n    unsigned int input  = *PTR_UINT;\n    //\n    int* order = NULL;\n    int order_1bit[32] = {2,6,10,14,18,22,26,30,\n                          4,8,12,16,20,24,28,32,\n                          1,5,9, 13,17,21,25,29,\n                          3,7,11,15,19,23,27,31};  // pre-defined order for bit-interleaving in FP6-LLM\n    int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15};  // pre-defined order for bit-interleaving in FP6-LLM\n    int order_4bit[8] = {2,6,4,8,1,5,3,7};  // pre-defined order for bit-interleaving in FP6-LLM\n    if(BIT_WIDTH==1) order = order_1bit;\n    if(BIT_WIDTH==2) order = order_2bit;\n    if(BIT_WIDTH==4) order = order_4bit;\n    assert(order);\n    //\n    int mask = 0x80000000;\n    assert(BIT_WIDTH>=1);\n    
mask = mask >> (BIT_WIDTH-1);\n    //\n    unsigned int output = 0x00000000;\n    for(int i=0; i<32/BIT_WIDTH; i++){\n        unsigned int Frag_xbit = ( input << BIT_WIDTH*(order[i]-1) ) & mask;    // The highest x bits are used to store the extracted fragments.\n        output |= Frag_xbit >> (i*BIT_WIDTH);\n    }\n    //\n    *PTR_UINT = output;\n}\n\ntemplate<int EXPONENT, int MANTISSA>\nvoid weight_matrix_prepacking_x_bit(int* packed_weights, int *FPxWeights, size_t M, size_t K)\n{\n    assert(M % 64 == 0);\n    assert(K % 64 == 0);\n    //\n    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    assert(BIT_WIDTH<8);\n    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;\n    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;\n    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;\n    //\n    unsigned char* Weight_xbit = reinterpret_cast<unsigned char*>(FPxWeights);\n    unsigned char* Weight_1bit = reinterpret_cast<unsigned char*>(packed_weights);\n    unsigned char* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M*K*1/8 : 0);\n    unsigned char* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? 
 M*K*2/8 : 0);\n    //\n    vector<unsigned char> A_Segment_1bit[32];\n    vector<unsigned char> A_Segment_2bit[32];\n    vector<unsigned char> A_Segment_4bit[32];\n    //\n    size_t BytesPerRow = K*BIT_WIDTH/8;\n    // Pass-1: (1) 1+2+4 split; (2) assign weights to 32 threads.\n    for (size_t i = 0; i < M / 64; i++){\n        for (size_t j = 0; j < K / 16; j++){\n            for(size_t k=0; k<64/16; k++){\n                size_t row = i*64 + k*16;\n                size_t col = j*16;\n                unsigned char* StartPTR_1 = Weight_xbit + row*BytesPerRow + col*(BIT_WIDTH)/8;\n                unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow;\n                unsigned char* StartPTR_3 = StartPTR_1 + 8*(BIT_WIDTH)/8;\n                unsigned char* StartPTR_4 = StartPTR_2 + 8*(BIT_WIDTH)/8;\n                // Dealing with each 16*16 blocks then...\n                for(int l=0; l<8; l++) {\n                    unsigned char* PTR[4]={StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow};\n                    Assign_32_FPx_To_4_Thread<EXPONENT,MANTISSA>(&A_Segment_1bit[l*4], &A_Segment_2bit[l*4], &A_Segment_4bit[l*4], PTR);\n                }\n            }\n        }\n    }\n    // Verifying the length of 1/2/4_bit segments.\n    size_t BytesPerThread_1bit = M*K*1/8/32;\n    size_t BytesPerThread_2bit = M*K*2/8/32;\n    size_t BytesPerThread_4bit = M*K*4/8/32;\n    for(int i=0; i<32; i++){\n        if(USE_SEG_1BIT)    assert(A_Segment_1bit[i].size()==BytesPerThread_1bit);\n        else                assert(A_Segment_1bit[i].size()==0);\n        if(USE_SEG_2BIT)    assert(A_Segment_2bit[i].size()==BytesPerThread_2bit);\n        else                assert(A_Segment_2bit[i].size()==0);\n        if(USE_SEG_4BIT)    assert(A_Segment_4bit[i].size()==BytesPerThread_4bit);\n        else                assert(A_Segment_4bit[i].size()==0);\n    }\n    // Pass-2: Optimizing coalesced global memory access\n    
if(USE_SEG_1BIT)\n    for(size_t i=0; i<BytesPerThread_1bit/4; i++) \n        for(int t=0; t<32; t++)\n            for(int b=0; b<4; b++)              // why (3-b): special byte order within a register\n                Weight_1bit[i*128+t*4+(3-b)] = A_Segment_1bit[t][i*4+b];    \n    if(USE_SEG_2BIT)\n    for(size_t i=0; i<BytesPerThread_2bit/4; i++) \n        for(int t=0; t<32; t++)\n            for(int b=0; b<4; b++)              // why (3-b): special byte order within a register\n                Weight_2bit[i*128+t*4+(3-b)] = A_Segment_2bit[t][i*4+b];    \n    if(USE_SEG_4BIT)\n    for(size_t i=0; i<BytesPerThread_4bit/4; i++) \n        for(int t=0; t<32; t++)\n            for(int b=0; b<4; b++)              // why (3-b):special byte order within a register\n                Weight_4bit[i*128+t*4+(3-b)] = A_Segment_4bit[t][i*4+b];\n    // Pass-3: Bit-level interleaving\n    if(USE_SEG_1BIT)\n    for(size_t i=0; i<BytesPerThread_1bit*32/4; i++)\n        BitInterleaving_x_bit<1>(Weight_1bit+4*i);\n    if(USE_SEG_2BIT)\n    for(size_t i=0; i<BytesPerThread_2bit*32/4; i++)\n        BitInterleaving_x_bit<2>(Weight_2bit+4*i);\n    if(USE_SEG_4BIT)\n    for(size_t i=0; i<BytesPerThread_4bit*32/4; i++)\n        BitInterleaving_x_bit<4>(Weight_4bit+4*i);\n}\n\n\nvoid weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K){\n    weight_matrix_prepacking_x_bit<3, 2>(packed_weights, FP6Weights, M, K);\n}\n\n//\nvoid weight_matrix_prepacking_fp_eXmY(const int EXPONENT, const int MANTISSA, int* packed_weights, int *FPxWeights, size_t M, size_t K){\n    if(EXPONENT==2 && MANTISSA==2)\n        return weight_matrix_prepacking_x_bit<2, 2>(packed_weights, FPxWeights, M, K);\n    if(EXPONENT==3 && MANTISSA==2)\n        return weight_matrix_prepacking_x_bit<3, 2>(packed_weights, FPxWeights, M, K);\n    printf(\"Weight_prepacking Error: Unsupported EXPONENT=%d, MANTISSA=%d!\\n\", EXPONENT, MANTISSA);\n    exit(-1);\n}"
  },
  {
    "path": "fp6_llm/csrc/utils/weight_quant.h",
    "content": "// Author: Zhen Zheng\n// To be used in the future as a tool for generating the FP6 matrix from the FP16 matrix.\n\n#include<iostream>\n\n/*\n * Function to pack 4 fake quantized FP16 values into continuously stored 4 FP6 values.\n */\nvoid cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4)\n{\n    // Constants for FP6\n    constexpr int exponent_nbits_fp6 = 3;\n    constexpr int mantissa_nbits_fp6 = 2;\n    constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1;\n    // Constants for FP16\n    constexpr int exponent_nbits_fp16 = 5;\n    constexpr int mantissa_nbits_fp16 = 10;\n    constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1;\n\n    int fp6_temp[4];\n\n    float absmin_nonzero_fp6 = 0.0625;\n    // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is\n    // the same as that in qtorch.\n    float absmax_fp6 = 28;\n\n    for (int i = 0; i < 4; ++i) {\n        uint16_t source = FP16x4[i];\n        float fp6_value_abs = std::abs(__half2float(*((half*)(&source))));\n        if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) ||\n            fp6_value_abs > absmax_fp6) {\n            // TODO(zhen): a better way may be rounding it to the nearest FP6 value.\n            throw std::invalid_argument(\"Input value out of range for FP6.\");\n        }\n\n        // It is not safe to do shift operation on uint16_t. So we promote it to int.\n        int source_promote = int(source);\n\n        int sign_bit = (source_promote >> 15);\n        // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111'\n        int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16;\n        // Extracting mantissa represented in FP16\n        int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1);\n\n        int new_exp_bit;\n        int new_mant_bit;\n\n        if (exp_bit == 0) {\n            // Subnormal FP16 number. 
Too small for FP6.\n            new_exp_bit = 0;\n            new_mant_bit = 0;\n        } else {\n            new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6);\n            new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6;\n\n            // Deal with subnormal FP6 values.\n            int target_exp_val = exp_bit - exp_bias_fp16;\n            int min_fp6_exp_val = -exp_bias_fp6 + 1;\n            bool subnormal_fp6 = target_exp_val < min_fp6_exp_val;\n            if (subnormal_fp6) {\n                // TODO(zhen): add the rounding logic.\n                new_exp_bit = 0;\n                // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we\n                // need to add it\n                new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >>\n                               (min_fp6_exp_val - target_exp_val);\n            }\n        }\n\n        fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) |\n                      (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit;\n    }\n    // Pack the values\n    FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4);\n    FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2);\n    FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3];\n}\n\n/*\n *  Function to prepack FP16 weights into continuous FP6 values.\n *\n *  Parameters:\n *     weight_16bit: input weight in FP16, size M*K\n *     weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8\n *     M, K: the shape of the weight\n */\nvoid weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit,\n                                   uint8_t* weight_6bit_packed,\n                                   size_t M,\n                                   size_t K)\n{\n    // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit).\n    if (K * 6 % 8 != 0) { throw std::invalid_argument(\"(K * 6 % 8) should be 0\"); }\n    size_t K_fp6_packed = K * 6 / 8;\n    // 
#pragma omp parallel for\n    for (auto m = 0; m < M; m++) {\n        uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed;\n        uint16_t* ptr_16bit = weight_16bit + m * K;\n        for (auto k = 0; k < K; k += 4) {\n            cast_fp16_fp6(ptr_16bit, ptr_6bit);\n            ptr_16bit += 4;\n            ptr_6bit += 3;\n        }\n    }\n}"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import find_packages, setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension\n\n\nextra_compile_args = {\n    \"cxx\": [\n        \"-O3\",\n        \"-std=c++17\"\n    ],\n    \"nvcc\": [\n        \"-O3\", \n        \"--use_fast_math\",\n        \"-std=c++17\",\n        \"-maxrregcount=255\",\n        \"--ptxas-options=-v,-warn-lmem-usage,--warn-on-spills\",\n        \"-gencode=arch=compute_80,code=sm_80\"\n    ],\n}\n\nsetup(\n    name=\"fp6_llm\",\n    author=\"Haojun Xia, Zhen Zheng, Xiaoxia Wu, Shiyang Chen, Zhewei Yao, Stephen Youn, Arash Bakhtiari, Michael Wyatt, Donglin Zhuang, Zhongzhu Zhou, Olatunji Ruwase, Yuxiong He, Shuaiwen Leon Song\",\n    version=\"0.2\",\n    author_email=\"xhjustc@gmail.com\",\n    description =\"An efficient GPU support for LLM inference with x-bit quantization (e.g., FP6 and FP5).\",\n    python_requires=\">=3.8\",\n    install_requires=[\n        \"torch\",\n        \"transformers\"\n    ],\n    packages=find_packages(),\n    ext_modules=[\n        CUDAExtension(\n            name=\"fp6_llm_cuda\",\n            sources=[\n                \"fp6_llm/csrc/pybind.cpp\", \n                \"fp6_llm/csrc/fp6_linear.cu\"\n            ],\n            extra_compile_args=extra_compile_args,\n        ),\n    ],\n    cmdclass={\"build_ext\": BuildExtension}\n)"
  },
  {
    "path": "tests/cpp/Makefile",
    "content": "# host compiler\nHOST_COMPILER ?= g++\nNVCC          := nvcc -ccbin $(HOST_COMPILER)\n\n# internal flags\nNVCCFLAGS   := -m$(shell getconf LONG_BIT)\nCCFLAGS     := -DNO_PYTORCH\nLDFLAGS     := -rpath=../../fp6_llm\n\nALL_CCFLAGS :=\nALL_CCFLAGS += $(NVCCFLAGS)\nALL_CCFLAGS += $(addprefix -Xcompiler ,$(CCFLAGS))\n\nALL_LDFLAGS :=\nALL_LDFLAGS += $(ALL_CCFLAGS)\nALL_LDFLAGS += $(addprefix -Xlinker ,$(LDFLAGS))\n\n# Common includes and paths for CUDA\nINCLUDES  := -I/usr/local/cuda/include/ \nLIBRARIES := -lcublas \n\n#\nINCLUDES  += -I../../fp6_llm/csrc\nLIBRARIES += -L../../fp6_llm -lfp6\n\n################################################################################\n# Gencode arguments\nSMS ?= 80\n# Generate SASS code for each SM architecture listed in $(SMS)\n$(foreach sm,$(SMS),$(eval GENCODE_FLAGS += -gencode arch=compute_$(sm),code=sm_$(sm)))\n################################################################################\n# Target rules\nall: kernel_test_fp6 kernel_test_fpx\n\nkernel_test_fp6.o:  kernel_test_fp6.cu kernel_test.h\n\t$(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $<\n\nkernel_test_fpx.o:  kernel_test_fpx.cu kernel_test.h\n\t$(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $<\n\nkernel_test_fp6: kernel_test_fp6.o \n\t$(EXEC) $(NVCC) $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES) \n\nkernel_test_fpx: kernel_test_fpx.o \n\t$(EXEC) $(NVCC) $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES) \n\nclean:\n\trm -f kernel_test_fp6 kernel_test_fp6.o kernel_test_fpx kernel_test_fpx.o"
  },
  {
    "path": "tests/cpp/kernel_test.h",
    "content": "#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <cublas_v2.h>\n\n// Performance Benchmark\n#define WARM_UP_ITERATION 100\n#define BENCHMARK_ITERATION 10000\n\nvoid __forceinline__ CheckMallocCPU(void* PTR, int line = -1) {\n    if (PTR == NULL) {\n        printf(\"Error in CPU Malloc, line %d!\\n\", line);\n        exit(-1);\n    }\n}\n\nvoid __forceinline__ CheckMallocCUDA(void* PTR, int line = -1) {\n    if (PTR == NULL) {\n        printf(\"Error in cudaMalloc, line %d!\\n\", line);\n        exit(-1);\n    }\n}\n\nvoid checkCublasError(cublasStatus_t status, int line)\n{\n    if (status != CUBLAS_STATUS_SUCCESS) {\n        printf(\"Cublas Error at line %d, Error Code: %d\\n\", line, status);\n        exit(EXIT_FAILURE);\n    }\n}\n\nvoid checkLastCudaError(int line)\n{\n    cudaError_t error = cudaGetLastError();\n    if (error != cudaSuccess) {\n        printf(\"Last Cuda Error Detected at line: %d, Error: %s.\\n\", line, cudaGetErrorString(error));\n        exit(EXIT_FAILURE);\n    }\n}\n\n// Note: totalAbsSum might overflow if (1)the shape of the output matrix is large and (2)the value of each element within the output matrix is large.\n// The overflow might result in NaN for the \"TotalAbsError/TotalAbsSum\", while the fp6_llm is working correctly.\n// This problem can be fixed by setting the quantization scales to a smaller value.\ndouble ComputeTotalError(half* CuBlas, half* Other, size_t m, size_t n)\n{\n    long double totalError = 0.0;\n    for (size_t i = 0; i < m * n; i++)\n        totalError += fabs(__half2float(CuBlas[i]) - __half2float(Other[i]));\n    long double totalAbsSum = 0.0;\n    for (size_t i = 0; i < m * n; i++)\n        totalAbsSum += fabs(__half2float(CuBlas[i]));    \n    return totalError/totalAbsSum;\n}\n\nvoid PrintPerformance(const char* KernelName, float milliseconds, float tflops, double error)\n{\n    printf(\"%-10s \\t -> 
\\t\\t Time/ms: %5.3f \\t Performance/TFLOPs: %4.2f \\t TotalAbsError/TotalAbsSum: %.8lf\\n\",\n           KernelName,\n           milliseconds,\n           tflops,\n           error);\n}\n\n\nvoid PrintMismatch(const char* KernelName,\n                   size_t      MaxNumMismatch,\n                   float       RelativeErrorThreshold,\n                   half*       CuBlas,\n                   half*       Other,\n                   size_t         M_GLOBAL,\n                   size_t         N_GLOBAL)\n{\n    //printf(\"First %d Mismatches between Cublas and %s:\\n\", MaxNumMismatch, KernelName);\n    size_t count = 0;\n    for (size_t i = 0; i < M_GLOBAL; i++) {\n        for (size_t j = 0; j < N_GLOBAL; j++) {\n            if (fabs(__half2float(CuBlas[i + j * M_GLOBAL]) - __half2float(Other[i + j * M_GLOBAL]))/fabs(__half2float(CuBlas[i + j * M_GLOBAL])) > RelativeErrorThreshold) {\n                count++;\n                printf(\"(%d,%d) CuBlas=%f %s=%f\\n\",\n                       i,\n                       j,\n                       __half2float(CuBlas[i + j * M_GLOBAL]),\n                       KernelName,\n                       __half2float(Other[i + j * M_GLOBAL]));\n            }\n            if (count == MaxNumMismatch)\n                break;\n        }\n        if (count == MaxNumMismatch)\n            break;\n    }\n}\n\n"
  },
  {
    "path": "tests/cpp/kernel_test_fp6.cu",
    "content": "#include \"kernel_test.h\"\n#include \"fp6_linear.cuh\"\n\nint main(int argc, char** argv)\n{\n    // Parsing the inputs from CLI.\n    if (argc != 5) {\n        printf(\"Wrong Inputs! Correct input format: ./kernel_test #Row_Weight #Column_Weight BatchSize SplitK\\n\");\n        return -1;\n    }\n    size_t M_GLOBAL = atoi(argv[1]);\n    size_t K_GLOBAL = atoi(argv[2]);\n    size_t N_GLOBAL = atoi(argv[3]);\n    int    SPLIT_K  = atoi(argv[4]);\n    assert(M_GLOBAL%256==0);                 // Currently, M_GLOBAL must be a multiple of 256.\n    assert(K_GLOBAL%64==0);                  // Currently, K_GLOBAL must be a multiple of 64.\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    // Matrices in quantized FP6 models with faked values.\n    unsigned char* A_6bit_h  = (unsigned char*)malloc(M_GLOBAL*K_GLOBAL*6/8);       CheckMallocCPU(A_6bit_h, __LINE__);     // Weight matrix with FP6 values, stored in row-major.\n    for(size_t i=0; i<M_GLOBAL*K_GLOBAL*6/8; i++)   A_6bit_h[i] = rand() % 256;                                             // Random initialization.\n    half*          A_Scale_h = (half*)malloc(M_GLOBAL*sizeof(half));                CheckMallocCPU(A_Scale_h, __LINE__);    // Quantization Scales with FP16 values.\n    for(size_t i=0; i<M_GLOBAL; i++)                A_Scale_h[i] = float(rand()%256)/64.0f;                                 // Scale\n    // Generaing FP16 format of the Weight Matrix\n    half* A_16bit_h = (half*) malloc(M_GLOBAL*K_GLOBAL*sizeof(half));                           CheckMallocCPU(A_16bit_h, __LINE__);\n    DeQuantMatrix_FP6_To_FP16(A_16bit_h, A_6bit_h, M_GLOBAL, K_GLOBAL, A_Scale_h);\n    // In-place weight pre-packing\n    weight_matrix_prepacking((int*)A_6bit_h, (int*)A_6bit_h, M_GLOBAL, K_GLOBAL);\n\n    // Devices Memory\n    unsigned char*  A_6bit;\n    half*           A_Scale;\n    half*           A_16bit;\n    
cudaMalloc(reinterpret_cast<void**>(&A_6bit),  M_GLOBAL*K_GLOBAL*6/8);             CheckMallocCUDA(A_6bit, __LINE__);\n    cudaMalloc(reinterpret_cast<void**>(&A_Scale), M_GLOBAL*sizeof(half));             CheckMallocCUDA(A_Scale, __LINE__);\n    cudaMalloc(reinterpret_cast<void**>(&A_16bit),          M_GLOBAL*K_GLOBAL*sizeof(half));    CheckMallocCUDA(A_16bit, __LINE__);\n    // Memory Copy from CPU to GPU\n    cudaMemcpy(A_6bit,     A_6bit_h,  M_GLOBAL*K_GLOBAL*6/8,          cudaMemcpyHostToDevice);\n    cudaMemcpy(A_Scale,    A_Scale_h,          M_GLOBAL*sizeof(half),          cudaMemcpyHostToDevice);\n    cudaMemcpy(A_16bit,             A_16bit_h,          M_GLOBAL*K_GLOBAL*sizeof(half), cudaMemcpyHostToDevice);\n    checkLastCudaError(__LINE__);\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    // B Matrix: Activations\n    half* B_h = (half*)malloc(sizeof(half) * K_GLOBAL * N_GLOBAL); CheckMallocCPU(B_h);       // col major \n    for (size_t i = 0; i < N_GLOBAL * K_GLOBAL; i++)\n        B_h[i] = __float2half_rn(static_cast<float>((rand() % 5)) / 5 - 0.5f);\n    // Device memory\n    half* B            = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&B), sizeof(half) * N_GLOBAL * K_GLOBAL);               CheckMallocCUDA(B, __LINE__);\n    // Memory Copy from CPU to GPU\n    cudaMemcpy(B, B_h, sizeof(half) * N_GLOBAL * K_GLOBAL, cudaMemcpyHostToDevice);\n    checkLastCudaError(__LINE__);\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    cublasStatus_t cublas_status;\n    cudaEvent_t start, stop;\n    cudaEventCreate(&start);\n    cudaEventCreate(&stop);\n    checkLastCudaError(__LINE__);\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    half* D_cublas = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&D_cublas), 
sizeof(half) * M_GLOBAL * N_GLOBAL);        CheckMallocCUDA(D_cublas, __LINE__);\n    cudaMemset(D_cublas, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);\n    cublasHandle_t handle;\n    cublasCreate(&handle);\n    cublasSetStream(handle, 0);\n    //cublasSetMathMode(handle, CUBLAS_PEDANTIC_MATH);          // Tensor core NOT enabled\n    cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);             // Tensor core enabled\n    cudaDeviceSynchronize();\n    int              m = M_GLOBAL, n = N_GLOBAL, k = K_GLOBAL;\n    const float      alpha     = 1.0;\n    const float      beta      = 0.0;\n    cublasGemmAlgo_t CuBlasALG = static_cast<cublasGemmAlgo_t>(0);\n    for (int i = 0; i < WARM_UP_ITERATION; i++) {\n        cublas_status = cublasGemmEx(handle,\n                                     CUBLAS_OP_T,   CUBLAS_OP_N,\n                                     m, n, k,\n                                     &alpha,\n                                     A_16bit,   CUDA_R_16F, k,\n                                     B,         CUDA_R_16F, k,\n                                     &beta,\n                                     D_cublas,  CUDA_R_16F, m,\n                                     CUDA_R_32F,\n                                     CuBlasALG);\n        checkCublasError(cublas_status, __LINE__);\n    }\n    cudaEventRecord(start);\n    for (int i = 0; i < BENCHMARK_ITERATION; i++)\n        cublas_status = cublasGemmEx(handle,\n                                     CUBLAS_OP_T,   CUBLAS_OP_N,\n                                     m, n, k,\n                                     &alpha,\n                                     A_16bit,   CUDA_R_16F, k,\n                                     B,         CUDA_R_16F, k,\n                                     &beta,\n                                     D_cublas,  CUDA_R_16F, m,\n                                     CUDA_R_32F,\n                                     CuBlasALG);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    
//\n    float milliseconds_cublas = 0;\n    cudaEventElapsedTime(&milliseconds_cublas, start, stop);\n    milliseconds_cublas = milliseconds_cublas / BENCHMARK_ITERATION;\n    float tflops_cublas = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_cublas / 1000.)) / 1e12;\n    //\n    half* D_cublas_h = NULL;  // col major\n    D_cublas_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);   CheckMallocCPU(D_cublas_h);\n    cudaMemcpy(D_cublas_h, D_cublas, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major\n    cudaFree(D_cublas);\n    checkLastCudaError(__LINE__);\n    /////////////////////////////////////////////////////////////////////////////////////////////////\n    half* D_fp6 = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&D_fp6), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_fp6);\n    cudaMemset(D_fp6, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);\n    //\n    int Split_K = SPLIT_K;\n    float* Reduction_Workspace = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&Reduction_Workspace), sizeof(float) * M_GLOBAL * N_GLOBAL * Split_K);   CheckMallocCUDA(Reduction_Workspace, __LINE__);\n    //\n    for (int i = 0; i < WARM_UP_ITERATION; i++)\n        fp6_linear_kernel(  0,\n                        (uint4*)A_6bit, A_Scale,\n                        B,\n                        D_fp6,\n                        M_GLOBAL, N_GLOBAL, K_GLOBAL,\n                        Reduction_Workspace,  \n                        Split_K);\n    cudaEventRecord(start);\n    for (int i = 0; i < BENCHMARK_ITERATION; i++)\n        fp6_linear_kernel(  0,\n                        (uint4*)A_6bit, A_Scale,\n                        B,\n                        D_fp6,\n                        M_GLOBAL, N_GLOBAL, K_GLOBAL,\n                        Reduction_Workspace,  \n                        Split_K);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    checkLastCudaError(__LINE__);\n    //\n    float 
milliseconds_fp6 = 0.0f;\n    cudaEventElapsedTime(&milliseconds_fp6, start, stop);\n    milliseconds_fp6 = milliseconds_fp6 / BENCHMARK_ITERATION;\n    float tflops_fp6 = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_fp6 / 1000.)) / 1e12;\n    half* D_fp6_h = NULL;  // col major\n    D_fp6_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);\n    cudaMemcpy(D_fp6_h, D_fp6, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major\n    cudaFree(D_fp6);\n    cudaFree(Reduction_Workspace);\n    /////////////////////////////////////////////////////////////////////////////////////////////////\n    double totalRelativeError_fp6  = ComputeTotalError(D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);\n    printf(\"************************************* \");\n    printf(\"[%d-bit Weights, e%dm%d] M: %d N: %d K: %d SplitK: %d\", 6, 3, 2, M_GLOBAL, N_GLOBAL, K_GLOBAL, SPLIT_K);\n    printf(\" ************************************\\n\");\n    PrintPerformance(\"cuBLAS\", milliseconds_cublas, tflops_cublas, 0.0);\n    PrintPerformance(\"fp6_llm\", milliseconds_fp6, tflops_fp6, totalRelativeError_fp6);\n\n    //PrintMismatch(\"fp6\", 100, 0.002, D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);\n\n    free(D_cublas_h);\n    free(D_fp6_h);\n    cudaFree(B);\n    return 0;\n}\n"
  },
  {
    "path": "tests/cpp/kernel_test_fpx.cu",
    "content": "#include \"kernel_test.h\"\n#include \"fp6_linear.cuh\"\n\n\nint main(int argc, char** argv)\n{\n    // Parsing the inputs from CLI.\n    if (argc != 7) {\n        printf(\"Wrong Inputs! Correct input format: ./kernel_test EXPONENT MANTISSA #Row_Weight #Column_Weight BatchSize SplitK\\n\");\n        return -1;\n    }\n    int EXPONENT    = atoi(argv[1]);\n    int MANTISSA    = atoi(argv[2]);\n    size_t M_GLOBAL = atoi(argv[3]);\n    size_t K_GLOBAL = atoi(argv[4]);\n    size_t N_GLOBAL = atoi(argv[5]);\n    int    SPLIT_K  = atoi(argv[6]);\n    int BIT_WIDTH = 1 + EXPONENT + MANTISSA;\n    assert(EXPONENT==2 || EXPONENT==3);\n    assert(MANTISSA==2);\n    assert(M_GLOBAL%256==0);                 // Currently, M_GLOBAL must be a multiple of 256.\n    assert(K_GLOBAL%64==0);                  // Currently, K_GLOBAL must be a multiple of 64.\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    // Matrices in quantized FPx models with faked values.\n    unsigned char* A_xbit_h  = (unsigned char*)malloc(M_GLOBAL*K_GLOBAL*BIT_WIDTH/8);       CheckMallocCPU(A_xbit_h, __LINE__);     // Weight matrix with FP6 values, stored in row-major.\n    for(size_t i=0; i<M_GLOBAL*K_GLOBAL*BIT_WIDTH/8; i++)   A_xbit_h[i] = rand() % 256;                                             // Random initialization.\n    half*          A_Scale_h = (half*)malloc(M_GLOBAL*sizeof(half));                CheckMallocCPU(A_Scale_h, __LINE__);    // Quantization Scales with FP16 values.\n    for(size_t i=0; i<M_GLOBAL; i++)                A_Scale_h[i] = float(rand()%256)/64.0f;                                 // Scale\n    // Generaing FP16 format of the Weight Matrix\n    half* A_16bit_h = (half*) malloc(M_GLOBAL*K_GLOBAL*sizeof(half));                           CheckMallocCPU(A_16bit_h, __LINE__);\n    dequant_matrix_fp_eXmY_to_fp16(EXPONENT, MANTISSA, A_16bit_h, A_xbit_h, M_GLOBAL, K_GLOBAL, A_Scale_h);\n    
// In-place weight pre-packing\n    weight_matrix_prepacking_fp_eXmY(EXPONENT, MANTISSA, (int*)A_xbit_h, (int*)A_xbit_h, M_GLOBAL, K_GLOBAL);\n\n    // Devices Memory\n    unsigned char*  A_xbit;\n    half*           A_Scale;\n    half*           A_16bit;\n    cudaMalloc(reinterpret_cast<void**>(&A_xbit),  M_GLOBAL*K_GLOBAL*BIT_WIDTH/8);             CheckMallocCUDA(A_xbit, __LINE__);\n    cudaMalloc(reinterpret_cast<void**>(&A_Scale), M_GLOBAL*sizeof(half));             CheckMallocCUDA(A_Scale, __LINE__);\n    cudaMalloc(reinterpret_cast<void**>(&A_16bit), M_GLOBAL*K_GLOBAL*sizeof(half));    CheckMallocCUDA(A_16bit, __LINE__);\n    // Memory Copy from CPU to GPU\n    cudaMemcpy(A_xbit,     A_xbit_h,  M_GLOBAL*K_GLOBAL*BIT_WIDTH/8,          cudaMemcpyHostToDevice);\n    cudaMemcpy(A_Scale,    A_Scale_h,          M_GLOBAL*sizeof(half),          cudaMemcpyHostToDevice);\n    cudaMemcpy(A_16bit,             A_16bit_h,          M_GLOBAL*K_GLOBAL*sizeof(half), cudaMemcpyHostToDevice);\n    checkLastCudaError(__LINE__);\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    // B Matrix: Activations\n    half* B_h = (half*)malloc(sizeof(half) * K_GLOBAL * N_GLOBAL); CheckMallocCPU(B_h);       // col major \n    for (size_t i = 0; i < N_GLOBAL * K_GLOBAL; i++)\n        B_h[i] = __float2half_rn(static_cast<float>((rand() % 5)) / 5 - 0.5f);\n    // Device memory\n    half* B            = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&B), sizeof(half) * N_GLOBAL * K_GLOBAL);               CheckMallocCUDA(B, __LINE__);\n    // Memory Copy from CPU to GPU\n    cudaMemcpy(B, B_h, sizeof(half) * N_GLOBAL * K_GLOBAL, cudaMemcpyHostToDevice);\n    checkLastCudaError(__LINE__);\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    cublasStatus_t cublas_status;\n    cudaEvent_t start, stop;\n    cudaEventCreate(&start);\n    
cudaEventCreate(&stop);\n    checkLastCudaError(__LINE__);\n    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    //printf(\"Launching CuBlas...\\n\");\n    half* D_cublas = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&D_cublas), sizeof(half) * M_GLOBAL * N_GLOBAL);        CheckMallocCUDA(D_cublas, __LINE__);\n    cudaMemset(D_cublas, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);\n    cublasHandle_t handle;\n    cublasCreate(&handle);\n    cublasSetStream(handle, 0);\n    //cublasSetMathMode(handle, CUBLAS_PEDANTIC_MATH);          // Tensor core NOT enabled\n    cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);             // Tensor core enabled\n    cudaDeviceSynchronize();\n    int              m = M_GLOBAL, n = N_GLOBAL, k = K_GLOBAL;\n    const float      alpha     = 1.0;\n    const float      beta      = 0.0;\n    cublasGemmAlgo_t CuBlasALG = static_cast<cublasGemmAlgo_t>(0);\n    for (int i = 0; i < WARM_UP_ITERATION; i++) {\n        cublas_status = cublasGemmEx(handle,\n                                     CUBLAS_OP_T,   CUBLAS_OP_N,\n                                     m, n, k,\n                                     &alpha,\n                                     A_16bit,   CUDA_R_16F, k,\n                                     B,         CUDA_R_16F, k,\n                                     &beta,\n                                     D_cublas,  CUDA_R_16F, m,\n                                     CUDA_R_32F,\n                                     CuBlasALG);\n        checkCublasError(cublas_status, __LINE__);\n    }\n    cudaEventRecord(start);\n    for (int i = 0; i < BENCHMARK_ITERATION; i++)\n        cublas_status = cublasGemmEx(handle,\n                                     CUBLAS_OP_T,   CUBLAS_OP_N,\n                                     m, n, k,\n                                     &alpha,\n                                     A_16bit,   CUDA_R_16F, k,\n                                 
    B,         CUDA_R_16F, k,\n                                     &beta,\n                                     D_cublas,  CUDA_R_16F, m,\n                                     CUDA_R_32F,\n                                     CuBlasALG);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    //\n    float milliseconds_cublas = 0;\n    cudaEventElapsedTime(&milliseconds_cublas, start, stop);\n    milliseconds_cublas = milliseconds_cublas / BENCHMARK_ITERATION;\n    float tflops_cublas = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_cublas / 1000.)) / 1e12;\n    //\n    half* D_cublas_h = NULL;  // col major\n    D_cublas_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);   CheckMallocCPU(D_cublas_h);\n    cudaMemcpy(D_cublas_h, D_cublas, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major\n    cudaFree(D_cublas);\n    checkLastCudaError(__LINE__);\n    /////////////////////////////////////////////////////////////////////////////////////////////////\n    //printf(\"Launching FP6-LLM...\\n\");\n    half* D_fp6 = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&D_fp6), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_fp6);\n    cudaMemset(D_fp6, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);\n    //\n    int Split_K = SPLIT_K;\n    float* Reduction_Workspace = NULL;\n    cudaMalloc(reinterpret_cast<void**>(&Reduction_Workspace), sizeof(float) * M_GLOBAL * N_GLOBAL * Split_K);   CheckMallocCUDA(Reduction_Workspace, __LINE__);\n    //\n    for (int i = 0; i < WARM_UP_ITERATION; i++)\n        fp_eXmY_linear_kernel(  \n            EXPONENT,\n            MANTISSA,\n            0,\n            (uint4*)A_xbit, A_Scale,\n            B,\n            D_fp6,\n            M_GLOBAL, N_GLOBAL, K_GLOBAL,\n            Reduction_Workspace,  \n            Split_K);\n    cudaEventRecord(start);\n    for (int i = 0; i < BENCHMARK_ITERATION; i++)\n        fp_eXmY_linear_kernel(  \n            
EXPONENT,\n            MANTISSA,\n            0,\n            (uint4*)A_xbit, A_Scale,\n            B,\n            D_fp6,\n            M_GLOBAL, N_GLOBAL, K_GLOBAL,\n            Reduction_Workspace,  \n            Split_K);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    checkLastCudaError(__LINE__);\n    //\n    float milliseconds_fp6 = 0.0f;\n    cudaEventElapsedTime(&milliseconds_fp6, start, stop);\n    milliseconds_fp6 = milliseconds_fp6 / BENCHMARK_ITERATION;\n    float tflops_fp6 = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_fp6 / 1000.)) / 1e12;\n    half* D_fp6_h = NULL;  // col major\n    D_fp6_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);\n    cudaMemcpy(D_fp6_h, D_fp6, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major\n    cudaFree(D_fp6);\n    cudaFree(Reduction_Workspace);\n    /////////////////////////////////////////////////////////////////////////////////////////////////\n    double totalRelativeError_fp6  = ComputeTotalError(D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);\n    printf(\"************************************* \");\n    printf(\"[%d-bit Weights, e%dm%d] M: %d N: %d K: %d SplitK: %d\", BIT_WIDTH, EXPONENT, MANTISSA, M_GLOBAL, N_GLOBAL, K_GLOBAL, SPLIT_K);\n    printf(\" ************************************\\n\");\n    PrintPerformance(\"cuBLAS\", milliseconds_cublas, tflops_cublas, 0.0);\n    PrintPerformance(\"quant_llm\", milliseconds_fp6, tflops_fp6, totalRelativeError_fp6);\n    //PrintMismatch(\"fp6\", 100, 0.002, D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);\n\n    free(D_cublas_h);\n    free(D_fp6_h);\n    cudaFree(B);\n    return 0;\n}\n"
  },
  {
    "path": "tests/cpp/run.sh",
    "content": "#! /bin/bash\n\n# Batch sizes to test\nN=(1 2 3 4 5 6 7 8)\n\n# Benchmarking the specific Matrix Shape from llama-1 65b\nM=(13824 5120  22016 8192)\nK=(5120  13824 8192  22016)\nSplitK=(4 10 5 6)      # SplitK for smaller Batch Sizes\n#SplitK=(3)      # SplitK for Batch Sizes 128\n#SplitK=(3)      # SplitK for Batch Sizes 512\n#SplitK=(1)      # SplitK for Batch Sizes 2048\n\n\n# Benchmarking Matrix Shapes from OPT models \n#M=(21504        7168     28672  7168    27648   9216    36864   9216    36864   12288   49152   12288)\n#K=(7168         7168     7168   28672   9216    9216    9216    36864   12288   12288   12288   49152)\n#SplitK=(2       7       7      7       2       6       3       6       3       4       1       4)        # SplitK for smaller Batch Sizes\n#SplitK=(5       7       7      7       1       3       3       3       3       2       1       2)        # SplitK for Batch Sizes 128\n#SplitK=(3       2       3      3       1       3       1       3       1       1       1       1)        # SplitK for Batch Sizes 512\n#SplitK=(1       2       1      2       1       1       1       1       1       1       1       1)        # SplitK for Batch Sizes 2048\n\n# Benchmarking Matrix Shapes from llama-1 models \n#M=(12288       4096    11008   4096    15360   5120    13824   5120    19968   6656    17920   6656    24576   8192    22016   8192)\n#K=(4096        4096    4096    11008   5120    5120    5120    13824   6656    6656    6656    17920   8192    8192    8192    22016)\n#SplitK=(9      13      5       13      3       10      4       10      5       8       3       8       2       6       5       6)      # SplitK for smaller Batch Sizes\n#SplitK=(2      6       5       6       3       5       2       5       4       4       3       4       1       3       5       3)      # SplitK for Batch Sizes 128\n#SplitK=(1      3       3       3       2       4       1       4       1       1       3       1       1       3       2       3)      
# SplitK for Batch Sizes 512\n#SplitK=(1      2       1       2       1       1       1       1       1       1       1       1       1       1       1       1)      # SplitK for Batch Sizes 2048\n\n#mkdir -p Profiling\nfor ((i=0;i<${#M[@]};i++)) \ndo\n    echo \"Processing Shape ${i}...\"\n    for BS in ${N[@]} \n    do\n        echo \"BS=${BS}\"\n        #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \\\n        #./kernel_test_fp6 ${M[i]} ${K[i]} ${BS} ${SplitK[i]}\n        ./kernel_test_fpx 2 2 ${M[i]} ${K[i]} ${BS} ${SplitK[i]}\n        ./kernel_test_fpx 3 2 ${M[i]} ${K[i]} ${BS} ${SplitK[i]}  \n    done\ndone\n"
  },
  {
    "path": "tests/python/kernel_test_fp6.py",
    "content": "import argparse\nimport torch\nimport fp6_llm\n\nWARMUP = 10\nREPEAT = 1000\n\nparser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).')\nparser.add_argument('--OC',        type=int, required=False,     default=4096,   help='number of rows of the weight matrix.')\nparser.add_argument('--IC',        type=int, required=False,     default=4096,   help='number of columns of the weight matrix.')\nparser.add_argument('--BS',        type=int, required=False,     default=32,     help='inference batch size.')\nparser.add_argument('--splitK',    type=int, required=False,     default=1,      help='Split-K parameters allow users to split the GEMM computation along the K dimension so that more CTAs will be created with a better SM utilization.')\nargs = parser.parse_args()\n\nassert(args.OC%256==0)\nassert(args.IC%64==0)\n\nprint(\"#\"*64)\nprint(args)\n\nfp6_weight = torch.randint(4294967295, (args.OC,args.IC//16*3)).to(torch.int)    # Randomly initialize each bytes. 
The highest value for randint() is set the the max value of uint32_t.\nfp16_scale = torch.rand(args.OC).to(torch.half)+0.5\nfp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5\nstart_event = torch.cuda.Event(enable_timing=True)\nend_event = torch.cuda.Event(enable_timing=True)\n\n# fp6-fp16 GEMM (fp6-llm)\n####################################################################################################################################\ntorch.cuda.synchronize()\nfp6_weight_packed = fp6_llm.weight_prepacking_cpu(fp6_weight)\nact_cuda = fp16_activation.cuda()\nweight_cuda = fp6_weight_packed.cuda()\nscale_cuda = fp16_scale.cuda()\nfor i in range(WARMUP):\n    results_fp6_llm = fp6_llm.linear_forward_cuda(act_cuda, weight_cuda, scale_cuda, args.splitK);\nstart_event.record()\nfor i in range(REPEAT):\n    results_fp6_llm = fp6_llm.linear_forward_cuda(act_cuda, weight_cuda, scale_cuda, args.splitK);\nend_event.record()\ntorch.cuda.synchronize()\nfp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT\nfp6_llm_tflops  = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9\n####################################################################################################################################\n\n# baseline fp16 GEMM (cuBLAS)\n####################################################################################################################################\ntorch.cuda.synchronize()\nfp16_weight = fp6_llm.weight_dequant_cpu(fp6_weight, fp16_scale)\ncuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False)\nresults_cublas = None\nwith torch.no_grad():\n    cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda())\n    act_cuda = fp16_activation.cuda()\n    for i in range(WARMUP):\n        results_cublas = cuBLAS_MatMul(act_cuda)\n    start_event.record()\n    for i in range(REPEAT):\n        results_cublas = cuBLAS_MatMul(act_cuda)\n    end_event.record()\ntorch.cuda.synchronize()\ncublas_time_ms = 
start_event.elapsed_time(end_event)/REPEAT\ncublas_tflops  = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9\n####################################################################################################################################\n\n# Performance\nprint( 'cuBLAS  time: {:.3f} ms \\t\\t cuBLAS  TFLOPs: {:.1f}'.format(cublas_time_ms,  cublas_tflops) )\nprint( 'fp6-llm time: {:.3f} ms \\t\\t fp6-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) )\nprint( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) )\n\n# Correctness\nerror             = results_cublas.cpu() - results_fp6_llm.cpu()\nground_truth      = results_cublas.cpu()\nmean_error        = torch.mean(abs(error))\nmean_ground_truth = torch.mean(abs(ground_truth))\nrelative_error    = mean_error.item()/mean_ground_truth.item()\nprint( \"relative error: {:.6f}\".format(relative_error) )"
  },
  {
    "path": "tests/python/kernel_test_fpx.py",
    "content": "import argparse\nimport torch\nimport fp6_llm\n\nWARMUP = 10\nREPEAT = 1000\n\nparser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).')\nparser.add_argument('--OC',        type=int, required=False,     default=4096,   help='number of rows of the weight matrix.')\nparser.add_argument('--IC',        type=int, required=False,     default=4096,   help='number of columns of the weight matrix.')\nparser.add_argument('--BS',        type=int, required=False,     default=32,     help='inference batch size.')\nparser.add_argument('--splitK',    type=int, required=False,     default=1,      help='Split-K parameters allow users to split the GEMM computation along the K dimension so that more CTAs will be created with a better SM utilization.')\nparser.add_argument('--EXP',       type=int, required=False,     default=2,      help='number of bits of fpx Exponent, can be set to 2 or 3.')\nparser.add_argument('--MAN',       type=int, required=False,     default=2,      help='number of bits of fpx Mantissa, can only be set to 2.')\nargs = parser.parse_args()\n\nEXPONENT  = args.EXP\nMANTISSA  = args.MAN\nBIT_WIDTH = 1 + EXPONENT + MANTISSA\nassert EXPONENT in [2,3]\nassert MANTISSA in [2]\n\nassert(args.OC%256==0)\nassert(args.IC%64==0)\n\nprint(\"#\"*64)\nprint(args)\n\nfpx_weight = torch.randint(4294967295, (args.OC,args.IC//32*BIT_WIDTH)).to(torch.int)    # Randomly initialize each bytes. 
The highest value for randint() is set the the max value of uint32_t.\nfp16_scale = torch.rand(args.OC).to(torch.half)+0.5\nfp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5\n\nstart_event = torch.cuda.Event(enable_timing=True)\nend_event = torch.cuda.Event(enable_timing=True)\n\n# fpx-fp16 GEMM (fp6-llm)\n####################################################################################################################################\ntorch.cuda.synchronize()\nfpx_weight_packed = fp6_llm.weight_prepacking_eXmY_cpu(EXPONENT, MANTISSA, fpx_weight)\nact_cuda = fp16_activation.cuda()\nweight_cuda = fpx_weight_packed.cuda()\nscale_cuda = fp16_scale.cuda()\nfor i in range(WARMUP):\n    results_fp6_llm = fp6_llm.linear_forward_eXmY_cuda(EXPONENT, MANTISSA, act_cuda, weight_cuda, scale_cuda, args.splitK)\nstart_event.record()\nfor i in range(REPEAT):\n    results_fp6_llm = fp6_llm.linear_forward_eXmY_cuda(EXPONENT, MANTISSA, act_cuda, weight_cuda, scale_cuda, args.splitK)\nend_event.record()\ntorch.cuda.synchronize()\nfp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT\nfp6_llm_tflops  = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9\n####################################################################################################################################\n\n# baseline fp16 GEMM (cuBLAS)\n####################################################################################################################################\ntorch.cuda.synchronize()\nfp16_weight = fp6_llm.weight_dequant_eXmY_cpu(EXPONENT, MANTISSA, fpx_weight, fp16_scale)\ncuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False)\nresults_cublas = None\nwith torch.no_grad():\n    cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda())\n    act_cuda = fp16_activation.cuda()\n    for i in range(WARMUP):\n        results_cublas = cuBLAS_MatMul(act_cuda)\n    start_event.record()\n    for i in range(REPEAT):\n        results_cublas = cuBLAS_MatMul(act_cuda)\n    
end_event.record()\ntorch.cuda.synchronize()\ncublas_time_ms = start_event.elapsed_time(end_event)/REPEAT\ncublas_tflops  = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9\n####################################################################################################################################\n\n# Performance\nprint( 'cuBLAS    time: {:.3f} ms \\t\\t cuBLAS    TFLOPs: {:.1f}'.format(cublas_time_ms,  cublas_tflops ) )\nprint( 'quant-llm time: {:.3f} ms \\t\\t quant-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) )\nprint( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) )\n\n# Correctness\nerror             = results_cublas.cpu() - results_fp6_llm.cpu()\nground_truth      = results_cublas.cpu()\nmean_error        = torch.mean(abs(error))\nmean_ground_truth = torch.mean(abs(ground_truth))\nrelative_error    = mean_error.item()/mean_ground_truth.item()\nprint( \"relative error: {:.6f}\".format(relative_error) )\n\n\n# [FP5_e2m2] Setting each element of the input matrices to 1.0f instead of a random value.\n#fpx_weight = torch.zeros(args.OC, args.IC//32*BIT_WIDTH).to(torch.int64)  \n#for i in range(args.OC):\n#    for j in range(args.IC//32):\n#        fpx_weight[i][j*BIT_WIDTH+0] = 272762913\n#        fpx_weight[i][j*BIT_WIDTH+1] = 1107829124\n#        fpx_weight[i][j*BIT_WIDTH+2] = 136414224 \n#        fpx_weight[i][j*BIT_WIDTH+3] = 562303042 \n#        fpx_weight[i][j*BIT_WIDTH+4] = 2215657992\n#fpx_weight = fpx_weight.to(torch.int)\n#fp16_scale      = torch.zeros(args.OC).to(torch.half)         +1.0\n#fp16_activation = torch.zeros(args.BS, args.IC).to(torch.half)+1.0\n\n# 1.0f values in fp5_e2m2\n# 00100 00100 00100 00100 00100 00100 00100 00100       00100 00100 00100 00100 00100 00100 00100 00100     00100 00100 00100 00100 00100 00100 00100 00100     00100 00100 00100 00100 00100 00100 00100 00100\n# “00100001 00001000 01000010 00010000”     “10000100 00100001 00001000 01000010”    “00010000 10000100 00100001 00001000”   “01000010 
00010000 10000100 00100001”  “00001000 01000010 00010000 10000100”\n# Considering Byte order within a INT32\n# 00010000 01000010 00001000 00100001       01000010 00001000 00100001 10000100      00001000 00100001 10000100 00010000      00100001 10000100 00010000 01000010   10000100 00010000 01000010 00001000\n# 00010000010000100000100000100001          01000010000010000010000110000100         00001000001000011000010000010000         00100001100001000001000001000010      10000100000100000100001000001000\n# 272762913                                 1107829124                               136414224                                562303042                             2215657992\n\n\n# [FP6_e3m2] Setting each element of the input matrices to 1.0f instead of a random value.\n#fp6_weight = torch.zeros(args.OC, args.IC//32*BIT_WIDTH).to(torch.int64)\n#for i in range(args.OC):\n#    for j in range(args.IC//16):\n#        fp6_weight[i][j*3+0] = 806142768\n#        fp6_weight[i][j*3+1] = 3274706115\n#        fp6_weight[i][j*3+2] = 214118412\n#        fp6_weight[i][j*3+3] = 806142768\n#        fp6_weight[i][j*3+4] = 3274706115\n#        fp6_weight[i][j*3+5] = 214118412\n#fp6_weight = fp6_weight.to(torch.int32)\n#fp16_scale = torch.zeros(args.OC).to(torch.half)+1.0\n#fp16_activation = torch.zeros(args.BS, args.IC).to(torch.half)+1.0\n\n# 1.0f values in fp6_e3m2\n# 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 \n# \"00110000110000110000110000110000\"       \"11000011000011000011000011000011\"     \"00001100001100001100001100001100\"\n# 00110000 11000011 00001100 00110000      11000011 00001100 00110000 11000011    00001100 00110000 11000011 00001100\n# Considering Byte order within a INT32\n# 00110000 00001100 11000011 00110000      11000011 00110000 00001100 11000011     00001100 11000011 00110000 00001100\n# 00110000000011001100001100110000         11000011001100000000110011000011        
00001100110000110011000000001100\n# 806142768                                3274706115                              214118412"
  },
  {
    "path": "tests/python/run.sh",
    "content": "#! /bin/bash\n\n# [Batch sizes to test]\n# If you want to test the performance of FP6-LLM for larger inference batch sizes, \n# which typically happens during prompt processing, \n# please revise this file by simply \"commenting\" and \"uncommenting\".\n\n# BS <=64\nN=(1 2 4 8 16 32 64)\nSplitK=(5 6 7 6)\n\n# BS = 128\n#N=(128)\n#SplitK=(5 3 3 3)\n\n# BS = 256\n#N=(256)\n#SplitK=(4 3 2 3)\n\n# BS = 512\n#N=(512)\n#SplitK=(2 5 2 4)\n\n# BS = 1024\n#N=(1024)\n#SplitK=(1 2 1 2)\n\n# BS >= 2048\n#N=(2048, 4096, 8192, 16384)\n#SplitK=(1 1 1 1)\n\n# Benchmarking the specific Matrix Shape from llama2-70b\nM=(10240 8192  57344 8192)\nK=(8192  8192  8192  28672)\n\n#mkdir -p Profiling\nfor ((i=0;i<${#M[@]};i++)) \ndo\n    for BS in ${N[@]} \n    do\n        #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \\\n        #python kernel_test_fp6.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]}\n        python kernel_test_fpx.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]}\n        python kernel_test_fpx.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]} --EXP=3 --MAN=2\n    done\ndone"
  }
]