Repository: hahnyuan/PTQ4ViT
Branch: main
Commit: b42c790c2811
Files: 16
Total size: 180.8 KB
Directory structure:
gitextract_dm7h0_9y/
├── .gitignore
├── README.md
├── configs/
│ ├── BasePTQ.py
│ └── PTQ4ViT.py
├── example/
│ ├── get_int.py
│ ├── test_ablation.py
│ ├── test_all.py
│ └── test_vit.py
├── quant_layers/
│ ├── conv.py
│ ├── linear.py
│ └── matmul.py
└── utils/
├── datasets.py
├── integer.py
├── models.py
├── net_wrap.py
└── quant_calib.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
tmp
*.pyc
__pycache__
*.pth
.vscode
checkpoints
*.log
*.csv
*.png
*.jpg
output
*.weights
*.tmp.*
data
ckt
*.out
*.zip
*.json
test.ipynb
int_weights
================================================
FILE: README.md
================================================
# PTQ4ViT
Post-Training Quantization Framework for Vision Transformers.
We use the twin uniform quantization method to reduce the quantization error on activation values with unbalanced distributions, namely the post-softmax and post-GELU activations.
We use a Hessian-guided metric to evaluate different scaling factors, which improves the accuracy of calibration at a small cost.
The quantized vision transformers (ViT, DeiT, and Swin) achieve near-lossless prediction accuracy (less than 0.5\% drop at 8-bit quantization) on the ImageNet classification task. Please read the [paper](https://arxiv.org/abs/2111.12293) for details.
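In a nutshell, twin uniform quantization keeps two uniform quantization ranges for one activation: a fine-grained range for the dense region of the distribution and a coarse range for the long tail. The snippet below is an illustrative sketch for post-softmax values only, not the repository's implementation (see `SoSPTQSLBatchingQuantMatMul` in `quant_layers/matmul.py` for the real one); the fixed interval $1/2^{k-1}$ for the coarse range follows the paper, and `delta_r1` stands for the searched fine-range interval.
```python
import torch

def twin_uniform_quant_softmax(score, k=8, delta_r1=1e-3):
    # k-bit twin quantization of post-softmax scores (illustrative sketch).
    # R1: a small, searched interval delta_r1 resolves the many near-zero scores.
    # R2: the fixed interval 1/2^(k-1), so the coarse range can reach scores near 1.
    m = 2 ** (k - 1)
    delta_r2 = 1.0 / m
    q_r1 = (score / delta_r1).round().clamp(0, m - 1) * delta_r1  # fine range
    q_r2 = (score / delta_r2).round().clamp(0, m - 1) * delta_r2  # coarse range
    return torch.where(score < m * delta_r1, q_r1, q_r2)          # pick per value
```
The Hessian-guided metric scores a candidate scaling factor by how much the induced output change would perturb the loss, i.e. $-(\nabla_O L \odot (O^{\mathrm{raw}} - O^{\mathrm{quant}}))^2$ averaged over the output (see `_get_similarity` with `metric="hessian"` in `quant_layers/conv.py`).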
## Updates
*19/07/2022*
Add discussion on Base PTQ, and provide more ablation study results.
### Number of Calibration Images
| Model | W8A8 #ims=32 | W6A6 #ims=32 | W8A8 #ims=128 | W6A6 #ims=128 |
|:------------:|:------------:|:------------:|:-------------:|:-------------:|
| ViT-S/224/32 | 75.58 | 71.91 | 75.54 | 72.29 |
| ViT-S/224 | 81.00 | 78.63 | 80.99 | 78.44 |
| ViT-B/224 | 84.25 | 81.65 | 84.27 | 81.84 |
| ViT-B/384 | 85.83 | 83.35 | 85.81 | 83.84 |
| DeiT-S/224 | 79.47 | 76.28 | 79.41 | 76.51 |
| DeiT-B/224 | 81.48 | 80.25 | 81.54 | 80.30 |
| DeiT-B/384 | 82.97 | 81.55 | 83.01 | 81.67 |
| Swin-T/224 | 81.25 | 80.47 | 81.27 | 80.30 |
| Swin-S/224 | 83.11 | 82.38 | 83.15 | 82.38 |
| Swin-B/224 | 85.15 | 84.01 | 85.17 | 84.15 |
| Swin-B/384 | 86.39 | 85.39 | 86.36 | 85.45 |
| Model | Time #ims=32 | Time #ims=128 |
|:------------:|:------------:|:-------------:|
| ViT-S/224/32 | 2 min | 5 min |
| ViT-S/224 | 3 min | 7 min |
| ViT-B/224 | 4 min | 13 min |
| ViT-B/384 | 12 min | 43 min |
| DeiT-S/224 | 3 min | 7 min |
| DeiT-B/224 | 4 min | 16 min |
| DeiT-B/384 | 14 min | 52 min |
| Swin-T/224 | 3 min | 9 min |
| Swin-S/224 | 8 min | 17 min |
| Swin-B/224 | 10 min | 23 min |
| Swin-B/384 | 25 min | 69 min |
One of the targets of PTQ4ViT is to quickly quantize a vision transformer.
We have proposed to pre-compute the output and gradient of each layer and compute the influence of scaling factor candidates in batches to reduce the quantization time.
As demonstrated in the second table, PTQ4ViT can quantize most vision transformers in several minutes using 32 calibration images.
Using 128 calibration images significantly increases the quantization time.
In the first table, we observe that the Top-1 accuracy varies only slightly, demonstrating that PTQ4ViT is not very sensitive to the number of calibration images.
### Base PTQ
Base PTQ is a simple quantization strategy and serves as a benchmark for our experiments.
Like PTQ4ViT, we quantize all weights and inputs of fully-connected layers (including the first projection layer and the last prediction layer), as well as all input matrices of matrix multiplication operations.
For fully-connected layers, we use layerwise scaling factors $\Delta_W$ for weight quantization and $\Delta_X$ for input quantization; for matrix multiplication operations, we use $\Delta_A$ and $\Delta_B$ for the quantization of A and B, respectively.
To get the best scaling factors, we apply a linear grid search on the search space.
Following EasyQuant and Liu et al., we take the hyper-parameters $\alpha=0.5$, $\beta = 1.2$, one search round, and use cosine distance as the metric.
Note that in PTQ4ViT, we change the hyper-parameters to $\alpha=0$, $\beta = 1.2$ and three search rounds, which slightly improves the performance.
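Concretely, the search enumerates evenly spaced multiples of the initial (minmax) scaling factor between $\alpha\Delta$ and $\beta\Delta$ and keeps the one that maximizes the similarity metric. The sketch below mirrors the candidate generation in `calibration_step2` of `quant_layers/conv.py`; `eval_similarity` is an assumed helper mapping an interval to its metric value (cosine distance for Base PTQ, the Hessian-guided metric for PTQ4ViT), and the real code evaluates candidates in chunks of `parallel_eq_n` to bound GPU memory.
```python
import torch

def grid_search_interval(eval_similarity, init_interval, eq_alpha, eq_beta, eq_n):
    # candidate scaling factors alpha*D, ..., beta*D (as in quant_layers/conv.py)
    candidates = torch.tensor(
        [eq_alpha + i * (eq_beta - eq_alpha) / eq_n for i in range(eq_n + 1)]
    ) * init_interval
    # score every candidate and keep the best one
    similarities = torch.tensor([float(eval_similarity(c)) for c in candidates])
    return candidates[similarities.argmax()]
```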
It should be noticed that Base PTQ adopts a parallel quantization paradigm, which makes it essentially different from sequential quantization paradigms such as EasyQuant.
In sequential quantization, the input data of the layer being quantized is generated with the weights and activations of all previous layers already quantized; in parallel quantization, the input data of the layer being quantized is simply the raw FP32 output of the previous layer.
In practice, we found that sequential quantization of vision transformers suffers from significant accuracy degradation on small calibration datasets, while parallel quantization remains robust.
Therefore, we choose parallel quantization for both Base PTQ and PTQ4ViT; the sketch below contrasts the two paradigms.
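The contrast between the two paradigms can be stated in a few lines of pseudocode; `calibrate`, `quant_forward`, and `raw_forward` below are hypothetical stand-ins for the calibration steps driven by `utils/quant_calib.py`.
```python
def sequential_calibration(layers, x):
    for layer in layers:
        layer.calibrate(x)          # input already carries quantization error
        x = layer.quant_forward(x)  # next layer sees quantized activations

def parallel_calibration(layers, x):  # used by Base PTQ and PTQ4ViT
    for layer in layers:
        layer.calibrate(x)          # input is the raw FP32 activation
        x = layer.raw_forward(x)    # next layer sees the raw FP32 output
```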
### More Ablation Study
We supply more ablation studies for the hyper-parameters.
It is enough to set the number of quantization intervals to $\ge 20$ (accuracy change $< 0.3\%$).
It is enough to set the upper bound of $m$ to $\ge 15$ (no accuracy change).
The best settings of $\alpha$ and $\beta$ vary across layers.
It is appropriate to set $\alpha=0$ and $\beta=1/2^{k-1}$, which has little impact on search efficiency.
We observe that the number of search rounds has little impact on the prediction accuracy (accuracy change $< 0.05\%$ when search rounds $> 1$).
We randomly take 32 calibration images to quantize different models 20 times and observe that the fluctuation is not significant.
The mean/std of accuracies are: ViT-S/32 $75.55\%/0.055\%$ , ViT-S $80.96\%/0.046\%$, ViT-B $84.12\%/0.068\%$, DeiT-S $79.45\%/0.094\%$ , and Swin-S $83.11\%/0.035\%$.
*15/01/2022*
Add saved quantized models with PTQ4ViT.
| model | link |
|:------------:|:--------:|
| ViT-S/224/32 | [Google](https://drive.google.com/file/d/195JJJKULvaukte6PA9U08oezjd176CTs/view?usp=sharing) |
| ViT-S/224 | [Google](https://drive.google.com/file/d/14uEDgRmDBYoKoZtpO9IWMfG8Uvkt_OuL/view?usp=sharing) |
| ViT-B/224 | [Google](https://drive.google.com/file/d/1ou6s9Vd-_iyQ7sj7VYET-pRvJA6WMMLA/view?usp=sharing) |
| ViT-B/384 | [Google](https://drive.google.com/file/d/1tuU8or8SfQomtoWam7WFTnUxtuw3n7fs/view?usp=sharing) |
| DeiT-S/224 | [Google](https://drive.google.com/file/d/1673fX-SuiRlHhm7k0Yyyx_3ynwtvUPyf/view?usp=sharing) |
| DeiT-B/224 | [Google](https://drive.google.com/file/d/1WRAtmPF0kDR9iTLc9gv_63aEkOCZ_zOI/view?usp=sharing) |
| DeiT-B/384 | [Google](https://drive.google.com/file/d/1mPPlM2ioe4zts_rdKdjZTCUj8KcbquyA/view?usp=sharing) |
| Swin-T/224 | [Google](https://drive.google.com/file/d/1bSahHgtL3yFaHPlG-SDtu__YY0zJ8lxr/view?usp=sharing) |
| Swin-S/224 | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing) |
| Swin-B/224 | [Google](https://drive.google.com/file/d/19UUUQYJGs5SQaDe27PjY3x1QTBU5hwXm/view?usp=sharing) |
| Swin-B/384 | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing) |
*10/12/2021*
Add `utils/integer.py`; you can now:
1. convert a calibrated fp32 model into int8 (see the sketch below);
2. register a pre-forward hook in the model and fetch activations in int8 (we use uint8 to store the results of twin quantization; please refer to the paper for the bit layout).
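For example, following `example/get_int.py`, the int8 weights of a calibrated model can be fetched and saved as follows (the file name is a placeholder; `wrapped_modules` is the output of `net_wrap.wrap_modules_in_net`):
```python
import torch
import utils.integer as integer

# `wrapped_modules` must already be calibrated, e.g. with
# HessianQuantCalibrator(...).batching_quant_calib()
int_weights = integer.get_model_int_weight(wrapped_modules)
torch.save(int_weights, "./int_weights/vit_base_patch16_384.pth")
```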
## Install
### Requirement
- python>=3.5
- pytorch>=1.5
- matplotlib
- pandas
- timm
### Datasets
To run example testing, you should put your ImageNet2012 dataset in path `/datasets/imagenet`.
We use `ViTImageNetLoaderGenerator` in `utils/datasets.py` to initialize our DataLoader.
If your ImageNet dataset is stored elsewhere, pass its root as an argument when instantiating a `ViTImageNetLoaderGenerator`, as in the sketch below.
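For example, mirroring `example/test_vit.py` (the root path is a placeholder; we assume the three integer arguments are the test batch size, calibration batch size, and worker count, copied from that script):
```python
import utils.datasets as datasets

# `net` is the timm model to be quantized (see get_net in utils/models.py)
g = datasets.ViTImageNetLoaderGenerator('/path/to/imagenet2012', 'imagenet',
                                        32, 32, 16, kwargs={"model": net})
test_loader = g.test_loader()
calib_loader = g.calib_loader(num=32)
```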
## Usage
### 1. Run example quantization
To test on all models with BasePTQ/PTQ4ViT, run
```bash
python example/test_all.py
```
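Each test point in `example/test_all.py` boils down to the following steps (condensed from `test_all`; timing and logging omitted, names imported as at the top of that script):
```python
quant_cfg = init_config("PTQ4ViT")                    # loads configs/PTQ4ViT.py
net = get_net("vit_base_patch16_224")                 # timm model, utils/models.py
wrapped_modules = net_wrap.wrap_modules_in_net(net, quant_cfg)
g = datasets.ViTImageNetLoaderGenerator('/datasets/imagenet', 'imagenet',
                                        32, 32, 16, kwargs={"model": net})
calib_loader = g.calib_loader(num=32)
HessianQuantCalibrator(net, wrapped_modules, calib_loader,
                       sequential=False, batch_size=4).batching_quant_calib()
acc = test_classification(net, g.test_loader())
```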
To run ablation testing, run
```bash
python example/test_ablation.py
```
You can run the testing scripts with multiple GPUs. For example, calling
```bash
python example/test_all.py --multiprocess --n_gpu 6
```
will spawn 6 worker processes, one per GPU, to run the tests.
### 2. Download quantized model checkpoints
See the saved quantized models linked in the *15/01/2022* update above.
## Results
### Results of BasePTQ
| model | original | w8a8 | w6a6 |
|:------------:|:--------:|:------:|:-------:|
| ViT-S/224/32 | 75.99 | 73.61 | 60.144 |
| ViT-S/224 | 81.39 | 80.468 | 70.244 |
| ViT-B/224 | 84.54 | 83.896 | 75.668 |
| ViT-B/384 | 86.00 | 85.352 | 46.886 |
| DeiT-S/224 | 79.80 | 77.654 | 72.268 |
| DeiT-B/224 | 81.80 | 80.946 | 78.786 |
| DeiT-B/384 | 83.11 | 82.33 | 68.442 |
| Swin-T/224 | 81.39 | 80.962 | 78.456 |
| Swin-S/224 | 83.23 | 82.758 | 81.742 |
| Swin-B/224 | 85.27 | 84.792 | 83.354 |
| Swin-B/384 | 86.44 | 86.168 | 85.226 |
### Results of PTQ4ViT
| model | original | w8a8 | w6a6 |
|:------------:|:--------:|:------:|:-------:|
| ViT-S/224/32 | 75.99 | 75.582 | 71.908 |
| ViT-S/224 | 81.39 | 81.002 | 78.63 |
| ViT-B/224 | 84.54 | 84.25 | 81.65 |
| ViT-B/384 | 86.00 | 85.828 | 83.348 |
| DeiT-S/224 | 79.80 | 79.474 | 76.282 |
| DeiT-B/224 | 81.80 | 81.482 | 80.25 |
| DeiT-B/384 | 83.11 | 82.974 | 81.55 |
| Swin-T/224 | 81.39 | 81.246 | 80.47 |
| Swin-S/224 | 83.23 | 83.106 | 82.38 |
| Swin-B/224 | 85.27 | 85.146 | 84.012 |
| Swin-B/384 | 86.44 | 86.394 | 85.388 |
### Results of Ablation
- ViT-S/224 (original top-1 accuracy 81.39%)
| Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 |
|:--------------:|:------------:|:---------:|:------:|:-------:|
| | | | 80.47 | 70.24 |
| ✓ | | | 80.93 | 77.20 |
| ✓ | ✓ | | 81.11 | 78.57 |
| ✓ | | ✓ | 80.84 | 76.93 |
| | ✓ | ✓ | 79.25 | 74.07 |
| ✓ | ✓ | ✓ | 81.00 | 78.63 |
- ViT-B/224 (original top-1 accuracy 84.54%)
| Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 |
|:--------------:|:------------:|:---------:|:------:|:-------:|
| | | | 83.90 | 75.67 |
| ✓ | | | 83.97 | 79.90 |
| ✓ | ✓ | | 84.07 | 80.76 |
| ✓ | | ✓ | 84.10 | 80.82 |
| | ✓ | ✓ | 83.40 | 78.86 |
| ✓ | ✓ | ✓ | 84.25 | 81.65 |
- ViT-B/384 (original top-1 accuracy 86.00%)
| Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 |
|:--------------:|:------------:|:---------:|:------:|:-------:|
| | | | 85.35 | 46.89 |
| ✓ | | | 85.42 | 79.99 |
| ✓ | ✓ | | 85.67 | 82.01 |
| ✓ | | ✓ | 85.60 | 82.21 |
| | ✓ | ✓ | 84.35 | 80.86 |
| ✓ | ✓ | ✓ | 85.89 | 83.19 |
## Citation
```
@article{PTQ4ViT_arxiv2022,
  title={PTQ4ViT: Post-Training Quantization Framework for Vision Transformers},
  author={Zhihang Yuan and Chenhao Xue and Yiqi Chen and Qiang Wu and Guangyu Sun},
  journal={arXiv preprint arXiv:2111.12293},
  year={2022},
}
```
================================================
FILE: configs/BasePTQ.py
================================================
from quant_layers.conv import PTQSLQuantConv2d, BatchingEasyQuantConv2d
from quant_layers.linear import PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear
from quant_layers.matmul import PTQSLBatchingQuantMatMul, SoSPTQSLBatchingQuantMatMul
bit = 8
conv_fc_name_list = ["qconv", "qlinear_qkv", "qlinear_proj", "qlinear_MLP_1", "qlinear_MLP_2", "qlinear_classifier", "qlinear_reduction"]
matmul_name_list = [ "qmatmul_qk", "qmatmul_scorev"]
w_bit = {name: bit for name in conv_fc_name_list}
a_bit = {name: bit for name in conv_fc_name_list}
A_bit = {name: bit for name in matmul_name_list}
B_bit = {name: bit for name in matmul_name_list}
ptqsl_conv2d_kwargs = {
"metric": "cosine",
"eq_alpha": 0.5,
"eq_beta": 1.2,
"eq_n": 100,
'search_round': 1,
"n_V": 1,
"n_H": 1,
}
ptqsl_linear_kwargs = {
"metric": "cosine",
"eq_alpha": 0.5,
"eq_beta": 1.2,
"eq_n": 100,
'search_round': 1,
"n_V": 1,
"n_H": 1,
"n_a": 1,
}
ptqsl_matmul_kwargs = {
"metric": "cosine",
"eq_alpha": 0.5,
"eq_beta": 1.2,
"eq_n": 100,
'search_round': 1,
"n_G_A": 1,
"n_V_A": 1,
"n_H_A": 1,
"n_G_B": 1,
"n_V_B": 1,
"n_H_B": 1,
}
def get_module(module_type, *args, **kwargs):
if module_type == "qconv":
kwargs.update(ptqsl_conv2d_kwargs)
module=BatchingEasyQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization
# module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization
elif "qlinear" in module_type:
kwargs.update(ptqsl_linear_kwargs)
if module_type == "qlinear_qkv":
kwargs["n_V"] *= 3 # q, k, v
module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])
else:
module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])
elif "qmatmul" in module_type:
kwargs.update(ptqsl_matmul_kwargs)
module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])
return module
================================================
FILE: configs/PTQ4ViT.py
================================================
from quant_layers.conv import PTQSLQuantConv2d, ChannelwiseBatchingQuantConv2d
from quant_layers.linear import PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear
from quant_layers.matmul import PTQSLBatchingQuantMatMul, SoSPTQSLBatchingQuantMatMul
no_softmax = False
no_postgelu = False
bit = 8
conv_fc_name_list = ["qconv", "qlinear_qkv", "qlinear_proj", "qlinear_MLP_1", "qlinear_MLP_2", "qlinear_classifier", "qlinear_reduction"]
matmul_name_list = [ "qmatmul_qk", "qmatmul_scorev"]
w_bit = {name: bit for name in conv_fc_name_list}
a_bit = {name: bit for name in conv_fc_name_list}
A_bit = {name: bit for name in matmul_name_list}
B_bit = {name: bit for name in matmul_name_list}
ptqsl_conv2d_kwargs = {
"metric": "hessian",
"eq_alpha": 0.01,
"eq_beta": 1.2,
"eq_n": 100,
'search_round': 3,
"n_V": 1,
"n_H": 1,
}
ptqsl_linear_kwargs = {
"metric": "hessian",
"eq_alpha": 0.01,
"eq_beta": 1.2,
"eq_n": 100,
'search_round': 3,
"n_V": 1,
"n_H": 1,
"n_a": 1,
"bias_correction":True # Conventionally I'll not add an actual bias correction in linear
}
ptqsl_matmul_kwargs = {
"metric": "hessian",
"eq_alpha": 0.01,
"eq_beta": 1.2,
"eq_n": 100,
'search_round': 3,
"n_G_A": 1,
"n_V_A": 1,
"n_H_A": 1,
"n_G_B": 1,
"n_V_B": 1,
"n_H_B": 1,
}
def get_module(module_type, *args, **kwargs):
if module_type == "qconv":
kwargs.update(ptqsl_conv2d_kwargs)
module=ChannelwiseBatchingQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization
# module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization
elif "qlinear" in module_type:
kwargs.update(ptqsl_linear_kwargs)
if module_type == "qlinear_qkv":
kwargs["n_V"] *= 3 # q, k, v
module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])
elif module_type == "qlinear_MLP_2":
if no_postgelu:
module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])
else:
module=PostGeluPTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])
elif module_type == "qlinear_classifier":
kwargs["n_V"] = 1
module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])
else:
module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])
elif "qmatmul" in module_type:
kwargs.update(ptqsl_matmul_kwargs)
if module_type == "qmatmul_qk":
module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])
elif module_type == "qmatmul_scorev":
if no_softmax:
module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])
else:
module=SoSPTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])
return module
================================================
FILE: example/get_int.py
================================================
import sys
sys.path.insert(0,'..')
sys.path.insert(0,'.')
from example.test_vit import *
import utils.net_wrap as net_wrap
import utils.datasets as datasets
import utils.integer as integer
from utils.quant_calib import HessianQuantCalibrator
from itertools import product
def get_int_weights(name, config_name):
quant_cfg = init_config(config_name)
net = get_net(name)
wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg)
g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net})
test_loader=g.test_loader()
calib_loader=g.calib_loader(num=32)
quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16
quant_calibrator.batching_quant_calib()
int_weights = integer.get_model_int_weight(wrapped_modules)
torch.save(int_weights, f"./int_weights/{name}.pth")
if __name__ == "__main__":
args = parse_args()
names = [
# "vit_tiny_patch16_224",
# "vit_small_patch32_224",
# "vit_small_patch16_224",
# "vit_base_patch16_224",
"vit_base_patch16_384",
# "deit_tiny_patch16_224",
# "deit_small_patch16_224",
# "deit_base_patch16_224",
# "deit_base_patch16_384",
# "swin_tiny_patch4_window7_224",
# "swin_small_patch4_window7_224",
# "swin_base_patch4_window7_224",
# "swin_base_patch4_window12_384",
]
config_names = ["PTQ4ViT", "BasePTQ"]
cfg_list = []
for name, config in product(names, config_names):
cfg_list.append({"name":name, "config_name":config})
if args.multiprocess:
multiprocess(get_int_weights, cfg_list, n_gpu=args.n_gpu)
else:
for cfg in cfg_list:
get_int_weights(**cfg)
================================================
FILE: example/test_ablation.py
================================================
from torch.nn.modules import module
from test_vit import *
from quant_layers.conv import MinMaxQuantConv2d
from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear
from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul
import matplotlib.pyplot as plt
from utils.net_wrap import wrap_certain_modules_in_net
from tqdm import tqdm
import torch.nn.functional as F
import pickle as pkl
from itertools import product
import types
from utils.quant_calib import HessianQuantCalibrator, QuantCalibrator
from utils.models import get_net
import time
def test_all_ablation(name, cfg_modifier=lambda x: x, calib_size=32):
quant_cfg = init_config("PTQ4ViT")
quant_cfg = cfg_modifier(quant_cfg)
net = get_net(name)
wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg)
g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net})
test_loader=g.test_loader()
calib_loader=g.calib_loader(num=calib_size)
quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16
quant_calibrator.batching_quant_calib()
acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs["metric"])
print(f"model: {name} \n")
print(f"calibration size: {calib_size} \n")
print(f"bit settings: {quant_cfg.bit} \n")
print(f"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \n")
print(f"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \n")
print(f"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \n")
print(f"accuracy: {acc} \n\n")
class cfg_modifier():
def __init__(self, **kwargs):
for name, value in kwargs.items():
setattr(self,name,value)
def __call__(self, cfg):
# bit setting
cfg.bit = self.bit_setting
cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list}
cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list}
cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}
cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}
# conv2d configs
cfg.ptqsl_conv2d_kwargs["n_V"] = self.linear_ptq_setting[0]
cfg.ptqsl_conv2d_kwargs["n_H"] = self.linear_ptq_setting[1]
cfg.ptqsl_conv2d_kwargs["metric"] = self.metric
cfg.ptqsl_conv2d_kwargs["search_round"] = self.search_round
cfg.ptqsl_conv2d_kwargs["parallel_eq_n"] = 1 # maximum 7 , reserve 4Gb for gradient
cfg.ptqsl_conv2d_kwargs["init_layerwise"] = False
# linear configs
cfg.ptqsl_linear_kwargs["n_V"] = self.linear_ptq_setting[0]
cfg.ptqsl_linear_kwargs["n_H"] = self.linear_ptq_setting[1]
cfg.ptqsl_linear_kwargs["n_a"] = self.linear_ptq_setting[2]
cfg.ptqsl_linear_kwargs["metric"] = self.metric
cfg.ptqsl_linear_kwargs["search_round"] = self.search_round
cfg.ptqsl_linear_kwargs["parallel_eq_n"] = 1 # maximum 7, reserve 4Gb for gradient
cfg.ptqsl_linear_kwargs["init_layerwise"] = False
# matmul configs
cfg.ptqsl_matmul_kwargs["metric"] = self.metric
cfg.ptqsl_matmul_kwargs["search_round"] = self.search_round
cfg.ptqsl_matmul_kwargs["parallel_eq_n"] = 1 # maximum 3!
cfg.ptqsl_matmul_kwargs["init_layerwise"] = False
# ablation
cfg.no_softmax = self.no_softmax
cfg.no_postgelu = self.no_postgelu
return cfg
if __name__=='__main__':
args = parse_args()
names = [
"vit_small_patch16_224",
"vit_base_patch16_224",
"vit_base_patch16_384",
]
metrics = ["hessian", "cosine"]
linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a
search_rounds = [3]
calib_sizes = [32]
bit_settings = [(8,8), (6,6)] # weight, activation
no_softmaxs = [True, False]
no_postgelus = [True, False]
cfg_list = []
for name, metric, linear_ptq_setting, search_round, calib_size, bit_setting, no_softmax, no_postgelu in product(names, metrics, linear_ptq_settings, search_rounds, calib_sizes, bit_settings, no_softmaxs, no_postgelus):
cfg_list.append({
"name": name,
"cfg_modifier":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, search_round=search_round, bit_setting=bit_setting, no_softmax=no_softmax, no_postgelu=no_postgelu),
"calib_size":calib_size,
})
if args.multiprocess:
multiprocess(test_all_ablation, cfg_list, n_gpu=args.n_gpu)
else:
for cfg in cfg_list:
test_all_ablation(**cfg)
================================================
FILE: example/test_all.py
================================================
from timm.models.layers import config
from torch.nn.modules import module
from test_vit import *
from quant_layers.conv import MinMaxQuantConv2d
from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear
from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul
import matplotlib.pyplot as plt
from utils.net_wrap import wrap_certain_modules_in_net
from tqdm import tqdm
import torch.nn.functional as F
import pickle as pkl
from itertools import product
import types
from utils.quant_calib import HessianQuantCalibrator, QuantCalibrator
from utils.models import get_net
import time
def test_all(name, cfg_modifier=lambda x: x, calib_size=32, config_name="PTQ4ViT"):
quant_cfg = init_config(config_name)
quant_cfg = cfg_modifier(quant_cfg)
net = get_net(name)
wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg)
g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net})
test_loader=g.test_loader()
calib_loader=g.calib_loader(num=calib_size)
# add timing
calib_start_time = time.time()
quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16
quant_calibrator.batching_quant_calib()
calib_end_time = time.time()
acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs["metric"])
print(f"model: {name} \n")
print(f"calibration size: {calib_size} \n")
print(f"bit settings: {quant_cfg.bit} \n")
print(f"config: {config_name} \n")
print(f"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \n")
print(f"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \n")
print(f"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \n")
print(f"calibration time: {(calib_end_time-calib_start_time)/60}min \n")
print(f"accuracy: {acc} \n\n")
class cfg_modifier():
def __init__(self, **kwargs):
for name, value in kwargs.items():
setattr(self,name,value)
def __call__(self, cfg):
# bit setting
cfg.bit = self.bit_setting
cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list}
cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list}
cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}
cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}
# conv2d configs
cfg.ptqsl_conv2d_kwargs["n_V"] = self.linear_ptq_setting[0]
cfg.ptqsl_conv2d_kwargs["n_H"] = self.linear_ptq_setting[1]
cfg.ptqsl_conv2d_kwargs["metric"] = self.metric
cfg.ptqsl_conv2d_kwargs["init_layerwise"] = False
# linear configs
cfg.ptqsl_linear_kwargs["n_V"] = self.linear_ptq_setting[0]
cfg.ptqsl_linear_kwargs["n_H"] = self.linear_ptq_setting[1]
cfg.ptqsl_linear_kwargs["n_a"] = self.linear_ptq_setting[2]
cfg.ptqsl_linear_kwargs["metric"] = self.metric
cfg.ptqsl_linear_kwargs["init_layerwise"] = False
# matmul configs
cfg.ptqsl_matmul_kwargs["metric"] = self.metric
cfg.ptqsl_matmul_kwargs["init_layerwise"] = False
return cfg
if __name__=='__main__':
args = parse_args()
names = [
"vit_tiny_patch16_224",
"vit_small_patch32_224",
"vit_small_patch16_224",
"vit_base_patch16_224",
"vit_base_patch16_384",
"deit_tiny_patch16_224",
"deit_small_patch16_224",
"deit_base_patch16_224",
"deit_base_patch16_384",
"swin_tiny_patch4_window7_224",
"swin_small_patch4_window7_224",
"swin_base_patch4_window7_224",
"swin_base_patch4_window12_384",
]
metrics = ["hessian"]
linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a
calib_sizes = [32,128]
bit_settings = [(8,8), (6,6)] # weight, activation
config_names = ["PTQ4ViT", "BasePTQ"]
cfg_list = []
for name, metric, linear_ptq_setting, calib_size, bit_setting, config_name in product(names, metrics, linear_ptq_settings, calib_sizes, bit_settings, config_names):
cfg_list.append({
"name": name,
"cfg_modifier":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, bit_setting=bit_setting),
"calib_size":calib_size,
"config_name": config_name
})
if args.multiprocess:
multiprocess(test_all, cfg_list, n_gpu=args.n_gpu)
else:
for cfg in cfg_list:
test_all(**cfg)
================================================
FILE: example/test_vit.py
================================================
import sys
sys.path.insert(0,'..')
sys.path.insert(0,'.')
import torch
import torch.nn as nn
from tqdm import tqdm
import argparse
from importlib import reload,import_module
import multiprocessing
import os
import time
from itertools import product
import utils.datasets as datasets
import utils.net_wrap as net_wrap
from utils.quant_calib import QuantCalibrator, HessianQuantCalibrator
from utils.models import get_net
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--n_gpu", type=int, default=6)
parser.add_argument("--multiprocess", action='store_true')
args = parser.parse_args()
return args
def test_classification(net,test_loader,max_iteration=None, description=None):
pos=0
tot=0
i = 0
max_iteration = len(test_loader) if max_iteration is None else max_iteration
with torch.no_grad():
q=tqdm(test_loader, desc=description)
for inp,target in q:
i+=1
inp=inp.cuda()
target=target.cuda()
out=net(inp)
pos_num=torch.sum(out.argmax(1)==target).item()
pos+=pos_num
tot+=inp.size(0)
q.set_postfix({"acc":pos/tot})
if i >= max_iteration:
break
print(pos/tot)
return pos/tot
def process(pid, experiment_process, args_queue, n_gpu):
"""
worker process.
"""
gpu_id=pid%n_gpu
os.environ['CUDA_VISIBLE_DEVICES']=f'{gpu_id}'
tot_run=0
while args_queue.qsize():
test_args=args_queue.get()
print(f"Run {test_args} on pid={pid} gpu_id={gpu_id}")
experiment_process(**test_args)
time.sleep(0.5)
tot_run+=1
# run_experiment(**args)
print(f"{pid} tot_run {tot_run}")
def multiprocess(experiment_process, cfg_list=None, n_gpu=6):
"""
run experiment processes on "n_gpu" GPUs via "n_gpu" worker processes.
"cfg_list" holds the kwargs of each test point; each worker process fetches kwargs and carries out an experiment.
"""
args_queue = multiprocessing.Queue()
for cfg in cfg_list:
args_queue.put(cfg)
ps=[]
for pid in range(n_gpu):
p=multiprocessing.Process(target=process,args=(pid,experiment_process,args_queue,n_gpu))
p.start()
ps.append(p)
for p in ps:
p.join()
def init_config(config_name):
"""initialize the config. Use reload to make sure it's fresh one!"""
_,_,files = next(os.walk("./configs"))
if config_name+".py" in files:
quant_cfg = import_module(f"configs.{config_name}")
else:
raise NotImplementedError(f"Invalid config name {config_name}")
reload(quant_cfg)
return quant_cfg
def experiment_basic(net='vit_base_patch16_384', config="PTQ4ViT"):
"""
A basic testbench.
"""
quant_cfg = init_config(config)
net = get_net(net)
wrapped_modules = net_wrap.wrap_modules_in_net(net,quant_cfg)
g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16,kwargs={"model":net})
test_loader=g.test_loader()
calib_loader=g.calib_loader(num=32)
quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16
quant_calibrator.batching_quant_calib()
test_classification(net,test_loader)
if __name__=='__main__':
args = parse_args()
cfg_list = []
nets = ['vit_tiny_patch16_224', "deit_base_patch16_384"]
configs= ['PTQ4ViT']
cfg_list = [{
"net":net,
"config":config,
}
for net, config in product(nets, configs)
]
if args.multiprocess:
multiprocess(experiment_basic, cfg_list, n_gpu=args.n_gpu)
else:
for cfg in cfg_list:
experiment_basic(**cfg)
================================================
FILE: quant_layers/conv.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import product
class MinMaxQuantConv2d(nn.Conv2d):
"""
MinMax quantize weight and output
"""
def __init__(self,in_channels: int,
out_channels: int,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None):
super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode)
self.n_calibration_steps=2
self.mode=mode
self.w_bit=w_bit
self.a_bit=a_bit
self.bias_bit=bias_bit
assert bias_bit is None, "bias quantization is not supported yet"
self.w_interval=None
self.a_interval=None
self.bias_interval=None
self.raw_input=None
self.raw_out=None
self.metric=None
self.next_nodes=[]
self.w_qmax=2**(self.w_bit-1)
self.a_qmax=2**(self.a_bit-1)
# self.bias_qmax=2**(self.bias_bit-1)
def forward(self, x):
if self.mode=='raw':
out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
elif self.mode=="quant_forward":
out=self.quant_forward(x)
elif self.mode=="calibration_step1":
out=self.calibration_step1(x)
elif self.mode=="calibration_step2":
out=self.calibration_step2(x)
else:
raise NotImplementedError
return out
def quant_weight_bias(self):
w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1)
w_sim=w.mul_(self.w_interval)
if self.bias is not None:
return w_sim,self.bias
# bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1)
# bias_sim=bias*self.bias_interval
# return w_sim,bias_sim
else:
return w_sim,None
def quant_input(self,x):
x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1)
x_sim.mul_(self.a_interval)
return x_sim
def quant_forward(self,x):
assert getattr(self, "calibrated", None) is not None, f"You should calibrate (run calibration_step2) before running quant_forward for {self}"
w_sim,bias_sim=self.quant_weight_bias()
x_sim=self.quant_input(x)
out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups)
return out
def calibration_step1(self,x):
# step 1: collect the FP32 values
out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
self.raw_input=x.cpu().detach()
self.raw_out=out.cpu().detach()
return out
def calibration_step2(self,x):
# step2: search for the best S^w and S^a of each layer
self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach()
self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach()
self.calibrated=True
out=self.quant_forward(x)
return out
class QuantileQuantConv2d(MinMaxQuantConv2d):
"""
Quantile quantize weight and output
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
mode='raw',w_bit=8,a_bit=8,bias_bit=None,
w_quantile=0.9999,a_quantile=0.9999):
super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode,mode,w_bit,a_bit,bias_bit)
self.w_quantile = w_quantile
self.a_quantile = a_quantile
def _quantile(self, tensor, quantile):
if tensor.numel() >= 16777216:
n = tensor.numel()//16777216
return torch.quantile(tensor.view(-1)[:16777216*n].view(n,16777216),quantile,1).mean()
else:
return torch.quantile(tensor,quantile)
def calibration_step2(self,x):
# step 2: search for the best S^w and S^a of each layer
self.w_interval=(self._quantile(self.weight.data.abs(),self.w_quantile)/(self.w_qmax-0.5)).detach()
self.a_interval=(self._quantile(x.abs(),self.a_quantile)/(self.a_qmax-0.5)).detach()
self.calibrated=True
out=self.quant_forward(x)
return out
class PTQSLQuantConv2d(MinMaxQuantConv2d):
"""
PTQSL on Conv2d
weight: (oc,ic,kw,kh) -> (oc,ic*kw*kh) -> divide into sub-matrices and quantize
input: (B,ic,W,H), keep this shape
Only support SL quantization on weights.
"""
def __init__(self, in_channels: int,
out_channels: int,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None,
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,
n_V=1, n_H=1, init_layerwise=False):
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit)
self.metric = metric
self.search_round = search_round
self.eq_alpha = eq_alpha
self.eq_beta = eq_beta
self.eq_n = eq_n
self.parallel_eq_n = parallel_eq_n
self.n_H = n_H
self.n_V = n_V
self.init_layerwise = init_layerwise
self.raw_grad = None
def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1):
"""
tensor_raw: *, features
tensor_sim: *, features
similarity: *
It's your job to calculate mean on * dims!
"""
if metric == "cosine":
similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim)
else:
if metric == "L1_norm":
similarity = -torch.abs(tensor_raw - tensor_sim)
elif metric == "L2_norm":
similarity = -(tensor_raw - tensor_sim) ** 2
elif metric == "linear_weighted_L2_norm":
similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
elif metric == "square_weighted_L2_norm":
similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
elif metric == "hessian":
raw_grad = self.raw_grad.reshape_as(tensor_raw)
similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
else:
raise NotImplementedError(f"metric {metric} not implemented!")
similarity = torch.mean(similarity, dim=dim)
return similarity
def quant_weight_bias(self):
# self.weight_interval shape: n_V, 1, n_H, 1
oc,ic,kw,kh=self.weight.data.shape
w_sim = self.weight.view(self.n_V, oc//self.n_V, self.n_H, (ic*kw*kh)//self.n_H)
w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval)
w_sim = w_sim.view(oc,ic,kw,kh)
return w_sim, self.bias
def _search_best_w_interval(self, x, weight_interval_candidates):
"""
Modularization of searching best weight intervals
"""
tmp_w_interval = self.w_interval.unsqueeze(0)
for v,h in product(range(self.n_V), range(self.n_H)):
similarities = []
for p_st in range(0, self.eq_n, self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1)
cur_w_interval[:,v:v+1,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,v:v+1,:,h:h+1,:]
# quantize weight and bias
oc,ic,kw,kh=self.weight.data.shape
w_sim = self.weight.view(self.n_V,oc//self.n_V,self.n_H,-1).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols
w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols
w_sim = w_sim.view(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh
bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None
# quantize input
x_sim = self.quant_input(x)
# calculate similarity and store them
out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: B,parallel_eq_n*oc,fw,fh
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: B,parallel_eq_n,oc,fw,fh
similarity = self._get_similarity(self.raw_out, out_sim, self.metric, dim=2) # shape: B,parallel_eq_n,fw,fh
similarity = torch.mean(similarity, [0,2,3]) # shape: parallel_eq_n
similarities.append(similarity)
# store best weight interval of h into tmp_w_interval
similarities = torch.cat(similarities, dim=0) # shape: eq_n
best_index = similarities.argmax(dim=0).reshape(-1,1,1,1,1)
tmp_w_interval[:,v:v+1,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,v:v+1,:,h:h+1,:],dim=0,index=best_index)
self.w_interval = tmp_w_interval.squeeze(dim=0)
def _search_best_a_interval(self, x, input_interval_candidates):
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_a_interval = input_interval_candidates[p_st:p_ed]
# quantize weight and bias
w_sim, bias_sim = self.quant_weight_bias()
# quantize input
B,ic,iw,ih = x.shape
x_sim=x.unsqueeze(0) # shape: 1,B,ic,iw,ih
x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,B,ic,iw,ih
x_sim=x_sim.view(-1,ic,iw,ih)
# calculate similarity and store them
out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*B,oc,fw,fh
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,B,oc,fw,fh
similarity = self._get_similarity(self.raw_out.transpose(0,1), out_sim, self.metric, dim=2) # shape: parallel_eq_n,B,fw,fh
similarity = torch.mean(similarity, dim=[1,2,3]) # shape: parallel_eq_n
similarities.append(similarity)
# select the best input interval among the candidates
similarities = torch.cat(similarities, dim=0) # shape: eq_n
a_best_index = similarities.argmax(dim=0).view(1,1,1,1,1)
self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze()
def _initialize_intervals(self, x):
self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach()
if self.init_layerwise:
self.w_interval = ((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
else:
self.w_interval = (self.weight.view(self.n_V,self.out_channels//self.n_V,self.n_H,-1).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
def calibration_step2(self, x):
# initialize intervals with minmax intervals
self._initialize_intervals(x)
# put raw outs on GPU
self.raw_out = self.raw_out.to(x.device).unsqueeze(1) # shape: B,1,oc,W,H
# put raw grad on GPU
self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad is not None else None
# prepare weight intervals and similarities
weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1
input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1
for e in range(self.search_round):
# search for best weight interval
self._search_best_w_interval(x, weight_interval_candidates)
# search for best input interval
self._search_best_a_interval(x, input_interval_candidates)
self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None
self.calibrated = True
out=self.quant_forward(x)
del self.raw_input, self.raw_out, self.raw_grad
return out
class BatchingEasyQuantConv2d(PTQSLQuantConv2d):
"""An agile implementation of Layerwise Easyquant"""
def __init__(self, in_channels: int,
out_channels: int,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None,
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,
n_V=1, n_H=1, init_layerwise=False):
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode,
mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_V=n_V, n_H=n_H, init_layerwise=init_layerwise)
self.n_V = 1
self.n_H = 1
def _initialize_calib_parameters(self):
"""
set parameters for feeding calibration data
"""
self.calib_size = int(self.raw_input.shape[0])
self.calib_batch_size = int(self.raw_input.shape[0])
while True:
numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # estimated number of fp32 elements on GPU for one calibration batch
self.parallel_eq_n = int((15*1024*1024*1024/4)//numel)
if self.parallel_eq_n <= 1:
self.calib_need_batching = True
self.calib_batch_size //= 2
else:
break
def _initialize_intervals(self):
self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach()
tmp_a_intervals = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x_ = self.raw_input[b_st:b_ed].cuda()
a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1)
tmp_a_intervals.append(a_interval_)
self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False)
def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1, raw_grad=None):
"""
tensor_raw: *, features
tensor_sim: *, features
similarity: *
It's your job to calculate mean on * dims!
"""
if metric == "cosine":
similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim)
elif metric == "pearson":
# calculate similarity w.r.t complete feature map, but maintain dimension requirement
b, parallel_eq_n = tensor_sim.shape[0], tensor_sim.shape[1]
similarity = F.cosine_similarity(tensor_raw.view(b,1,-1), tensor_sim.view(b,parallel_eq_n,-1), dim=dim).view(b,parallel_eq_n,1,1)
else:
if metric == "L1_norm":
similarity = -torch.abs(tensor_raw - tensor_sim)
elif metric == "L2_norm":
similarity = -(tensor_raw - tensor_sim) ** 2
elif metric == "linear_weighted_L2_norm":
similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
elif metric == "square_weighted_L2_norm":
similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
elif metric == "hessian":
assert raw_grad is not None, "No raw grad!"
raw_grad = raw_grad.reshape_as(tensor_raw)
similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
else:
raise NotImplementedError(f"metric {metric} not implemented!")
similarity = torch.mean(similarity, dim=dim)
return similarity
def quant_weight_bias(self):
w_sim = self.weight
w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval)
return w_sim, self.bias
def quant_forward(self, x):
assert getattr(self, "calibrated", None) is not None, f"You should calibrate (run calibration_step2) before running quant_forward for {self}"
w_sim,bias_sim=self.quant_weight_bias()
x_sim=self.quant_input(x) if self.a_bit < 32 else x
out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups)
return out
def _search_best_w_interval(self, weight_interval_candidates):
batch_similarities = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x = self.raw_input[b_st:b_ed].cuda()
raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh
raw_grad = self.raw_grad[b_st:b_ed].cuda()
similarities = []
for p_st in range(0, self.eq_n, self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1
# quantize weight and bias
oc,ic,kw,kh = self.weight.data.shape
w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh
w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh
w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh
bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None
# quantize input
x_sim = self.quant_input(x)
# calculate similarity and store them
out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh
similarity = self._get_similarity(raw_out, out_sim, self.metric, dim=-3, raw_grad=raw_grad) # shape: b,parallel_eq_n,fw,fh
similarity = torch.mean(similarity, [2,3]) # shape: b,parallel_eq_n
similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n
similarities.append(similarity)
# gather the similarities of all candidate intervals for this calibration batch
similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n
best_index = batch_similarities.argmax(dim=0).reshape(1,1,1,1,1) # shape: 1,1,1,1,1
self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0)
def _search_best_a_interval(self, input_interval_candidates):
batch_similarities = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x = self.raw_input[b_st:b_ed].cuda()
raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(0) # shape: 1,b,oc,fw,fh
raw_grad = self.raw_grad[b_st:b_ed].cuda()
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1
# quantize weight and bias
w_sim, bias_sim = self.quant_weight_bias()
# quantize input
B,ic,iw,ih = x.shape
x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih
x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih
x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih
# calculate similarity and store them
out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh
similarity = self._get_similarity(raw_out, out_sim, self.metric, dim=-3, raw_grad=raw_grad) # shape: parallel_eq_n,b,fw,fh
similarity = torch.mean(similarity, dim=[3,4]) # shape: parallel_eq_n,b
similarity = torch.sum(similarity, dim=1, keepdim=True) # shape: parallel_eq_n,1
similarities.append(similarity)
similarities = torch.cat(similarities, dim=0) # shape: eq_n, 1
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n
a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1)
self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze()
def calibration_step2(self):
self._initialize_calib_parameters()
self._initialize_intervals()
weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval # shape: eq_n,1,1,1,1
input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1
for e in range(self.search_round):
# search for best weight interval
self._search_best_w_interval(weight_interval_candidates)
# search for best input interval
if self.a_bit < 32:
self._search_best_a_interval(input_interval_candidates)
self.calibrated = True
del self.raw_input, self.raw_out, self.raw_grad
class ChannelwiseBatchingQuantConv2d(PTQSLQuantConv2d):
"""
Only the accelerated, batching version of calibration_step2 is implemented.
Setting a_bit >= 32 turns off activation quantization (inputs are passed through unquantized).
"""
def __init__(self, in_channels: int,
out_channels: int,
kernel_size,
stride = 1,
padding = 0,
dilation = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None,
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,
n_V=1, n_H=1, init_layerwise=False):
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode,
mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n,
n_V=n_V, n_H=n_H, init_layerwise=init_layerwise)
self.n_V = self.out_channels
self.n_H = 1
def _initialize_calib_parameters(self):
"""
set parameters for feeding calibration data
"""
self.calib_size = int(self.raw_input.shape[0])
self.calib_batch_size = int(self.raw_input.shape[0])
while True:
numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # estimated number of fp32 elements on GPU for one calibration batch
self.parallel_eq_n = int((15*1024*1024*1024/4)//numel)
if self.parallel_eq_n <= 1:
self.calib_need_batching = True
self.calib_batch_size //= 2
else:
break
def _initialize_intervals(self):
# weight intervals: shape oc,1,1,1
if self.init_layerwise:
self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.out_channels,1,1,1)
else:
self.w_interval=((self.weight.abs().amax([1,2,3],keepdim=True))/(self.w_qmax-0.5))
# activation intervals: shape 1
tmp_a_intervals = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x_ = self.raw_input[b_st:b_ed].cuda()
a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1)
tmp_a_intervals.append(a_interval_)
self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False)
def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None):
"""
tensor_raw: *, features
tensor_sim: *, features
similarity: *, features
"""
if metric == "cosine":
# compute cosine similarity per output channel over the flattened spatial dims, which is sub-optimal
# searching the best a interval is not supported with this metric
b, parallel_eq_n, oc = tensor_sim.shape[0], tensor_sim.shape[1], tensor_sim.shape[2]
similarity = F.cosine_similarity(tensor_raw.view(b,1,oc,-1), tensor_sim.view(b,parallel_eq_n,oc,-1), dim=-1).view(b,parallel_eq_n,oc,1,1)
else:
if metric == "L1_norm":
similarity = -torch.abs(tensor_raw - tensor_sim)
elif metric == "L2_norm":
similarity = -(tensor_raw - tensor_sim) ** 2
elif metric == "linear_weighted_L2_norm":
similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
elif metric == "square_weighted_L2_norm":
similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
elif metric == "hessian":
assert raw_grad is not None, "raw_grad is None in _get_similarity!"
raw_grad = raw_grad.reshape_as(tensor_raw)
similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
else:
raise NotImplementedError(f"metric {metric} not implemented!")
return similarity
def _search_best_w_interval(self, weight_interval_candidates):
batch_similarities = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x = self.raw_input[b_st:b_ed].cuda()
raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh
raw_grad = self.raw_grad[b_st:b_ed].cuda()
similarities = []
for p_st in range(0, self.eq_n, self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,oc,1,1,1
# quantize weight and bias
oc,ic,kw,kh = self.weight.data.shape
w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh
w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh
w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh
bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None
# quantize input
x_sim = self.quant_input(x) if self.a_bit < 32 else x
# calculate similarity and store them
out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh
similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad) # shape: b,parallel_eq_n,oc,fw,fh
similarity = torch.mean(similarity, [3,4]) # shape: b,parallel_eq_n,oc
similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n, oc
similarities.append(similarity)
# gather the similarities of all candidate intervals for this calibration batch
similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n,oc
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n,oc
best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,oc,1,1,1
self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0)
def _search_best_a_interval(self, input_interval_candidates):
batch_similarities = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x = self.raw_input[b_st:b_ed].cuda()
raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh
raw_grad = self.raw_grad[b_st:b_ed].cuda()
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1
# quantize weight and bias
w_sim, bias_sim = self.quant_weight_bias()
# quantize input
B,ic,iw,ih = x.shape
x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih
x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih
x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih
# calculate similarity and store them
out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh
out_sim = out_sim.transpose_(0, 1) # shape: b,parallel_eq_n,oc,fw,fh
similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: b,parallel_eq_n,oc,fw,fh
similarity = torch.mean(similarity, dim=[2,3,4]) # shape: b,parallel_eq_n
similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1,parallel_eq_n
similarities.append(similarity)
similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n
a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1)
self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze()
def calibration_step2(self):
self._initialize_calib_parameters()
self._initialize_intervals()
weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,oc,1,1,1
input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1
for e in range(self.search_round):
# search for best weight interval
self._search_best_w_interval(weight_interval_candidates)
# search for best input interval
if self.a_bit < 32:
self._search_best_a_interval(input_interval_candidates)
self.calibrated = True
del self.raw_input, self.raw_out, self.raw_grad
def quant_weight_bias(self):
w_sim = (self.weight/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval)
return w_sim, self.bias
def quant_forward(self, x):
assert getattr(self, "calibrated", None) is not None, f"You should calibrate (run calibration_step2) before running quant_forward for {self}"
w_sim,bias_sim=self.quant_weight_bias()
x_sim=self.quant_input(x) if self.a_bit < 32 else x
out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups)
return out
================================================
FILE: quant_layers/linear.py
================================================
from quant_layers.matmul import PTQSLBatchingQuantMatMul
import torch
import torch.nn as nn
import torch.nn.functional as F
class MinMaxQuantLinear(nn.Linear):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
mode = "raw",
w_bit = 8,
a_bit = 8,
bias_bit = None,
bias_correction=False):
super().__init__(in_features,out_features,bias)
self.n_calibration_step=2
self.mode = mode
self.w_bit = w_bit
self.a_bit = a_bit
self.bias_bit=bias_bit
        assert bias_bit is None, "bias-bit quantization is not supported yet"
        self.w_interval=None
        self.a_interval=None
        self.calibrated=None
self.raw_input=None
self.raw_out=None
self.metric=None
self.next_nodes=[]
self.w_qmax=2**(self.w_bit-1)
self.a_qmax=2**(self.a_bit-1)
self.bias_correction = bias_correction
def forward(self, x):
if self.mode=='raw':
out=F.linear(x, self.weight, self.bias)
elif self.mode=="quant_forward":
out=self.quant_forward(x)
elif self.mode=="calibration_step1":
out=self.calibration_step1(x)
elif self.mode=="calibration_step2":
out=self.calibration_step2(x)
else:
raise NotImplementedError
return out
def quant_weight_bias(self):
w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1)
w_sim=w.mul_(self.w_interval)
if self.bias is not None:
return w_sim,self.bias
# bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1)
# bias_sim=bias*self.bias_interval
# return w_sim,bias_sim
else:
return w_sim,None
def quant_input(self, x):
x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1)
x_sim.mul_(self.a_interval)
return x_sim
def quant_forward(self,x):
        assert self.calibrated is not None, f"You should run calibration before running quant_forward for {self}"
w_sim,bias_sim=self.quant_weight_bias()
x_sim=self.quant_input(x)
out=F.linear(x_sim, w_sim, bias_sim)
return out
def _bias_correction_quant_forward(self, x):
        if self.bias_correction and self.bias is not None:
w_sim = self.quant_weight_bias()[0]
x_sim = self.quant_input(x)
eps = F.linear(x_sim, w_sim-self.weight.data, None)
eps = torch.mean(eps, dim=(list(range(len(eps.shape)-1))), keepdim=False)
            self.bias.data -= eps
self.bias_correction = False
return self.quant_forward(x)
def calibration_step1(self,x):
        # step 1: collect the FP32 values
out=F.linear(x, self.weight, self.bias)
self.raw_input=x.cpu().detach()
self.raw_out=out.cpu().detach()
return out
def calibration_step2(self,x):
        # step 2: set the min-max quantization intervals for weights and activations
self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach()
self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach()
self.calibrated=True
out=self._bias_correction_quant_forward(x)
return out
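# Editor's sketch (not part of the original repository): how the two-step
# calibration of MinMaxQuantLinear is typically driven. The layer shape and
# the input are placeholders.
def _demo_minmax_linear_calibration():
    layer = MinMaxQuantLinear(16, 8)
    x = torch.randn(4, 16)
    layer.mode = "calibration_step1"  # cache the FP32 input and output
    layer(x)
    layer.mode = "calibration_step2"  # derive min-max intervals, mark calibrated
    layer(x)
    layer.mode = "quant_forward"      # simulated 8-bit inference
    return layer(x)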
class PTQSLQuantLinear(MinMaxQuantLinear):
"""
PTQSL on linear modules.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
mode = "raw",
w_bit = 8,
a_bit = 8,
bias_bit = None,
bias_correction = False,
metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):
super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction)
self.metric = metric
self.search_round = search_round
self.eq_alpha = eq_alpha
self.eq_beta = eq_beta
self.eq_n = eq_n
self.n_H = n_H
self.n_V = n_V
self.n_a = n_a
self.crb_rows = out_features // n_V
        self.crb_cols = in_features // n_H # ignore cases where the remainder != 0
self.crb_acts = in_features // n_a
self.parallel_eq_n = parallel_eq_n
self.init_layerwise = init_layerwise
self.raw_grad = None
def _get_similarity(self, tensor_raw, tensor_sim, metric=None):
"""
tensor_raw: *, features
tensor_sim: *, features
similarity: *
It's your job to calculate mean on * dims!
"""
if metric == "cosine":
similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1)
elif metric == "pearson":
similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=-1,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=-1,keepdim=True), dim=-1)
else:
if metric == "L1_norm":
similarity = -torch.abs(tensor_raw - tensor_sim)
elif metric == "L2_norm":
similarity = -(tensor_raw - tensor_sim) ** 2
elif metric == "linear_weighted_L2_norm":
similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
elif metric == "square_weighted_L2_norm":
similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
elif metric == "hessian":
raw_grad = self.raw_grad.reshape_as(tensor_raw)
similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
else:
raise NotImplementedError(f"metric {metric} not implemented!")
similarity = torch.mean(similarity, dim=-1)
return similarity
def quant_weight_bias(self):
# self.w_interval shape: n_V, 1, n_H, 1
w=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols)/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1)
w_sim=w.mul_(self.w_interval).view(self.out_features,self.in_features)
if self.bias is not None:
return w_sim,self.bias
# bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1)
# bias_sim=bias*self.bias_interval
# return w_sim,bias_sim
else:
return w_sim,None
def quant_input(self, x):
# self.a_interval shape: n_a,1
x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)
x_sim=(x_sim.div_(self.a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)
x_sim = x_sim.mul_(self.a_interval).reshape_as(x)
return x_sim
def _search_best_w_interval(self, x, weight_interval_candidates, raw_out_expanded_chunked):
"""
Modularization of searching best weight intervals
"""
tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1
for h in range(self.n_H):
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1)
cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:]
# quantize weight and bias
w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols
w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols
w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic
bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None
# quantize input
x_sim = self.quant_input(x)
# calculate similarity and store them
out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n*oc
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,oc
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,n_V,crb_rows
similarity = self._get_similarity(raw_out_expanded_chunked, out_sim, self.metric) # shape: B,*,parallel_eq_n,n_V
similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-2))) # shape: parallel_eq_n, n_V
similarities.append(similarity)
# store best weight interval of h into tmp_w_interval
similarities = torch.cat(similarities, dim=0) # shape: eq_n, n_V
h_best_index = similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1
tmp_w_interval[:,:,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index)
self.w_interval = tmp_w_interval.squeeze(dim=0)
def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded):
tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1
for a in range(self.n_a):
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n
cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]
# quantize weight and bias
w_sim, bias_sim = self.quant_weight_bias()
# quantize input
x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)
x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n
x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic
# calculate similarity and store them
out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc
similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n
similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n
similarities.append(similarity)
            # store the best input interval into tmp_a_interval
similarities = torch.cat(similarities, dim=0) # shape: eq_n
a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)
tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)
self.a_interval = tmp_a_interval.squeeze(-1)
def _initialize_intervals(self, x):
if self.init_layerwise:
self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)
else:
self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
self.a_interval=((x.view(*x.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)
def calibration_step2(self,x):
# initialize intervals with minmax intervals
self._initialize_intervals(x)
# put raw outs on GPU
raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2) # shape: B,*,1,oc
raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows
# put raw grad on GPU
        self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad is not None else None
# prepare weight intervals and similarities
weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1
input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n
for e in range(self.search_round):
# search for best weight interval
self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked)
# search for best input interval
self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded)
        self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad is not None else None
self.calibrated = True
out=self._bias_correction_quant_forward(x)
del self.raw_input, self.raw_out, self.raw_grad
return out
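# Editor's sketch (not part of the original repository): the "hessian" metric
# used by _get_similarity scores a candidate with -(dL/dO * (O_raw - O_sim))^2,
# a first-order proxy for how much the task loss would change if the layer
# output moved from its FP32 value to the quantized value.
def _demo_hessian_metric(raw_out, out_sim, raw_grad):
    # raw_grad is dL/dO cached on the calibration batch; the same output
    # error counts for more where the gradient magnitude is larger.
    return torch.mean(-(raw_grad * (raw_out - out_sim)) ** 2, dim=-1)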
class PostGeluPTQSLQuantLinear(PTQSLQuantLinear):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
mode = "raw",
w_bit = 8,
a_bit = 8,
bias_bit = None,
bias_correction = False,
metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):
super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction,
metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)
def quant_input(self, x):
"""
self.a_interval = [a_interval_pos, a_interval_neg]
"""
# self.a_interval[0] shape: n_a,1
# self.a_interval[1] shape: 1
x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)
x_pos=(x_/(self.a_interval[0])).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval[0])
x_neg=(x_/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0).mul_(self.a_interval[1])
return (x_pos + x_neg).reshape_as(x)
def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded):
tmp_a_interval = self.a_interval[0].unsqueeze(-1) # shape: n_a,1,1
for a in range(self.n_a):
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n
cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]
# quantize weight and bias
w_sim, bias_sim = self.quant_weight_bias()
# quantize input
x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)
x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n
x_neg=(x_sim/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0)*(self.a_interval[1]) # shape: B,*,n_a,crb_acts,1
x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic
# calculate similarity and store them
out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc
similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n
similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n
similarities.append(similarity)
            # store the best input interval into tmp_a_interval
similarities = torch.cat(similarities, dim=0) # shape: eq_n
a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)
tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)
self.a_interval[0] = tmp_a_interval.squeeze(-1)
def _initialize_intervals(self, x):
if self.init_layerwise:
self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
self.a_interval=[(x.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)]
else:
self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
self.a_interval=[((x.view(*x.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)]
        self.a_interval.append(0.16997124254703522/self.a_qmax)  # fixed negative interval: 0.1700 ~ |min of GELU(x)|, so the whole negative range is covered
def calibration_step2(self,x):
# initialize intervals with minmax intervals
self._initialize_intervals(x)
# put raw outs on GPU
raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2) # shape: B,*,1,oc
raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows
# put raw grad on GPU
        self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad is not None else None
# prepare weight intervals and similarities
weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1
input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval[0].unsqueeze(-1) # shape: n_a,1,eq_n
for e in range(self.search_round):
# search for best weight interval
self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked)
# search for best input interval
self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded)
        self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad is not None else None
self.calibrated = True
out=self._bias_correction_quant_forward(x)
del self.raw_input, self.raw_out, self.raw_grad
return out
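# Editor's sketch (not part of the original repository): twin uniform
# quantization of post-GELU activations. Negative outputs are confined to
# (min GELU(x) ~ -0.1700, 0], so they get their own small fixed interval,
# while the positive interval is searched as usual.
def _demo_twin_gelu_quant(x, a_pos_interval, n_bits=8):
    qmax = 2 ** (n_bits - 1)
    a_neg_interval = 0.16997124254703522 / qmax  # |min of GELU| covers all negatives
    x_pos = (x / a_pos_interval).round().clamp(0, qmax - 1) * a_pos_interval
    x_neg = (x / a_neg_interval).round().clamp(-qmax, 0) * a_neg_interval
    # after clamping, every value contributes through exactly one branch
    return x_pos + x_neg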
class PTQSLBatchingQuantLinear(PTQSLQuantLinear):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
mode = "raw",
w_bit = 8,
a_bit = 8,
bias_bit = None,
bias_correction = False,
metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):
super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)
self.calib_size = None
self.calib_batch_size = None
self.calib_need_batching = False
def _initialize_calib_parameters(self):
"""
set parameters for feeding calibration data
"""
self.calib_size = int(self.raw_input.shape[0])
self.calib_batch_size = int(self.raw_input.shape[0])
while True:
            numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # rough element count kept on the GPU per candidate
            self.parallel_eq_n = int((3*1024*1024*1024/4)//numel) # fit into ~3 GiB of float32 scratch (4 bytes per element)
if self.parallel_eq_n <= 1:
self.calib_need_batching = True
self.calib_batch_size //= 2
else:
break
def _initialize_intervals(self):
# weight intervals
if self.init_layerwise:
self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
else:
self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
# activation intervals
tmp_a_intervals = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x_ = self.raw_input[b_st:b_ed].cuda()
if self.init_layerwise:
a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)
else:
a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)
tmp_a_intervals.append(a_interval_)
self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)
def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None):
"""
tensor_raw: *, features
tensor_sim: *, features
similarity: *
It's your job to calculate mean on * dims!
"""
if metric == "cosine":
similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1)
else:
if metric == "L1_norm":
similarity = -torch.abs(tensor_raw - tensor_sim)
elif metric == "L2_norm":
similarity = -(tensor_raw - tensor_sim) ** 2
elif metric == "linear_weighted_L2_norm":
similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
elif metric == "square_weighted_L2_norm":
similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
elif metric == "hessian":
                assert raw_grad is not None, "raw_grad is None in _get_similarity!"
raw_grad = raw_grad.reshape_as(tensor_raw)
similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
else:
raise NotImplementedError(f"metric {metric} not implemented!")
similarity = torch.mean(similarity, dim=-1)
return similarity
def _get_pearson_w(self, tensor_raw, tensor_sim):
"""
        Quick, batched implementation of the Pearson similarity used in the weight-interval search.
tensor_sim: b,*,parallel_eq_n,n_V,crb_rows
tensor_raw: b,*,1,n_V,crb_rows
"""
b, parallel_eq_n, n_V = tensor_sim.shape[0],tensor_sim.shape[-3],tensor_sim.shape[-2]
        tensor_sim = tensor_sim.transpose(-1,-3).contiguous().view(b,-1,n_V,parallel_eq_n)
tensor_raw = tensor_raw.transpose(-1,-3).view(b,-1,n_V,1)
tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True)
tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True)
similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,n_V,parallel_eq_n
        similarity = similarity.permute(0,2,1).contiguous()
return similarity
def _get_pearson_a(self, tensor_raw, tensor_sim):
"""
        Quick, batched implementation of the Pearson similarity used in the input-interval search.
tensor_sim: b,*,parallel_eq_n,oc
tensor_raw: b,*,1,oc
"""
b, parallel_eq_n = tensor_sim.shape[0],tensor_sim.shape[-2]
        tensor_sim = tensor_sim.transpose(-1,-2).contiguous().view(b,-1,parallel_eq_n)
tensor_raw = tensor_raw.transpose(-1,-2).view(b,-1,1)
tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True)
tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True)
similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,parallel_eq_n
return similarity
def _search_best_w_interval(self, weight_interval_candidates):
tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1
for h in range(self.n_H):
batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax)
for b_st in range(0, self.calib_size, self.calib_batch_size):
b_ed = min(self.calib_size, b_st + self.calib_batch_size)
x = self.raw_input[b_st:b_ed].cuda()
raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc
raw_out_expanded = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,1,n_V,crb_rows
raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1)
cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:]
# quantize weight and bias
w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols
w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols
w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic
bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None
# quantize input
x_sim = self.quant_input(x)
# calculate similarity and store them
out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n*oc
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,oc
out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,n_V,crb_rows
if self.metric != "pearson":
similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n,n_V
if len(similarity.shape) > 3:
similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-2))) # shape: b, parallel_eq_n, n_V
else:
similarity = self._get_pearson_w(raw_out_expanded, out_sim)
similarity = similarity.sum(dim=0, keepdim=True) # shape: 1, parallel_eq_n, n_V
similarities.append(similarity)
# store best weight interval of h into tmp_w_interval
similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n, n_V
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n, n_V
h_best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1
tmp_w_interval[:,:,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index)
self.w_interval = tmp_w_interval.squeeze(dim=0)
def _search_best_a_interval(self, input_interval_candidates):
tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1
for a in range(self.n_a):
batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax)
for b_st in range(0, self.calib_size, self.calib_batch_size):
b_ed = min(self.calib_size, b_st + self.calib_batch_size)
x = self.raw_input[b_st:b_ed].cuda()
raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc
raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n
cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]
# quantize weight and bias
w_sim, bias_sim = self.quant_weight_bias()
# quantize input
x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)
x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n
x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic
# calculate similarity and store them
out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc
if self.metric != "pearson":
similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n
if len(similarity.shape) > 2:
similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b, parallel_eq_n
else:
similarity = self._get_pearson_a(raw_out_expanded, out_sim)
similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n
similarities.append(similarity)
                # store the best input interval into tmp_a_interval
similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n
a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)
tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)
self.a_interval = tmp_a_interval.squeeze(-1)
def calibration_step2(self):
"""
        Calibrate using only the cached raw inputs/outputs/gradients.
"""
self._initialize_calib_parameters()
self._initialize_intervals()
# prepare weight intervals and similarities
weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1
input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n
for e in range(self.search_round):
# search for best weight interval
self._search_best_w_interval(weight_interval_candidates)
# search for best input interval
self._search_best_a_interval(input_interval_candidates)
self.calibrated = True
# self._bias_correction_quant_forward(self.raw_input.cuda()) # debugging
del self.raw_input, self.raw_out, self.raw_grad
return None
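# Editor's sketch (not part of the original repository): the memory-budget
# rule of _initialize_calib_parameters above. With roughly 3 GiB of float32
# scratch, the number of candidates evaluated in parallel is the budget
# divided by the elements replicated per candidate; if not even one candidate
# fits, the calibration batch is halved and the rule is re-applied.
def _demo_calib_budget(raw_in_numel, raw_out_numel, calib_size, budget_bytes=3 << 30):
    batch = calib_size
    while batch > 0:
        per_candidate = 2 * (raw_in_numel + raw_out_numel) / calib_size * batch
        parallel_eq_n = int((budget_bytes / 4) // per_candidate)  # 4 bytes per fp32
        if parallel_eq_n > 1:
            return batch, parallel_eq_n
        batch //= 2  # halve the calibration batch and retry
    raise RuntimeError("calibration data too large for the scratch budget")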
class PostGeluPTQSLBatchingQuantLinear(PTQSLBatchingQuantLinear):
"""
An Agile implementation of PostGeluPTQSLBatchingQuantLinear
use a_interval for positive activation quantization and a_neg_interval for negative activation quantization
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
mode = "raw",
w_bit = 8,
a_bit = 8,
bias_bit = None,
bias_correction = False,
metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):
super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction,
metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)
        self.a_neg_interval = 0.16997124254703522/self.a_qmax  # fixed negative interval: 0.1700 ~ |min of GELU(x)|, so the whole negative range is covered
def _initialize_intervals(self):
# weight intervals
if self.init_layerwise:
self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
else:
self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
# activation intervals (for positive parts)
if self.init_layerwise:
tmp_a_intervals = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x_ = self.raw_input[b_st:b_ed].cuda()
a_interval_=(x_.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)
tmp_a_intervals.append(a_interval_)
self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)
else:
tmp_a_intervals = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
x_ = self.raw_input[b_st:b_ed].cuda()
a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)
tmp_a_intervals.append(a_interval_)
self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)
def quant_input(self, x):
# self.a_interval shape: n_a,1
# self.a_neg_interval shape: 1
x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)
x_pos=(x_/(self.a_interval)).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval)
x_neg=(x_/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0).mul_(self.a_neg_interval)
return (x_pos + x_neg).reshape_as(x)
def _search_best_a_interval(self, input_interval_candidates):
tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1
for a in range(self.n_a):
batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax)
for b_st in range(0, self.calib_size, self.calib_batch_size):
b_ed = min(self.calib_size, b_st + self.calib_batch_size)
x = self.raw_input[b_st:b_ed].cuda()
raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc
raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later
similarities = []
for p_st in range(0,self.eq_n,self.parallel_eq_n):
p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n
cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]
# quantize weight and bias
w_sim, bias_sim = self.quant_weight_bias()
# quantize input
x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)
x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n
x_neg=(x_sim/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0)*(self.a_neg_interval) # shape: b,*,n_a,crb_acts,1
x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic
# calculate similarity and store them
out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc
similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n
similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b, parallel_eq_n
similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n
similarities.append(similarity)
                # store the best input interval into tmp_a_interval
similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n
a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)
tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)
self.a_interval = tmp_a_interval.squeeze(-1)
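# Editor's note (sketch, not part of the original repository): the batching
# searches above accumulate per-batch similarity *sums* before the argmax.
# Since every candidate is summed over the same number of calibration samples,
# the argmax of the sum equals the argmax of the mean, so no normalization is
# needed:
def _demo_sum_vs_mean_argmax(per_batch_similarities):
    # per_batch_similarities: list of tensors, each of shape (eq_n,)
    total = torch.stack(per_batch_similarities).sum(dim=0)
    return total.argmax()  # identical to the argmax of the per-sample mean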
================================================
FILE: quant_layers/matmul.py
================================================
import numpy as np
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
from itertools import product
class MinMaxQuantMatMul(nn.Module):
"""Matrix Multiplication base class"""
def __init__(self, A_bit=8, B_bit=8, mode="raw"):
super().__init__()
self.A_bit=A_bit
self.B_bit=B_bit
self.A_interval=None
self.B_interval=None
self.A_qmax=2**(self.A_bit-1)
self.B_qmax=2**(self.B_bit-1)
self.mode=mode
        self.raw_input = None
        self.raw_out = None
        self.calibrated = None
def forward(self, A,B):
if self.mode=='raw':
out=A @ B
elif self.mode=="quant_forward":
out=self.quant_forward(A,B)
elif self.mode=="calibration_step1":
out=self.calibration_step1(A,B)
elif self.mode=="calibration_step2":
out=self.calibration_step2(A,B)
else:
raise NotImplementedError
return out
def quant_input(self,x,interval,qmax):
x_sim=(x/interval).round_().clamp_(-qmax,qmax-1)
x_sim.mul_(interval)
return x_sim
def quant_forward(self,A,B):
        assert self.calibrated is not None, f"You should run calibration before running quant_forward for {self}"
A_sim=self.quant_input(A,self.A_interval,self.A_qmax)
B_sim=self.quant_input(B,self.B_interval,self.B_qmax)
out=A_sim@B_sim
return out
def calibration_step1(self,A,B):
        # step 1: collect the FP32 values
self.raw_input=A.cpu().detach(), B.cpu().detach()
out=A@B
self.raw_out=out.cpu().detach()
return out
def calibration_step2(self,A,B):
        # step 2: set the min-max quantization intervals for A and B
self.A_interval=(A.data.abs().max()/(self.A_qmax-0.5)).detach()
self.B_interval=(B.data.abs().max()/(self.B_qmax-0.5)).detach()
self.calibrated=True
out=self.quant_forward(A,B)
return out
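# Editor's sketch (not part of the original repository): driving the two-step
# calibration of MinMaxQuantMatMul on a toy Q @ K^T product. Shapes are
# placeholders for B,H,S,W and B,H,W,S.
def _demo_minmax_matmul_calibration():
    mm = MinMaxQuantMatMul(A_bit=8, B_bit=8)
    q = torch.randn(2, 4, 16, 8)
    k_t = torch.randn(2, 4, 8, 16)
    mm.mode = "calibration_step1"  # cache the FP32 inputs and output
    mm(q, k_t)
    mm.mode = "calibration_step2"  # min-max intervals for A and B
    mm(q, k_t)
    mm.mode = "quant_forward"      # simulated 8-bit matmul
    return mm(q, k_t)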
class PTQSLQuantMatMul(MinMaxQuantMatMul):
"""
Chunk matrix into blockes and quantize.
Chunking follows naive padding strategy.
Alternately search for best intervals of each individual blocks for A and B.
two different scenarios:
- Q @ K:
- A's shape: B,H,S,W
- B's shape: B,H,W,S
- scores @ V:
- A's shape: B,H,S,S
- B's shape: B,H,S,W
- interval shape: 1,n_G,1,n_V,1,n_H,1
"""
def __init__(self, A_bit=8, B_bit=8, mode="raw",
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,
n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False):
super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode)
self.metric = metric
self.search_round = search_round
self.eq_alpha = eq_alpha
self.eq_beta = eq_beta
self.eq_n = eq_n
self.parallel_eq_n = parallel_eq_n
self.n_G_A = n_G_A
self.n_V_A = n_V_A
self.n_H_A = n_H_A
self.n_G_B = n_G_B
self.n_V_B = n_V_B
self.n_H_B = n_H_B
# init these parameters in self.calibration_step1
self.crb_groups_A = None
self.crb_groups_B = None
self.crb_rows_A = None
self.crb_cols_A = None
self.crb_rows_B = None
self.crb_cols_B = None
self.pad_groups_A = None
self.pad_groups_B = None
self.pad_rows_A = None
self.pad_rows_B = None
self.pad_cols_A = None
self.pad_cols_B = None
self.raw_grad = None
self.init_layerwise = init_layerwise
def _get_padding_parameters(self, A, B):
self.crb_groups_A = (A.shape[1]+self.n_G_A-1) // self.n_G_A
self.crb_groups_B = (B.shape[1]+self.n_G_B-1) // self.n_G_B
self.crb_rows_A = (A.shape[2]+self.n_V_A-1) // self.n_V_A
self.crb_cols_A = (A.shape[3]+self.n_H_A-1) // self.n_H_A
self.crb_rows_B = (B.shape[2]+self.n_V_B-1) // self.n_V_B
self.crb_cols_B = (B.shape[3]+self.n_H_B-1) // self.n_H_B
self.pad_groups_A = self.crb_groups_A*self.n_G_A - A.shape[1]
self.pad_rows_A = self.crb_rows_A*self.n_V_A - A.shape[2]
self.pad_cols_A = self.crb_cols_A*self.n_H_A - A.shape[3]
self.pad_groups_B = self.crb_groups_B*self.n_G_B - B.shape[1]
self.pad_rows_B = self.crb_rows_B*self.n_V_B - B.shape[2]
self.pad_cols_B = self.crb_cols_B*self.n_H_B - B.shape[3]
def quant_input_A(self, x):
x = F.pad(x, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A])
x = x.view(-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)
x = (x/self.A_interval).round_().clamp(-self.A_qmax,self.A_qmax-1).mul_(self.A_interval)
x = x.view(-1,self.n_G_A*self.crb_groups_A,self.n_V_A*self.crb_rows_A,self.n_H_A*self.crb_cols_A)
x = x[:,:x.shape[1]-self.pad_groups_A,:x.shape[2]-self.pad_rows_A,:x.shape[3]-self.pad_cols_A]
return x
def quant_input_B(self, x):
x = F.pad(x, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B])
x = x.view(-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
x = (x/self.B_interval).round_().clamp(-self.B_qmax,self.B_qmax-1).mul_(self.B_interval)
x = x.view(-1,self.n_G_B*self.crb_groups_B,self.n_V_B*self.crb_rows_B,self.n_H_B*self.crb_cols_B)
x = x[:,:x.shape[1]-self.pad_groups_B,:x.shape[2]-self.pad_rows_B,:x.shape[3]-self.pad_cols_B]
return x
def quant_forward(self, A, B):
        assert self.calibrated is not None, f"You should run calibration before running quant_forward for {self}"
A_sim=self.quant_input_A(A)
B_sim=self.quant_input_B(B)
out=A_sim@B_sim
return out
def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1):
"""
tensor_raw: *, features, *
tensor_sim: *, features, *
similarity: *
It's your job to calculate mean on non-feature * dims!
        Metrics without an inherent feature structure are more amenable to parallelism.
        """
        if metric == "cosine":
            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # only supports dim=-1 and cannot be parallelized
        elif metric == "pearson":
            similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw), tensor_sim-torch.mean(tensor_sim), dim=dim)
else:
if metric == "L1_norm":
similarity = -torch.abs(tensor_raw - tensor_sim)
elif metric == "L2_norm":
similarity = -(tensor_raw - tensor_sim) ** 2
elif metric == "linear_weighted_L2_norm":
similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
elif metric == "square_weighted_L2_norm":
similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
elif metric == "hessian":
raw_grad = self.raw_grad.reshape_as(tensor_raw)
similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
else:
raise NotImplementedError(f"metric {metric} not implemented!")
similarity = torch.mean(similarity, dim=dim)
return similarity
def _search_best_A_interval(self, A, B, A_interval_candidates):
"""
Modularization of searching best interval
"""
# recalculate A_pad
A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)
tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1
# out-of-loop optimization
B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3
for v, h in product(range(self.n_V_A), range(self.n_H_A)):
similarities = []
for p_st in range(0, self.eq_n, self.parallel_eq_n):
p_ed = min(self.eq_n,p_st+self.parallel_eq_n)
# quantize A
cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)
cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]
A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval)
A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stand for padding)
A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,B,H,dim1,dim2
# quantize B, this quantization is optimized out of loop
# calculate similarity and store them
out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3
similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1
similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (remaining mean operation will be done later on)
similarities.append(similarity)
# calculate best similarity for this block
similarities = torch.cat(similarities, 0) # shape: eq_n,H
similarities = F.pad(similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n, n_G_A
best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)
tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)
self.A_interval = tmp_A_interval.squeeze(0)
def _search_best_B_interval(self, A, B, B_interval_candidates):
"""
Modularization of searching best interval
"""
# recalculate B_pad
B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1
# out-of-loop optimization
A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2
for v, h in product(range(self.n_V_B), range(self.n_H_B)):
similarities = []
for p_st in range(0, self.eq_n, self.parallel_eq_n):
p_ed = min(self.eq_n,p_st+self.parallel_eq_n)
# quantize A, this quantization is optimized out of loop
# quantize B
cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)
cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]
B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval)
B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,B,H*,dim2*,dim3* (* stand for padding)
B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,B,H,dim2,dim3
# calculate similarity and store them
out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3
similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1
similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (remaining mean operation will be done later on)
similarities.append(similarity)
# calculate best similarity for this block
similarities = torch.cat(similarities, 0) # shape: eq_n,H
similarities = F.pad(similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n, n_G_B
best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)
tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)
self.B_interval = tmp_B_interval.squeeze(0)
def _initialize_intervals(self, A, B):
# pad A and B for future quantization
self._get_padding_parameters(A, B) # put it here because hessian does not use calibration step 1
A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) # shape: 1,B,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols
B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
# initialize intervals with minmax intervals
if self.init_layerwise:
self.A_interval = (A.abs().max()/(self.A_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_A,1,self.n_V_A,1,self.n_H_A,1)
self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)
else:
self.A_interval=(A_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.A_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1
self.B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1
def calibration_step2(self, A, B):
# put raw outs/grads on GPU
self.raw_out = self.raw_out.unsqueeze(0).to(A.device)
        self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad is not None else None
self._initialize_intervals(A, B)
# prepare weight intervals and similarities
A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0)
B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)
for e in range(self.search_round):
# search for best A interval
self._search_best_A_interval(A, B, A_interval_candidates)
# search for best B interval
self._search_best_B_interval(A, B, B_interval_candidates)
# put raw data back to cpu
self.raw_out = self.raw_out.squeeze(0).to("cpu")
        self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad is not None else None
# finish calibration and output the result
self.calibrated = True
del self.raw_input, self.raw_out, self.raw_grad
out=self.quant_forward(A,B)
return out
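# Editor's sketch (not part of the original repository): the "naive padding"
# block arithmetic of _get_padding_parameters, shown for a single axis. The
# axis is ceil-divided into n blocks and padded so that it splits evenly,
# e.g. _demo_block_padding(197, 3) == (66, 1).
def _demo_block_padding(axis_len, n_blocks):
    crb = (axis_len + n_blocks - 1) // n_blocks  # block size: ceil(axis_len / n_blocks)
    pad = crb * n_blocks - axis_len              # zeros appended at the end of the axis
    return crb, pad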
class SoSPTQSLQuantMatMul(PTQSLQuantMatMul):
"""
Sublayerwise PTQ on matmul modules with Split-of-Softmax (SoS) on score matrix.
Data after softmaxing has highly biased distribution, making it difficult to quantize with uniform quantization.
An elegant tradeoff between great majority of unimportant values and few crucial values is impossible under low bit quantization.
Therefore, we propose to split complete interval of (0, 1) into several smaller intervals and perform uniform quantization on each.
We could manually assgin or search for the best split point.
Currently, we only consider single split point scenarios, since this proves to be effective enough.
The algorithm no longer requires PTQSL on score matrix, and will ignore relevant parameters.
with proper hardware implementation, we don't need to use a sign bit anymore.
"""
def __init__(self, A_bit=8, B_bit=8, mode="raw",
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,
n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False,
split=None):
super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode,
metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n,
n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise)
self.n_G_A = 1
self.n_V_A = 1
self.n_H_A = 1
        self.A_qmax = 2**(self.A_bit-1) # still needed even though the sign bit goes unused
        self.split = split
        if split is not None:
            self.A_interval = self.split/(self.A_qmax-1)
def quant_input_A(self, x):
x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)
x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval
return x_high + x_low
def _search_best_A_interval(self, A, B, split_candidates):
"""
search for best split point
"""
# out-of-loop optimization
A_ = A.unsqueeze(0)
# B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3
B_sim = B.unsqueeze(0)
similarities = []
for i in range(len(split_candidates)):
# quantize A
cur_A_interval = split_candidates[i]/(self.A_qmax-1)
A_high = (A_.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)
            A_low = (A_.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval
A_sim = A_high + A_low # shape: 1,B,H,S,S
# quantize B, this quantization is optimized out of loop
# calculate similarity and store them (dim1=dim2=S, dim3=W)
out_sim = A_sim @ B_sim # shape: 1,B,H,dim1,dim3
            similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: 1,B,H,dim1
similarity = similarity.mean([1,2,3]) # shape: 1
similarities.append(similarity)
# calculate best similarity for this block
        similarities = torch.cat(similarities, 0) # shape: len(split_candidates)
best_index = torch.argmax(similarities, dim=0, keepdim=False)
self.split = split_candidates[best_index]
self.A_interval = self.split/(self.A_qmax-1)
# debugging
# print(f"best split: {self.split}")
def _initialize_intervals(self, A, B):
# pad A and B for future quantization
self._get_padding_parameters(A, B)
B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
# initialize intervals with minmax intervals
self.split = 0.01
self.A_interval = self.split/(self.A_qmax-1)
if self.init_layerwise:
self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)
else:
self.B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1
def calibration_step2(self, A, B):
# put raw outs/grads on GPU
self.raw_out = self.raw_out.unsqueeze(0).to(A.device)
        self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad is not None else None
self._initialize_intervals(A, B)
# prepare weight intervals and similarities
A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda()
# split_eq_alpha, split_eq_beta, split_eq_n = 0.002, 0.03, 50
# A_split_candidates = torch.tensor([split_eq_alpha + (split_eq_beta- split_eq_alpha)*i/split_eq_n for i in range(split_eq_n + 1)]).cuda()
B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)
for e in range(self.search_round):
# search for best A interval
self._search_best_A_interval(A, B, A_split_candidates)
# search for best B interval
self._search_best_B_interval(A, B, B_interval_candidates)
# put raw data back to cpu
self.raw_out = self.raw_out.squeeze(0).to("cpu")
        self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad is not None else None
# finish calibration and output the result
self.calibrated = True
del self.raw_input, self.raw_out, self.raw_grad
out=self.quant_forward(A,B)
return out
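# Editor's sketch (not part of the original repository): the split-point
# search of calibration_step2 above, restricted to power-of-two candidates
# (hardware-friendly: rescaling becomes a shift). The two-range arithmetic
# mirrors quant_input_A; a plain L2 error stands in for the configurable
# similarity metric.
def _demo_search_split(scores, n_candidates=20, n_bits=8):
    qmax = 2 ** (n_bits - 1)
    best_split, best_err = None, float("inf")
    for i in range(n_candidates):
        split = 2.0 ** (-i)
        interval = split / (qmax - 1)
        high = (scores.clamp(split, 1) * (qmax - 1)).round().clamp(0, qmax - 1) / (qmax - 1)
        low = (scores.clamp(0, split) / interval).round().clamp(0, qmax - 1) * interval
        err = ((high + low - scores) ** 2).sum().item()
        if err < best_err:
            best_split, best_err = split, err
    return best_split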
class PTQSLBatchingQuantMatMul(PTQSLQuantMatMul):
def __init__(self, A_bit=8, B_bit=8, mode="raw",
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,
n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False):
super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise)
def _initialize_calib_parameters(self):
"""
set parameters for feeding calibration data
"""
self.calib_size = int(self.raw_input[0].shape[0])
self.calib_batch_size = int(self.raw_input[0].shape[0])
while True:
            numel = ((self.raw_input[0].numel()+self.raw_input[1].numel()+2*self.raw_out.numel())/self.calib_size*self.calib_batch_size) # rough element count kept on the GPU per candidate
            self.parallel_eq_n = int((3*1024*1024*1024/4)//numel) # fit into ~3 GiB of float32 scratch (4 bytes per element)
if self.parallel_eq_n <= 1:
self.calib_need_batching = True
self.calib_batch_size //= 2
else:
break
def _get_padding_parameters(self, A, B):
"""
        We adopt head-wise quantization here: one group per attention head.
"""
self.n_G_A = A.shape[1]
self.n_G_B = B.shape[1]
super()._get_padding_parameters(A,B)
def _initialize_intervals(self):
# pad A and B for future quantization
self._get_padding_parameters(self.raw_input[0], self.raw_input[1]) # put it here because hessian does not use calibration step 1
# initialize intervals with minmax intervals
tmp_A_intervals = []
tmp_B_intervals = []
for b_st in range(0,self.calib_size,self.calib_batch_size):
b_ed = min(self.calib_size, b_st+self.calib_batch_size)
A, B = self.raw_input[0][b_st:b_ed].cuda(), self.raw_input[1][b_st:b_ed].cuda()
if self.init_layerwise:
A_interval = (A.abs().max()/(self.A_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_A,1,self.n_V_A,1,self.n_H_A,1)
B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)
else:
A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)
B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
A_interval=(A_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.A_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1
B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1
tmp_A_intervals.append(A_interval)
tmp_B_intervals.append(B_interval)
self.A_interval = torch.cat(tmp_A_intervals, dim=0).amax(0, keepdim=True)
self.B_interval = torch.cat(tmp_B_intervals, dim=0).amax(0, keepdim=True)
def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1, raw_grad=None):
"""
tensor_raw: *, features, *
tensor_sim: *, features, *
similarity: *
It's your job to calculate mean on non-feature * dims!
        Metrics without an inherent feature structure are more amenable to parallelism.
        """
        if metric == "cosine":
            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # only supports dim=-1 and cannot be parallelized
        elif metric == "pearson":
            similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=dim,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=dim,keepdim=True), dim=dim) # only supports dim=-1 and cannot be parallelized
# a quick implementation of pearson similarity
# tensor_raw: 1,B,H,dim1,dim3
# tensor_sim: parallel_eq_n,B,H,dim1,dim3
# parallel_eq_n,B,H,dim1,dim3 = tensor_sim.shape
# tensor_sim = tensor_sim.view(parallel_eq_n,B,-1)
# tensor_raw = tensor_raw.view(1,B,-1)
# tensor_sim_mean = tensor_sim.mean(dim=[1,2],keepdim=True)
# tensor_raw_mean = tensor_raw.mean(dim=[1,2],keepdim=True)
# similarity = F.cosine_similarity(tensor_raw-tensor_raw_mean,tensor_sim-tensor_sim_mean,dim=-1) # shape: parallel_eq_n,B
# similarity = similarity.reshape(parallel_eq_n,B,1,1) # restore two dims
else:
if metric == "L1_norm":
similarity = -torch.abs(tensor_raw - tensor_sim)
elif metric == "L2_norm":
similarity = -(tensor_raw - tensor_sim) ** 2
elif metric == "linear_weighted_L2_norm":
similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
elif metric == "square_weighted_L2_norm":
similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
elif metric == "hessian":
                assert raw_grad is not None, "No raw_grad in PTQSLBatchingQuantMatMul!"
raw_grad = raw_grad.reshape_as(tensor_raw)
similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
else:
raise NotImplementedError(f"metric {metric} not implemented!")
similarity = torch.mean(similarity, dim=dim)
return similarity
def _search_best_A_interval(self, A_interval_candidates):
"""
Modularization of searching best interval
"""
tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1
# out-of-loop optimization
for v, h in product(range(self.n_V_A), range(self.n_H_A)):
batch_similarities = [] # similarities, need to concatenate and calculate sum
for b_st in range(0, self.calib_size, self.calib_batch_size):
b_ed = min(self.calib_size, b_st + self.calib_batch_size)
A = self.raw_input[0][b_st:b_ed].cuda()
A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)
B = self.raw_input[1][b_st:b_ed].cuda()
B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,b,H,dim2,dim3
raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda()
raw_grad = self.raw_grad[b_st:b_ed].cuda()
similarities = []
for p_st in range(0, self.eq_n, self.parallel_eq_n):
p_ed = min(self.eq_n,p_st+self.parallel_eq_n)
# quantize A
cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)
cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]
A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval)
A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stand for padding)
A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,b,H,dim1,dim2
# quantize B, this quantization is optimized out of loop
# calculate similarity and store them
out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3
similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1
similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on)
similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H
similarities.append(similarity)
# calculate best similarity for this block
similarities = torch.cat(similarities, 0) # shape: eq_n,1,H
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H
batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n, n_G_A
best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)
tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)
self.A_interval = tmp_A_interval.squeeze(0)
def _search_best_B_interval(self, B_interval_candidates):
"""
Modularization of searching best interval
"""
tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1
# out-of-loop optimization
for v, h in product(range(self.n_V_B), range(self.n_H_B)):
batch_similarities = [] # similarities, need to concatenate and calculate sum
for b_st in range(0, self.calib_size, self.calib_batch_size):
b_ed = min(self.calib_size, b_st + self.calib_batch_size)
A = self.raw_input[0][b_st:b_ed].cuda()
A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2
B = self.raw_input[1][b_st:b_ed].cuda()
B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda()
raw_grad = self.raw_grad[b_st:b_ed].cuda()
similarities = []
for p_st in range(0, self.eq_n, self.parallel_eq_n):
p_ed = min(self.eq_n,p_st+self.parallel_eq_n)
# quantize A, this quantization is optimized out of loop
# quantize B
cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)
cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]
B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval)
B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,b,H*,dim2*,dim3* (* marks padded dims)
B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,b,H,dim2,dim3
# calculate similarity and store them
out_sim = A_sim @ B_sim # shape: parallel_eq_n,b,H,dim1,dim3
similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1
similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on)
similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H
similarities.append(similarity)
# calculate best similarity for this block
similarities = torch.cat(similarities, 0) # shape: eq_n,1,H
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H
batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n, n_G_B
best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)
tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)
self.B_interval = tmp_B_interval.squeeze(0)
def calibration_step2(self):
self._initialize_calib_parameters()
self._initialize_intervals()
A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0)
B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)
for e in range(self.search_round):
# search for best A interval
self._search_best_A_interval(A_interval_candidates)
# search for best B interval
self._search_best_B_interval(B_interval_candidates)
self.calibrated = True
del self.raw_input, self.raw_out, self.raw_grad
class SoSPTQSLBatchingQuantMatMul(PTQSLBatchingQuantMatMul):
def __init__(self, A_bit=8, B_bit=8, mode="raw",
metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,
n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False,
split=None):
super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode,
metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n,
n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise)
self.n_G_A = 1
self.n_V_A = 1
self.n_H_A = 1
# softmax outputs are non-negative, so no sign bit is needed: with a suitable
# hardware decoding, the MSB of the 8-bit code selects the large or small
# interval instead (see utils/integer.py)
self.A_qmax = 2**(self.A_bit-1)
self.split = split
if split is not None:
self.A_interval = self.split/(self.A_qmax-1)
def quant_input_A(self, x):
x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)
x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval
return x_high + x_low
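# Illustrative numbers for the twin-range quantizer above (not part of the
# original file), assuming A_bit=8 (A_qmax=128) and a hypothetical split of
# 2**-4: the small range [0, split] is covered with step split/127 ~ 4.9e-4,
# while the large range [split, 1] uses step 1/127 ~ 7.9e-3, so the many tiny
# post-softmax scores keep much finer resolution than one uniform 8-bit
# quantizer would give them.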
def _search_best_A_interval(self, split_candidates):
batch_similarities = []
for b_st in range(0, self.calib_size, self.calib_batch_size):
b_ed = min(self.calib_size, b_st + self.calib_batch_size)
A = self.raw_input[0][b_st:b_ed].unsqueeze(0).cuda()
B = self.raw_input[1][b_st:b_ed].unsqueeze(0).cuda()
B_sim = B
raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda()
raw_grad = self.raw_grad[b_st:b_ed].cuda()
similarities = []
for i in range(len(split_candidates)):
# quantize A
cur_A_interval = split_candidates[i]/(self.A_qmax-1)
A_high = (A.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)
A_low = (A.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval
A_sim = A_high + A_low # shape: 1,b,H,S,S
# quantize B, this quantization is optimized out of loop
# calculate similarity and store them (dim1=dim2=S, dim3=W)
out_sim = A_sim @ B_sim # shape: 1,b,H,dim1,dim3
similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: 1,b,H,dim1
similarity = similarity.mean([2,3]) # shape: 1,b
similarity = similarity.sum(dim=1,keepdim=True) # shape: 1,1
similarities.append(similarity)
# calculate best similarity for this block
similarities = torch.cat(similarities, 0) # shape: n_split_candidates, 1
batch_similarities.append(similarities)
batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) # shape: n_split_candidates
best_index = torch.argmax(batch_similarities, dim=0, keepdim=False)
self.split = split_candidates[best_index]
self.A_interval = self.split/(self.A_qmax-1)
# debugging
# print(f"best split: {self.split}")
def calibration_step2(self):
self._initialize_calib_parameters()
self._initialize_intervals()
A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda()
B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)
for e in range(self.search_round):
# search for best A interval
self._search_best_A_interval(A_split_candidates)
# search for best B interval
self._search_best_B_interval(B_interval_candidates)
self.calibrated = True
del self.raw_input, self.raw_out, self.raw_grad
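# A minimal sketch (not part of the original file) of the search grid built in
# calibration_step2 above: eq_n+1 evenly spaced multipliers over
# [eq_alpha, eq_beta] scale the initial interval, and the alternating A/B
# search keeps the candidate that maximizes the similarity metric. The initial
# interval below is a hypothetical placeholder.
def _interval_candidates_sketch(eq_alpha=0.1, eq_beta=2.0, eq_n=100):
    multipliers = torch.tensor([eq_alpha + i * (eq_beta - eq_alpha) / eq_n for i in range(eq_n + 1)])
    init_interval = torch.tensor(0.05)  # hypothetical initial scaling factor
    return multipliers * init_interval  # shape: eq_n+1 candidate intervals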
================================================
FILE: utils/datasets.py
================================================
"""
Reuse version v4
Author: Hahn Yuan
"""
import PIL
import torch
import argparse
import numpy as np
import os
import copy
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder,DatasetFolder
import torch.utils.data
import re
import warnings
from PIL import Image
from PIL import ImageFile
import random
import torch.nn.functional as F
from torch.utils.data import Dataset
def calculate_n_correct(outputs,targets):
_, predicted = outputs.max(1)
n_correct= predicted.eq(targets).sum().item()
return n_correct
class SetSplittor():
def __init__(self,fraction=0.2):
self.fraction=fraction
def split(self,dataset):
pass
class LoaderGenerator():
"""
"""
def __init__(self,root,dataset_name,train_batch_size=1,test_batch_size=1,num_workers=0,kwargs={}):
self.root=root
self.dataset_name=str.lower(dataset_name)
self.train_batch_size=train_batch_size
self.test_batch_size=test_batch_size
self.num_workers=num_workers
self.kwargs=kwargs
self.items=[]
self._train_set=None
self._test_set=None
self._calib_set=None
self.train_transform=None
self.test_transform=None
self.train_loader_kwargs = {
'num_workers': self.num_workers ,
'pin_memory': kwargs.get('pin_memory',True),
'drop_last':kwargs.get('drop_last',False)
}
self.test_loader_kwargs=self.train_loader_kwargs.copy()
self.load()
@property
def train_set(self):
pass
@property
def test_set(self):
pass
def load(self):
pass
def train_loader(self):
assert self.train_set is not None
return torch.utils.data.DataLoader(self.train_set, batch_size=self.train_batch_size, shuffle=True, **self.train_loader_kwargs)
def test_loader(self,shuffle=False,batch_size=None):
assert self.test_set is not None
if batch_size is None:
batch_size=self.test_batch_size
return torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs)
def val_loader(self):
assert self.val_set is not None
return torch.utils.data.DataLoader(self.val_set, batch_size=self.test_batch_size, shuffle=False, **self.test_loader_kwargs)
def trainval_loader(self):
assert self.trainval_set is not None
return torch.utils.data.DataLoader(self.trainval_set, batch_size=self.train_batch_size, shuffle=True, **self.train_loader_kwargs)
def calib_loader(self,num=1024,seed=3):
if self._calib_set is None:
np.random.seed(seed)
inds=np.random.permutation(len(self.train_set))[:num]
self._calib_set=torch.utils.data.Subset(copy.deepcopy(self.train_set),inds)
self._calib_set.dataset.transform=self.test_transform
return torch.utils.data.DataLoader(self._calib_set, batch_size=num, shuffle=False, **self.train_loader_kwargs) # one batch holding the whole calibration subset
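# Hedged usage sketch (not part of the original file): draw a deterministic
# (fixed-seed) calibration subset from the training set; it arrives as a single
# batch with test-time transforms applied. The dataset path is hypothetical.
#
#   g = ImageNetLoaderGenerator("data/imagenet", "imagenet", 32, 32, 8)
#   calib_loader = g.calib_loader(num=32)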
class CIFARLoaderGenerator(LoaderGenerator):
def load(self):
if self.dataset_name=='cifar100':
self.dataset_fn=datasets.CIFAR100
normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
std=[0.2673, 0.2564, 0.2762])
elif self.dataset_name=='cifar10':
self.dataset_fn=datasets.CIFAR10
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2470, 0.2435, 0.2616])
else:
raise NotImplementedError
self.train_transform = transforms.Compose([
transforms.RandomCrop(32,padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
self.test_transform = transforms.Compose([
transforms.ToTensor(),
normalize,
])
@property
def train_set(self):
if self._train_set is None:
self._train_set=self.dataset_fn(self.root, train=True, download=True, transform=self.train_transform)
return self._train_set
@property
def test_set(self):
if self._test_set is None:
self._test_set=self.dataset_fn(self.root, train=False, transform=self.test_transform)
return self._test_set
class COCOLoaderGenerator(LoaderGenerator):
def load(self):
# download from https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh
# note: augmentation_detection_tansforms and detection_tansforms are assumed to be provided elsewhere; they are not defined in this repository
self.train_set = DetectionListDataset(os.path.join(self.root,'trainvalno5k.txt'),transform=augmentation_detection_tansforms)
self.test_set = DetectionListDataset(os.path.join(self.root,'5k.txt'),transform=detection_tansforms,multiscale=False)
self.train_loader_kwargs={"collate_fn":self.train_set.collate_fn}
self.test_loader_kwargs={"collate_fn":self.test_set.collate_fn}
class DetectionListDataset(Dataset):
def __init__(self, list_path, img_size=416, multiscale=True, transform=None):
with open(list_path, "r") as file:
self.img_files = [path for path in file.readlines()]
self.label_files = [
path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
for path in self.img_files
]
self.img_size = img_size
self.max_objects = 100
self.multiscale = multiscale
self.min_size = self.img_size - 3 * 32
self.max_size = self.img_size + 3 * 32
self.batch_count = 0
self.transform = transform
def __getitem__(self, index):
try:
img_path = self.img_files[index % len(self.img_files)].rstrip()
img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8)
except Exception as e:
print(f"Could not read image '{img_path}'.")
return
try:
label_path = self.label_files[index % len(self.img_files)].rstrip()
# Ignore warning if file is empty
with warnings.catch_warnings():
warnings.simplefilter("ignore")
boxes = np.loadtxt(label_path).reshape(-1, 5)
except Exception as e:
print(f"Could not read label '{label_path}'.")
return
if self.transform:
try:
img, bb_targets = self.transform((img, boxes))
except Exception:
print("Could not apply transform.")
return
return img_path, img, bb_targets
def collate_fn(self, batch):
self.batch_count += 1
# Drop invalid images
batch = [data for data in batch if data is not None]
paths, imgs, bb_targets = list(zip(*batch))
# Selects new image size every tenth batch
if self.multiscale and self.batch_count % 10 == 0:
self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
# Resize images to input shape
imgs = torch.stack([F.interpolate(img.unsqueeze(0), size=self.img_size, mode="nearest").squeeze(0) for img in imgs])
# Add sample index to targets
for i, boxes in enumerate(bb_targets):
boxes[:, 0] = i
bb_targets = torch.cat(bb_targets, 0)
return paths, imgs, bb_targets
def __len__(self):
return len(self.img_files)
# def faster_im_loader(path):
# with open(path,'rb') as f:
# bgr_array = TurboJPEG().decode(f.read())
# rgb_array=np.concatenate([bgr_array[:,:,2:3],bgr_array[:,:,1:2],bgr_array[:,:,0:1]],-1)
# return torch.Tensor(rgb_array)/255
class ImageNetLoaderGenerator(LoaderGenerator):
def load(self):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
self.train_transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
self.test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
@property
def train_set(self):
if self._train_set is None:
self._train_set=ImageFolder(os.path.join(self.root,'train'), self.train_transform)
return self._train_set
@property
def test_set(self):
if self._test_set is None:
self._test_set=ImageFolder(os.path.join(self.root,'val'), self.test_transform)
return self._test_set
class CacheDataset(Dataset):
def __init__(self,datas,targets) -> None:
super().__init__()
self.datas=datas
self.targets=targets
def __getitem__(self,idx):
return self.datas[idx],self.targets[idx]
def __len__(self):
return len(self.datas)
class FasterImageNetLoaderGenerator(ImageNetLoaderGenerator):
def test_loader(self,shuffle=False,batch_size=None):
cache='/dev/shm/imagenet.pkl'
assert self.test_set is not None
if batch_size is None:
batch_size=self.test_batch_size
if os.path.exists(cache):
print("Loading the dataset from shared memory")
datas,targets=torch.load(cache)
else:
print("Preprocessing the dataset and save it to shared memory")
loader=torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs)
datas=[]
targets=[]
for data,target in loader:
datas.append(data)
targets.append(target)
datas=torch.cat(datas,0)
targets=torch.cat(targets,0)
torch.save([datas,targets],cache)
dataset=CacheDataset(datas,targets)
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs)
class DebugLoaderGenerator(LoaderGenerator):
def load(self):
version=re.findall(r"\d+",self.dataset_name)[0]
class DebugSet(torch.utils.data.Dataset):
def __getitem__(self,idx):
if version=='0':
return torch.ones([1,4,4]),0
if version=='1':
return torch.ones([1,8,8]),0
if version=='2':
return torch.ones([1,1,1]),0
if version=='3':
return torch.ones([1,3,3]),0
else:
raise NotImplementedError(f"version {version} of Debug dataset is not supported")
def __len__(self): return 1
self.train_set=DebugSet()
self.test_set=DebugSet()
def get_dataset(args:argparse.Namespace):
""" Preparing Datasets, args:
dataset (required): MNIST, cifar10/100, ImageNet, coco
dataset_root: str, default='./datasets'
num_workers: int
batch_size: int
test_batch_size: int
val_fraction: float, default=0
"""
dataset_name=str.lower(args.dataset)
dataset_root=getattr(args,'dataset_root','./datasets')
num_workers=args.num_workers if hasattr(args,'num_workers') else 4
batch_size=args.batch_size if hasattr(args,'batch_size') else 64
test_batch_size=args.test_batch_size if hasattr(args,'test_batch_size') else batch_size
val_fraction=args.val_fraction if hasattr(args,"val_fraction") else 0
if "cifar" in dataset_name:
# Data loading code
g=CIFARLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)
elif "coco" in dataset_name:
g=COCOLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)
elif "debug" in dataset_name:
g=DebugLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)
elif dataset_name=='imagenet':
g=ImageNetLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)
else:
raise NotImplementedError
return g.train_loader(),g.test_loader()
import timm
from timm.models.vision_transformer import VisionTransformer
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class ViTImageNetLoaderGenerator(ImageNetLoaderGenerator):
"""
DataLoader for Vision Transformer.
To comply with timm's framework, we use the model's corresponding transform.
"""
def __init__(self, root, dataset_name, train_batch_size, test_batch_size, num_workers, kwargs={}):
kwargs.update({"pin_memory":False})
super().__init__(root, dataset_name, train_batch_size=train_batch_size, test_batch_size=test_batch_size, num_workers=num_workers, kwargs=kwargs)
def load(self):
model = self.kwargs.get("model", None)
assert model is not None, "No model in ViTImageNetLoaderGenerator!"
config = resolve_data_config({}, model=model)
self.train_transform = create_transform(**config, is_training=True)
self.test_transform = create_transform(**config)
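# Hedged usage sketch (not part of the original file): ViTImageNetLoaderGenerator
# requires the timm model in kwargs so resolve_data_config can recover that
# model's own preprocessing (input size, interpolation, mean/std). The dataset
# path is hypothetical.
#
#   model = timm.create_model("vit_base_patch16_224", pretrained=True)
#   g = ViTImageNetLoaderGenerator("data/imagenet", "imagenet", 32, 32, 8, kwargs={"model": model})
#   calib_loader = g.calib_loader(num=32)
#   test_loader = g.test_loader()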
================================================
FILE: utils/integer.py
================================================
from quant_layers.matmul import MinMaxQuantMatMul, PTQSLBatchingQuantMatMul, PTQSLQuantMatMul, SoSPTQSLBatchingQuantMatMul, SoSPTQSLQuantMatMul
from quant_layers.linear import MinMaxQuantLinear, PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear, PostGeluPTQSLQuantLinear
import torch
import torch.nn as nn
import torch.nn.functional as F
def quantize_int_weight(module):
"""
get weight of type 'uint8' of a quantized module.
Bias are not quantized and you can use raw bias.
"""
assert hasattr(module, 'weight'), f"module {module} does not have weight"
assert module.w_bit == 8, f"module {module}'s weight is quantized with {module.w_bit} bits"
w_int = (module.weight/module.w_interval).round_().clamp_(-module.w_qmax, module.w_qmax-1)
w_int = w_int.cpu().detach().to(torch.int8)
return w_int
def dequantize_int_weight(module, w_int):
"""
Make sure it's the same module that generates w_int
"""
w_sim = module.w_interval.cpu() * w_int.float()
return w_sim
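def _int_weight_round_trip_sketch(module):
    """
    Illustrative sketch (not part of the original file): export a calibrated
    module's weight to int8, reconstruct it, and report the worst-case error,
    which stays around w_interval/2 per element up to clamped outliers.
    """
    w_int = quantize_int_weight(module)            # int8 tensor on CPU
    w_sim = dequantize_int_weight(module, w_int)   # float reconstruction
    err = (w_sim - module.weight.detach().cpu()).abs().max().item()
    print(f"max |w_sim - w| = {err:.3e}")
    return err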
def quantize_matmul_input(input, interval, qmax, n_G, n_V, n_H, crb_groups, crb_rows, crb_cols):
"""
quantize input matrix of matmul operation, with respect to sublayerwise padding settings
"""
pad_groups = crb_groups*n_G - input.shape[1]
pad_rows = crb_rows*n_V - input.shape[2]
pad_cols = crb_cols*n_H - input.shape[3]
x = F.pad(input, [0,pad_cols,0,pad_rows,0,pad_groups])
x = x.view(-1,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols)
x = (x/interval).round_().clamp(-qmax,qmax-1)
x = x.view(-1,n_G*crb_groups,n_V*crb_rows,n_H*crb_cols)
x = x[:,:x.shape[1]-pad_groups,:x.shape[2]-pad_rows,:x.shape[3]-pad_cols]
return x
def quantize_int_activation(module, input):
"""
Quantize the current inputs to 8-bit integers and store them as an attribute of the module.
The function is a forward pre-hook that needs to be registered manually on the calibrated model.
You need to consume the cached data before feeding in the next batch of images.
Currently only 8-bit is supported. (For twin quantization, we use uint8.)
For twin quantization:
- For softmax, the MSB being 1 means using large interval, while MSB being 0 means using small interval.
- For post-GELU, the MSB serves as sign bit. We use 1 for positive values and 0 for negative values.
"""
if isinstance(module, PostGeluPTQSLQuantLinear) or isinstance(module, PostGeluPTQSLBatchingQuantLinear):
assert module.a_bit == 8, f"module {module}'s activation is quantized with {module.a_bit} bits"
x = input[0]
int_input_pos = (x/module.a_interval).round_().clamp_(0, module.a_qmax-1)
int_input_pos = int_input_pos.detach().to(torch.uint8) + 128
int_input_neg = (x/module.a_neg_interval).round_().clamp_(-module.a_qmax+1, 0).abs()
int_input_neg = int_input_neg.detach().to(torch.uint8)
int_input = (int_input_pos + int_input_neg).cpu()
module.int_input = [int_input]
elif isinstance(module, MinMaxQuantLinear):
assert module.a_bit == 8, f"module {module}'s activation is quantized with {module.a_bit} bits"
x = input[0]
int_input = (x/module.a_interval).round_().clamp_(-module.a_qmax, module.a_qmax-1)
int_input = int_input.cpu().detach().to(torch.int8)
module.int_input = [int_input]
elif isinstance(module, SoSPTQSLQuantMatMul) or isinstance(module, SoSPTQSLBatchingQuantMatMul):
assert module.A_bit == 8, f"module {module}'s matrix A is quantized with {module.A_bit} bits"
assert module.B_bit == 8, f"module {module}'s matrix B is quantized with {module.B_bit} bits"
A, B = input[0], input[1]
A_high = (A.clamp(module.split, 1)*(module.A_qmax-1)).round_().clamp_(0,module.A_qmax-1)
A_high = A_high.detach().to(torch.uint8) + 128
A_low = (A.clamp(0, module.split)/module.A_interval).round_().clamp_(0,module.A_qmax-1)
A_low = A_low.detach().to(torch.uint8)
A_int = (A_high + A_low).cpu()
B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B)
B_int = B_int.cpu().detach().to(torch.int8)
module.int_input = [A_int, B_int]
elif isinstance(module, PTQSLQuantMatMul) or isinstance(module, PTQSLBatchingQuantMatMul):
assert module.A_bit == 8, f"module {module}'s matrix A is quantized with {module.A_bit} bits"
assert module.B_bit == 8, f"module {module}'s matrix B is quantized with {module.B_bit} bits"
A, B = input[0], input[1]
A_int = quantize_matmul_input(A,module.A_interval,module.A_qmax,module.n_G_A,module.n_V_A,module.n_H_A,module.crb_groups_A,module.crb_rows_A,module.crb_cols_A)
A_int = A_int.cpu().detach().to(torch.int8)
B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B)
B_int = B_int.cpu().detach().to(torch.int8)
module.int_input = [A_int, B_int]
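def _capture_int_inputs_sketch(calibrated_module, *sample_inputs):
    """
    Illustrative sketch (not part of the original file): quantize_int_activation
    is designed as a forward pre-hook, so registering it on a calibrated module
    caches the integer inputs of the next forward pass on the module itself.
    """
    handle = calibrated_module.register_forward_pre_hook(quantize_int_activation)
    with torch.no_grad():
        calibrated_module(*sample_inputs)
    handle.remove()
    return calibrated_module.int_input  # list of int8/uint8 tensors on CPU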
def get_model_int_weight(wrapped_modules):
"""
Get quantized weights (in int8) of a model.
Return:
A dict, with modules' names as keys, and int weights as values.
"""
int_weights = {}
for name, m in wrapped_modules.items():
try:
int_weights[name] = quantize_int_weight(m)
except Exception:
# skip modules without an 8-bit quantized weight (e.g. matmul wrappers)
pass
return int_weights
================================================
FILE: utils/models.py
================================================
from types import MethodType
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models import vision_transformer
from timm.models.vision_transformer import Attention
from timm.models.swin_transformer import WindowAttention
def attention_forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# attn = (q @ k.transpose(-2, -1)) * self.scale
attn = self.matmul1(q, k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
del q, k
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.matmul2(attn, v).transpose(1, 2).reshape(B, N, C)
del attn, v
x = self.proj(x)
x = self.proj_drop(x)
return x
def window_attention_forward(self, x, mask = None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
# attn = (q @ k.transpose(-2, -1))
attn = self.matmul1(q, k.transpose(-2,-1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
# x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.matmul2(attn, v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MatMul(nn.Module):
def forward(self, A, B):
return A @ B
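# A small sketch (not part of the original file) of why `@` is replaced with an
# explicit MatMul module: module hooks and module substitution only see
# nn.Module instances, so wrapping the two attention matmuls is what makes
# them observable and quantizable downstream.
def _matmul_module_sketch():
    m = MatMul()
    shapes = []
    handle = m.register_forward_hook(lambda mod, inp, out: shapes.append(out.shape))
    a, b = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
    out = m(a, b)  # identical to a @ b, but visible to hooks
    handle.remove()
    assert torch.allclose(out, a @ b) and shapes == [out.shape]
    return out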
def get_net(name):
"""
Get a vision transformer model.
This will replace matrix multiplication operations with matmul modules in the model.
Currently supports almost all vision transformer models in timm.models, including:
- vit_tiny/small/base/large_patch16/patch32_224/384,
- deit_tiny/small/base(_distilled)_patch16_224,
- deit_base(_distilled)_patch16_384,
- swin_tiny/small/base/large_patch4_window7_224,
- swin_base/large_patch4_window12_384
These models are finetuned on ImageNet-1k and should use ViTImageNetLoaderGenerator
for calibration and testing.
"""
net = timm.create_model(name, pretrained=True)
for name, module in net.named_modules():
if isinstance(module, Attention):
setattr(module, "matmul1", MatMul())
setattr(module, "matmul2", MatMul())
module.forward = MethodType(attention_forward, module)
if isinstance(module, WindowAttention):
setattr(module, "matmul1", MatMul())
setattr(module, "matmul2", MatMul())
module.forward = MethodType(window_attention_forward, module)
net.cuda()
net.eval()
return net
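# Hedged usage sketch (not part of the original file): get_net downloads
# pretrained timm weights and moves the model to CUDA, so it needs both a GPU
# and network access.
#
#   net = get_net("vit_base_patch16_224")
#   print(type(net.blocks[0].attn.matmul1))  # -> MatMul, substituted above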
================================================
FILE: utils/net_wrap.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.models import MatMul
import re
def _fold_bn(conv_module, bn_module):
w = conv_module.weight.data
y_mean = bn_module.running_mean
y_var = bn_module.running_var
safe_std = torch.sqrt(y_var + bn_module.eps)
w_view = (conv_module.out_channels, 1, 1, 1)
if bn_module.affine:
weight = w * (bn_module.weight / safe_std).view(w_view)
beta = bn_module.bias - bn_module.weight * y_mean / safe_std
if conv_module.bias is not None:
bias = bn_module.weight * conv_module.bias / safe_std + beta
else:
bias = beta
else:
weight = w / safe_std.view(w_view)
beta = -y_mean / safe_std
if conv_module.bias is not None:
bias = conv_module.bias / safe_std + beta
else:
bias = beta
return weight, bias
def fold_bn_into_conv(conv_module, bn_module):
w, b = _fold_bn(conv_module, bn_module)
if conv_module.bias is None:
conv_module.bias = nn.Parameter(b.data)
else:
conv_module.bias.data = b.data
conv_module.weight.data = w.data
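def _check_bn_folding_sketch():
    """
    Illustrative check (not part of the original file): folding a BatchNorm into
    the preceding conv should leave eval-mode outputs numerically unchanged.
    """
    conv = nn.Conv2d(3, 8, 3, bias=False)
    bn = nn.BatchNorm2d(8)
    bn.running_mean.uniform_(-0.1, 0.1)
    bn.running_var.uniform_(0.9, 1.1)
    conv.eval()
    bn.eval()
    x = torch.randn(1, 3, 16, 16)
    ref = bn(conv(x))
    fold_bn_into_conv(conv, bn)
    assert torch.allclose(conv(x), ref, atol=1e-5)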
def wrap_modules_in_net(net,cfg):
wrapped_modules={}
module_dict={}
module_types = {"qkv":"qlinear_qkv", "proj":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':"qlinear_MLP_2", 'head':'qlinear_classifier','matmul1':"qmatmul_qk", 'matmul2':"qmatmul_scorev", "reduction": "qlinear_reduction"}
it=[(name,m) for name,m in net.named_modules()]
for name,m in it:
module_dict[name]=m
idx=name.rfind('.')
if idx==-1:
idx=0
father_name=name[:idx]
if father_name in module_dict:
father_module=module_dict[father_name]
else:
raise RuntimeError(f"father module {father_name} not found")
if isinstance(m,nn.Conv2d):
# Embedding Layer
idx = idx+1 if idx != 0 else idx
new_m=cfg.get_module("qconv",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode)
new_m.weight.data=m.weight.data
new_m.bias=m.bias
replace_m=new_m
wrapped_modules[name] = new_m
setattr(father_module,name[idx:],replace_m)
elif isinstance(m,nn.Linear):
# Linear Layer
idx = idx+1 if idx != 0 else idx
new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features)
new_m.weight.data=m.weight.data
new_m.bias=m.bias
replace_m=new_m
wrapped_modules[name] = new_m
setattr(father_module,name[idx:],replace_m)
elif isinstance(m,MatMul):
# Matmul Layer
idx = idx+1 if idx != 0 else idx
new_m = cfg.get_module(module_types[name[idx:]])
replace_m=new_m
wrapped_modules[name] = new_m
setattr(father_module,name[idx:],replace_m)
print("Completed net wrap.")
return wrapped_modules
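# Hedged usage sketch (not part of the original file), assuming a config module
# such as configs/PTQ4ViT.py that exposes get_module(module_type, *args, **kwargs):
#
#   from importlib import import_module
#   cfg = import_module("configs.PTQ4ViT")
#   wrapped = wrap_modules_in_net(net, cfg)  # dict: module name -> quantized replacement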
def wrap_certain_modules_in_net(net,cfg,layers,modules_to_wrap,wrap_embedding=False):
"""
wrap specific module inside transformer block of specific layer
layers: list of integers, indicating layers to wrap
modules_to_wrap: list of modules to wrap
"""
wrapped_modules={}
module_dict={}
module_types = {"qkv":"qlinear_qkv", "proj":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':"qlinear_MLP_2", 'head':'qlinear_classifier','matmul1':"qmatmul_qk", 'matmul2':"qmatmul_scorev"}
it=[(name,m) for name,m in net.named_modules()]
for name,m in it:
module_dict[name]=m
idx=name.rfind('.')
if idx==-1:
idx=0
father_name=name[:idx]
if father_name in module_dict:
father_module=module_dict[father_name]
else:
raise RuntimeError(f"father module {father_name} not found")
layer = re.search(r'\d+', name)
if layer is not None: # inside a transformer block
layer = int(name[layer.span()[0]:layer.span()[1]])
if layer not in layers: continue
if isinstance(m,nn.Conv2d):
# Embedding Layer
idx = idx+1 if idx != 0 else idx
if not wrap_embedding:
continue # timm's patch_embed uses a submodule named proj as well, so conv layers are skipped unless requested
# if name[idx:] not in modules_to_wrap: continue
new_m=cfg.get_module("qconv",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode)
new_m.weight.data=m.weight.data
new_m.bias=m.bias
replace_m=new_m
wrapped_modules[name] = new_m
setattr(father_module,name[idx:],replace_m)
elif isinstance(m,nn.Linear):
# Linear Layer
idx = idx+1 if idx != 0 else idx
if name[idx:] not in modules_to_wrap: continue
new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features)
new_m.weight.data=m.weight.data
new_m.bias=m.bias
replace_m=new_m
wrapped_modules[name] = new_m
setattr(father_module,name[idx:],replace_m)
elif isinstance(m,MatMul):
# Matmul Layer
idx = idx+1 if idx != 0 else idx
if name[idx:] not in modules_to_wrap: continue
new_m = cfg.get_module(module_types[name[idx:]])
replace_m=new_m
wrapped_modules[name] = new_m
setattr(father_module,name[idx:],replace_m)
print("Completed net wrap.")
return wrapped_modules
================================================
FILE: utils/quant_calib.py
================================================
import torch
from quant_layers.conv import MinMaxQuantConv2d
from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear
from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul
import torch.nn.functional as F
from tqdm import tqdm
class QuantCalibrator():
"""
Modularization of quant calib.
Notice:
all quant modules has method "calibration_step1" that should only store raw inputs and outputs
all quant modules has method "calibration_step2" that should only quantize its intervals
and we assume we could feed in all calibration data in one batch, without backward propagations
sequential calibration is memory-friendly, while parallel calibration may consume
hundreds of GB of memory.
"""
def __init__(self, net, wrapped_modules, calib_loader, sequential=True):
self.net = net
self.wrapped_modules = wrapped_modules
self.calib_loader = calib_loader
self.sequential = sequential
self.calibrated = False
def sequential_quant_calib(self):
"""
A quick implementation of calibration.
Assumes the whole calibration dataset can be fed at once.
"""
# run calibration
n_calibration_steps=2
for step in range(n_calibration_steps):
print(f"Start calibration step={step+1}")
for name,module in self.wrapped_modules.items():
# corner cases for already-calibrated modules: keep them in raw mode while
# step-1 data is collected, then switch them to quant_forward for step 2
if hasattr(module, "calibrated"):
if step == 0:
module.mode = "raw"
elif step == 1:
module.mode = "quant_forward"
else:
module.mode=f'calibration_step{step+1}'
with torch.no_grad():
for inp,target in self.calib_loader:
inp=inp.cuda()
self.net(inp)
# finish calibration
for name,module in self.wrapped_modules.items():
module.mode='quant_forward'
torch.cuda.empty_cache() # memory footprint cleanup
print("sequential calibration finished")
def parallel_quant_calib(self):
"""
A quick implementation of parallel quant calibration.
Assumes the whole calibration dataset can be fed at once and that memory can hold all raw inputs/outputs.
"""
# calibration step1: collect raw data
print(f"Start calibration step=1")
for name,module in self.wrapped_modules.items():
# corner cases for calibrated modules
if hasattr(module, "calibrated"):
module.mode = "raw"
else:
module.mode=f'calibration_step1'
with torch.no_grad():
for inp,target in self.calib_loader:
inp=inp.cuda()
self.net(inp)
# calibration step2: each module run calibration with collected raw data
for name,module in self.wrapped_modules.items():
if hasattr(module, "calibrated"):
continue
else:
module.mode=f"calibration_step2"
with torch.no_grad():
if isinstance(module, MinMaxQuantLinear):
module.forward(module.raw_input.cuda())
elif isinstance(module, MinMaxQuantConv2d):
module.forward(module.raw_input.cuda())
elif isinstance(module, MinMaxQuantMatMul):
module.forward(module.raw_input[0].cuda(), module.raw_input[1].cuda())
torch.cuda.empty_cache()
# finish calibration
for name,module in self.wrapped_modules.items():
module.mode='quant_forward'
torch.cuda.empty_cache() # memory footprint cleanup
print("calibration finished")
def quant_calib(self):
calib_layers=[]
for name,module in self.wrapped_modules.items():
calib_layers.append(name)
print(f"prepare parallel calibration for {calib_layers}")
if self.sequential:
self.sequential_quant_calib()
else:
self.parallel_quant_calib()
self.calibrated = True
def batching_quant_calib(self):
calib_layers=[]
for name,module in self.wrapped_modules.items():
calib_layers.append(name)
print(f"prepare parallel calibration for {calib_layers}")
print("start calibration")
# assume wrapped modules are in order (insertion order is guaranteed for dict in python>=3.7)
q = tqdm(self.wrapped_modules.items(), desc="Brecq")
for name, module in q:
q.set_postfix_str(name)
# add fp and bp hooks to current modules, which bypass calibration step 1
# precedent modules are using quant forward
hooks = []
if isinstance(module, MinMaxQuantLinear):
hooks.append(module.register_forward_hook(linear_forward_hook))
if isinstance(module, MinMaxQuantConv2d):
hooks.append(module.register_forward_hook(conv2d_forward_hook))
if isinstance(module, MinMaxQuantMatMul):
hooks.append(module.register_forward_hook(matmul_forward_hook))
# feed in calibration data, and store the data
# (self.batch_size is provided by HessianQuantCalibrator, the intended driver of this method)
for inp, target in self.calib_loader:
for batch_st in range(0,self.calib_loader.batch_size,self.batch_size):
self.net.zero_grad()
inp_ = inp[batch_st:batch_st+self.batch_size].cuda()
self.net(inp_)
del inp, target
torch.cuda.empty_cache()
# replace cached raw_inputs, raw_outs
if isinstance(module, MinMaxQuantLinear):
module.raw_input = torch.cat(module.raw_input, dim=0)
module.raw_out = torch.cat(module.raw_out, dim=0)
if isinstance(module, MinMaxQuantConv2d):
module.raw_input = torch.cat(module.raw_input, dim=0)
module.raw_out = torch.cat(module.raw_out, dim=0)
if isinstance(module, MinMaxQuantMatMul):
module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input]
module.raw_out = torch.cat(module.raw_out, dim=0)
for hook in hooks:
hook.remove()
# run calibration step2
with torch.no_grad():
if isinstance(module, MinMaxQuantLinear):
module.calibration_step2()
if isinstance(module, MinMaxQuantConv2d):
module.calibration_step2()
if isinstance(module, MinMaxQuantMatMul):
module.calibration_step2()
torch.cuda.empty_cache()
# finishing up current module calibration
if self.sequential:
module.mode = "quant_forward"
else:
module.mode = "raw"
# finish calibration
for name, module in self.wrapped_modules.items():
module.mode = "quant_forward"
print("calibration finished")
def grad_hook(module, grad_input, grad_output):
if module.raw_grad is None:
module.raw_grad = []
module.raw_grad.append(grad_output[0].cpu().detach()) # that's a tuple!
def linear_forward_hook(module, input, output):
if module.raw_input is None:
module.raw_input = []
if module.raw_out is None:
module.raw_out = []
module.raw_input.append(input[0].cpu().detach())
module.raw_out.append(output.cpu().detach())
def conv2d_forward_hook(module, input, output):
if module.raw_input is None:
module.raw_input = []
if module.raw_out is None:
module.raw_out = []
module.raw_input.append(input[0].cpu().detach())
module.raw_out.append(output.cpu().detach())
def matmul_forward_hook(module, input, output):
if module.raw_input is None:
module.raw_input = [[],[]]
if module.raw_out is None:
module.raw_out = []
module.raw_input[0].append(input[0].cpu().detach())
module.raw_input[1].append(input[1].cpu().detach())
module.raw_out.append(output.cpu().detach())
class HessianQuantCalibrator(QuantCalibrator):
"""
Modularization of hessian_quant_calib
Hessian metric needs gradients of layer outputs to weigh the loss,
which calls for back propagation in calibration, both sequentially
and parallelly. Despite the complexity of bp, hessian quant calibrator
is compatible with other non-gradient quantization metrics.
"""
def __init__(self, net, wrapped_modules, calib_loader, sequential=False, batch_size=1):
super().__init__(net, wrapped_modules, calib_loader, sequential=sequential)
self.batch_size = batch_size
def quant_calib(self):
"""
An implementation of original hessian calibration.
"""
calib_layers=[]
for name,module in self.wrapped_modules.items():
calib_layers.append(name)
print(f"prepare parallel calibration for {calib_layers}")
print("start hessian calibration")
# get raw_pred as target distribution
# (the calib loader yields the whole calibration set as a single batch, so raw_pred covers every image)
with torch.no_grad():
for inp, _ in self.calib_loader:
raw_pred = self.net(inp.cuda())
raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach()
torch.cuda.empty_cache()
# assume wrapped modules are in order (insertion order is guaranteed for dict in python>=3.7)
q = tqdm(self.wrapped_modules.items(), desc="Brecq")
for name, module in q:
q.set_postfix_str(name)
# add fp and bp hooks to current modules, which bypass calibration step 1
# precedent modules are using quant forward
hooks = []
if isinstance(module, MinMaxQuantLinear):
hooks.append(module.register_forward_hook(linear_forward_hook))
if isinstance(module, MinMaxQuantConv2d):
hooks.append(module.register_forward_hook(conv2d_forward_hook))
if isinstance(module, MinMaxQuantMatMul):
hooks.append(module.register_forward_hook(matmul_forward_hook))
if hasattr(module, "metric") and module.metric == "hessian":
hooks.append(module.register_backward_hook(grad_hook))
# feed in calibration data, and store the data
for inp, target in self.calib_loader:
for batch_st in range(0,self.calib_loader.batch_size,self.batch_size):
self.net.zero_grad()
inp_ = inp[batch_st:batch_st+self.batch_size].cuda()
pred = self.net(inp_)
loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction="batchmean")
loss.backward()
del inp, target, pred, loss
torch.cuda.empty_cache()
# replace cached raw_inputs, raw_outs
if isinstance(module, MinMaxQuantLinear):
module.raw_input = torch.cat(module.raw_input, dim=0)
module.raw_out = torch.cat(module.raw_out, dim=0)
if isinstance(module, MinMaxQuantConv2d):
module.raw_input = torch.cat(module.raw_input, dim=0)
module.raw_out = torch.cat(module.raw_out, dim=0)
if isinstance(module, MinMaxQuantMatMul):
module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input]
module.raw_out = torch.cat(module.raw_out, dim=0)
if hasattr(module, "metric") and module.metric == "hessian":
module.raw_grad = torch.cat(module.raw_grad, dim=0)
for hook in hooks:
hook.remove()
# run calibration step2
with torch.no_grad():
if isinstance(module, MinMaxQuantLinear):
module.calibration_step2(module.raw_input.cuda())
if isinstance(module, MinMaxQuantConv2d):
module.calibration_step2(module.raw_input.cuda())
if isinstance(module, MinMaxQuantMatMul):
module.calibration_step2(module.raw_input[0].cuda(), module.raw_input[1].cuda())
torch.cuda.empty_cache()
# finishing up current module calibration
if self.sequential:
module.mode = "quant_forward"
else:
module.mode = "raw"
# finish calibration
for name, module in self.wrapped_modules.items():
module.mode = "quant_forward"
print("hessian calibration finished")
def batching_quant_calib(self):
calib_layers=[]
for name,module in self.wrapped_modules.items():
calib_layers.append(name)
print(f"prepare parallel calibration for {calib_layers}")
print("start hessian calibration")
# get raw_pred as target distribution
with torch.no_grad():
for inp, _ in self.calib_loader:
raw_pred = self.net(inp.cuda())
raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach()
torch.cuda.empty_cache()
# assume wrapped modules are in order (insertion order is guaranteed for dict in python>=3.7)
q = tqdm(self.wrapped_modules.items(), desc="Hessian")
for name, module in q:
q.set_postfix_str(name)
# add fp and bp hooks to current modules, which bypass calibration step 1
# precedent modules are using quant forward
hooks = []
if isinstance(module, MinMaxQuantLinear):
hooks.append(module.register_forward_hook(linear_forward_hook))
if isinstance(module, MinMaxQuantConv2d):
hooks.append(module.register_forward_hook(conv2d_forward_hook))
if isinstance(module, MinMaxQuantMatMul):
hooks.append(module.register_forward_hook(matmul_forward_hook))
if hasattr(module, "metric"):
hooks.append(module.register_backward_hook(grad_hook))
# feed in calibration data, and store the data
for inp, target in self.calib_loader:
for batch_st in range(0,self.calib_loader.batch_size,self.batch_size):
self.net.zero_grad()
inp_ = inp[batch_st:batch_st+self.batch_size].cuda()
pred = self.net(inp_)
loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction="batchmean")
loss.backward()
del inp, target, pred, loss
torch.cuda.empty_cache()
# replace cached raw_inputs, raw_outs
if isinstance(module, MinMaxQuantLinear):
module.raw_input = torch.cat(module.raw_input, dim=0)
module.raw_out = torch.cat(module.raw_out, dim=0)
if isinstance(module, MinMaxQuantConv2d):
module.raw_input = torch.cat(module.raw_input, dim=0)
module.raw_out = torch.cat(module.raw_out, dim=0)
if isinstance(module, MinMaxQuantMatMul):
module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input]
module.raw_out = torch.cat(module.raw_out, dim=0)
if hasattr(module, "metric"):
module.raw_grad = torch.cat(module.raw_grad, dim=0)
for hook in hooks:
hook.remove()
# run calibration step2
with torch.no_grad():
if isinstance(module, MinMaxQuantLinear):
module.calibration_step2()
if isinstance(module, MinMaxQuantConv2d):
module.calibration_step2()
if isinstance(module, MinMaxQuantMatMul):
module.calibration_step2()
torch.cuda.empty_cache()
# finishing up current module calibration
if self.sequential:
module.mode = "quant_forward"
else:
module.mode = "raw"
# finish calibration
for name, module in self.wrapped_modules.items():
module.mode = "quant_forward"
print("hessian calibration finished")
SYMBOL INDEX (203 symbols across 14 files)
FILE: configs/BasePTQ.py
function get_module (line 47) | def get_module(module_type, *args, **kwargs):
FILE: configs/PTQ4ViT.py
function get_module (line 51) | def get_module(module_type, *args, **kwargs):
FILE: example/get_int.py
function get_int_weights (line 12) | def get_int_weights(name, config_name):
FILE: example/test_ablation.py
function test_all_ablation (line 17) | def test_all_ablation(name, cfg_modifier=lambda x: x, calib_size=32):
class cfg_modifier (line 42) | class cfg_modifier():
method __init__ (line 43) | def __init__(self, **kwargs):
method __call__ (line 47) | def __call__(self, cfg):
FILE: example/test_all.py
function test_all (line 18) | def test_all(name, cfg_modifier=lambda x: x, calib_size=32, config_name=...
class cfg_modifier (line 48) | class cfg_modifier():
method __init__ (line 49) | def __init__(self, **kwargs):
method __call__ (line 53) | def __call__(self, cfg):
FILE: example/test_vit.py
function parse_args (line 19) | def parse_args():
function test_classification (line 26) | def test_classification(net,test_loader,max_iteration=None, description=...
function process (line 47) | def process(pid, experiment_process, args_queue, n_gpu):
function multiprocess (line 65) | def multiprocess(experiment_process, cfg_list=None, n_gpu=6):
function init_config (line 82) | def init_config(config_name):
function experiment_basic (line 93) | def experiment_basic(net='vit_base_patch16_384', config="PTQ4ViT"):
FILE: quant_layers/conv.py
class MinMaxQuantConv2d (line 9) | class MinMaxQuantConv2d(nn.Conv2d):
method __init__ (line 13) | def __init__(self,in_channels: int,
method forward (line 40) | def forward(self, x):
method quant_weight_bias (line 53) | def quant_weight_bias(self):
method quant_input (line 64) | def quant_input(self,x):
method quant_forward (line 69) | def quant_forward(self,x):
method calibration_step1 (line 76) | def calibration_step1(self,x):
method calibration_step2 (line 83) | def calibration_step2(self,x):
class QuantileQuantConv2d (line 91) | class QuantileQuantConv2d(MinMaxQuantConv2d):
method __init__ (line 95) | def __init__(self,
method _quantile (line 111) | def _quantile(self, tensor, quantile):
method calibration_step2 (line 118) | def calibration_step2(self,x):
class PTQSLQuantConv2d (line 126) | class PTQSLQuantConv2d(MinMaxQuantConv2d):
method __init__ (line 134) | def __init__(self, in_channels: int,
method _get_similarity (line 157) | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1):
method quant_weight_bias (line 183) | def quant_weight_bias(self):
method _search_best_w_interval (line 191) | def _search_best_w_interval(self, x, weight_interval_candidates):
method _search_best_a_interval (line 222) | def _search_best_a_interval(self, x, input_interval_candidates):
method _initialize_intervals (line 246) | def _initialize_intervals(self, x):
method calibration_step2 (line 253) | def calibration_step2(self, x):
class BatchingEasyQuantConv2d (line 279) | class BatchingEasyQuantConv2d(PTQSLQuantConv2d):
method __init__ (line 281) | def __init__(self, in_channels: int,
method _initialize_calib_parameters (line 297) | def _initialize_calib_parameters(self):
method _initialize_intervals (line 312) | def _initialize_intervals(self):
method _get_similarity (line 322) | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1,...
method quant_weight_bias (line 353) | def quant_weight_bias(self):
method quant_forward (line 358) | def quant_forward(self, x):
method _search_best_w_interval (line 365) | def _search_best_w_interval(self, weight_interval_candidates):
method _search_best_a_interval (line 398) | def _search_best_a_interval(self, input_interval_candidates):
method calibration_step2 (line 429) | def calibration_step2(self):
class ChannelwiseBatchingQuantConv2d (line 444) | class ChannelwiseBatchingQuantConv2d(PTQSLQuantConv2d):
method __init__ (line 450) | def __init__(self, in_channels: int,
method _initialize_calib_parameters (line 467) | def _initialize_calib_parameters(self):
method _initialize_intervals (line 482) | def _initialize_intervals(self):
method _get_similarity (line 498) | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_gra...
method _search_best_w_interval (line 526) | def _search_best_w_interval(self, weight_interval_candidates):
method _search_best_a_interval (line 559) | def _search_best_a_interval(self, input_interval_candidates):
method calibration_step2 (line 591) | def calibration_step2(self):
method quant_weight_bias (line 605) | def quant_weight_bias(self):
method quant_forward (line 609) | def quant_forward(self, x):
FILE: quant_layers/linear.py
class MinMaxQuantLinear (line 6) | class MinMaxQuantLinear(nn.Linear):
method __init__ (line 7) | def __init__(self,
method forward (line 33) | def forward(self, x):
method quant_weight_bias (line 46) | def quant_weight_bias(self):
method quant_input (line 57) | def quant_input(self, x):
method quant_forward (line 62) | def quant_forward(self,x):
method _bias_correction_quant_forward (line 69) | def _bias_correction_quant_forward(self, x):
method calibration_step1 (line 79) | def calibration_step1(self,x):
method calibration_step2 (line 86) | def calibration_step2(self,x):
class PTQSLQuantLinear (line 94) | class PTQSLQuantLinear(MinMaxQuantLinear):
method __init__ (line 98) | def __init__(self,
method _get_similarity (line 124) | def _get_similarity(self, tensor_raw, tensor_sim, metric=None):
method quant_weight_bias (line 152) | def quant_weight_bias(self):
method quant_input (line 164) | def quant_input(self, x):
method _search_best_w_interval (line 171) | def _search_best_w_interval(self, x, weight_interval_candidates, raw_o...
method _search_best_a_interval (line 202) | def _search_best_a_interval(self, x, input_interval_candidates, raw_ou...
method _initialize_intervals (line 227) | def _initialize_intervals(self, x):
method calibration_step2 (line 235) | def calibration_step2(self,x):
class PostGeluPTQSLQuantLinear (line 262) | class PostGeluPTQSLQuantLinear(PTQSLQuantLinear):
method __init__ (line 263) | def __init__(self,
method quant_input (line 276) | def quant_input(self, x):
method _search_best_a_interval (line 287) | def _search_best_a_interval(self, x, input_interval_candidates, raw_ou...
method _initialize_intervals (line 313) | def _initialize_intervals(self, x):
method calibration_step2 (line 322) | def calibration_step2(self,x):
class PTQSLBatchingQuantLinear (line 349) | class PTQSLBatchingQuantLinear(PTQSLQuantLinear):
method __init__ (line 350) | def __init__(self,
method _initialize_calib_parameters (line 365) | def _initialize_calib_parameters(self):
method _initialize_intervals (line 380) | def _initialize_intervals(self):
method _get_similarity (line 399) | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_gra...
method _get_pearson_w (line 426) | def _get_pearson_w(self, tensor_raw, tensor_sim):
method _get_pearson_a (line 441) | def _get_pearson_a(self, tensor_raw, tensor_sim):
method _search_best_w_interval (line 455) | def _search_best_w_interval(self, weight_interval_candidates):
method _search_best_a_interval (line 497) | def _search_best_a_interval(self, input_interval_candidates):
method calibration_step2 (line 536) | def calibration_step2(self):
class PostGeluPTQSLBatchingQuantLinear (line 557) | class PostGeluPTQSLBatchingQuantLinear(PTQSLBatchingQuantLinear):
method __init__ (line 562) | def __init__(self,
method _initialize_intervals (line 576) | def _initialize_intervals(self):
method quant_input (line 601) | def quant_input(self, x):
method _search_best_a_interval (line 609) | def _search_best_a_interval(self, input_interval_candidates):
FILE: quant_layers/matmul.py
class MinMaxQuantMatMul (line 8) | class MinMaxQuantMatMul(nn.Module):
method __init__ (line 10) | def __init__(self, A_bit=8, B_bit=8, mode="raw"):
method forward (line 22) | def forward(self, A,B):
method quant_input (line 35) | def quant_input(self,x,interval,qmax):
method quant_forward (line 40) | def quant_forward(self,A,B):
method calibration_step1 (line 47) | def calibration_step1(self,A,B):
method calibration_step2 (line 54) | def calibration_step2(self,A,B):
class PTQSLQuantMatMul (line 62) | class PTQSLQuantMatMul(MinMaxQuantMatMul):
method __init__ (line 77) | def __init__(self, A_bit=8, B_bit=8, mode="raw",
method _get_padding_parameters (line 109) | def _get_padding_parameters(self, A, B):
method quant_input_A (line 124) | def quant_input_A(self, x):
method quant_input_B (line 132) | def quant_input_B(self, x):
method quant_forward (line 140) | def quant_forward(self, A, B):
method _get_similarity (line 147) | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1):
method _search_best_A_interval (line 177) | def _search_best_A_interval(self, A, B, A_interval_candidates):
method _search_best_B_interval (line 210) | def _search_best_B_interval(self, A, B, B_interval_candidates):
method _initialize_intervals (line 243) | def _initialize_intervals(self, A, B):
method calibration_step2 (line 257) | def calibration_step2(self, A, B):
class SoSPTQSLQuantMatMul (line 284) | class SoSPTQSLQuantMatMul(PTQSLQuantMatMul):
method __init__ (line 298) | def __init__(self, A_bit=8, B_bit=8, mode="raw",
method quant_input_A (line 313) | def quant_input_A(self, x):
method _search_best_A_interval (line 318) | def _search_best_A_interval(self, A, B, split_candidates):
method _initialize_intervals (line 348) | def _initialize_intervals(self, A, B):
method calibration_step2 (line 361) | def calibration_step2(self, A, B):
class PTQSLBatchingQuantMatMul (line 390) | class PTQSLBatchingQuantMatMul(PTQSLQuantMatMul):
method __init__ (line 391) | def __init__(self, A_bit=8, B_bit=8, mode="raw",
method _initialize_calib_parameters (line 396) | def _initialize_calib_parameters(self):
method _get_padding_parameters (line 411) | def _get_padding_parameters(self, A, B):
method _initialize_intervals (line 419) | def _initialize_intervals(self):
method _get_similarity (line 442) | def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1,...
method _search_best_A_interval (line 483) | def _search_best_A_interval(self, A_interval_candidates):
method _search_best_B_interval (line 524) | def _search_best_B_interval(self, B_interval_candidates):
method calibration_step2 (line 565) | def calibration_step2(self):
class SoSPTQSLBatchingQuantMatMul (line 578) | class SoSPTQSLBatchingQuantMatMul(PTQSLBatchingQuantMatMul):
method __init__ (line 579) | def __init__(self, A_bit=8, B_bit=8, mode="raw",
method quant_input_A (line 595) | def quant_input_A(self, x):
method _search_best_A_interval (line 600) | def _search_best_A_interval(self, split_candidates):
method calibration_step2 (line 633) | def calibration_step2(self):
FILE: utils/datasets.py
function calculate_n_correct (line 23) | def calculate_n_correct(outputs,targets):
class SetSplittor (line 28) | class SetSplittor():
method __init__ (line 29) | def __init__(self,fraction=0.2):
method split (line 32) | def split(self,dataset):
class LoaderGenerator (line 35) | class LoaderGenerator():
method __init__ (line 38) | def __init__(self,root,dataset_name,train_batch_size=1,test_batch_size...
method train_set (line 60) | def train_set(self):
method test_set (line 64) | def test_set(self):
method load (line 67) | def load(self):
method train_loader (line 70) | def train_loader(self):
method test_loader (line 74) | def test_loader(self,shuffle=False,batch_size=None):
method val_loader (line 80) | def val_loader(self):
method trainval_loader (line 84) | def trainval_loader(self):
method calib_loader (line 88) | def calib_loader(self,num=1024,seed=3):
class CIFARLoaderGenerator (line 96) | class CIFARLoaderGenerator(LoaderGenerator):
method load (line 97) | def load(self):
method train_set (line 119) | def train_set(self):
method test_set (line 125) | def test_set(self):
class COCOLoaderGenerator (line 130) | class COCOLoaderGenerator(LoaderGenerator):
method load (line 131) | def load(self):
class DetectionListDataset (line 138) | class DetectionListDataset(Dataset):
method __init__ (line 139) | def __init__(self, list_path, img_size=416, multiscale=True, transform...
method __getitem__ (line 154) | def __getitem__(self, index):
method collate_fn (line 178) | def collate_fn(self, batch):
method __len__ (line 195) | def __len__(self):
class ImageNetLoaderGenerator (line 204) | class ImageNetLoaderGenerator(LoaderGenerator):
method load (line 205) | def load(self):
method train_set (line 224) | def train_set(self):
method test_set (line 230) | def test_set(self):
class CacheDataset (line 235) | class CacheDataset(Dataset):
method __init__ (line 236) | def __init__(self,datas,targets) -> None:
method __getitem__ (line 241) | def __getitem__(self,idx):
method __len__ (line 244) | def __len__(self):
class FasterImageNetLoaderGenerator (line 247) | class FasterImageNetLoaderGenerator(ImageNetLoaderGenerator):
method test_loader (line 248) | def test_loader(self,shuffle=False,batch_size=None):
class DebugLoaderGenerator (line 270) | class DebugLoaderGenerator(LoaderGenerator):
method load (line 272) | def load(self):
function get_dataset (line 290) | def get_dataset(args:argparse.Namespace):
class ViTImageNetLoaderGenerator (line 325) | class ViTImageNetLoaderGenerator(ImageNetLoaderGenerator):
method __init__ (line 330) | def __init__(self, root, dataset_name, train_batch_size, test_batch_si...
method load (line 334) | def load(self):
FILE: utils/integer.py
function quantize_int_weight (line 8) | def quantize_int_weight(module):
function dequantize_int_weight (line 20) | def dequantize_int_weight(module, w_int):
function quantize_matmul_input (line 27) | def quantize_matmul_input(input, interval, qmax, n_G, n_V, n_H, crb_grou...
function quantize_int_activation (line 44) | def quantize_int_activation(module, input):
function get_model_int_weight (line 113) | def get_model_int_weight(wrapped_modules):
FILE: utils/models.py
function attention_forward (line 10) | def attention_forward(self, x):
function window_attention_forward (line 28) | def window_attention_forward(self, x, mask = None):
class MatMul (line 58) | class MatMul(nn.Module):
method forward (line 59) | def forward(self, A, B):
function get_net (line 62) | def get_net(name):
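MatMul.forward(A, B) is a one-line wrapper around the batched matrix product. Wrapping it in an nn.Module is presumably what lets the two attention matmuls (patched in via attention_forward and window_attention_forward) appear in the module tree, where net_wrap can replace them with quantized variants and hooks can observe them. A self-contained illustration; only the class shape comes from the index above, the hook usage is our example.

import torch
import torch.nn as nn

class MatMul(nn.Module):
    # A bare `q @ k` inside a function is invisible to module replacement
    # and hooks; as a module, the product becomes a named, swappable node.
    def forward(self, A, B):
        return A @ B

mm = MatMul()
mm.register_forward_hook(lambda m, inp, out: print("matmul output:", out.shape))
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 64, 16)
scores = mm(q, k)  # hook fires; a quantized MatMul could be swapped in here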
FILE: utils/net_wrap.py
function _fold_bn (line 8) | def _fold_bn(conv_module, bn_module):
function fold_bn_into_conv (line 30) | def fold_bn_into_conv(conv_module, bn_module):
function wrap_modules_in_net (line 39) | def wrap_modules_in_net(net,cfg):
function wrap_certain_modules_in_net (line 83) | def wrap_certain_modules_in_net(net,cfg,layers,modules_to_wrap,wrap_embe...
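_fold_bn / fold_bn_into_conv point at the standard trick of absorbing a BatchNorm into the convolution before it, so the quantizer sees a single conv: with gamma, beta, mu, sigma^2 from the BN, the folded weights are w * gamma / sqrt(sigma^2 + eps) and the folded bias is beta + (b - mu) * gamma / sqrt(sigma^2 + eps). A sketch of that formula; the repository's exact handling may differ in details.

import torch
import torch.nn as nn

def fold_bn_into_conv_sketch(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # gamma / std
    conv.weight.data *= factor.reshape(-1, 1, 1, 1)           # scale each out-channel
    b = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
    conv.bias = nn.Parameter(bn.bias + (b - bn.running_mean) * factor)

conv = nn.Conv2d(3, 8, 3, bias=False)
bn = nn.BatchNorm2d(8).eval()          # folding uses the running statistics
x = torch.randn(1, 3, 16, 16)
ref = bn(conv(x))
fold_bn_into_conv_sketch(conv, bn)
print((conv(x) - ref).abs().max())     # near zero: folded conv matches conv + BN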
FILE: utils/quant_calib.py
class QuantCalibrator (line 9) | class QuantCalibrator():
method __init__ (line 21) | def __init__(self, net, wrapped_modules, calib_loader, sequential=True):
method sequential_quant_calib (line 28) | def sequential_quant_calib(self):
method parallel_quant_calib (line 57) | def parallel_quant_calib(self):
method quant_calib (line 95) | def quant_calib(self):
method batching_quant_calib (line 106) | def batching_quant_calib(self):
function grad_hook (line 173) | def grad_hook(module, grad_input, grad_output):
function linear_forward_hook (line 178) | def linear_forward_hook(module, input, output):
function conv2d_forward_hook (line 186) | def conv2d_forward_hook(module, input, output):
function matmul_forward_hook (line 194) | def matmul_forward_hook(module, input, output):
class HessianQuantCalibrator (line 203) | class HessianQuantCalibrator(QuantCalibrator):
method __init__ (line 212) | def __init__(self, net, wrapped_modules, calib_loader, sequential=Fals...
method quant_calib (line 216) | def quant_calib(self):
method batching_quant_calib (line 300) | def batching_quant_calib(self):
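The hook functions (grad_hook plus the linear/conv2d/matmul forward hooks) and HessianQuantCalibrator fit the README's description: cache each layer's raw output and the gradient of the task loss with respect to it over a few calibration batches, then rank scaling-factor candidates by a gradient-weighted output distance rather than a plain MSE. A sketch of that metric, assuming the common diagonal approximation sum_i (dL/dO_i)^2 (O_hat_i - O_i)^2; the per-candidate fake quantization here is illustrative, not the repository's search.

import torch
import torch.nn as nn

def hessian_guided_score(raw_out, quant_out, grad_out):
    # Squared output error weighted by (dL/dO)^2 -- a diagonal proxy for
    # how much each output element's perturbation moves the task loss.
    return ((grad_out * (quant_out - raw_out)) ** 2).sum()

layer = nn.Linear(16, 16)
cache = {}
layer.register_forward_hook(lambda m, i, o: cache.update(out=o.detach()))
layer.register_full_backward_hook(lambda m, gi, go: cache.update(grad=go[0].detach()))

x = torch.randn(4, 16)
layer(x).pow(2).sum().backward()       # stand-in task loss; hooks fill the cache

for scale in (0.05, 0.1, 0.2):         # candidate scaling factors
    quant_out = torch.round(cache["out"] / scale) * scale  # simulated quantization
    print(scale, hessian_guided_score(cache["out"], quant_out, cache["grad"]).item())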