Full Code of usyd-fsalab/fp6_llm for AI

Repository: usyd-fsalab/fp6_llm
Branch: main
Commit: 12e83379f16a
Files: 30
Total size: 145.7 KB

Directory structure:
gitextract_xbh0tnpv/

├── .gitignore
├── LICENSE
├── README.md
├── examples/
│   └── README.md
├── fp6_llm/
│   ├── Makefile
│   ├── __init__.py
│   └── csrc/
│       ├── fp6_linear.cu
│       ├── fp6_linear.cuh
│       ├── include/
│       │   ├── configs.h
│       │   ├── kernel_matmul.cuh
│       │   ├── kernel_reduction.cuh
│       │   ├── ptx_cp.async.cuh
│       │   ├── ptx_mma.cuh
│       │   ├── utils_core.cuh
│       │   ├── utils_gmem.cuh
│       │   └── utils_parallel_dequant.cuh
│       ├── pybind.cpp
│       └── utils/
│           ├── common.h
│           ├── weight_dequant.h
│           ├── weight_prepacking.h
│           └── weight_quant.h
├── setup.py
└── tests/
    ├── cpp/
    │   ├── Makefile
    │   ├── kernel_test.h
    │   ├── kernel_test_fp6.cu
    │   ├── kernel_test_fpx.cu
    │   └── run.sh
    └── python/
        ├── kernel_test_fp6.py
        ├── kernel_test_fpx.py
        └── run.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*ncu
*.so
*.o
*.ncu-rep
*.nsys-rep
*.sqlite
build
kernel_test
*.egg-info
.vscode

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Quant-LLM (FP6, FP5, FPx...)
Six-bit quantization (FP6) can achieve **better trade-offs** between [*model quality*](#1-model-accuracy) and [*inference cost*](#2-speedups-on-linear-layers) compared to its 4-bit and 8-bit quantization counterparts, reducing the size of large language models (LLMs) effectively while preserving model quality consistently across varied applications.
To support **efficient 6-bit inference of LLMs on modern GPUs**, we provide the official implementation of [**FP6-LLM**](https://arxiv.org/pdf/2401.14112.pdf), achieving significant *speedups on linear layers* and *GPU memory reduction* over the fp16/int8 baselines.

Our long-term goal is to support **various quantization methods** by providing **extensible & high-performance** GPU kernels for mixed-input matrix multiplication, using the **unified design scheme** presented in our [paper](https://arxiv.org/pdf/2401.14112.pdf), which was recently accepted by [USENIX ATC24](https://www.usenix.org/conference/atc24/presentation/xia).

Currently, we have tested **FP6_e3m2** and **FP5_e2m2**.
However, our code is templated, making it easy to support other combinations of eXmY if necessary.

![Overview of FP6-LLM.](./docs/figures/banner.png)

## Roadmap

The current release contains:
- Support for model weights in **FP6_e3m2** or **FP5_e2m2** format, with activations in FP16.
- An efficient CUDA implementation of the mixed-input matrix multiplication in linear layers (weights in FP6 and activations in FP16 format), with Tensor Cores enabled.
- C++ APIs and PyTorch APIs to use the CUDA kernels.
- Test codes to demonstrate the usage of FP6-LLM and verify its correctness.
- **End-to-end inference** support of [LLaMA2](https://arxiv.org/abs/2307.09288) models, released in [DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024).

Our future plans include, but are not limited to:
- [ ] Currently, FP6-LLM only supports **FP6** quantization due to its accuracy/performance trade-off benefits. However, the technology behind FP6-LLM can easily be applied to other quantization methods, e.g., **FP4, INT5**.
- [ ] Currently, FP6-LLM is only tested and verified on **A100 GPUs**, but the core design methods can also be applied to other Tensor Core GPUs like **NVIDIA H100 and GH200**. Furthermore, W6A8 quantization can be supported on H100 GPUs by exploiting the FP8 Tensor Cores.

## Installation
1. Clone this repository.
```sh
git clone https://github.com/usyd-fsalab/fp6_llm.git
cd fp6_llm
```
2. Install the python package. [Enabling PyTorch APIs]
```sh
pip install .
```

3. Compile the .so file. [Enabling C++ APIs]
```sh
cd fp6_llm && make
```

## Tests
We provide scripts to verify the correctness of FP6-LLM.
The outputs of FP6-LLM are compared to the outputs of the FP16 baselines (the official implementation of linear layers in PyTorch and NVIDIA cuBLAS).

#### 1. Run tests for PyTorch APIs
```
cd ../tests/python
./run.sh
```

#### 2. Run tests for C++ APIs
```
cd ../cpp
make
./run.sh
```

## How to use our FP6 CUDA kernels
We implement a CUDA kernel supporting the matrix multiplication C = A × B, where A is the weight matrix of shape [OC, IC] in FP6 and B is the activation matrix of shape [IC, BS] in FP16.
Both C and B are column-major matrices.
The CUDA kernels can be launched via **PyTorch APIs** or **C++ APIs**.
Currently:
- OC (Output Channels) must be a multiple of 256, and IC (Input Channels) must be a multiple of 64.
- BS (Batch Size) can take arbitrary values; smaller BS values achieve larger speedups over the FP16 baseline.
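These constraints can be validated before launching the kernel. Below is a minimal sketch, assuming a hypothetical helper name `check_shapes` (not part of the library):

```python
def check_shapes(OC: int, IC: int) -> None:
    """Validate the kernel's shape constraints for the FP6 weight matrix."""
    if OC % 256 != 0:
        raise ValueError(f"OC ({OC}) must be a multiple of 256")
    if IC % 64 != 0:
        raise ValueError(f"IC ({IC}) must be a multiple of 64")

check_shapes(4096, 4096)  # a typical projection shape: passes
```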

For more details of using FP6-LLM APIs, please see the source code of the [C++/Python Test](#tests).

#### 1. PyTorch APIs
To use the PyTorch APIs of FP6-LLM, you only need to import the Python module:
```
import fp6_llm
```

* Ahead of time weight prepacking:
```
fp6_packed_weight = fp6_llm.weight_prepacking_cpu(fp6_weight)

Arguments:
    fp6_weight:         int tensor of shape [OC, IC // 16 * 3];     // 3 INT32 words contain 16 FP6 weights.
Return:
    fp6_packed_weight:  int tensor of shape [OC, IC // 16 * 3];     // 3 INT32 words contain 16 FP6 weights.
```
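The `[OC, IC // 16 * 3]` layout follows from the fact that 16 FP6 values occupy 16 × 6 = 96 bits, i.e. exactly three 32-bit words. A quick sketch of this storage arithmetic (the helper name `packed_int32_shape` is hypothetical):

```python
def packed_int32_shape(OC, IC, bits=6):
    # A group of 16 eXmY weights must fill a whole number of 32-bit words.
    total_bits = 16 * bits              # FP6: 16 * 6 = 96 bits
    assert total_bits % 32 == 0
    words_per_group = total_bits // 32  # FP6: 3 INT32 words per 16 weights
    return (OC, IC // 16 * words_per_group)

print(packed_int32_shape(4096, 4096))  # (4096, 768), i.e. 6/16 of the FP16 footprint
```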

* Execute FP6-FP16 mixed input GEMM on GPUs:
```
fp16_output=fp6_llm.linear_forward_cuda(fp6_packed_weight, fp16_scale, fp16_input, splitK)

Arguments:
    fp6_packed_weight:  int tensor of shape [OC, IC // 16 * 3];     // 3 INT32 words contain 16 FP6 weights.
    fp16_scale:         tensor of shape [OC];                       // half tensor
    fp16_input:         tensor of shape [B, IC];                    // half tensor
    splitK:             int value, splitting the MatMul problem along the K dimension for higher GPU utilization, default 1.
Return:
    fp16_output:        tensor of shape [B, OC];                    // half tensor
```

* To help users determine a good SplitK for other MatMul shapes, we provide a heuristic function in this repo to calculate a SplitK value.
This function provides a good initial candidate for SplitK.
To achieve the best performance, please try different SplitK values when executing the MatMul of the given shape, and choose the best one according to the measured kernel latency.
```
import fp6_llm
SplitK = fp6_llm.HeuristicFuntion_SplitK(M, N, Number_GPU_SMs)

Arguments:
    M is the number of rows of the weight matrix, and N is the inference batch size. More specifically, the shape of the MatMul is (M, K) * (K, N) -> (M, N).
    Number_GPU_SMs is the number of Streaming Multiprocessors of the GPU you are using. For A100 GPUs, Number_GPU_SMs=108.
```
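The heuristic chooses the smallest SplitK whose final GPU "wave" is sufficiently full. The following pure-Python sketch mirrors the wave-quantization logic shipped in `fp6_llm/__init__.py` (each thread block covers a 512×64 output tile, and 0.8 is the library's efficiency threshold; the function names here are illustrative):

```python
import math

def num_waves(M, N, split_k, num_sms):
    # Thread blocks launched (512x64 output tiles, times SplitK), in units of full waves.
    return math.ceil(M / 512) * math.ceil(N / 64) * split_k / num_sms

def heuristic_split_k(M, N, num_sms, min_efficiency=0.8):
    # Pick the smallest SplitK whose last wave is at least 80% occupied.
    for split_k in range(1, 100):
        waves = num_waves(M, N, split_k, num_sms)
        if waves / math.ceil(waves) >= min_efficiency:
            return split_k
    return 1

print(heuristic_split_k(8192, 8, 108))  # 6 on an A100 (108 SMs)
```

As the README advises above, treat the returned value as a starting point and sweep nearby SplitK values against measured kernel latency.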

* Dequantize an FP6 matrix back to an FP16 matrix on the CPU (a useful tool for constructing input matrices for the FP16 GEMM baseline):
```
fp16_tensor=fp6_llm.weight_dequant_cpu(fp6_tensor, fp16_scale)

Arguments:
    fp6_tensor:         int  tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contain 16 FP6 weights.
    fp16_scale:         half tensor of shape [OC];                 // for row-wise quantization.
Return:
    fp16_tensor:        half tensor of shape [OC, IC].
```

#### 2. C++ APIs
To use the C++ APIs of FP6-LLM, you need to include [this header file](fp6_llm/csrc/fp6_linear.cuh) of FP6-LLM in your C++ code:
```
#include "fp6_linear.cuh"
```
Besides, you need to link against the shared library (*libfp6.so*) under [this directory](fp6_llm/) during compilation.

* Ahead of time weight prepacking:
```
void weight_matrix_prepacking(int* packed_weights,                  // [Output] prepacked FP6 weight matrix
                              int *FP6Weights,                      // [Input]  original FP6 weight matrix
                              size_t M,                             // OC
                              size_t K);                            // IC
```

* Execute FP6-FP16 mixed input GEMM on GPUs:
```
cudaError_t fp6_linear_kernel(cudaStream_t    stream,               // CUDA stream to execute the GPU kernel.
                              const uint4     *Weight,              // [Input]  Pointer to the FP6 weight matrix.
                              const half      *Scales,              // [Input]  Pointer to the FP16 quantization scales.
                              const half      *B,                   // [Input]  Pointer to the FP16 input activation.
                              half            *C,                   // [Output] Pointer to the FP16 output activation.
                              const size_t    M_Global,             // OC
                              const size_t    N_Global,             // BS
                              const size_t    K_Global,             // IC
                              float           *Reduction_Workspace, // Pointer to the temporary workspace for splitK reduction. Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)
                              int             Split_K);             // splitK
```

* Dequantize an FP6 matrix back to an FP16 matrix on the CPU:
```
void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h,                     // [Output] Pointer to the dequantized FP16 weight matrix.
                               unsigned char* A_6bit_h,             // [Input]  Pointer to the quantized FP6 weight matrix. 
                               size_t M,                            // OC
                               size_t K,                            // IC
                               half* scale);                        // [Input]  Pointer to the FP16 quantization scales.
```
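When calling `fp6_linear_kernel` with Split_K > 1, the caller must allocate `Reduction_Workspace` according to the size formula in the signature above. A small sketch of that size computation (the helper name is hypothetical):

```python
def splitk_workspace_bytes(M_global, N_global, split_k):
    # Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)
    SIZEOF_FP32 = 4
    return split_k * M_global * N_global * SIZEOF_FP32

print(splitk_workspace_bytes(8192, 8, 6))  # 1572864 bytes (1.5 MiB)
```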

## How to use our FP5 CUDA kernels
The FP5 APIs are similar to the FP6 APIs, except that two more parameters (EXPONENT, MANTISSA) are required. More details will be provided soon.

## Performance

#### 1. Model Accuracy
> **FP6 quantization is a practical alternative to further democratize the deployment of LLMs without significantly sacrificing model quality on complex tasks and various model sizes.**

- While 4-bit quantization unavoidably causes degradation in model quality, near-lossless model compression can be achieved with 6-bit quantization. As shown in Table 1 and Table 2, FP6 displays **strong and consistent performance across various tasks** including code generation and zero-shot perplexity performance.
- It also shows **high robustness across various model sizes**, e.g., 1B, 13B, and 65B LLaMA models. 
- FP6 quantization already works well on coarse-grained quantization, while INT4 quantization heavily relies on Fine-Grained Quantization (FGQ) methods to maintain high model quality.
- *More details can be found in these two papers ([FP6-LLM](https://arxiv.org/pdf/2401.14112.pdf) & [ZeroQuant(4+2)](https://arxiv.org/abs/2312.08583)).*

![ModelAccuracy](docs/figures/Accuracy.png)


#### 2. Speedups on linear layers

> **Compared to the FP16 baselines (cuBLAS), INT8 quantization (TensorRT_LLM),
and FP4 quantization (bitsandbytes):**

- FP6-LLM outperforms bitsandbytes, cuBLAS, and TensorRT_LLM by **up to** 8.9×, 2.6×, and 1.9×, respectively.
- FP6-LLM outperforms bitsandbytes, cuBLAS, and TensorRT_LLM by 7.2×, 2.1×, and 1.3× **on average**, respectively.

![Speedup_To_8bit](docs/figures/Speedup_to_8bit.png)

> **Compared to baselines for INT4 quantization (TensorRT_LLM):**

- FP6-LLM is **1.06×/1.04×/0.94× faster** than **Fine-grained_W4A16** at batch sizes 8/16/32, respectively.
- FP6-LLM is **only 16%/17%/24% slower** than **Coarse-grained_W4A16** at batch sizes 8/16/32.
- It is **a worthwhile trade-off** between model quality and inference speed, since FP6 quantization can provide [higher model quality](#1-model-accuracy) than INT4 quantization. 

![Speedup_To_4bit](docs/figures/Speedup_to_4bit.png)

#### 3. Speedups on end-to-end inference

> **Experimental settings and baselines:**

- We integrate our FP6-LLM kernel into DeepSpeed for end-to-end evaluation. The baseline for comparison is the FP16 execution of the original DeepSpeed inference system.
- We set the prefill/prompt length of each request to 0.5K, and generate 1.5K tokens for each request ignoring the "EOS" (end of sequence) token. 

![End-to-end Inference](./docs/figures/e2e_inference.png)


> **Results of LLaMA-70b:**
- Both FP6-LLM and the FP16 baseline can support an inference **batch size of at most 32** before running out of GPU memory, while **FP6-LLM** requires only **a single GPU** and the **baseline** uses **two GPUs**.
- FP6-LLM achieves **1.69×-2.65× higher normalized inference throughput** than the FP16 baseline.

> **Results of OPT-30b:**
- FP6-LLM can support an inference batch size of at most 16 before running out of GPU memory, while the FP16 baseline can serve at most 4 requests per batch.
- Using an 80GB A100 GPU, FP6-LLM/FP16-baseline achieve at most 319.1/78.8 tokens per GPU-second at batch size 16/4.
- FP6-LLM achieves 1.91×/1.84×/1.72× higher generation throughput than the FP16 baseline when their batch sizes are set to 1/2/4.

> **Inference latency breakdown of LLaMA-70b:**
- The execution of linear layers (MatMul) implemented with FP6-LLM is 1.20× faster than the FP16 baseline on average, even when using half the number of GPUs.
- The FP16 baseline runs multi-head attention (MHA) faster due to its 2-way tensor parallelism.
- Cross-GPU communication (NCCL) is avoided with FP6-LLM, since only a single GPU is required.

> **Inference latency breakdown of OPT-30b:**
- End-to-end performance improvements mainly come from the reduced time spent executing linear layers.
- The linear layers implemented with FP6-LLM are 2.39× faster than the FP16 baselines on average.


## Key Innovations
> We propose **TC-FPx**, **the first full-stack GPU system design scheme with unified Tensor Core support of floating-point weights for various quantization bit-widths (6-bit, 5-bit, 3-bit, etc.), mitigating the "memory wall" issues during LLM inference.**
TC-FPx breaks the limitations of the underlying GPU hardware, allowing the GPU to support linear layer computations involving model weights of arbitrary bit-width.
In TC-FPx, Tensor Cores are utilized for the intensive computation of matrix multiplications,
while SIMT cores are effectively leveraged for weight dequantization, transforming the x-bit model weights into FP16 at runtime before feeding them to the Tensor Cores.

- We propose **Ahead-of-time Bit-level Pre-packing** to resolve the challenge of unfriendly memory access for weights with irregular bit-width, enabling optimal GPU memory access.
- Besides, we propose **SIMT-Efficient GPU Runtime** to minimize the runtime overhead of weight de-quantization. 
- Last but not least, we present the software pipeline of TC-FPx kernel, where SIMT cores, Tensor Cores, and the GPU memory hierarchy cooperate efficiently with high performance.

![](./docs/figures/designs.png)

## FP6-LLM Community
FP6-LLM is already integrated into [DeepSpeed](https://github.com/microsoft/DeepSpeed), and this new feature will be available soon.
Given the easy-to-use PyTorch and C++ APIs provided, FP6-LLM can be easily integrated into any inference framework as a useful component.
We welcome collaborations to integrate FP6-LLM into other inference frameworks.
We also welcome all AI developers/practitioners/researchers to join this on-going project, fully exploring the potential of different quantization methods.

## Citation

If you find FP6-LLM useful or relevant to your research, please kindly cite [our paper](https://arxiv.org/pdf/2401.14112.pdf):

```
@misc{xia2024fp6llm,
      title={FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design}, 
      author={Haojun Xia and Zhen Zheng and Xiaoxia Wu and Shiyang Chen and Zhewei Yao and Stephen Youn and Arash Bakhtiari and Michael Wyatt and Donglin Zhuang and Zhongzhu Zhou and Olatunji Ruwase and Yuxiong He and Shuaiwen Leon Song},
      year={2024},
      eprint={2401.14112},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

## Change Logs
- **[28th May, 2024]**: Release of FP6-LLM v0.2. Change the project name back to QuantLLM.
- **[31st April, 2024]**: Our paper is accepted by [USENIX ATC24](https://www.usenix.org/conference/atc24/presentation/xia).
- **[4th March, 2024]**: Release of FP6-LLM v0.1.

## Related Projects
- [ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks](https://arxiv.org/abs/2312.08583)
- [DeepSpeed: Extreme Speed and Scale for DL Training and Inference](https://github.com/microsoft/DeepSpeed)
- [cuBLAS: Basic Linear Algebra on NVIDIA GPUs](https://developer.nvidia.com/cublas)
- [TensorRT-LLM: A TensorRT Toolbox for Optimized Large Language Model Inference](https://github.com/NVIDIA/TensorRT-LLM)
- [bitsandbytes: the library including quantization primitives for 8-bit & 4-bit operations](https://github.com/TimDettmers/bitsandbytes)
- [Mixed-input matrix multiplication performance optimizations](https://blog.research.google/2024/01/mixed-input-matrix-multiplication.html)
- [Flash-LLM: Enabling Cost-Effective and Highly-Efficient Large Generative Model Inference with Unstructured Sparsity](https://www.vldb.org/pvldb/vol17/p211-xia.pdf)


================================================
FILE: examples/README.md
================================================
# Example of LLM inference using FP6-LLM

Example scripts of using FP6-LLM for end-to-end inference are coming soon.

================================================
FILE: fp6_llm/Makefile
================================================
# host compiler
HOST_COMPILER ?= g++
NVCC          := nvcc -ccbin $(HOST_COMPILER)

# internal flags
NVCCFLAGS   := -m$(shell getconf LONG_BIT)
CCFLAGS     := -fPIC
LDFLAGS     :=

ALL_CCFLAGS :=
ALL_CCFLAGS += $(NVCCFLAGS)
ALL_CCFLAGS += $(addprefix -Xcompiler ,$(CCFLAGS))

################################################################################
ALL_CCFLAGS += -DNO_PYTORCH
ALL_CCFLAGS += --std=c++17
ALL_CCFLAGS += -maxrregcount=255
ALL_CCFLAGS += --use_fast_math
ALL_CCFLAGS += --ptxas-options=-v,-warn-lmem-usage,--warn-on-spills
################################################################################

ALL_LDFLAGS :=
ALL_LDFLAGS += $(ALL_CCFLAGS)
ALL_LDFLAGS += $(addprefix -Xlinker ,$(LDFLAGS))

# Common includes and paths for CUDA
INCLUDES  := -I/usr/local/cuda/include/

# Gencode arguments
SMS ?= 80
# Generate SASS code for each SM architecture listed in $(SMS)
$(foreach sm,$(SMS),$(eval GENCODE_FLAGS += -gencode arch=compute_$(sm),code=sm_$(sm)))


HEAD_FILES = csrc/fp6_linear.cuh \
			 csrc/include/configs.h \
			 csrc/include/kernel_matmul.cuh \
			 csrc/include/kernel_reduction.cuh \
			 csrc/include/ptx_cp.async.cuh \
			 csrc/include/ptx_mma.cuh \
			 csrc/include/utils_core.cuh \
			 csrc/include/utils_gmem.cuh \
			 csrc/include/utils_parallel_dequant.cuh \
			 csrc/utils/common.h \
			 csrc/utils/weight_prepacking.h \
			 csrc/utils/weight_quant.h \
			 csrc/utils/weight_dequant.h

# Target rules
all: libfp6.so

libfp6.so: fp6.o
	$(EXEC) $(NVCC) --shared $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES)

fp6.o:  csrc/fp6_linear.cu $(HEAD_FILES)  
	$(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $<

clean:
	rm -f libfp6.so fp6.o

================================================
FILE: fp6_llm/__init__.py
================================================
from fp6_llm_cuda import linear_forward_cuda, weight_prepacking_cpu, weight_dequant_cpu, linear_forward_eXmY_cuda, weight_prepacking_eXmY_cpu, weight_dequant_eXmY_cpu



import math
def Num_Wave(M, N, SplitK, Num_GPU_SMs):
    Num_Wave = math.ceil(M/512) * math.ceil(N/64) * SplitK / Num_GPU_SMs
    return Num_Wave

# The shape of the MatMul: (M, K)*(K, N)->(M, N).
# Typically, M is the number of rows of weight matrix, and N is the inference batch size.
# Num_GPU_SMs is the number of SMs (Streaming Multiprocessors) within the GPU, e.g., each A100 GPU has 108 SMs.
def HeuristicFuntion_SplitK(M, N, Num_GPU_SMs):
    # Return the smallest Split_K for which the last wave is at least 80% occupied.
    SplitK = 1
    for i in range(1, 100):
        numWave = Num_Wave(M, N, i, Num_GPU_SMs)
        eff = numWave / math.ceil(numWave)
        if eff >= 0.8:
            SplitK = i
            break
    return SplitK
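A quick way to sanity-check the heuristic above is to evaluate it for a concrete shape. The sketch below re-implements `Num_Wave` and the 80%-efficiency search locally (the 512/64 granularity and the 108-SM A100 figure come from the comments above); `pick_split_k` and `num_waves` are illustrative names, not part of the package API.

```python
import math

def num_waves(M, N, split_k, num_sms):
    # Mirrors Num_Wave() above.
    return math.ceil(M / 512) * math.ceil(N / 64) * split_k / num_sms

def pick_split_k(M, N, num_sms, target_eff=0.8):
    # Mirrors HeuristicFuntion_SplitK(): first Split_K whose
    # last wave is at least 80% occupied.
    for k in range(1, 100):
        waves = num_waves(M, N, k, num_sms)
        if waves / math.ceil(waves) >= target_eff:
            return k
    return 1

# A 4096-row weight matrix at batch size 8 on an A100 (108 SMs):
# each Split_K step adds 8 row-tiles * 1 col-tile of work; k = 11 is
# the first value where 8 * 11 / 108 ≈ 0.815 reaches the 80% target.
print(pick_split_k(4096, 8, 108))   # -> 11
```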
    

================================================
FILE: fp6_llm/csrc/fp6_linear.cu
================================================
#include "include/kernel_matmul.cuh"
#include "include/kernel_reduction.cuh"
#include "utils/weight_prepacking.h"
#include "utils/weight_dequant.h"
#include "utils/weight_quant.h"
#include <ATen/cuda/CUDAContext.h> // For CUDA stream management


#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>
static void Kernel_Ex(cudaStream_t    stream,
                      const uint4     *Weight,
                      const half      *Scales,
                      const half      *B,
                      OutputDataType  *C,
                      const size_t    M_Global,
                      const size_t    N_Global,
                      const size_t    K_Global, 
                      int             Split_K) 
{
    #ifdef DEBUG_MODE
        printf("\n");
        printf("Launcher.cu->Kernel_Ex():\n");
        printf("M: %zu, N: %zu, K: %zu, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K);
        printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N);
    #endif
    static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE);
    cudaFuncSetAttribute(QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);
    size_t  dimN = (N_Global-1) / TilingConfig::TILE_N + 1;
    size_t  dimM = M_Global * Split_K / TilingConfig::TILE_M;
    dim3    GridDim(dimN, dimM, 1);
    dim3    BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);
    //
    #ifdef DEBUG_MODE
        printf("GridDim.x: %u, GridDim.y: %u, GridDim.z: %u, BlockDim.x: %u, BlockDim.y: %u, BlockDim.z: %u SHMEM_SZ: %zu\n",
                GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ);
        printf("\n");
    #endif
    QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA><<<GridDim, BlockDim, SHMEM_SZ, stream>>>
                    (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
}

template<int EXPONENT, int MANTISSA>
cudaError_t fpx_linear_kernel(cudaStream_t    stream,
                              const uint4     *Weight,
                              const half      *Scales,
                              const half      *B,
                              half            *C,
                              const size_t    M_Global,
                              const size_t    N_Global,
                              const size_t    K_Global, 
                              float           *Reduction_Workspace,  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)
                              int             Split_K)
{
    assert(M_Global % 256 == 0);
    assert(K_Global % 64 == 0);
    assert(N_Global>0);

    // Workaround to support more N shapes:
    size_t N_PowerOf2;
    if(N_Global>0 &&  N_Global<=8)      N_PowerOf2 = 8;
    if(N_Global>8 &&  N_Global<=16)     N_PowerOf2 = 16;
    if(N_Global>16 && N_Global<=32)     N_PowerOf2 = 32;
    if(N_Global>32 && N_Global<=64)     N_PowerOf2 = 64;
    if(N_Global>64 && N_Global<=128)    N_PowerOf2 = 128;
    if(N_Global>128)                    N_PowerOf2 = ((N_Global-1)/128+1) * 128;

    if (Split_K == 1) {
        switch (N_PowerOf2) {
            case 8:     Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;
            case 16:    Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;
            case 32:    Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;
            case 64:    Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;
            case 128:   Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;
            default:    if (N_PowerOf2 % 128 != 0) {
                            printf("FP6LLM_API Error: Unsupported N dimension %zu!\n", N_PowerOf2);
                            return cudaErrorUnknown;
                        }
                        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K);  break;
        }
    }
    else {
        switch (N_PowerOf2) {
            case 8:     Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;
            case 16:    Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;
            case 32:    Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;
            case 64:    Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;
            case 128:   Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;
            default:    if (N_PowerOf2 % 128 != 0) {
                            printf("FP6LLM_API Error: Unsupported N dimension %zu!\n", N_PowerOf2);
                            return cudaErrorUnknown;
                        }
                        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);  break;
        }
        // Reduction for SplitK
        dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1);
        dim3 BlockDim(WARP_SIZE, 1, 1);
        SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(C, Reduction_Workspace, M_Global, N_Global, Split_K);
    }
    return cudaGetLastError();
}
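The `N_PowerOf2` if-chain above rounds the runtime N up to the nearest template-specialized tile width (8/16/32/64, then multiples of 128). A minimal Python sketch of the same selection (`round_up_n` is an illustrative name):

```python
def round_up_n(n_global):
    # Mirrors the launcher's N_PowerOf2 selection above.
    assert n_global > 0
    for cap in (8, 16, 32, 64, 128):
        if n_global <= cap:
            return cap
    # Beyond 128, round up to the next multiple of 128.
    return ((n_global - 1) // 128 + 1) * 128

print(round_up_n(5), round_up_n(100), round_up_n(300))   # -> 8 128 384
```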

cudaError_t fp6_linear_kernel(
    cudaStream_t    stream,
    const uint4     *Weight,
    const half      *Scales,
    const half      *B,
    half            *C,
    const size_t    M_Global,
    const size_t    N_Global,
    const size_t    K_Global, 
    float           *Reduction_Workspace,
    int             Split_K) {
    //
    return fpx_linear_kernel<3,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,  Reduction_Workspace, Split_K);
}

cudaError_t fp_eXmY_linear_kernel(
    const int       EXPONENT,
    const int       MANTISSA,
    cudaStream_t    stream,
    const uint4     *Weight,
    const half      *Scales,
    const half      *B,
    half            *C,
    const size_t    M_Global,
    const size_t    N_Global,
    const size_t    K_Global, 
    float           *Reduction_Workspace,
    int             Split_K) {
    //
    if(EXPONENT==2 && MANTISSA==2)
        return fpx_linear_kernel<2,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,  Reduction_Workspace, Split_K);
    if(EXPONENT==3 && MANTISSA==2)
        return fpx_linear_kernel<3,2>( stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,  Reduction_Workspace, Split_K);
    printf("QuantLLM_API Error: Unsupported EXPONENT=%d, MANTISSA=%d!\n", EXPONENT, MANTISSA);
    exit(-1);
}

#ifndef NO_PYTORCH
#include <torch/extension.h>
#include <ATen/ATen.h>

/////////////////////////////////////////////////// Old Interface only Supporting FP6 /////////////////////////////////////////////////////////////////////
/*
Computes FP6-FP16 GEMM (PyTorch interface).

[Mathematical Formula]
Standard definition of linear layer:    Out = In * trans(W), where In, Out, and W are stored in row-major.
After equivalent transformation    :    trans(Out) = W * trans(In). Note that no "transpose" is performed at runtime; instead, In/Out are interpreted as column-major matrices when calling our CUDA kernel.

[Inputs]
  _in_feats:  tensor of shape [B, IC];                  // half 
  _weights:   int tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contain 16 FP6 weights.
  _scales:    tensor of shape [OC];                     // half
  splitK:     splitting the MatMul problem along the K dimension for higher GPU utilization, default 1.
[Outputs]
  _out_feats: tensor of shape [B, OC];                  // half
*/
torch::Tensor fp6_linear_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _weights,
    torch::Tensor _scales,
    int           splitK=1)
{
    int num_in_feats      = _in_feats.size(0);
    int num_in_channels   = _in_feats.size(1);
    int num_out_channels  = _weights.size(0);
    assert( num_in_channels%64 == 0 );
    assert( (num_in_channels/16*3) == _weights.size(1) );    // Making sure the K dimension is matched.
    //
    int M = num_out_channels;
    int K = num_in_channels;
    int N = num_in_feats;
    // Input Tensors
    auto weight = reinterpret_cast<const uint4*>(_weights.data_ptr<int>());  // weights is [OC, IC] but in FP6.
    auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
    auto scales   = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());
    // Output Tensors
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options);
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());

    options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());
    at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options);
    auto Reduction_Workspace = reinterpret_cast<float*>(_workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)

    // Get the current device and stream
    int device_id = _in_feats.device().index();
    auto stream = at::cuda::getCurrentCUDAStream(device_id).stream();

    fp6_linear_kernel(stream, // Using current stream.
                      weight,
                      scales,
                      in_feats,
                      out_feats,
                      M,
                      N,
                      K, 
                      Reduction_Workspace,  
                      splitK);

    return _out_feats;
}
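For reference, the math computed by `fp6_linear_forward_cuda` (after weight dequantization) is an ordinary row-scaled linear layer. A pure-Python stand-in with dense floats in place of the packed FP6 weights (`linear_ref` is an illustrative name, not part of the API):

```python
def linear_ref(in_feats, weights, scales):
    # Out[b][oc] = sum_ic In[b][ic] * (scales[oc] * W[oc][ic]),
    # i.e. Out = In @ trans(diag(scales) @ W).
    B, OC = len(in_feats), len(weights)
    return [[sum(in_feats[b][ic] * scales[oc] * weights[oc][ic]
                 for ic in range(len(in_feats[b])))
             for oc in range(OC)]
            for b in range(B)]

print(linear_ref([[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]], [1.0, 0.5]))
# -> [[11.0, 8.5]]
```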


/*
 * Weight prepacking (Pytorch interface).
 * [Input]
 *  fp6_tensor: int tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contains 16 FP6 weights.
 * [Output]
 *  packed_tensor: int tensor of shape [OC, IC // 16 * 3];
 */
torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor)
{
    size_t OC = fp6_tensor.size(0);
    size_t IC = fp6_tensor.size(1);
    assert (IC%3==0);   
    IC = IC*16/3;
    assert( (OC%256==0) && (IC%64==0) );
    auto packed_tensor = torch::empty_like(fp6_tensor);
    auto packed_tensor_ptr = reinterpret_cast<int*>(packed_tensor.data_ptr<int>());
    auto fp6_tensor_ptr = reinterpret_cast<int*>(fp6_tensor.data_ptr<int>());
    weight_matrix_prepacking(packed_tensor_ptr, fp6_tensor_ptr, OC, IC);
    return packed_tensor;
}

/*
 * Dequantize an FP6 matrix to the equivalent FP16 matrix on the CPU.
 * A useful tool to construct input matrices for the FP16 GEMM baseline.
 * [Input]
 *  fp6_tensor:  int  tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contain 16 FP6 weights.
 *  fp16_scale:  half tensor of shape [OC];                 // for row-wise quantization.
 * [Output]
 *  fp16_tensor: half tensor of shape [OC, IC].     
 */
torch::Tensor weight_matrix_dequant_cpu(torch::Tensor fp6_tensor, torch::Tensor fp16_scale) 
{
    int OC = fp6_tensor.size(0);
    assert(fp6_tensor.size(1) % 3 == 0);
    int IC = fp6_tensor.size(1) / 3 * 16;
    assert(fp16_scale.size(0)==OC);
    //
    auto fp6_tensor_ptr = reinterpret_cast<int*>(fp6_tensor.data_ptr<int>());
    auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>());
    //
    auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device());
    at::Tensor fp16_tensor = torch::empty({OC, IC}, options);
    auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>());
    //
    DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, (unsigned char*)fp6_tensor_ptr, OC, IC, fp16_scale_ptr);
    //
    return fp16_tensor;
}
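The dequantization performed by `DeQuantMatrix_FP6_To_FP16` maps each 6-bit e3m2 code to an FP16 value and applies the per-row scale. A simplified scalar sketch of the e3m2 decoding (assuming IEEE-style subnormals and no inf/NaN codes; the authoritative logic lives in `utils/weight_dequant.h`):

```python
def decode_fp6_e3m2(bits):
    # 6-bit layout: [sign | 3-bit exponent | 2-bit mantissa], bias = 3.
    sign = -1.0 if (bits >> 5) & 1 else 1.0
    exp  = (bits >> 2) & 0x7
    man  = bits & 0x3
    if exp == 0:                                   # subnormal
        return sign * (man / 4.0) * 2.0 ** (1 - 3)
    return sign * (1.0 + man / 4.0) * 2.0 ** (exp - 3)

print(decode_fp6_e3m2(0b001100),   # +1.0  (exponent == bias, mantissa 0)
      decode_fp6_e3m2(0b001101),   # +1.25
      decode_fp6_e3m2(0b101100))   # -1.0
```

The per-row FP16 scale is then multiplied in, matching the row-wise quantization described above.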

/////////////////////////////////////////////////// New Interface Supporting FPx /////////////////////////////////////////////////////////////////////
/*
Computes FPx-FP16 GEMM (PyTorch interface).

[Mathematical Formula]
Standard definition of linear layer:    Out = In * trans(W), where In, Out, and W are stored in row-major.
After equivalent transformation    :    trans(Out) = W * trans(In). Note that no "transpose" is performed at runtime; instead, In/Out are interpreted as column-major matrices when calling our CUDA kernel.

[Inputs]
  _in_feats:  tensor of shape [B, IC];                  // half 
  _weights:   int tensor of shape [OC, IC // 32 * x];   // x INT32 words contain 32 FPx weights.
  _scales:    tensor of shape [OC];                     // half
  splitK:     splitting the MatMul problem along the K dimension for higher GPU utilization, default 1.
[Outputs]
  _out_feats: tensor of shape [B, OC];                  // half
*/
torch::Tensor fp_eXmY_linear_forward_cuda(
    int             EXPONENT,
    int             MANTISSA,
    torch::Tensor   _in_feats,
    torch::Tensor   _weights,
    torch::Tensor   _scales,
    int             splitK=1)
{
    int num_in_feats      = _in_feats.size(0);
    int num_in_channels   = _in_feats.size(1);
    int num_out_channels  = _weights.size(0);
    assert( num_in_channels%64 == 0 );
    assert( (num_in_channels/32*(1+EXPONENT+MANTISSA)) == _weights.size(1) );    // Making sure the K dimension is matched.
    //
    int M = num_out_channels;
    int K = num_in_channels;
    int N = num_in_feats;
    // Input Tensors
    auto weight = reinterpret_cast<const uint4*>(_weights.data_ptr<int>());  // weights is [OC, IC] but in FPx.
    auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
    auto scales   = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());
    // Output Tensors
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options);
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());

    options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());
    at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options);
    auto Reduction_Workspace = reinterpret_cast<float*>(_workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)

    // Get the current device and stream
    int device_id = _in_feats.device().index();
    auto stream = at::cuda::getCurrentCUDAStream(device_id).stream();

    //
    fp_eXmY_linear_kernel(
        EXPONENT,
        MANTISSA,
        stream, // Using torch's current stream here.
        weight,
        scales,
        in_feats,
        out_feats,
        M,
        N,
        K, 
        Reduction_Workspace,  
        splitK);
    return _out_feats;
}


/*
 * Weight prepacking (Pytorch interface).
 * [Input]
 *  fpx_tensor: int tensor of shape [OC, IC // 32 * x];
 * [Output]
 *  packed_tensor: int tensor of shape [OC, IC // 32 * x];
 */
torch::Tensor weight_matrix_prepacking_fp_eXmY_cpu(
    int EXPONENT,
    int MANTISSA,
    torch::Tensor fpx_tensor)
{
    int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    //
    size_t OC = fpx_tensor.size(0);
    size_t IC = fpx_tensor.size(1);
    assert (IC%BIT_WIDTH==0);   
    IC = IC*32/BIT_WIDTH;
    assert( (OC%256==0) && (IC%64==0) );
    auto packed_tensor = torch::empty_like(fpx_tensor);
    auto packed_tensor_ptr = reinterpret_cast<int*>(packed_tensor.data_ptr<int>());
    auto fpx_tensor_ptr = reinterpret_cast<int*>(fpx_tensor.data_ptr<int>());
    //
    weight_matrix_prepacking_fp_eXmY(EXPONENT, MANTISSA, packed_tensor_ptr, fpx_tensor_ptr, OC, IC);
    return packed_tensor;
}

/*
 * Dequantize an FPx matrix to the equivalent FP16 matrix on the CPU.
 * A useful tool to construct input matrices for the FP16 GEMM baseline.
 * [Input]
 *  fpx_tensor:  int  tensor of shape [OC, IC // 32 * x];   //
 *  fp16_scale:  half tensor of shape [OC];                 // for row-wise quantization.
 * [Output]
 *  fp16_tensor: half tensor of shape [OC, IC].     
 */
torch::Tensor weight_matrix_dequant_fp_eXmY_cpu(
    int EXPONENT,
    int MANTISSA,
    torch::Tensor fpx_tensor,
    torch::Tensor fp16_scale) 
{
    int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    //
    int OC = fpx_tensor.size(0);
    assert(fpx_tensor.size(1) % BIT_WIDTH == 0);
    int IC = fpx_tensor.size(1) / BIT_WIDTH * 32;
    assert(fp16_scale.size(0)==OC);
    //
    auto fpx_tensor_ptr = reinterpret_cast<int*>(fpx_tensor.data_ptr<int>());
    auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>());
    //
    auto options = torch::TensorOptions().dtype(fp16_scale.dtype()).device(fp16_scale.device());
    at::Tensor fp16_tensor = torch::empty({OC, IC}, options);
    auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>());
    //
    dequant_matrix_fp_eXmY_to_fp16(EXPONENT, MANTISSA, fp16_tensor_ptr, (unsigned char*)fpx_tensor_ptr, OC, IC, fp16_scale_ptr);
    //
    return fp16_tensor;
}
#endif

================================================
FILE: fp6_llm/csrc/fp6_linear.cuh
================================================
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

/*
* Computes FP6-FP16 GEMM (C++ interface).
*/
cudaError_t fp6_linear_kernel(
    cudaStream_t    stream,
    const uint4     *Weight,
    const half      *Scales,
    const half      *B,
    half            *C,
    const size_t    M_Global,
    const size_t    N_Global,
    const size_t    K_Global, 
    float           *Reduction_Workspace,  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)
    int             Split_K);

cudaError_t fp_eXmY_linear_kernel(
    const int       EXPONENT,
    const int       MANTISSA,
    cudaStream_t    stream,
    const uint4     *Weight,
    const half      *Scales,
    const half      *B,
    half            *C,
    const size_t    M_Global,
    const size_t    N_Global,
    const size_t    K_Global, 
    float           *Reduction_Workspace,
    int             Split_K);
/*
 * Weight prepacking (C++ interface).
 */
void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K);
void weight_matrix_prepacking_fp_eXmY(const int EXPONENT, const int MANTISSA, int* packed_weights, int *FPxWeights, size_t M, size_t K);

/*
 * Dequantize an FP6 matrix to the equivalent FP16 matrix on the CPU.
 */
void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale);
void dequant_matrix_fp_eXmY_to_fp16(const int EXPONENT, const int MANTISSA, half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale);

#ifndef NO_PYTORCH
#include <torch/extension.h>
/*
* Computes FP6-FP16 GEMM (PyTorch interface).
*/
torch::Tensor fp6_linear_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _weights,
    torch::Tensor _scales,
    int           splitK=1);
torch::Tensor fp_eXmY_linear_forward_cuda(
    int             EXPONENT,
    int             MANTISSA,
    torch::Tensor   _in_feats,
    torch::Tensor   _weights,
    torch::Tensor   _scales,
    int             splitK=1);


/*
 * Weight prepacking (Pytorch interface).
 */
torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor);
torch::Tensor weight_matrix_prepacking_fp_eXmY_cpu(
    int EXPONENT,
    int MANTISSA,
    torch::Tensor fpx_tensor);


/*
 * Dequantize an FP6 matrix to the equivalent FP16 matrix on the CPU.
 * A useful tool to construct input matrices for the FP16 GEMM baseline.
 * [Input]
 *  fp6_tensor:  int  tensor of shape [OC, IC // 16 * 3];   // 3 INT32 words contain 16 FP6 weights.
 *  fp16_scale:  half tensor of shape [OC];                 // for row-wise quantization.
 * [Output]
 *  fp16_tensor: half tensor of shape [OC, IC].     
 */
torch::Tensor weight_matrix_dequant_cpu(
    torch::Tensor fp6_tensor, 
    torch::Tensor fp16_scale);
torch::Tensor weight_matrix_dequant_fp_eXmY_cpu(
    int EXPONENT,
    int MANTISSA,
    torch::Tensor fpx_tensor,
    torch::Tensor fp16_scale);
#endif

================================================
FILE: fp6_llm/csrc/include/configs.h
================================================
#ifndef CONFIGS_H
#define CONFIGS_H

//#define DEBUG_MODE
#define PIPELINE_LEVEL_GMEM 2
#define PIPELINE_LEVEL_SMEM 2       // only supports 2

/************************ Hardware Parameters ************************/
#define WARP_SIZE                           32
#define REG_BIT_WIDTH                       32
// mma: M=16 K=16 N=8
#define MMA_8                               8
#define MMA_16                              16
// for memory access
#define THREAD_OPT_ACCESS_BIT_WIDTH_128     128 // LDS.128, cp_async.128, ...
#define BIT_WIDTH_PER_HALF                  16  // Half precision: FP16

/******************** Register Allocation For GEMM ********************/
#define REG_PER_THREAD_C_TENSOR_16_16       8   // 8 for FP32 Accumulation
/********************** Memory Padding Parameters **********************/   
// Eliminating bank-conflict
#define PADDING_BYTES_16                    16 // Padding 16 bytes per column
#define PADDING_SHARED_MEM_FOR_B_8          8  // Padding 8 halves per column, during CopyFromGlobalToShared() for B
#define PADDING_SHARED_MEM_FOR_C_4          4  // Padding 4 floats per column, during StoreToSharedMemoryFromRegister() for C
/************************* WARP Tiling part-1 *************************/
#define WARP_ROW_MMA_TENSORS                4
#define WARP_M                              (WARP_ROW_MMA_TENSORS * MMA_16)       // 64
#define WARP_K_MMA_TENSORS                  4
#define WARP_K                              (WARP_K_MMA_TENSORS   * MMA_16)       // 64
template<int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
struct TilingConfig {
    // Depending on "n" dimension of the GEMM
    static constexpr int BLOCK_ROW_WARPS        = BLOCK_ROW_WARPS_;
    static constexpr int BLOCK_COL_WARPS        = BLOCK_COL_WARPS_;
    static constexpr int WARP_COL_MMA_TENSORS   = WARP_COL_MMA_TENSORS_;    
    /************************* WARP Tiling part-2 *************************/
    static constexpr int WARP_N                 = WARP_COL_MMA_TENSORS * MMA_8;
    /*************************Thread Block Tiling *************************/
    static constexpr int TILE_M                 = WARP_M * BLOCK_ROW_WARPS;
    static constexpr int TILE_N                 = MMA_8  * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
    static constexpr int TILE_K                 = WARP_K;
    /********************** #Thread per Thread Block **********************/
    static constexpr int BLOCK_WARPS   = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
    static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;
    /******************************* Others *******************************/
    static constexpr int SMEM_SIZE_B_TILE   = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM;          // sizeof(half)=2, doubleBuffer=2
    static constexpr int SMEM_SIZE_C_TILE   = TILE_N * (TILE_M + PADDING_BYTES_16) * 4;                             // sizeof(float)=4
};



#endif  // CONFIGS_H
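The compile-time constants above can be checked numerically. For the `TilingConfig<4, 1, 8>` instantiation used by the launcher (with WARP_M = WARP_K = 64 as defined above), the tile and shared-memory sizes work out as follows (a Python sketch mirroring the constexpr arithmetic; `tiling` is an illustrative name):

```python
WARP_SIZE, MMA_8, WARP_M, WARP_K = 32, 8, 64, 64
PADDING_BYTES_16, PIPELINE_LEVEL_GMEM = 16, 2

def tiling(block_row_warps, block_col_warps, warp_col_mma_tensors):
    tile_m = WARP_M * block_row_warps
    tile_n = MMA_8 * warp_col_mma_tensors * block_col_warps
    tile_k = WARP_K
    smem_b = tile_n * (tile_k + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM  # halves, double-buffered
    smem_c = tile_n * (tile_m + PADDING_BYTES_16) * 4                        # floats
    return dict(TILE_M=tile_m, TILE_N=tile_n, TILE_K=tile_k,
                SMEM_SIZE_B_TILE=smem_b, SMEM_SIZE_C_TILE=smem_c)

# TilingConfig<4, 1, 8>: a 256x64 output tile with a K-tile of 64.
print(tiling(4, 1, 8))
```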

================================================
FILE: fp6_llm/csrc/include/kernel_matmul.cuh
================================================
#include "configs.h"
#include "utils_gmem.cuh"
#include "utils_core.cuh"

/************************** Bitwidth of Weight Segments ************************/
#define BIT_WIDTH_1                 1
#define BIT_WIDTH_2                 2
#define BIT_WIDTH_4                 4
/*************************** 64*64 Weights of Weight Matrix *********************/
#define WEIGHT_PER_WARP            (WARP_M*WARP_K)                                                            // 64*64 = 4096
#define SMEM_SIZE_PER_WARP_1BIT    (WEIGHT_PER_WARP*BIT_WIDTH_1/8)                                            // 512 Bytes,  doubleBuffer not taken into consideration
#define SMEM_SIZE_PER_WARP_2BIT    (WEIGHT_PER_WARP*BIT_WIDTH_2/8)                                            // 1024 Bytes, doubleBuffer not taken into consideration
#define SMEM_SIZE_PER_WARP_4BIT    (WEIGHT_PER_WARP*BIT_WIDTH_4/8)                                            // 2048 Bytes, doubleBuffer not taken into consideration
#define SMEM_SIZE_PER_TB_1BIT      (SMEM_SIZE_PER_WARP_1BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM)    // #WARP=4; Triple buffer for 3-level pipeline for A = 6 KB;  double buffer for 2-level pipeline A= 4  KB. 
#define SMEM_SIZE_PER_TB_2BIT      (SMEM_SIZE_PER_WARP_2BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM)    // #WARP=4; Triple buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8  KB.   
#define SMEM_SIZE_PER_TB_4BIT      (SMEM_SIZE_PER_WARP_4BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM)    // #WARP=4; Triple buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB.
#define SMEM_SIZE_PER_TB_A_TILE    (SMEM_SIZE_PER_TB_1BIT+SMEM_SIZE_PER_TB_2BIT+SMEM_SIZE_PER_TB_4BIT)        // used in fp6_linear.cu, Kernel_Ex().
/******************** Global Memory Layout For QUANTIZED DATA *******************/
#define NUM_INT4_PER_WARP_1BIT     (WEIGHT_PER_WARP*BIT_WIDTH_1/128)                                          // 32
#define NUM_INT4_PER_WARP_2BIT     (WEIGHT_PER_WARP*BIT_WIDTH_2/128)                                          // 64
#define NUM_INT4_PER_WARP_4BIT     (WEIGHT_PER_WARP*BIT_WIDTH_4/128)                                          // 128
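Since 6 = 4 + 2 (and 5 = 4 + 1), each BIT_WIDTH-bit weight is stored across the 1/2/4-bit segments selected by the binary digits of BIT_WIDTH, and the per-thread-block A-tile budget above sums over all three segments. A Python check of the arithmetic (BLOCK_WARPS = 4 and PIPELINE_LEVEL_GMEM = 2, as configured above):

```python
WEIGHT_PER_WARP = 64 * 64            # 64x64 weights per warp tile
BLOCK_WARPS, PIPELINE_LEVEL_GMEM = 4, 2

def segments(bit_width):
    # Bit-slicing: e.g. FP6 (6 = 4 + 2) uses the 2-bit and 4-bit segments.
    return [w for w in (1, 2, 4) if bit_width & w]

def smem_per_tb(seg_bits):
    # Bytes of shared memory for one segment, over all warps and buffers.
    per_warp_bytes = WEIGHT_PER_WARP * seg_bits // 8
    return per_warp_bytes * BLOCK_WARPS * PIPELINE_LEVEL_GMEM

total = sum(smem_per_tb(w) for w in (1, 2, 4))  # SMEM_SIZE_PER_TB_A_TILE
print(segments(6), segments(5), total)   # -> [2, 4] [1, 4] 28672
```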

/*
 * C = A*B
 * A: row major with ahead-of-time layout transformation, FP6
 * B: col major, FP16
 * C: col major, FP16
 */ 
 template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>
__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
                                  const half *B,
                                  OutputDataType* C,
                                  const size_t M_Global, const size_t N_Global, const size_t K_Global,
                                  int Split_K)
{
  #ifdef DEBUG_MODE
    assert(K_Global%TilingConfig::TILE_K==0);
    assert(M_Global%TilingConfig::TILE_M==0);
    assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M));
  #endif
  // 1+2+4 weight split
  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
  const uint4* Weight_1bit = Weight;
  const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0);
  const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M_Global*K_Global*BIT_WIDTH_2/128 : 0);
  // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned
  extern __shared__ __align__(128) half smem[];   
  half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast<half (*)[WARP_K+PADDING_SHARED_MEM_FOR_B_8]> ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles
  __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS];  // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes
  // Thread Block Mapping, considering SplitK
  const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M);
  const size_t x = blockIdx.x;                                     // Output Block ID: (BlockID_Row = y; BlockID_Col = x )
  const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M);   // Output Block ID: (BlockID_Row = y; BlockID_Col = x )
  const size_t Tile_Start_M = y * TilingConfig::TILE_M;
  const size_t Tile_Start_N = x * TilingConfig::TILE_N;
  const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? (N_Global-Tile_Start_N) : TilingConfig::TILE_N;
  const size_t NumBlock_K = K_Global/TilingConfig::TILE_K;
  const size_t AverageNumBlock_K = NumBlock_K/Split_K;
  const size_t ExtraNumBlock_K   = NumBlock_K - AverageNumBlock_K * Split_K;
  size_t NumIter = AverageNumBlock_K;
  size_t StartBlockID_K = AverageNumBlock_K*BatchID;  
  if(BatchID<ExtraNumBlock_K){
    NumIter ++; 
    StartBlockID_K += BatchID;
  }
  else
    StartBlockID_K += ExtraNumBlock_K;
  // Warp ID.
  const int warpId = threadIdx.x / WARP_SIZE;
  int WARP_i = warpId / TilingConfig::BLOCK_COL_WARPS;  // WARP_i: row number;  WARP_j: column number
  //int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
  // Global Memory Address for Matrix A (Weight) /////////////////////////////////////////////////////////////////////////
  // StartPTR for each ThreadBlock(TB)
  const uint4* TB_StartGPTR_A_1BIT = Weight_1bit + (y*TilingConfig::BLOCK_ROW_WARPS)*NumBlock_K * NUM_INT4_PER_WARP_1BIT;
  const uint4* TB_StartGPTR_A_2BIT = Weight_2bit + (y*TilingConfig::BLOCK_ROW_WARPS)*NumBlock_K * NUM_INT4_PER_WARP_2BIT;
  const uint4* TB_StartGPTR_A_4BIT = Weight_4bit + (y*TilingConfig::BLOCK_ROW_WARPS)*NumBlock_K * NUM_INT4_PER_WARP_4BIT;
  // StartPTR for each WARP.
  const uint4* WARP_StartGPTR_A_1BIT  = TB_StartGPTR_A_1BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_1BIT;
  const uint4* WARP_StartGPTR_A_2BIT  = TB_StartGPTR_A_2BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_2BIT;
  const uint4* WARP_StartGPTR_A_4BIT  = TB_StartGPTR_A_4BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_4BIT;
  // StartPTR for each WARP, considering SplitK
  const size_t WARP_Start_UnitID_K = StartBlockID_K;
  WARP_StartGPTR_A_1BIT  += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_1BIT;
  WARP_StartGPTR_A_2BIT  += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_2BIT;
  WARP_StartGPTR_A_4BIT  += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_4BIT;
  // Copying A tile from Global to Shared, using double-buffer //////////////////////////////////////////////////////////
  // StartSPTR for each ThreadBlock
  uint32_t* AFrag_1BIT_SPTR = reinterpret_cast<uint32_t*>(smem);
  uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT/4;
  uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR + SMEM_SIZE_PER_TB_2BIT/4;   // 8 buffers including double buffers, 12 for triple buffers
  // StartSPTR for each WARP
  AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT/4;
  AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT/4;
  AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4;
  // Pre-fetch of A tile
  for(int i=0; i<PIPELINE_LEVEL_GMEM-1; i++) {
    if(USE_SEG_1BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT);
    if(USE_SEG_2BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT);
    if(USE_SEG_4BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT);
    WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16;
    WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16;
    WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16;
  }
  // Global Memory Address for Matrix A (QuantScale) /////////////////////////////////////////////////////////////////////
  const half* TB_StartGPTR_A_Scale    = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64;
  const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64;
  CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales);
  // Copying B tile from Global to Shared, considering SplitK /////////////////////////////////////////////////////////////
  const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K;
  for(int i=0; i<PIPELINE_LEVEL_GMEM-1; i++) {
    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS> (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy);
    BTile_GPTR += TilingConfig::TILE_K;    
  }
  // Register Allocation for A, B, and C, Initialized to Zeros /////////////////////////////////////////////////////////////////////
  constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS;                                                                  // 1 set = 4 registers, containing a 16*16 MMA block
  constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2;    // 1 set = 4 registers, containing a 16*16 MMA block
  uint32_t a  [NumRegSets_a * PIPELINE_LEVEL_SMEM][4];      // double/triple buffer is used // Registers to store decompressed FP6
  uint32_t b  [NumRegSets_b * PIPELINE_LEVEL_SMEM][4];      // double/triple buffer is used // Registers to store FP16 B matrix (a slice)
  float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16];
  for(int i=0; i<NumRegSets_a * NumRegSets_b; i++)
    for(int j=0; j<REG_PER_THREAD_C_TENSOR_16_16; j++)
      c[i][j] = 0.0f;
  //
  cp_async_wait_all();
  __syncthreads();

  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales
  ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64);
  // Initializing the Software Pipeline: writing registers. ////////////////////////////////////////////////////////////////////////////////////////////////
  initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR);
  // The outer loop. /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  #pragma unroll(1)
  for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++)
  {
    // Triple-Buffer for A Tile
    uint32_t* __restrict__ read_SPTR_Frag_1bit  = AFrag_1BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
    uint32_t* __restrict__ read_SPTR_Frag_2bit  = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
    uint32_t* __restrict__ read_SPTR_Frag_4bit  = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
    uint32_t* __restrict__ read2_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4;
    uint32_t* __restrict__ read2_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4;
    uint32_t* __restrict__ read2_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4;
    uint32_t* __restrict__ write_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
    uint32_t* __restrict__ write_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
    uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
    // Triple-Buffer for B Tile
    half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0)  % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
    half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
    half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1))  % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
    //
    bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter;
    // Copying A tile from Global to Shared, bypassing L1, using the triple buffer
    if(USE_SEG_1BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy);
    if(USE_SEG_2BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy);
    if(USE_SEG_4BIT) CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);
    // copying B tile from GlobalMemory to SharedMemory
    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS> (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);
    cp_async_group_commit();
    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs
    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2);
    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3);
    // Barriers and Synchronizations
    cp_async_wait_group<PIPELINE_LEVEL_GMEM-2>();
    __syncthreads();
    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);
    // Updating global PTRs
    WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16;  // 2KB/16=128 (1)/16: int4*+1 = char*+16
    WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16;  // 4KB/16=256 (1)/16: int4*+1 = char*+16
    WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16;  // 8KB/16=512 (1)/16: int4*+1 = char*+16
    BTile_GPTR += TilingConfig::TILE_K;
  }
  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  // Store the C fragments to shared memory.
  float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] =
        reinterpret_cast <float (*)[TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4]> (smem);
  StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c);
  __syncthreads();
  // Now that shared memory contains all the D tiles, stream them to global memory.
  OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global;
  for(size_t i=warpId; i<NumColumnToCopy; i+=TilingConfig::BLOCK_WARPS)    // i-th column
    #pragma unroll
    for(size_t j=threadIdx.x%WARP_SIZE; j<TilingConfig::TILE_M; j+=WARP_SIZE) // j-th row
    {
      if constexpr (std::is_same<OutputDataType, half>::value)   BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]);
      else                                            BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j];
    }
}
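The `(tile_id_k + offset) % PIPELINE_LEVEL_GMEM` indexing that selects the read/read2/write shared-memory buffers in the main loop above can be sanity-checked with a small sketch (Python, illustrative only; `levels` stands in for `PIPELINE_LEVEL_GMEM`):

```python
def buffer_schedule(num_iter, levels=3):
    # For each outer-loop iteration k, return the triple-buffer slots:
    # (buffer read by slices 1..3, buffer read by slice 0 of the next tile,
    #  buffer currently being filled from global memory via cp.async).
    return [((k + 0) % levels,
             (k + 1) % levels,
             (k + levels - 1) % levels)
            for k in range(num_iter)]
```

With `levels=3`, the slot being written at iteration k is exactly the slot read two iterations later, so the in-flight cp.async fill never overlaps a buffer that the tensor-core slices are still consuming.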

================================================
FILE: fp6_llm/csrc/include/kernel_reduction.cuh
================================================
/***************************************************************************
 * Copyright 2023 The FLash-LLM Authors. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ***************************************************************************/
// Used for the reduction of result matrix if Split-K is used
// Reduction_Workspace:     (Split_K, M_Global, N_Global),  column major
// C:                       (M_Global, N_Global),           column major
// Each thread deals with 8 output elements; each element is the sum of Split_K partial values
//      Read  Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp
//      Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread  (128bit) -> 256 half  per warp
// GridSize = (M_Global*N_Global) / 256

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#define REDUCTION_ELEMENT_PER_THREADBLOCK   256
#define HALF_PER_128BIT                     8

__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K)
{
    half*  WARP_GPTR_C      = C                    + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
    float* WARP_GPTR_R      = Reduction_Workspace  + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
    half*  THREAD_GPTR_C    = WARP_GPTR_C          + threadIdx.x * HALF_PER_128BIT;    
    float* THREAD_GPTR_R    = WARP_GPTR_R          + threadIdx.x * HALF_PER_128BIT;
    // Initializing Thread-Local Results
    float Results[HALF_PER_128BIT];
    #pragma unroll
    for (int i = 0; i < HALF_PER_128BIT; i++)       Results[i] = 0.0f;
    // Reduction
    for (int i = 0; i < Split_K; i++) {
        #pragma unroll
        for (int j = 0; j < HALF_PER_128BIT; j++)   Results[j] += THREAD_GPTR_R[j];
        THREAD_GPTR_R += M_Global * N_Global;
    }
    // Writing to global memory
    #pragma unroll
    for (int i = 0; i < HALF_PER_128BIT; i++)       THREAD_GPTR_C[i] = __float2half_rn(Results[i]);
}
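As a reference for the layout comments above, here is a pure-Python sketch of what SplitK_Reduction computes (names are illustrative, not from the repo): the workspace holds Split_K consecutive column-major partial results, and C is their elementwise sum.

```python
def splitk_reduce(workspace, split_k, m_global, n_global):
    # workspace: flat list holding split_k consecutive partial results,
    # each of size m_global * n_global (column major; layout does not
    # affect the sum since every copy uses the same layout).
    out = [0.0] * (m_global * n_global)
    for k in range(split_k):
        base = k * m_global * n_global
        for idx in range(m_global * n_global):
            out[idx] += workspace[base + idx]
    return out
```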


================================================
FILE: fp6_llm/csrc/include/ptx_cp.async.cuh
================================================
/***************************************************************************
 * Copyright 2023 The FLash-LLM Authors. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ***************************************************************************/
// Extended from CUTLASS's source code

#ifndef PTX_CP_ASYNC_CUH
#define PTX_CP_ASYNC_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

template<int SizeInBytes>
__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true)
{
    static_assert(SizeInBytes == 16, "Size is not supported");
    unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr);
    asm volatile("{ \n"
                 "  .reg .pred p;\n"
                 "  setp.ne.b32 p, %0, 0;\n"
                 "  @p cp.async.cg.shared.global [%1], [%2], %3;\n"
                 "}\n" ::"r"((int)pred_guard),
                 "r"(smem_int_ptr),
                 "l"(global_ptr),
                 "n"(SizeInBytes));
}

/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
__device__ __forceinline__ void cp_async_group_commit()
{
    asm volatile("cp.async.commit_group;\n" ::);
}

/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
template<int N>
__device__ __forceinline__ void cp_async_wait_group()
{
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}

/// Blocks until all previous cp.async.commit_group operations have committed.
// cp.async.wait_all is equivalent to:
// cp.async.commit_group;
// cp.async.wait_group 0;
__device__ __forceinline__ void cp_async_wait_all()
{
    asm volatile("cp.async.wait_all;\n" ::);
}

#endif

================================================
FILE: fp6_llm/csrc/include/ptx_mma.cuh
================================================
/***************************************************************************
 * Copyright 2023 The FLash-LLM Authors. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ***************************************************************************/
#ifndef PTX_MMA_CUH
#define PTX_MMA_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <assert.h>
#include "configs.h"

template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(uint32_t  __restrict__    Reg[][4],
                                                  half      __restrict__    (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
                                                  int                       slice_id) {
    #ifdef DEBUG_MODE
        static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );
    #endif
    
    const int   warpId  = threadIdx.x / WARP_SIZE;
    int         lane_id = threadIdx.x % WARP_SIZE;
    int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
    int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j;   // each warp may start from reading warp_start_col'th column of the B tile in shared memory
    #ifdef DEBUG_MODE
        assert( warp_start_col==0 );
    #endif    

    int col = (lane_id%8) + (lane_id/16)*8;
    int row = (lane_id%16) / 8 * 8;
    uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row]));
    if(TilingConfig::WARP_COL_MMA_TENSORS==1) {
        asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
                     : "=r"(Reg[0][0]), "=r"(Reg[0][1])
                     : "r"(smem_local_ptr));
    }
    else {
        #pragma unroll
        for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++)
        {
            asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                         : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3])
                         : "r"(smem_local_ptr));
            smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
        }
    }
}

__device__ __forceinline__ void
MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
                 "{ %0, %1, %2, %3},"
                 "{ %4, %5, %6, %7 },"
                 "{ %8, %9 },"
                 "{ %10, %11, %12, %13 };"
                 : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
                 : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
                   "r"(b[0]), "r"(b[1]),
                   "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}

#endif

================================================
FILE: fp6_llm/csrc/include/utils_core.cuh
================================================
#ifndef UTILS_CORE_CUH
#define UTILS_CORE_CUH

#include <assert.h>

#include "configs.h"
#include "ptx_mma.cuh"
#include "utils_parallel_dequant.cuh"


template<int NUM_INT_PER_THREAD>
__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) {
    SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE);
    int     lane_id = threadIdx.x % WARP_SIZE;
    #pragma unroll
    for(int i=0; i<NUM_INT_PER_THREAD; i++) {
        Reg[i] = SPTR[lane_id+i*WARP_SIZE];
    }
}

template <typename TilingConfig, int EXPONENT, int MANTISSA>
__device__ __forceinline__ void initialize_mma_slice(uint32_t                  (*a)[4],
                                                     uint32_t                  (*b)[4],
                                                     uint32_t* __restrict__    A_1BIT_SPTR_read,
                                                     uint32_t* __restrict__    A_2BIT_SPTR_read,
                                                     uint32_t* __restrict__    A_4BIT_SPTR_read,
                                                     half      __restrict__    (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
                                                     uint32_t*                 RPTR_Scales)
{
    // 1+2+4 weight split
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
    // Writing registers
    // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 registers per thread;
    uint32_t a_1bit[1];                      // NO double buffer
    uint32_t a_2bit[2];                      // NO double buffer
    uint32_t a_4bit[4];                      // NO double buffer
    if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1>   (a_1bit, A_1BIT_SPTR_read, 0);
    if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2>   (a_2bit, A_2BIT_SPTR_read, 0);
    if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4>   (a_4bit, A_4BIT_SPTR_read, 0);
    Dequant_32FP6_4Way<EXPONENT, MANTISSA>(a, a_1bit, a_2bit, a_4bit, RPTR_Scales);   // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time 
    B_FromSharedToReg<TilingConfig>(b, B_SPTR_read, 0); // Loading B from shared to registers
}

template <typename TilingConfig, int EXPONENT, int MANTISSA>
__device__ __forceinline__ void core_mma_slice(float                     c[][REG_PER_THREAD_C_TENSOR_16_16],
                                               uint32_t                  (*a)[4],
                                               uint32_t                  (*b)[4],
                                               uint32_t* __restrict__    A_1bit_SPTR_read,
                                               uint32_t* __restrict__    A_2bit_SPTR_read,
                                               uint32_t* __restrict__    A_4bit_SPTR_read,
                                               half      __restrict__    (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
                                               uint32_t*                 RPTR_Scales,
                                               int                       slice_id)      // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
{
    // 1+2+4 weight split
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;

    #ifdef DEBUG_MODE
        assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0));   // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block
    #endif
    const int NumRegSets_a = WARP_ROW_MMA_TENSORS;                                                                              // 1 set = 4 registers, containing a 16*16 MMA block
    const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2;                // 1 set = 4 registers, containing a 16*16 MMA block
    uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast<uint32_t(*)[REG_PER_THREAD_C_TENSOR_16_16]>(c);    // Registers for accumulated FP32 results

    // Setting RPTRs for double buffers
    uint32_t (*a_read )[4] = a;
    uint32_t (*a_write)[4] = a;
    uint32_t (*b_read )[4] = b;
    uint32_t (*b_write)[4] = b;
    if(slice_id%2==1)   { b_write += NumRegSets_b; a_write += NumRegSets_a;} 
    else                { b_read  += NumRegSets_b; a_read  += NumRegSets_a;}

    // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory)
    #pragma unroll
    for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
        if(TilingConfig::WARP_COL_MMA_TENSORS==1) {
            MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] );
        }
        else {            
            #pragma unroll
            for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) {
                MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS],     a_read[i], b_read[j]     );
                MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2
            }
        }
    }
    // Writing registers
    // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 registers per thread;
    uint32_t a_1bit[1];                      // NO double buffer
    uint32_t a_2bit[2];                      // NO double buffer
    uint32_t a_4bit[4];                      // NO double buffer
    if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1>   (a_1bit, A_1bit_SPTR_read, slice_id);
    if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2>   (a_2bit, A_2bit_SPTR_read, slice_id);
    if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4>   (a_4bit, A_4bit_SPTR_read, slice_id);
    Dequant_32FP6_4Way<EXPONENT, MANTISSA>(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales);   // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time 
    B_FromSharedToReg<TilingConfig>     (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers
}

template <typename TilingConfig>
__device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4],
                                                                float c[][REG_PER_THREAD_C_TENSOR_16_16])
{
    const int   lane_id             = threadIdx.x % WARP_SIZE;
    const int   warpId              = threadIdx.x / WARP_SIZE;
    int         warp_row_offset     = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS);
    #pragma unroll
    for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
        #pragma unroll
        for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) {    // Dealing with one 16*8 Tensor
            int RegSetID            = i + (j/2)*WARP_ROW_MMA_TENSORS;
            int RegOffset           = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2);
            int Tensor_row_offset   = warp_row_offset + i * MMA_16;
            int Tensor_col_offset   = j * MMA_8;
            #pragma unroll
            for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) {
                int row_offset = lane_id / 4;
                if (r >= 2) row_offset += 8;
                int col_offset = (lane_id % 4) * 2;
                if (r%2==1) col_offset += 1;
                smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset];
            }
        }
    }
}

#endif

================================================
FILE: fp6_llm/csrc/include/utils_gmem.cuh
================================================
#ifndef UTILS_GMEM_CUH
#define UTILS_GMEM_CUH

#include <assert.h>
#include "configs.h"
#include "ptx_cp.async.cuh"

/* 
 * Copying A1/A2 from global memory to shared memory.
 * Usually 1024 or 2048 Bytes
 */
template<int SMEM_SIZE_IN_BYTES_PER_WARP>
__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, 
                                                        const uint4* GPTR,
                                                        bool pred_guard = true) {
    #ifdef DEBUG_MODE
        static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0);
    #endif
    int lane_id      = threadIdx.x % WARP_SIZE;
    half* SPTR_HALF = reinterpret_cast<half*>(SPTR);
    const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR);
    SPTR_HALF += lane_id*8;
    GPTR_HALF += lane_id*8;
    #pragma unroll
    for(int i=0; i<SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE/16; i++) {
        cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard);
        SPTR_HALF += 256;   // Forward 512 Bytes
        GPTR_HALF += 256;   // Forward 512 Bytes
    }

}

/* 
 * Copying 64 Quant Scales (FP16) from global memory to shared memory.
 */
__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales,
                                                              const half* GPTR_A_Scales) {
    int lane_id         = threadIdx.x % WARP_SIZE;
    int Offset_Shared   = lane_id*2; 
    int Offset_Global   = lane_id/4 + (lane_id%4)*16;
    for(int i=0; i<2; i++)  SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8];
}

/* 
 * (1) Copying X  rows * 64 columns of FP16 values, originally in row    major
 * (2) Copying 64 rows * X  columns of FP16 values, originally in column major
 * 16 Bytes per thread -> 512 Bytes per WARP = 4 lines per WARP = 1 line per 8 Threads
 */
template<int MaxNumOfLinesToCopy, int BLOCK_WARPS>
__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
                                                       const half*       GlobalPTR,
                                                       const int         GlobalStride,
                                                       const int         NumOfLinesLeft,        // To support arbitrary N dimensions.
                                                       bool              Pred = true) {
    // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
    const int NumOfThreads  = BLOCK_WARPS * WARP_SIZE;
    const int NumOfGroups   = NumOfThreads / 8; 
    const int MaxIteration  = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1;
    // runtime variables   
    const int line_id       = threadIdx.x / 8;
    const int line_offset   = (threadIdx.x%8) * 8;
    // PTR for source global memory and target shared memory
    GlobalPTR += line_id * GlobalStride + line_offset;
    SharedPTR += line_id;
    #pragma unroll
    for (int i = 0; i < MaxIteration; i++) {
        bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred;
        cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
        //
        GlobalPTR += NumOfGroups * GlobalStride;
        SharedPTR += NumOfGroups;
    }
}

#endif
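The interleaved global offsets in CopyFromGlobalToShared_Scales (`lane_id/4 + (lane_id%4)*16`, with two elements 8 apart per lane) are easy to misread; a quick sketch confirms that the 32 lanes together touch each of the 64 scales exactly once (Python, illustrative check only):

```python
def scale_copy_offsets(warp_size=32):
    # Reproduce the per-lane global offsets used by CopyFromGlobalToShared_Scales.
    offsets = []
    for lane_id in range(warp_size):
        base = lane_id // 4 + (lane_id % 4) * 16
        offsets += [base, base + 8]   # i = 0, 1 with stride 8
    return offsets
```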

================================================
FILE: fp6_llm/csrc/include/utils_parallel_dequant.cuh
================================================
#ifndef UTILS_PARALLELDEQUANT_CUH
#define UTILS_PARALLELDEQUANT_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

/*
 * Input:   R1
 * Outputs: R1, R2
 * Note:    Simplified Exponent calculation is applied.
 */
template<int EXPONENT, int MANTISSA>
__device__ __forceinline__ void FPx_FP16_Cast_4Way(u_int32_t *In, u_int32_t *Out1, u_int32_t *Out2) {
    //
    constexpr int RIGHT_SHIFT = 5 - EXPONENT;
    constexpr int MASK1 = 0x80000000;
    constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
    constexpr int MASK3 = MASK2 & 0x7fffffff;
    constexpr int MASK  = MASK3 | MASK3 >> 16;
    //
    *Out1  = *In & 0x80008000;
    *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
    //
    *In    = (*In) << 8;
    *Out2  = *In & 0x80008000;
    *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
}

template<int EXPONENT, int MANTISSA>
__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) {
    constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1));
    constexpr int BIAS        = int(1) << BIAS_OFFSET;
    //
    half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
    half* FP16_2 = FP16_1 + 1;
    uint32_t output;
    half* output_half_ptr = reinterpret_cast<half*>(&output);
    output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale);
    output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale);   
    return output;
}

template<int EXPONENT, int MANTISSA>
__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__   Reg[][4], 
                                                   u_int32_t __restrict__   *read_RPTR_1bit,
                                                   u_int32_t __restrict__   *read_RPTR_2bit, 
                                                   u_int32_t __restrict__   *read_RPTR_4bit,
                                                   u_int32_t                *Scales) {
    // 1+2+4 weight split
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
    //
    u_int32_t *OutputRegs    = reinterpret_cast<u_int32_t*> (Reg);
    u_int32_t *Frag_PTR_1bit = read_RPTR_1bit;
    u_int32_t *Frag_PTR_2bit = read_RPTR_2bit;
    u_int32_t *Frag_PTR_4bit = read_RPTR_4bit;
    half      *Scale_RPTR    = reinterpret_cast<half*>(Scales);
    // Dequantizing 32 FP6, each Loop dequantizing 4 FP6
    #pragma unroll(8)
    for(int i=0; i<8; i++) { 
        u_int32_t Packed_FP6 = 0;
        u_int32_t tmp        = 0;
        // 1bit Frag
        if(USE_SEG_1BIT) {
            tmp = (*Frag_PTR_1bit) & 0x80808080;
            Packed_FP6 |= tmp >> (BIT_WIDTH & 0);
            if(i%8==7)  Frag_PTR_1bit++;
            else        (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1;
        }
        // 2bit Frag
        if(USE_SEG_2BIT) {
            tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0;
            Packed_FP6 |= tmp >> (BIT_WIDTH & 1);
            if(i%4==3)  Frag_PTR_2bit++;
            else        (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2;
        }
        // 4bit Frag
        if(USE_SEG_4BIT) {
            tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0;
            Packed_FP6 |= tmp >> (BIT_WIDTH & 3);
            if(i%2==1)  Frag_PTR_4bit++;
            else        (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4;
        }
        //
        u_int32_t out1, out2;
        FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);
        //
        *OutputRegs = MultScale<EXPONENT, MANTISSA>(out1, Scale_RPTR[0]  );       // Multiply FP16 scales
        OutputRegs += 1;
        *OutputRegs = MultScale<EXPONENT, MANTISSA>(out2, Scale_RPTR[1]);         // Multiply FP16 scales
        OutputRegs += 1;
        // Updating offset for FP16 scales for every two iterations
        if(i%2==1)  Scale_RPTR += 2;
    }
    
}

/*
 * 
 */
__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) {
    int lane_id = threadIdx.x % WARP_SIZE;
    uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales);
    uint32_t tmpReg = SPTR_uint[lane_id];
    #pragma unroll
    for(int i=0; i<4; i++) {
        // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); 
        Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); 
    }
}

#endif

================================================
FILE: fp6_llm/csrc/pybind.cpp
================================================
#include <pybind11/pybind11.h>
#include <torch/extension.h>

#include "fp6_linear.cuh"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Old Interfaces.
    m.def("linear_forward_cuda", &fp6_linear_forward_cuda, "Computes FP6-FP16 GEMM.");
    m.def("weight_prepacking_cpu", &weight_matrix_prepacking_cpu, "Weight prepacking.");
    m.def("weight_dequant_cpu", &weight_matrix_dequant_cpu, "Dequantize weight from fp6 to fp16.");
    // New Interfaces.
    m.def("linear_forward_eXmY_cuda", &fp_eXmY_linear_forward_cuda, "Computes FPx-FP16 GEMM.");
    m.def("weight_prepacking_eXmY_cpu", &weight_matrix_prepacking_fp_eXmY_cpu, "FPx Weight prepacking.");
    m.def("weight_dequant_eXmY_cpu", &weight_matrix_dequant_fp_eXmY_cpu, "Dequantize weight from fpx to fp16.");    
}

================================================
FILE: fp6_llm/csrc/utils/common.h
================================================
#ifndef UTILS_COMMON_H
#define UTILS_COMMON_H

#include <assert.h>

template<int EXPONENT, int MANTISSA>
unsigned char Extract_X_Bits_To_A_Byte(unsigned char* Bytes, int ByteOffset, int BitOffset){
    assert (sizeof(unsigned int)==4);
    unsigned int tmp_int32_word=0;
    unsigned char* uchar_ptr = reinterpret_cast<unsigned char*>(&tmp_int32_word);
    uchar_ptr[3] = Bytes[ByteOffset+0];
    uchar_ptr[2] = Bytes[ByteOffset+1];
    tmp_int32_word = tmp_int32_word << BitOffset;
    //
    signed int mask = 0x80000000;
    mask = mask >> (EXPONENT+MANTISSA);
    tmp_int32_word &= mask;
    //
    unsigned char out = uchar_ptr[3];
    return out;
}

#endif
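A portable re-derivation of `Extract_X_Bits_To_A_Byte<3, 2>` may help when reading the code above (a sketch, not part of the repo: the repo version relies on little-endian byte access through `uchar_ptr[3]`, while this variant assembles the word explicitly; the helper name is illustrative). It extracts the 6-bit field starting at `BitOffset` within two consecutive bytes and returns it left-aligned in a byte:

```cpp
#include <cassert>
#include <cstdint>

// Extract the 6-bit field that starts at BitOffset within
// Bytes[ByteOffset], Bytes[ByteOffset+1], returning it left-aligned
// in a byte (mirrors Extract_X_Bits_To_A_Byte<3, 2>).
inline unsigned char extract_6bits_demo(const unsigned char* Bytes,
                                        int ByteOffset, int BitOffset)
{
    uint32_t word = (uint32_t(Bytes[ByteOffset])     << 24)
                  | (uint32_t(Bytes[ByteOffset + 1]) << 16);
    word <<= BitOffset;          // bring the target field to the top of the word
    word &= 0xFC000000u;         // keep the top 1 + 3 + 2 = 6 bits
    return (unsigned char)(word >> 24);
}
```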

================================================
FILE: fp6_llm/csrc/utils/weight_dequant.h
================================================
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "common.h"

template<int EXPONENT, int MANTISSA>
void DeQuantMatrix_FPx_To_FP16(half* A_16bit_h, unsigned char* A_x_bit_h, size_t M, size_t K, half* scale) {
    //
    assert(M%64==0);                 // Currently, M must be a multiple of 64.
    assert(K%64==0);                 // Currently, K must be a multiple of 64.
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    assert(BIT_WIDTH<=8);
    size_t TotalSizeInByte = M * K * BIT_WIDTH / 8;
    //
    half* OutPTR = A_16bit_h;
    for(size_t i=0; i<TotalSizeInByte/BIT_WIDTH; i++) {    // Processing BIT_WIDTH bytes per loop iteration, generating 8 FP16 values.
        unsigned char Bytes[BIT_WIDTH];
        for(int x=0; x<BIT_WIDTH; x++)  Bytes[x] = A_x_bit_h[i*BIT_WIDTH+x];
        unsigned char OUT[8];
        for(int x=0; x<8; x++) {                        // Prepare Initial memory layout for Dequant
            int ByteOffset  = BIT_WIDTH * x / 8;
            int BitOffset   = BIT_WIDTH * x % 8;
            OUT[x] = Extract_X_Bits_To_A_Byte<EXPONENT, MANTISSA>(Bytes, ByteOffset, BitOffset);
        }
        // Dequant
        constexpr int MASK1 = 0x80000000;
        constexpr int MASK2 = MASK1 >> (EXPONENT + MANTISSA);
        constexpr int MASK  = MASK2 & 0x7fffffff;
        constexpr int RIGHT_SHIFT = 5 - EXPONENT;
        constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1));
        constexpr int BIAS        = int(1) << BIAS_OFFSET;
        for(int x=0; x<8; x++) {
            unsigned int OUT_fp16;        // Storing fp16 in the high 16 bits.
            OUT_fp16 = int(OUT[x]) << 24;
            OUT_fp16 = (OUT_fp16 & 0x80000000) | ( (OUT_fp16 & MASK) >> RIGHT_SHIFT );
            OUT_fp16 = OUT_fp16 >> 16;
            //
            half* OUT_FP16_PTR = reinterpret_cast<half*>(&OUT_fp16);
            OutPTR[x] = __float2half_rn ( __half2float(*OUT_FP16_PTR) * (1.0f*BIAS) * __half2float(scale[(8*i)/K]) );
        }   
        //
        OutPTR +=8;
    }
}


void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) {
    DeQuantMatrix_FPx_To_FP16<3, 2>(A_16bit_h, A_6bit_h, M, K, scale);
}
void dequant_matrix_fp_eXmY_to_fp16(const int EXPONENT, const int MANTISSA, half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale){
    if(EXPONENT==2 && MANTISSA==2)
        return DeQuantMatrix_FPx_To_FP16<2, 2>(A_16bit_h, A_6bit_h, M, K, scale);
    if(EXPONENT==3 && MANTISSA==2)
        return DeQuantMatrix_FPx_To_FP16<3, 2>(A_16bit_h, A_6bit_h, M, K, scale);
    printf("DeQuantMatrix Error: Unsupported EXPONENT=%d, MANTISSA=%d!\n", EXPONENT, MANTISSA);
    exit(-1);
}
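Specialized to FP6 (e3m2), the constants above evaluate to MASK = 0x7C000000, RIGHT_SHIFT = 5 - 3 = 2, BIAS_OFFSET = 16 - 4 = 12, and BIAS = 2^12 = 4096. A standalone sketch of the per-value bit manipulation (the function name is illustrative, not from the repo):

```cpp
#include <cassert>
#include <cstdint>

// FP6 (e3m2) -> raw FP16 bits, specialized from the generic code above:
// the sign bit stays in place, while exponent and mantissa shift right by
// RIGHT_SHIFT = 2. The result still has to be multiplied by BIAS = 2^12
// (and by the per-row quantization scale) to recover the dequantized value,
// as DeQuantMatrix_FPx_To_FP16 does.
inline uint16_t fp6_e3m2_to_fp16_bits(unsigned char fp6_left_aligned)
{
    uint32_t v = uint32_t(fp6_left_aligned) << 24;      // s eee mm at the word's top
    v = (v & 0x80000000u) | ((v & 0x7C000000u) >> 2);   // MASK = 0x7C000000
    return uint16_t(v >> 16);                           // FP16 kept in the low 16 bits
}
```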

================================================
FILE: fp6_llm/csrc/utils/weight_prepacking.h
================================================
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <vector>
#include "common.h"

/*
 * Inputs:
 * (1) unsigned char Weight_6bit [M*K*6/8]
 * Outputs:
 * (1) unsigned char Weight_2bit [M*K*2/8]
 * (2) unsigned char Weight_4bit [M*K*4/8]
 * 
 * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major.
 * 8 FP6 = 6 Bytes
 * 8 FP4 = 4 Bytes
 * 8 FP2 = 2 Bytes
 */

using namespace std;


void Extract_segments_from_8_padded_fpx(unsigned char Seg_xbit[], unsigned char Padded_8_FPx[], int bit_width, int bit_offset){
    for(int i=0; i< bit_width; i++)
        Seg_xbit[i] = 0;
    for(int i=0; i<8; i++){
        unsigned int seg = (Padded_8_FPx[i] << bit_offset) & 0x000000ff;
        int mask = 0xffffff00;
        seg &= mask >> bit_width;
        //
        int Seg_idx = (i * bit_width) / 8;
        int Seg_off = (i * bit_width) % 8;
        Seg_xbit[Seg_idx] |= seg >> Seg_off;
    }
}


// Dealing with four 1*8 blocks of FPx values
template<int EXPONENT, int MANTISSA>
void Assign_32_FPx_To_4_Thread(vector<unsigned char> Vec_Seg_1bit[], vector<unsigned char> Vec_Seg_2bit[], vector<unsigned char> Vec_Seg_4bit[], unsigned char* PTR[]) 
{
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    assert(BIT_WIDTH<8);
    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
    //
    constexpr int nTHREADS = 4;
    constexpr int FPx_PER_THREAD = 8;
    unsigned char Padded_8_FPx[nTHREADS][FPx_PER_THREAD];
    for(int i=0; i<nTHREADS; i++){                             // 4 threads
        for(int j=0; j<FPx_PER_THREAD; j++){                   // 8 FPx per thread
            int offset = (i*2 + j%2) * BIT_WIDTH;
            int ByteOffset = offset / 8;
            int BitOffset  = offset % 8;
            Padded_8_FPx[i][j] = Extract_X_Bits_To_A_Byte<EXPONENT, MANTISSA>(PTR[j/2], ByteOffset, BitOffset);
        }
    }
    //
    unsigned char Seg_1bit[nTHREADS][1];
    unsigned char Seg_2bit[nTHREADS][2];
    unsigned char Seg_4bit[nTHREADS][4];
    for(int t=0; t<nTHREADS; t++){
        Extract_segments_from_8_padded_fpx(Seg_1bit[t], Padded_8_FPx[t], 1, int(BIT_WIDTH & 0));
        Extract_segments_from_8_padded_fpx(Seg_2bit[t], Padded_8_FPx[t], 2, int(BIT_WIDTH & 1));
        Extract_segments_from_8_padded_fpx(Seg_4bit[t], Padded_8_FPx[t], 4, int(BIT_WIDTH & 3));
    }
    //
    for(int t=0; t<4; t++)
    {
        if (USE_SEG_1BIT) {
            Vec_Seg_1bit[t].push_back(Seg_1bit[t][0]);
        }
        if (USE_SEG_2BIT) {
            Vec_Seg_2bit[t].push_back(Seg_2bit[t][0]);
            Vec_Seg_2bit[t].push_back(Seg_2bit[t][1]);
        }
        if (USE_SEG_4BIT) {
            Vec_Seg_4bit[t].push_back(Seg_4bit[t][0]);
            Vec_Seg_4bit[t].push_back(Seg_4bit[t][1]);
            Vec_Seg_4bit[t].push_back(Seg_4bit[t][2]);
            Vec_Seg_4bit[t].push_back(Seg_4bit[t][3]);
        }
    }
}

template<int BIT_WIDTH>
void BitInterleaving_x_bit(unsigned char* PTR_4Bytes)
{
    unsigned int *PTR_UINT = reinterpret_cast<unsigned int*>(PTR_4Bytes);
    unsigned int input  = *PTR_UINT;
    //
    int* order = NULL;
    int order_1bit[32] = {2,6,10,14,18,22,26,30,
                          4,8,12,16,20,24,28,32,
                          1,5,9, 13,17,21,25,29,
                          3,7,11,15,19,23,27,31};  // pre-defined order for bit-interleaving in FP6-LLM
    int order_2bit[16] = {2,6,10,14,4,8,12,16,1,5,9,13,3,7,11,15};  // pre-defined order for bit-interleaving in FP6-LLM
    int order_4bit[8] = {2,6,4,8,1,5,3,7};  // pre-defined order for bit-interleaving in FP6-LLM
    if(BIT_WIDTH==1) order = order_1bit;
    if(BIT_WIDTH==2) order = order_2bit;
    if(BIT_WIDTH==4) order = order_4bit;
    assert(order);
    //
    int mask = 0x80000000;
    assert(BIT_WIDTH>=1);
    mask = mask >> (BIT_WIDTH-1);
    //
    unsigned int output = 0x00000000;
    for(int i=0; i<32/BIT_WIDTH; i++){
        unsigned int Frag_xbit = ( input << BIT_WIDTH*(order[i]-1) ) & mask;    // The highest x bits are used to store the extracted fragments.
        output |= Frag_xbit >> (i*BIT_WIDTH);
    }
    //
    *PTR_UINT = output;
}
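The 4-bit path of `BitInterleaving_x_bit` can be copied out and checked in isolation (a sketch; the helper name is illustrative). The eight 4-bit nibbles of a 32-bit word are permuted into the order {2,6,4,8,1,5,3,7}, 1-indexed from the most significant nibble:

```cpp
#include <cassert>
#include <cstdint>

// Standalone copy of BitInterleaving_x_bit<4>: permute the eight nibbles
// of a 32-bit word into the pre-defined FP6-LLM order.
inline uint32_t bit_interleave_4bit(uint32_t input)
{
    const int order[8] = {2,6,4,8,1,5,3,7};
    uint32_t output = 0;
    for (int i = 0; i < 8; i++) {
        // Bring nibble order[i] to the top of the word, then place it
        // at output nibble position i.
        uint32_t frag = (input << 4 * (order[i] - 1)) & 0xF0000000u;
        output |= frag >> (i * 4);
    }
    return output;
}
```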

template<int EXPONENT, int MANTISSA>
void weight_matrix_prepacking_x_bit(int* packed_weights, int *FPxWeights, size_t M, size_t K)
{
    assert(M % 64 == 0);
    assert(K % 64 == 0);
    //
    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    assert(BIT_WIDTH<8);
    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
    //
    unsigned char* Weight_xbit = reinterpret_cast<unsigned char*>(FPxWeights);
    unsigned char* Weight_1bit = reinterpret_cast<unsigned char*>(packed_weights);
    unsigned char* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M*K*1/8 : 0);
    unsigned char* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M*K*2/8 : 0);
    //
    vector<unsigned char> A_Segment_1bit[32];
    vector<unsigned char> A_Segment_2bit[32];
    vector<unsigned char> A_Segment_4bit[32];
    //
    size_t BytesPerRow = K*BIT_WIDTH/8;
    // Pass-1: (1) 1+2+4 split; (2) assign weights to 32 threads.
    for (size_t i = 0; i < M / 64; i++){
        for (size_t j = 0; j < K / 16; j++){
            for(size_t k=0; k<64/16; k++){
                size_t row = i*64 + k*16;
                size_t col = j*16;
                unsigned char* StartPTR_1 = Weight_xbit + row*BytesPerRow + col*(BIT_WIDTH)/8;
                unsigned char* StartPTR_2 = StartPTR_1 + 8*BytesPerRow;
                unsigned char* StartPTR_3 = StartPTR_1 + 8*(BIT_WIDTH)/8;
                unsigned char* StartPTR_4 = StartPTR_2 + 8*(BIT_WIDTH)/8;
                // Dealing with each 16*16 block then...
                for(int l=0; l<8; l++) {
                    unsigned char* PTR[4]={StartPTR_1+l*BytesPerRow, StartPTR_2+l*BytesPerRow, StartPTR_3+l*BytesPerRow, StartPTR_4+l*BytesPerRow};
                    Assign_32_FPx_To_4_Thread<EXPONENT,MANTISSA>(&A_Segment_1bit[l*4], &A_Segment_2bit[l*4], &A_Segment_4bit[l*4], PTR);
                }
            }
        }
    }
    // Verifying the length of 1/2/4_bit segments.
    size_t BytesPerThread_1bit = M*K*1/8/32;
    size_t BytesPerThread_2bit = M*K*2/8/32;
    size_t BytesPerThread_4bit = M*K*4/8/32;
    for(int i=0; i<32; i++){
        if(USE_SEG_1BIT)    assert(A_Segment_1bit[i].size()==BytesPerThread_1bit);
        else                assert(A_Segment_1bit[i].size()==0);
        if(USE_SEG_2BIT)    assert(A_Segment_2bit[i].size()==BytesPerThread_2bit);
        else                assert(A_Segment_2bit[i].size()==0);
        if(USE_SEG_4BIT)    assert(A_Segment_4bit[i].size()==BytesPerThread_4bit);
        else                assert(A_Segment_4bit[i].size()==0);
    }
    // Pass-2: Optimizing coalesced global memory access
    if(USE_SEG_1BIT)
    for(size_t i=0; i<BytesPerThread_1bit/4; i++) 
        for(int t=0; t<32; t++)
            for(int b=0; b<4; b++)              // why (3-b): special byte order within a register
                Weight_1bit[i*128+t*4+(3-b)] = A_Segment_1bit[t][i*4+b];    
    if(USE_SEG_2BIT)
    for(size_t i=0; i<BytesPerThread_2bit/4; i++) 
        for(int t=0; t<32; t++)
            for(int b=0; b<4; b++)              // why (3-b): special byte order within a register
                Weight_2bit[i*128+t*4+(3-b)] = A_Segment_2bit[t][i*4+b];    
    if(USE_SEG_4BIT)
    for(size_t i=0; i<BytesPerThread_4bit/4; i++) 
        for(int t=0; t<32; t++)
            for(int b=0; b<4; b++)              // why (3-b):special byte order within a register
                Weight_4bit[i*128+t*4+(3-b)] = A_Segment_4bit[t][i*4+b];
    // Pass-3: Bit-level interleaving
    if(USE_SEG_1BIT)
    for(size_t i=0; i<BytesPerThread_1bit*32/4; i++)
        BitInterleaving_x_bit<1>(Weight_1bit+4*i);
    if(USE_SEG_2BIT)
    for(size_t i=0; i<BytesPerThread_2bit*32/4; i++)
        BitInterleaving_x_bit<2>(Weight_2bit+4*i);
    if(USE_SEG_4BIT)
    for(size_t i=0; i<BytesPerThread_4bit*32/4; i++)
        BitInterleaving_x_bit<4>(Weight_4bit+4*i);
}


void weight_matrix_prepacking(int* packed_weights, int *FP6Weights, size_t M, size_t K){
    weight_matrix_prepacking_x_bit<3, 2>(packed_weights, FP6Weights, M, K);
}

//
void weight_matrix_prepacking_fp_eXmY(const int EXPONENT, const int MANTISSA, int* packed_weights, int *FPxWeights, size_t M, size_t K){
    if(EXPONENT==2 && MANTISSA==2)
        return weight_matrix_prepacking_x_bit<2, 2>(packed_weights, FPxWeights, M, K);
    if(EXPONENT==3 && MANTISSA==2)
        return weight_matrix_prepacking_x_bit<3, 2>(packed_weights, FPxWeights, M, K);
    printf("Weight_prepacking Error: Unsupported EXPONENT=%d, MANTISSA=%d!\n", EXPONENT, MANTISSA);
    exit(-1);
}
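For FP6 (BIT_WIDTH = 6), only USE_SEG_2BIT and USE_SEG_4BIT are active: each left-aligned FP6 byte is split into a 2-bit segment at bit offset `BIT_WIDTH & 1 == 0` (the top two bits) and a 4-bit segment at bit offset `BIT_WIDTH & 3 == 2` (the remaining four bits). A minimal sketch of this split and of the recombination that the dequant loop in `utils_parallel_dequant.cuh` performs on the GPU (helper names are illustrative):

```cpp
#include <cassert>

// Split a left-aligned FP6 byte into its 2-bit and 4-bit segments,
// both left-aligned, as Extract_segments_from_8_padded_fpx does.
inline void split_fp6(unsigned char fp6, unsigned char* seg2, unsigned char* seg4)
{
    *seg2 = fp6 & 0xC0;                          // top 2 bits (sign + top exponent bit)
    *seg4 = (unsigned char)((fp6 << 2) & 0xF0);  // remaining 4 bits, left-aligned
}

// Recombine the two segments into the original left-aligned FP6 byte,
// mirroring the `tmp >> (BIT_WIDTH & 1)` / `tmp >> (BIT_WIDTH & 3)` shifts
// in the dequant kernel.
inline unsigned char merge_fp6(unsigned char seg2, unsigned char seg4)
{
    return (unsigned char)((seg2 & 0xC0) | ((seg4 & 0xF0) >> 2));
}
```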

================================================
FILE: fp6_llm/csrc/utils/weight_quant.h
================================================
// Author: Zhen Zheng
// To be used in the future as a tool for generating the FP6 matrix from the FP16 matrix.

#include <iostream>
#include <cstdint>
#include <cmath>
#include <stdexcept>
#include <cuda_fp16.h>

/*
 * Function to pack 4 fake-quantized FP16 values into 4 continuously stored FP6 values.
 */
void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4)
{
    // Constants for FP6
    constexpr int exponent_nbits_fp6 = 3;
    constexpr int mantissa_nbits_fp6 = 2;
    constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1;
    // Constants for FP16
    constexpr int exponent_nbits_fp16 = 5;
    constexpr int mantissa_nbits_fp16 = 10;
    constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1;

    int fp6_temp[4];

    float absmin_nonzero_fp6 = 0.0625;
    // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is
    // the same with that in qtorch.
    float absmax_fp6 = 28;

    for (int i = 0; i < 4; ++i) {
        uint16_t source = FP16x4[i];
        float fp6_value_abs = std::abs(__half2float(*((half*)(&source))));
        if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) ||
            fp6_value_abs > absmax_fp6) {
            // TODO(zhen): a better way may be rounding it to the nearest FP6 value.
            throw std::invalid_argument("Input value out of range for FP6.");
        }

        // It is not safe to do shift operation on uint16_t. So we promote it to int.
        int source_promote = int(source);

        int sign_bit = (source_promote >> 15);
        // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111'
        int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16;
        // Extracting mantissa represented in FP16
        int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1);

        int new_exp_bit;
        int new_mant_bit;

        if (exp_bit == 0) {
            // Subnormal FP16 number. Too small for FP6.
            new_exp_bit = 0;
            new_mant_bit = 0;
        } else {
            new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6);
            new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6;

            // Deal with subnormal FP6 values.
            int target_exp_val = exp_bit - exp_bias_fp16;
            int min_fp6_exp_val = -exp_bias_fp6 + 1;
            bool subnormal_fp6 = target_exp_val < min_fp6_exp_val;
            if (subnormal_fp6) {
                // TODO(zhen): add the rounding logic.
                new_exp_bit = 0;
                // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we
                // need to add it
                new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >>
                               (min_fp6_exp_val - target_exp_val);
            }
        }

        fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) |
                      (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit;
    }
    // Pack the values
    FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4);
    FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2);
    FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3];
}
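The packing step at the end of the function above can be sketched in isolation: four 6-bit values a, b, c, d become the three bytes `aaaaaabb bbbbcccc ccdddddd` (the helper name is illustrative):

```cpp
#include <cassert>

// Pack four 6-bit values into three bytes, exactly as the tail of
// cast_fp16_fp6 does: aaaaaabb bbbbcccc ccdddddd.
inline void pack_4xfp6(const int v[4], unsigned char out[3])
{
    out[0] = (unsigned char)((v[0] << 2) | (v[1] >> 4));
    out[1] = (unsigned char)(((v[1] & 0x0F) << 4) | (v[2] >> 2));
    out[2] = (unsigned char)(((v[2] & 0x03) << 6) | v[3]);
}
```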

/*
 *  Function to prepack FP16 weights into continuous FP6 values.
 *
 *  Parameters:
 *     weight_16bit: input weight in FP16, size M*K
 *     weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8
 *     M, K: the shape of the weight
 */
void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit,
                                   uint8_t* weight_6bit_packed,
                                   size_t M,
                                   size_t K)
{
    // Every four FP16 elements are packed into three bytes (4 * 6 bits == 3 * 8 bits).
    if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); }
    size_t K_fp6_packed = K * 6 / 8;
    // #pragma omp parallel for
    for (size_t m = 0; m < M; m++) {
        uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed;
        uint16_t* ptr_16bit = weight_16bit + m * K;
        for (size_t k = 0; k < K; k += 4) {
            cast_fp16_fp6(ptr_16bit, ptr_6bit);
            ptr_16bit += 4;
            ptr_6bit += 3;
        }
    }
}
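A pure-integer sketch of the FP16 -> FP6 (e3m2) bit conversion performed by `cast_fp16_fp6`, operating on raw FP16 bits so it runs without CUDA headers (the function name is illustrative). It truncates the mantissa and omits the range check and subnormal-output handling that the repo function performs:

```cpp
#include <cassert>
#include <cstdint>

// Raw FP16 bits -> 6-bit e3m2 code: re-bias the exponent (FP16 bias 15,
// FP6 bias 3) and keep the top 2 mantissa bits. No rounding, no range or
// subnormal-output handling (see cast_fp16_fp6 for those).
inline int fp16_bits_to_fp6_e3m2(uint16_t h)
{
    int sign = (h >> 15) & 1;
    int exp  = (h & 0x7FFF) >> 10;   // FP16 exponent bits
    int mant = h & 0x3FF;            // FP16 mantissa bits
    if (exp == 0) return sign << 5;  // FP16 subnormal: flush to signed zero
    int new_exp  = exp - 15 + 3;     // re-bias the exponent
    int new_mant = mant >> 8;        // keep the top 2 mantissa bits
    return (sign << 5) | (new_exp << 2) | new_mant;
}
```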

================================================
FILE: setup.py
================================================
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension


extra_compile_args = {
    "cxx": [
        "-O3",
        "-std=c++17"
    ],
    "nvcc": [
        "-O3", 
        "--use_fast_math",
        "-std=c++17",
        "-maxrregcount=255",
        "--ptxas-options=-v,-warn-lmem-usage,--warn-on-spills",
        "-gencode=arch=compute_80,code=sm_80"
    ],
}

setup(
    name="fp6_llm",
    author="Haojun Xia, Zhen Zheng, Xiaoxia Wu, Shiyang Chen, Zhewei Yao, Stephen Youn, Arash Bakhtiari, Michael Wyatt, Donglin Zhuang, Zhongzhu Zhou, Olatunji Ruwase, Yuxiong He, Shuaiwen Leon Song",
    version="0.2",
    author_email="xhjustc@gmail.com",
    description ="An efficient GPU support for LLM inference with x-bit quantization (e.g., FP6 and FP5).",
    python_requires=">=3.8",
    install_requires=[
        "torch",
        "transformers"
    ],
    packages=find_packages(),
    ext_modules=[
        CUDAExtension(
            name="fp6_llm_cuda",
            sources=[
                "fp6_llm/csrc/pybind.cpp", 
                "fp6_llm/csrc/fp6_linear.cu"
            ],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension}
)

================================================
FILE: tests/cpp/Makefile
================================================
# host compiler
HOST_COMPILER ?= g++
NVCC          := nvcc -ccbin $(HOST_COMPILER)

# internal flags
NVCCFLAGS   := -m$(shell getconf LONG_BIT)
CCFLAGS     := -DNO_PYTORCH
LDFLAGS     := -rpath=../../fp6_llm

ALL_CCFLAGS :=
ALL_CCFLAGS += $(NVCCFLAGS)
ALL_CCFLAGS += $(addprefix -Xcompiler ,$(CCFLAGS))

ALL_LDFLAGS :=
ALL_LDFLAGS += $(ALL_CCFLAGS)
ALL_LDFLAGS += $(addprefix -Xlinker ,$(LDFLAGS))

# Common includes and paths for CUDA
INCLUDES  := -I/usr/local/cuda/include/ 
LIBRARIES := -lcublas 

#
INCLUDES  += -I../../fp6_llm/csrc
LIBRARIES += -L../../fp6_llm -lfp6

################################################################################
# Gencode arguments
SMS ?= 80
# Generate SASS code for each SM architecture listed in $(SMS)
$(foreach sm,$(SMS),$(eval GENCODE_FLAGS += -gencode arch=compute_$(sm),code=sm_$(sm)))
################################################################################
# Target rules
all: kernel_test_fp6 kernel_test_fpx

kernel_test_fp6.o:  kernel_test_fp6.cu kernel_test.h
	$(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $<

kernel_test_fpx.o:  kernel_test_fpx.cu kernel_test.h
	$(EXEC) $(NVCC) $(INCLUDES) $(ALL_CCFLAGS) $(GENCODE_FLAGS) -o $@ -c $<

kernel_test_fp6: kernel_test_fp6.o 
	$(EXEC) $(NVCC) $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES) 

kernel_test_fpx: kernel_test_fpx.o 
	$(EXEC) $(NVCC) $(ALL_LDFLAGS) $(GENCODE_FLAGS) -o $@ $+ $(LIBRARIES) 

clean:
	rm -f kernel_test_fp6 kernel_test_fp6.o kernel_test_fpx kernel_test_fpx.o

================================================
FILE: tests/cpp/kernel_test.h
================================================
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

// Performance Benchmark
#define WARM_UP_ITERATION 100
#define BENCHMARK_ITERATION 10000

void __forceinline__ CheckMallocCPU(void* PTR, int line = -1) {
    if (PTR == NULL) {
        printf("Error in CPU Malloc, line %d!\n", line);
        exit(-1);
    }
}

void __forceinline__ CheckMallocCUDA(void* PTR, int line = -1) {
    if (PTR == NULL) {
        printf("Error in cudaMalloc, line %d!\n", line);
        exit(-1);
    }
}

void checkCublasError(cublasStatus_t status, int line)
{
    if (status != CUBLAS_STATUS_SUCCESS) {
        printf("Cublas Error at line %d, Error Code: %d\n", line, status);
        exit(EXIT_FAILURE);
    }
}

void checkLastCudaError(int line)
{
    cudaError_t error = cudaGetLastError();
    if (error != cudaSuccess) {
        printf("Last Cuda Error Detected at line: %d, Error: %s.\n", line, cudaGetErrorString(error));
        exit(EXIT_FAILURE);
    }
}

// Note: totalAbsSum might overflow if (1) the output matrix is large and (2) the values of its elements are large.
// The overflow might result in NaN for "TotalAbsError/TotalAbsSum" even though fp6_llm is working correctly.
// This problem can be fixed by setting the quantization scales to smaller values.
double ComputeTotalError(half* CuBlas, half* Other, size_t m, size_t n)
{
    long double totalError = 0.0;
    for (size_t i = 0; i < m * n; i++)
        totalError += fabs(__half2float(CuBlas[i]) - __half2float(Other[i]));
    long double totalAbsSum = 0.0;
    for (size_t i = 0; i < m * n; i++)
        totalAbsSum += fabs(__half2float(CuBlas[i]));    
    return totalError/totalAbsSum;
}
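The same metric in plain float, for reference (a sketch; the repo version operates on `half` outputs and the helper name is illustrative): the sum of absolute element-wise differences, normalized by the total absolute magnitude of the reference output.

```cpp
#include <cassert>
#include <cmath>

// Sum of |ref - other| over all elements, divided by the sum of |ref|,
// mirroring ComputeTotalError.
inline double total_relative_error(const float* ref, const float* other, int n)
{
    double err = 0.0, sum = 0.0;
    for (int i = 0; i < n; i++) {
        err += std::fabs((double)ref[i] - (double)other[i]);
        sum += std::fabs((double)ref[i]);
    }
    return err / sum;
}
```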

void PrintPerformance(const char* KernelName, float milliseconds, float tflops, double error)
{
    printf("%-10s \t -> \t\t Time/ms: %5.3f \t Performance/TFLOPs: %4.2f \t TotalAbsError/TotalAbsSum: %.8lf\n",
           KernelName,
           milliseconds,
           tflops,
           error);
}


void PrintMismatch(const char* KernelName,
                   size_t      MaxNumMismatch,
                   float       RelativeErrorThreshold,
                   half*       CuBlas,
                   half*       Other,
                   size_t         M_GLOBAL,
                   size_t         N_GLOBAL)
{
    //printf("First %d Mismatches between Cublas and %s:\n", MaxNumMismatch, KernelName);
    size_t count = 0;
    for (size_t i = 0; i < M_GLOBAL; i++) {
        for (size_t j = 0; j < N_GLOBAL; j++) {
            if (fabs(__half2float(CuBlas[i + j * M_GLOBAL]) - __half2float(Other[i + j * M_GLOBAL]))/fabs(__half2float(CuBlas[i + j * M_GLOBAL])) > RelativeErrorThreshold) {
                count++;
                printf("(%zu,%zu) CuBlas=%f %s=%f\n",
                       i,
                       j,
                       __half2float(CuBlas[i + j * M_GLOBAL]),
                       KernelName,
                       __half2float(Other[i + j * M_GLOBAL]));
            }
            if (count == MaxNumMismatch)
                break;
        }
        if (count == MaxNumMismatch)
            break;
    }
}



================================================
FILE: tests/cpp/kernel_test_fp6.cu
================================================
#include "kernel_test.h"
#include "fp6_linear.cuh"

int main(int argc, char** argv)
{
    // Parsing the inputs from CLI.
    if (argc != 5) {
        printf("Wrong Inputs! Correct input format: ./kernel_test #Row_Weight #Column_Weight BatchSize SplitK\n");
        return -1;
    }
    size_t M_GLOBAL = atoi(argv[1]);
    size_t K_GLOBAL = atoi(argv[2]);
    size_t N_GLOBAL = atoi(argv[3]);
    int    SPLIT_K  = atoi(argv[4]);
    assert(M_GLOBAL%256==0);                 // Currently, M_GLOBAL must be a multiple of 256.
    assert(K_GLOBAL%64==0);                  // Currently, K_GLOBAL must be a multiple of 64.
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Matrices in quantized FP6 models with faked values.
    unsigned char* A_6bit_h  = (unsigned char*)malloc(M_GLOBAL*K_GLOBAL*6/8);       CheckMallocCPU(A_6bit_h, __LINE__);     // Weight matrix with FP6 values, stored in row-major.
    for(size_t i=0; i<M_GLOBAL*K_GLOBAL*6/8; i++)   A_6bit_h[i] = rand() % 256;                                             // Random initialization.
    half*          A_Scale_h = (half*)malloc(M_GLOBAL*sizeof(half));                CheckMallocCPU(A_Scale_h, __LINE__);    // Quantization Scales with FP16 values.
    for(size_t i=0; i<M_GLOBAL; i++)                A_Scale_h[i] = float(rand()%256)/64.0f;                                 // Scale
    // Generating FP16 format of the Weight Matrix
    half* A_16bit_h = (half*) malloc(M_GLOBAL*K_GLOBAL*sizeof(half));                           CheckMallocCPU(A_16bit_h, __LINE__);
    DeQuantMatrix_FP6_To_FP16(A_16bit_h, A_6bit_h, M_GLOBAL, K_GLOBAL, A_Scale_h);
    // In-place weight pre-packing
    weight_matrix_prepacking((int*)A_6bit_h, (int*)A_6bit_h, M_GLOBAL, K_GLOBAL);

    // Devices Memory
    unsigned char*  A_6bit;
    half*           A_Scale;
    half*           A_16bit;
    cudaMalloc(reinterpret_cast<void**>(&A_6bit),  M_GLOBAL*K_GLOBAL*6/8);             CheckMallocCUDA(A_6bit, __LINE__);
    cudaMalloc(reinterpret_cast<void**>(&A_Scale), M_GLOBAL*sizeof(half));             CheckMallocCUDA(A_Scale, __LINE__);
    cudaMalloc(reinterpret_cast<void**>(&A_16bit),          M_GLOBAL*K_GLOBAL*sizeof(half));    CheckMallocCUDA(A_16bit, __LINE__);
    // Memory Copy from CPU to GPU
    cudaMemcpy(A_6bit,     A_6bit_h,  M_GLOBAL*K_GLOBAL*6/8,          cudaMemcpyHostToDevice);
    cudaMemcpy(A_Scale,    A_Scale_h,          M_GLOBAL*sizeof(half),          cudaMemcpyHostToDevice);
    cudaMemcpy(A_16bit,             A_16bit_h,          M_GLOBAL*K_GLOBAL*sizeof(half), cudaMemcpyHostToDevice);
    checkLastCudaError(__LINE__);
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // B Matrix: Activations
    half* B_h = (half*)malloc(sizeof(half) * K_GLOBAL * N_GLOBAL); CheckMallocCPU(B_h);       // col major 
    for (size_t i = 0; i < N_GLOBAL * K_GLOBAL; i++)
        B_h[i] = __float2half_rn(static_cast<float>((rand() % 5)) / 5 - 0.5f);
    // Device memory
    half* B            = NULL;
    cudaMalloc(reinterpret_cast<void**>(&B), sizeof(half) * N_GLOBAL * K_GLOBAL);               CheckMallocCUDA(B, __LINE__);
    // Memory Copy from CPU to GPU
    cudaMemcpy(B, B_h, sizeof(half) * N_GLOBAL * K_GLOBAL, cudaMemcpyHostToDevice);
    checkLastCudaError(__LINE__);
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    cublasStatus_t cublas_status;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    checkLastCudaError(__LINE__);
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    half* D_cublas = NULL;
    cudaMalloc(reinterpret_cast<void**>(&D_cublas), sizeof(half) * M_GLOBAL * N_GLOBAL);        CheckMallocCUDA(D_cublas, __LINE__);
    cudaMemset(D_cublas, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);
    cublasHandle_t handle;
    cublasCreate(&handle);
    cublasSetStream(handle, 0);
    //cublasSetMathMode(handle, CUBLAS_PEDANTIC_MATH);          // Tensor core NOT enabled
    cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);             // Tensor core enabled
    cudaDeviceSynchronize();
    int              m = M_GLOBAL, n = N_GLOBAL, k = K_GLOBAL;
    const float      alpha     = 1.0;
    const float      beta      = 0.0;
    cublasGemmAlgo_t CuBlasALG = static_cast<cublasGemmAlgo_t>(0);
    for (int i = 0; i < WARM_UP_ITERATION; i++) {
        cublas_status = cublasGemmEx(handle,
                                     CUBLAS_OP_T,   CUBLAS_OP_N,
                                     m, n, k,
                                     &alpha,
                                     A_16bit,   CUDA_R_16F, k,
                                     B,         CUDA_R_16F, k,
                                     &beta,
                                     D_cublas,  CUDA_R_16F, m,
                                     CUDA_R_32F,
                                     CuBlasALG);
        checkCublasError(cublas_status, __LINE__);
    }
    cudaEventRecord(start);
    for (int i = 0; i < BENCHMARK_ITERATION; i++)
        cublas_status = cublasGemmEx(handle,
                                     CUBLAS_OP_T,   CUBLAS_OP_N,
                                     m, n, k,
                                     &alpha,
                                     A_16bit,   CUDA_R_16F, k,
                                     B,         CUDA_R_16F, k,
                                     &beta,
                                     D_cublas,  CUDA_R_16F, m,
                                     CUDA_R_32F,
                                     CuBlasALG);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    //
    float milliseconds_cublas = 0;
    cudaEventElapsedTime(&milliseconds_cublas, start, stop);
    milliseconds_cublas = milliseconds_cublas / BENCHMARK_ITERATION;
    float tflops_cublas = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_cublas / 1000.)) / 1e12;
    //
    half* D_cublas_h = NULL;  // col major
    D_cublas_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);   CheckMallocCPU(D_cublas_h);
    cudaMemcpy(D_cublas_h, D_cublas, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major
    cudaFree(D_cublas);
    checkLastCudaError(__LINE__);
    /////////////////////////////////////////////////////////////////////////////////////////////////
    half* D_fp6 = NULL;
    cudaMalloc(reinterpret_cast<void**>(&D_fp6), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_fp6);
    cudaMemset(D_fp6, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);
    //
    int Split_K = SPLIT_K;
    float* Reduction_Workspace = NULL;
    cudaMalloc(reinterpret_cast<void**>(&Reduction_Workspace), sizeof(float) * M_GLOBAL * N_GLOBAL * Split_K);   CheckMallocCUDA(Reduction_Workspace, __LINE__);
    //
    for (int i = 0; i < WARM_UP_ITERATION; i++)
        fp6_linear_kernel(  0,
                        (uint4*)A_6bit, A_Scale,
                        B,
                        D_fp6,
                        M_GLOBAL, N_GLOBAL, K_GLOBAL,
                        Reduction_Workspace,  
                        Split_K);
    cudaEventRecord(start);
    for (int i = 0; i < BENCHMARK_ITERATION; i++)
        fp6_linear_kernel(  0,
                        (uint4*)A_6bit, A_Scale,
                        B,
                        D_fp6,
                        M_GLOBAL, N_GLOBAL, K_GLOBAL,
                        Reduction_Workspace,  
                        Split_K);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    checkLastCudaError(__LINE__);
    //
    float milliseconds_fp6 = 0.0f;
    cudaEventElapsedTime(&milliseconds_fp6, start, stop);
    milliseconds_fp6 = milliseconds_fp6 / BENCHMARK_ITERATION;
    float tflops_fp6 = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_fp6 / 1000.)) / 1e12;
    half* D_fp6_h = NULL;  // col major
    D_fp6_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);     CheckMallocCPU(D_fp6_h);
    cudaMemcpy(D_fp6_h, D_fp6, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major
    cudaFree(D_fp6);
    cudaFree(Reduction_Workspace);
    /////////////////////////////////////////////////////////////////////////////////////////////////
    double totalRelativeError_fp6  = ComputeTotalError(D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);
    printf("************************************* ");
    printf("[%d-bit Weights, e%dm%d] M: %zu N: %zu K: %zu SplitK: %d", 6, 3, 2, M_GLOBAL, N_GLOBAL, K_GLOBAL, SPLIT_K);
    printf(" ************************************\n");
    PrintPerformance("cuBLAS", milliseconds_cublas, tflops_cublas, 0.0);
    PrintPerformance("fp6_llm", milliseconds_fp6, tflops_fp6, totalRelativeError_fp6);

    //PrintMismatch("fp6", 100, 0.002, D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);

    free(D_cublas_h);
    free(D_fp6_h);
    cudaFree(B);
    return 0;
}
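The effective-TFLOPS figures printed by both benchmarks follow the usual 2*M*N*K FLOP count for a GEMM (one multiply plus one add per MAC), divided by elapsed seconds. The same arithmetic can be sketched in a few lines of Python; the shape and timing values below are illustrative, not measured numbers.

```python
def gemm_tflops(m, n, k, milliseconds):
    # 2*M*N*K floating-point operations per GEMM,
    # divided by elapsed seconds, scaled to tera-FLOPS.
    flops = 2 * m * n * k
    return flops / (milliseconds / 1000.0) / 1e12

# Example: a 4096x32x4096 GEMM finishing in 0.1 ms -> ~10.74 TFLOPS
print(gemm_tflops(4096, 32, 4096, 0.1))
```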


================================================
FILE: tests/cpp/kernel_test_fpx.cu
================================================
#include "kernel_test.h"
#include "fp6_linear.cuh"


int main(int argc, char** argv)
{
    // Parsing the inputs from CLI.
    if (argc != 7) {
        printf("Wrong Inputs! Correct input format: ./kernel_test EXPONENT MANTISSA #Row_Weight #Column_Weight BatchSize SplitK\n");
        return -1;
    }
    int EXPONENT    = atoi(argv[1]);
    int MANTISSA    = atoi(argv[2]);
    size_t M_GLOBAL = atoi(argv[3]);
    size_t K_GLOBAL = atoi(argv[4]);
    size_t N_GLOBAL = atoi(argv[5]);
    int    SPLIT_K  = atoi(argv[6]);
    int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
    assert(EXPONENT==2 || EXPONENT==3);
    assert(MANTISSA==2);
    assert(M_GLOBAL%256==0);                 // Currently, M_GLOBAL must be a multiple of 256.
    assert(K_GLOBAL%64==0);                  // Currently, K_GLOBAL must be a multiple of 64.
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Matrices in quantized FPx models with fake values.
    unsigned char* A_xbit_h  = (unsigned char*)malloc(M_GLOBAL*K_GLOBAL*BIT_WIDTH/8);   CheckMallocCPU(A_xbit_h, __LINE__);     // Weight matrix with FPx values, stored in row-major order.
    for(size_t i=0; i<M_GLOBAL*K_GLOBAL*BIT_WIDTH/8; i++)   A_xbit_h[i] = rand() % 256;                                         // Random initialization.
    half*          A_Scale_h = (half*)malloc(M_GLOBAL*sizeof(half));                    CheckMallocCPU(A_Scale_h, __LINE__);    // Per-row quantization scales in FP16.
    for(size_t i=0; i<M_GLOBAL; i++)                A_Scale_h[i] = float(rand()%256)/64.0f;
    // Generating the FP16 version of the weight matrix
    half* A_16bit_h = (half*) malloc(M_GLOBAL*K_GLOBAL*sizeof(half));                           CheckMallocCPU(A_16bit_h, __LINE__);
    dequant_matrix_fp_eXmY_to_fp16(EXPONENT, MANTISSA, A_16bit_h, A_xbit_h, M_GLOBAL, K_GLOBAL, A_Scale_h);
    // In-place weight pre-packing
    weight_matrix_prepacking_fp_eXmY(EXPONENT, MANTISSA, (int*)A_xbit_h, (int*)A_xbit_h, M_GLOBAL, K_GLOBAL);

    // Devices Memory
    unsigned char*  A_xbit;
    half*           A_Scale;
    half*           A_16bit;
    cudaMalloc(reinterpret_cast<void**>(&A_xbit),  M_GLOBAL*K_GLOBAL*BIT_WIDTH/8);             CheckMallocCUDA(A_xbit, __LINE__);
    cudaMalloc(reinterpret_cast<void**>(&A_Scale), M_GLOBAL*sizeof(half));             CheckMallocCUDA(A_Scale, __LINE__);
    cudaMalloc(reinterpret_cast<void**>(&A_16bit), M_GLOBAL*K_GLOBAL*sizeof(half));    CheckMallocCUDA(A_16bit, __LINE__);
    // Memory Copy from CPU to GPU
    cudaMemcpy(A_xbit,  A_xbit_h,  M_GLOBAL*K_GLOBAL*BIT_WIDTH/8,  cudaMemcpyHostToDevice);
    cudaMemcpy(A_Scale, A_Scale_h, M_GLOBAL*sizeof(half),          cudaMemcpyHostToDevice);
    cudaMemcpy(A_16bit, A_16bit_h, M_GLOBAL*K_GLOBAL*sizeof(half), cudaMemcpyHostToDevice);
    checkLastCudaError(__LINE__);
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // B Matrix: Activations
    half* B_h = (half*)malloc(sizeof(half) * K_GLOBAL * N_GLOBAL); CheckMallocCPU(B_h);       // col major 
    for (size_t i = 0; i < N_GLOBAL * K_GLOBAL; i++)
        B_h[i] = __float2half_rn(static_cast<float>((rand() % 5)) / 5 - 0.5f);
    // Device memory
    half* B            = NULL;
    cudaMalloc(reinterpret_cast<void**>(&B), sizeof(half) * N_GLOBAL * K_GLOBAL);               CheckMallocCUDA(B, __LINE__);
    // Memory Copy from CPU to GPU
    cudaMemcpy(B, B_h, sizeof(half) * N_GLOBAL * K_GLOBAL, cudaMemcpyHostToDevice);
    checkLastCudaError(__LINE__);
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    cublasStatus_t cublas_status;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    checkLastCudaError(__LINE__);
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    //printf("Launching CuBlas...\n");
    half* D_cublas = NULL;
    cudaMalloc(reinterpret_cast<void**>(&D_cublas), sizeof(half) * M_GLOBAL * N_GLOBAL);        CheckMallocCUDA(D_cublas, __LINE__);
    cudaMemset(D_cublas, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);
    cublasHandle_t handle;
    cublasCreate(&handle);
    cublasSetStream(handle, 0);
    //cublasSetMathMode(handle, CUBLAS_PEDANTIC_MATH);          // Tensor core NOT enabled
    cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);             // Tensor core enabled
    cudaDeviceSynchronize();
    int              m = M_GLOBAL, n = N_GLOBAL, k = K_GLOBAL;
    const float      alpha     = 1.0;
    const float      beta      = 0.0;
    cublasGemmAlgo_t CuBlasALG = static_cast<cublasGemmAlgo_t>(0);
    for (int i = 0; i < WARM_UP_ITERATION; i++) {
        cublas_status = cublasGemmEx(handle,
                                     CUBLAS_OP_T,   CUBLAS_OP_N,
                                     m, n, k,
                                     &alpha,
                                     A_16bit,   CUDA_R_16F, k,
                                     B,         CUDA_R_16F, k,
                                     &beta,
                                     D_cublas,  CUDA_R_16F, m,
                                     CUDA_R_32F,
                                     CuBlasALG);
        checkCublasError(cublas_status, __LINE__);
    }
    cudaEventRecord(start);
    for (int i = 0; i < BENCHMARK_ITERATION; i++)
        cublas_status = cublasGemmEx(handle,
                                     CUBLAS_OP_T,   CUBLAS_OP_N,
                                     m, n, k,
                                     &alpha,
                                     A_16bit,   CUDA_R_16F, k,
                                     B,         CUDA_R_16F, k,
                                     &beta,
                                     D_cublas,  CUDA_R_16F, m,
                                     CUDA_R_32F,
                                     CuBlasALG);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    //
    float milliseconds_cublas = 0;
    cudaEventElapsedTime(&milliseconds_cublas, start, stop);
    milliseconds_cublas = milliseconds_cublas / BENCHMARK_ITERATION;
    float tflops_cublas = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_cublas / 1000.)) / 1e12;
    //
    half* D_cublas_h = NULL;  // col major
    D_cublas_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);   CheckMallocCPU(D_cublas_h);
    cudaMemcpy(D_cublas_h, D_cublas, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major
    cudaFree(D_cublas);
    checkLastCudaError(__LINE__);
    /////////////////////////////////////////////////////////////////////////////////////////////////
    //printf("Launching FP6-LLM...\n");
    half* D_fp6 = NULL;
    cudaMalloc(reinterpret_cast<void**>(&D_fp6), sizeof(half) * M_GLOBAL * N_GLOBAL); CheckMallocCUDA(D_fp6);
    cudaMemset(D_fp6, 0, sizeof(half) * M_GLOBAL * N_GLOBAL);
    //
    int Split_K = SPLIT_K;
    float* Reduction_Workspace = NULL;
    cudaMalloc(reinterpret_cast<void**>(&Reduction_Workspace), sizeof(float) * M_GLOBAL * N_GLOBAL * Split_K);   CheckMallocCUDA(Reduction_Workspace, __LINE__);
    //
    for (int i = 0; i < WARM_UP_ITERATION; i++)
        fp_eXmY_linear_kernel(  
            EXPONENT,
            MANTISSA,
            0,
            (uint4*)A_xbit, A_Scale,
            B,
            D_fp6,
            M_GLOBAL, N_GLOBAL, K_GLOBAL,
            Reduction_Workspace,  
            Split_K);
    cudaEventRecord(start);
    for (int i = 0; i < BENCHMARK_ITERATION; i++)
        fp_eXmY_linear_kernel(  
            EXPONENT,
            MANTISSA,
            0,
            (uint4*)A_xbit, A_Scale,
            B,
            D_fp6,
            M_GLOBAL, N_GLOBAL, K_GLOBAL,
            Reduction_Workspace,  
            Split_K);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    checkLastCudaError(__LINE__);
    //
    float milliseconds_fp6 = 0.0f;
    cudaEventElapsedTime(&milliseconds_fp6, start, stop);
    milliseconds_fp6 = milliseconds_fp6 / BENCHMARK_ITERATION;
    float tflops_fp6 = static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL * K_GLOBAL * 2) / (milliseconds_fp6 / 1000.)) / 1e12;
    half* D_fp6_h = NULL;  // col major
    D_fp6_h       = (half*)malloc(sizeof(half) * M_GLOBAL * N_GLOBAL);     CheckMallocCPU(D_fp6_h);
    cudaMemcpy(D_fp6_h, D_fp6, sizeof(half) * M_GLOBAL * N_GLOBAL, cudaMemcpyDeviceToHost);  // Col Major
    cudaFree(D_fp6);
    cudaFree(Reduction_Workspace);
    /////////////////////////////////////////////////////////////////////////////////////////////////
    double totalRelativeError_fp6  = ComputeTotalError(D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);
    printf("************************************* ");
    printf("[%d-bit Weights, e%dm%d] M: %zu N: %zu K: %zu SplitK: %d", BIT_WIDTH, EXPONENT, MANTISSA, M_GLOBAL, N_GLOBAL, K_GLOBAL, SPLIT_K);
    printf(" ************************************\n");
    PrintPerformance("cuBLAS", milliseconds_cublas, tflops_cublas, 0.0);
    PrintPerformance("quant_llm", milliseconds_fp6, tflops_fp6, totalRelativeError_fp6);
    //PrintMismatch("fp6", 100, 0.002, D_cublas_h, D_fp6_h, M_GLOBAL, N_GLOBAL);

    free(D_cublas_h);
    free(D_fp6_h);
    cudaFree(B);
    return 0;
}
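The `Reduction_Workspace` buffer allocated above holds `Split_K` partial FP32 results per output element: each K-slice of the GEMM writes its partial sums there, and a reduction pass adds them into the final output. A toy pure-Python sketch of that split-K scheme follows; the shapes and the `split_k_gemm` name are illustrative and not the kernel's actual implementation.

```python
def matmul(A, B):
    # Plain GEMM over nested lists: C[m][n] = sum_k A[m][k] * B[k][n]
    M, K, N = len(A), len(A[0]), len(B[0])
    return [[sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)] for m in range(M)]

def split_k_gemm(A, B, split_k):
    # Each of the split_k slices computes a partial GEMM over one chunk of K
    # and stores it in a "workspace"; a reduction step then sums the partials.
    M, K, N = len(A), len(A[0]), len(B[0])
    chunk = K // split_k
    workspace = []
    for s in range(split_k):
        ks = range(s * chunk, (s + 1) * chunk)
        workspace.append([[sum(A[m][k] * B[k][n] for k in ks) for n in range(N)] for m in range(M)])
    # Reduction: elementwise sum over the split_k partial results.
    return [[sum(w[m][n] for w in workspace) for n in range(N)] for m in range(M)]
```

Splitting K creates `split_k` times as many independent tiles, which is exactly why the benchmarks tune `SplitK` per matrix shape to keep all SMs busy at small batch sizes.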


================================================
FILE: tests/cpp/run.sh
================================================
#! /bin/bash

# Batch sizes to test
N=(1 2 3 4 5 6 7 8)

# Benchmarking the specific Matrix Shape from llama-1 65b
M=(13824 5120  22016 8192)
K=(5120  13824 8192  22016)
SplitK=(4 10 5 6)      # SplitK for smaller Batch Sizes
#SplitK=(3)      # SplitK for Batch Sizes 128
#SplitK=(3)      # SplitK for Batch Sizes 512
#SplitK=(1)      # SplitK for Batch Sizes 2048


# Benchmarking Matrix Shapes from OPT models 
#M=(21504        7168     28672  7168    27648   9216    36864   9216    36864   12288   49152   12288)
#K=(7168         7168     7168   28672   9216    9216    9216    36864   12288   12288   12288   49152)
#SplitK=(2       7       7      7       2       6       3       6       3       4       1       4)        # SplitK for smaller Batch Sizes
#SplitK=(5       7       7      7       1       3       3       3       3       2       1       2)        # SplitK for Batch Sizes 128
#SplitK=(3       2       3      3       1       3       1       3       1       1       1       1)        # SplitK for Batch Sizes 512
#SplitK=(1       2       1      2       1       1       1       1       1       1       1       1)        # SplitK for Batch Sizes 2048

# Benchmarking Matrix Shapes from llama-1 models 
#M=(12288       4096    11008   4096    15360   5120    13824   5120    19968   6656    17920   6656    24576   8192    22016   8192)
#K=(4096        4096    4096    11008   5120    5120    5120    13824   6656    6656    6656    17920   8192    8192    8192    22016)
#SplitK=(9      13      5       13      3       10      4       10      5       8       3       8       2       6       5       6)      # SplitK for smaller Batch Sizes
#SplitK=(2      6       5       6       3       5       2       5       4       4       3       4       1       3       5       3)      # SplitK for Batch Sizes 128
#SplitK=(1      3       3       3       2       4       1       4       1       1       3       1       1       3       2       3)      # SplitK for Batch Sizes 512
#SplitK=(1      2       1       2       1       1       1       1       1       1       1       1       1       1       1       1)      # SplitK for Batch Sizes 2048

#mkdir -p Profiling
for ((i=0;i<${#M[@]};i++)) 
do
    echo "Processing Shape ${i}..."
    for BS in ${N[@]} 
    do
        echo "BS=${BS}"
        #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \
        #./kernel_test_fp6 ${M[i]} ${K[i]} ${BS} ${SplitK[i]}
        ./kernel_test_fpx 2 2 ${M[i]} ${K[i]} ${BS} ${SplitK[i]}
        ./kernel_test_fpx 3 2 ${M[i]} ${K[i]} ${BS} ${SplitK[i]}  
    done
done


================================================
FILE: tests/python/kernel_test_fp6.py
================================================
import argparse
import torch
import fp6_llm

WARMUP = 10
REPEAT = 1000

parser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).')
parser.add_argument('--OC',        type=int, required=False,     default=4096,   help='number of rows of the weight matrix.')
parser.add_argument('--IC',        type=int, required=False,     default=4096,   help='number of columns of the weight matrix.')
parser.add_argument('--BS',        type=int, required=False,     default=32,     help='inference batch size.')
parser.add_argument('--splitK',    type=int, required=False,     default=1,      help='Split-K factor: splits the GEMM along the K dimension so that more CTAs are created, improving SM utilization.')
args = parser.parse_args()

assert(args.OC%256==0)
assert(args.IC%64==0)

print("#"*64)
print(args)

fp6_weight = torch.randint(4294967295, (args.OC,args.IC//16*3)).to(torch.int)    # Randomly initialize each 32-bit word; the upper bound for randint() is set to the max value of uint32_t.
fp16_scale = torch.rand(args.OC).to(torch.half)+0.5
fp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# fp6-fp16 GEMM (fp6-llm)
####################################################################################################################################
torch.cuda.synchronize()
fp6_weight_packed = fp6_llm.weight_prepacking_cpu(fp6_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fp6_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
for i in range(WARMUP):
    results_fp6_llm = fp6_llm.linear_forward_cuda(act_cuda, weight_cuda, scale_cuda, args.splitK)
start_event.record()
for i in range(REPEAT):
    results_fp6_llm = fp6_llm.linear_forward_cuda(act_cuda, weight_cuda, scale_cuda, args.splitK)
end_event.record()
torch.cuda.synchronize()
fp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT
fp6_llm_tflops  = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9
####################################################################################################################################

# baseline fp16 GEMM (cuBLAS)
####################################################################################################################################
torch.cuda.synchronize()
fp16_weight = fp6_llm.weight_dequant_cpu(fp6_weight, fp16_scale)
cuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False)
results_cublas = None
with torch.no_grad():
    cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda())
    act_cuda = fp16_activation.cuda()
    for i in range(WARMUP):
        results_cublas = cuBLAS_MatMul(act_cuda)
    start_event.record()
    for i in range(REPEAT):
        results_cublas = cuBLAS_MatMul(act_cuda)
    end_event.record()
torch.cuda.synchronize()
cublas_time_ms = start_event.elapsed_time(end_event)/REPEAT
cublas_tflops  = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9
####################################################################################################################################

# Performance
print( 'cuBLAS  time: {:.3f} ms \t\t cuBLAS  TFLOPs: {:.1f}'.format(cublas_time_ms,  cublas_tflops) )
print( 'fp6-llm time: {:.3f} ms \t\t fp6-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) )
print( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) )

# Correctness
error             = results_cublas.cpu() - results_fp6_llm.cpu()
ground_truth      = results_cublas.cpu()
mean_error        = torch.mean(abs(error))
mean_ground_truth = torch.mean(abs(ground_truth))
relative_error    = mean_error.item()/mean_ground_truth.item()
print( "relative error: {:.6f}".format(relative_error) )
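The `(args.OC, args.IC//16*3)` weight shape used above comes from packing: 16 FP6 weights occupy 96 bits, i.e. exactly three 32-bit words per group. A small sketch of that arithmetic (the helper name is illustrative, not part of the library):

```python
def packed_int32_cols(ic, bit_width=6):
    # A row of ic weights, bit_width bits each, must pack into whole int32 words.
    total_bits = ic * bit_width
    assert total_bits % 32 == 0
    return total_bits // 32

# FP6: IC=4096 -> 4096//16*3 = 768 int32 columns per row.
# The same formula gives the FPx shape IC//32*BIT_WIDTH used in kernel_test_fpx.py.
```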

================================================
FILE: tests/python/kernel_test_fpx.py
================================================
import argparse
import torch
import fp6_llm

WARMUP = 10
REPEAT = 1000

parser = argparse.ArgumentParser(description='The shape of the MatMul: (M, K)*(K, N)->(M, N).')
parser.add_argument('--OC',        type=int, required=False,     default=4096,   help='number of rows of the weight matrix.')
parser.add_argument('--IC',        type=int, required=False,     default=4096,   help='number of columns of the weight matrix.')
parser.add_argument('--BS',        type=int, required=False,     default=32,     help='inference batch size.')
parser.add_argument('--splitK',    type=int, required=False,     default=1,      help='Split-K factor: splits the GEMM along the K dimension so that more CTAs are created, improving SM utilization.')
parser.add_argument('--EXP',       type=int, required=False,     default=2,      help='number of bits of fpx Exponent, can be set to 2 or 3.')
parser.add_argument('--MAN',       type=int, required=False,     default=2,      help='number of bits of fpx Mantissa, can only be set to 2.')
args = parser.parse_args()

EXPONENT  = args.EXP
MANTISSA  = args.MAN
BIT_WIDTH = 1 + EXPONENT + MANTISSA
assert EXPONENT in [2,3]
assert MANTISSA in [2]

assert(args.OC%256==0)
assert(args.IC%64==0)

print("#"*64)
print(args)

fpx_weight = torch.randint(4294967295, (args.OC,args.IC//32*BIT_WIDTH)).to(torch.int)    # Randomly initialize each 32-bit word; the upper bound for randint() is set to the max value of uint32_t.
fp16_scale = torch.rand(args.OC).to(torch.half)+0.5
fp16_activation = torch.rand(args.BS, args.IC).to(torch.half)+0.5

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# fpx-fp16 GEMM (fp6-llm)
####################################################################################################################################
torch.cuda.synchronize()
fpx_weight_packed = fp6_llm.weight_prepacking_eXmY_cpu(EXPONENT, MANTISSA, fpx_weight)
act_cuda = fp16_activation.cuda()
weight_cuda = fpx_weight_packed.cuda()
scale_cuda = fp16_scale.cuda()
for i in range(WARMUP):
    results_fp6_llm = fp6_llm.linear_forward_eXmY_cuda(EXPONENT, MANTISSA, act_cuda, weight_cuda, scale_cuda, args.splitK)
start_event.record()
for i in range(REPEAT):
    results_fp6_llm = fp6_llm.linear_forward_eXmY_cuda(EXPONENT, MANTISSA, act_cuda, weight_cuda, scale_cuda, args.splitK)
end_event.record()
torch.cuda.synchronize()
fp6_llm_time_ms = start_event.elapsed_time(end_event)/REPEAT
fp6_llm_tflops  = args.OC*args.IC*args.BS*2/fp6_llm_time_ms/1e9
####################################################################################################################################

# baseline fp16 GEMM (cuBLAS)
####################################################################################################################################
torch.cuda.synchronize()
fp16_weight = fp6_llm.weight_dequant_eXmY_cpu(EXPONENT, MANTISSA, fpx_weight, fp16_scale)
cuBLAS_MatMul = torch.nn.Linear(args.IC, args.OC, False)
results_cublas = None
with torch.no_grad():
    cuBLAS_MatMul.weight = torch.nn.Parameter(fp16_weight.clone().cuda())
    act_cuda = fp16_activation.cuda()
    for i in range(WARMUP):
        results_cublas = cuBLAS_MatMul(act_cuda)
    start_event.record()
    for i in range(REPEAT):
        results_cublas = cuBLAS_MatMul(act_cuda)
    end_event.record()
torch.cuda.synchronize()
cublas_time_ms = start_event.elapsed_time(end_event)/REPEAT
cublas_tflops  = args.OC*args.IC*args.BS*2/cublas_time_ms/1e9
####################################################################################################################################

# Performance
print( 'cuBLAS    time: {:.3f} ms \t\t cuBLAS    TFLOPs: {:.1f}'.format(cublas_time_ms,  cublas_tflops ) )
print( 'quant-llm time: {:.3f} ms \t\t quant-llm TFLOPs: {:.1f}'.format(fp6_llm_time_ms, fp6_llm_tflops) )
print( 'speedup: {:.2f}'.format(cublas_time_ms/fp6_llm_time_ms) )

# Correctness
error             = results_cublas.cpu() - results_fp6_llm.cpu()
ground_truth      = results_cublas.cpu()
mean_error        = torch.mean(abs(error))
mean_ground_truth = torch.mean(abs(ground_truth))
relative_error    = mean_error.item()/mean_ground_truth.item()
print( "relative error: {:.6f}".format(relative_error) )


# [FP5_e2m2] Setting each element of the input matrices to 1.0f instead of a random value.
#fpx_weight = torch.zeros(args.OC, args.IC//32*BIT_WIDTH).to(torch.int64)  
#for i in range(args.OC):
#    for j in range(args.IC//32):
#        fpx_weight[i][j*BIT_WIDTH+0] = 272762913
#        fpx_weight[i][j*BIT_WIDTH+1] = 1107829124
#        fpx_weight[i][j*BIT_WIDTH+2] = 136414224 
#        fpx_weight[i][j*BIT_WIDTH+3] = 562303042 
#        fpx_weight[i][j*BIT_WIDTH+4] = 2215657992
#fpx_weight = fpx_weight.to(torch.int)
#fp16_scale      = torch.zeros(args.OC).to(torch.half)         +1.0
#fp16_activation = torch.zeros(args.BS, args.IC).to(torch.half)+1.0

# 1.0f values in fp5_e2m2
# 00100 00100 00100 00100 00100 00100 00100 00100       00100 00100 00100 00100 00100 00100 00100 00100     00100 00100 00100 00100 00100 00100 00100 00100     00100 00100 00100 00100 00100 00100 00100 00100
# "00100001 00001000 01000010 00010000"     "10000100 00100001 00001000 01000010"    "00010000 10000100 00100001 00001000"   "01000010 00010000 10000100 00100001"  "00001000 01000010 00010000 10000100"
# Considering byte order within an INT32
# 00010000 01000010 00001000 00100001       01000010 00001000 00100001 10000100      00001000 00100001 10000100 00010000      00100001 10000100 00010000 01000010   10000100 00010000 01000010 00001000
# 00010000010000100000100000100001          01000010000010000010000110000100         00001000001000011000010000010000         00100001100001000001000001000010      10000100000100000100001000001000
# 272762913                                 1107829124                               136414224                                562303042                             2215657992


# [FP6_e3m2] Setting each element of the input matrices to 1.0f instead of a random value.
#fp6_weight = torch.zeros(args.OC, args.IC//32*BIT_WIDTH).to(torch.int64)
#for i in range(args.OC):
#    for j in range(args.IC//32):
#        fp6_weight[i][j*6+0] = 806142768
#        fp6_weight[i][j*6+1] = 3274706115
#        fp6_weight[i][j*6+2] = 214118412
#        fp6_weight[i][j*6+3] = 806142768
#        fp6_weight[i][j*6+4] = 3274706115
#        fp6_weight[i][j*6+5] = 214118412
#fp6_weight = fp6_weight.to(torch.int32)
#fp16_scale = torch.zeros(args.OC).to(torch.half)+1.0
#fp16_activation = torch.zeros(args.BS, args.IC).to(torch.half)+1.0

# 1.0f values in fp6_e3m2
# 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 001100 
# "00110000110000110000110000110000"       "11000011000011000011000011000011"     "00001100001100001100001100001100"
# 00110000 11000011 00001100 00110000      11000011 00001100 00110000 11000011    00001100 00110000 11000011 00001100
# Considering byte order within an INT32
# 00110000 00001100 11000011 00110000      11000011 00110000 00001100 11000011     00001100 11000011 00110000 00001100
# 00110000000011001100001100110000         11000011001100000000110011000011        00001100110000110011000000001100
# 806142768                                3274706115                              214118412
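# The worked constants in the comments above (int32 words encoding all-1.0 weights)
# can be reproduced programmatically: repeat the per-value bit pattern, cut into
# 32-bit chunks, and byte-swap each chunk to account for little-endian int32
# storage. A small sketch with an illustrative helper name:

```python
def pack_ones(pattern, n_values):
    # pattern: bit string of one value (e.g. "001100" is 1.0 in fp6_e3m2).
    # Repeat it, slice into 32-bit chunks, then reverse the byte order of each
    # chunk (little-endian int32 view) and interpret it as an integer.
    bits = pattern * n_values
    assert len(bits) % 32 == 0
    words = []
    for off in range(0, len(bits), 32):
        chunk = bits[off:off + 32]
        swapped = "".join(chunk[i:i + 8] for i in range(24, -1, -8))
        words.append(int(swapped, 2))
    return words
```

`pack_ones("001100", 16)` reproduces the three FP6_e3m2 constants above, and `pack_ones("00100", 32)` the five FP5_e2m2 constants.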

================================================
FILE: tests/python/run.sh
================================================
#! /bin/bash

# [Batch sizes to test]
# If you want to test the performance of FP6-LLM for larger inference batch sizes, 
# which typically happens during prompt processing, 
# please revise this file by simply "commenting" and "uncommenting".

# BS <=64
N=(1 2 4 8 16 32 64)
SplitK=(5 6 7 6)

# BS = 128
#N=(128)
#SplitK=(5 3 3 3)

# BS = 256
#N=(256)
#SplitK=(4 3 2 3)

# BS = 512
#N=(512)
#SplitK=(2 5 2 4)

# BS = 1024
#N=(1024)
#SplitK=(1 2 1 2)

# BS >= 2048
#N=(2048 4096 8192 16384)
#SplitK=(1 1 1 1)

# Benchmarking the specific Matrix Shape from llama2-70b
M=(10240 8192  57344 8192)
K=(8192  8192  8192  28672)

#mkdir -p Profiling
for ((i=0;i<${#M[@]};i++)) 
do
    for BS in ${N[@]} 
    do
        #ncu -f -o Profiling/M${M[i]}K${K[i]}N${BS} --set full \
        #python kernel_test_fp6.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]}
        python kernel_test_fpx.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]}
        python kernel_test_fpx.py --OC=${M[i]} --IC=${K[i]} --BS=${BS} --splitK=${SplitK[i]} --EXP=3 --MAN=2
    done
done
    "path": "setup.py",
    "chars": 1270,
    "preview": "from setuptools import find_packages, setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExte"
  },
  {
    "path": "tests/cpp/Makefile",
    "chars": 1522,
    "preview": "# host compiler\nHOST_COMPILER ?= g++\nNVCC          := nvcc -ccbin $(HOST_COMPILER)\n\n# internal flags\nNVCCFLAGS   := -m$("
  },
  {
    "path": "tests/cpp/kernel_test.h",
    "chars": 3219,
    "preview": "#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runt"
  },
  {
    "path": "tests/cpp/kernel_test_fp6.cu",
    "chars": 9104,
    "preview": "#include \"kernel_test.h\"\n#include \"fp6_linear.cuh\"\n\nint main(int argc, char** argv)\n{\n    // Parsing the inputs from CLI"
  },
  {
    "path": "tests/cpp/kernel_test_fpx.cu",
    "chars": 9463,
    "preview": "#include \"kernel_test.h\"\n#include \"fp6_linear.cuh\"\n\n\nint main(int argc, char** argv)\n{\n    // Parsing the inputs from CL"
  },
  {
    "path": "tests/cpp/run.sh",
    "chars": 2566,
    "preview": "#! /bin/bash\n\n# Batch sizes to test\nN=(1 2 3 4 5 6 7 8)\n\n# Benchmarking the specific Matrix Shape from llama-1 65b\nM=(13"
  },
  {
    "path": "tests/python/kernel_test_fp6.py",
    "chars": 3762,
    "preview": "import argparse\nimport torch\nimport fp6_llm\n\nWARMUP = 10\nREPEAT = 1000\n\nparser = argparse.ArgumentParser(description='Th"
  },
  {
    "path": "tests/python/kernel_test_fpx.py",
    "chars": 7387,
    "preview": "import argparse\nimport torch\nimport fp6_llm\n\nWARMUP = 10\nREPEAT = 1000\n\nparser = argparse.ArgumentParser(description='Th"
  },
  {
    "path": "tests/python/run.sh",
    "chars": 1081,
    "preview": "#! /bin/bash\n\n# [Batch sizes to test]\n# If you want to test the performance of FP6-LLM for larger inference batch sizes,"
  }
]

About this extraction

This page contains the full source code of the usyd-fsalab/fp6_llm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 30 files (145.7 KB), approximately 42.7k tokens, and a symbol index with 16 extracted functions, classes, methods, constants, and types. It can be used with Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract, a free GitHub repo to text converter for AI. Built by Nikandr Surkov.
