Repository: hahnyuan/PTQ4ViT
Branch: main
Commit: b42c790c2811
Files: 16
Total size: 180.8 KB

Directory structure:
gitextract_dm7h0_9y/
├── .gitignore
├── README.md
├── configs/
│   ├── BasePTQ.py
│   └── PTQ4ViT.py
├── example/
│   ├── get_int.py
│   ├── test_ablation.py
│   ├── test_all.py
│   └── test_vit.py
├── quant_layers/
│   ├── conv.py
│   ├── linear.py
│   └── matmul.py
└── utils/
    ├── datasets.py
    ├── integer.py
    ├── models.py
    ├── net_wrap.py
    └── quant_calib.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
tmp
*.pyc
__pycache__
*.pth
.vscode
checkpoints
*.log
*.csv
*.png
*.jpg
output
*.weights
*.tmp.*
data
ckt
*.out
*.zip
*.json
test.ipynb
int_weights

================================================
FILE: README.md
================================================
# PTQ4ViT
Post-Training Quantization Framework for Vision Transformers.
We use the twin uniform quantization method to reduce the quantization error on the activation values that follow the softmax and GELU functions, and a Hessian guided metric to evaluate different scaling factors, which improves the accuracy of calibration at a small cost. The quantized vision transformers (ViT, DeiT, and Swin) achieve near-lossless prediction accuracy (less than 0.5\% drop at 8-bit quantization) on the ImageNet classification task. Please read the [paper](https://arxiv.org/abs/2111.12293) for details.

## Updates
*19/07/2022* Add discussion on Base PTQ and provide more ablation study results.

### Number of Calibration Images
| Model | W8A8 #ims=32 | W6A6 #ims=32 | W8A8 #ims=128 | W6A6 #ims=128 |
|:------------:|:------------:|:------------:|:-------------:|:-------------:|
| ViT-S/224/32 | 75.58 | 71.91 | 75.54 | 72.29 |
| ViT-S/224 | 81.00 | 78.63 | 80.99 | 78.44 |
| ViT-B/224 | 84.25 | 81.65 | 84.27 | 81.84 |
| ViT-B/384 | 85.83 | 83.35 | 85.81 | 83.84 |
| DeiT-S/224 | 79.47 | 76.28 | 79.41 | 76.51 |
| DeiT-B/224 | 81.48 | 80.25 | 81.54 | 80.30 |
| DeiT-B/384 | 82.97 | 81.55 | 83.01 | 81.67 |
| Swin-T/224 | 81.25 | 80.47 | 81.27 | 80.30 |
| Swin-S/224 | 83.11 | 82.38 | 83.15 | 82.38 |
| Swin-B/224 | 85.15 | 84.01 | 85.17 | 84.15 |
| Swin-B/384 | 86.39 | 85.39 | 86.36 | 85.45 |

| Model | Time #ims=32 | Time #ims=128 |
|:------------:|:------------:|:-------------:|
| ViT-S/224/32 | 2 min | 5 min |
| ViT-S/224 | 3 min | 7 min |
| ViT-B/224 | 4 min | 13 min |
| ViT-B/384 | 12 min | 43 min |
| DeiT-S/224 | 3 min | 7 min |
| DeiT-B/224 | 4 min | 16 min |
| DeiT-B/384 | 14 min | 52 min |
| Swin-T/224 | 3 min | 9 min |
| Swin-S/224 | 8 min | 17 min |
| Swin-B/224 | 10 min | 23 min |
| Swin-B/384 | 25 min | 69 min |

One of the goals of PTQ4ViT is to quantize a vision transformer quickly. We pre-compute the output and gradient of each layer and evaluate the influence of scaling factor candidates in batches, which reduces the quantization time. As the second table demonstrates, PTQ4ViT can quantize most vision transformers in several minutes using 32 calibration images, while using 128 calibration images significantly increases the quantization time. As the first table shows, the Top-1 accuracy varies only slightly with the number of calibration images, demonstrating that PTQ4ViT is not very sensitive to it.

### Base PTQ
Base PTQ is a simple quantization strategy and serves as a benchmark for our experiments.
Like PTQ4ViT, we quantize all weights and inputs of fully-connected layers (including the first projection layer and the last prediction layer), as well as all input matrices of matrix multiplication operations. For fully-connected layers, we use layerwise scaling factors $\Delta_W$ for weight quantization and $\Delta_X$ for input quantization, while for matrix multiplication operations we use $\Delta_A$ and $\Delta_B$ for the quantization of A and B, respectively.

To find the best scaling factors, we apply a linear grid search over the search space. Following EasyQuant and Liu et al., we take hyper-parameters $\alpha=0.5$, $\beta = 1.2$, one search round, and cosine distance as the metric. Note that in PTQ4ViT, we change the hyper-parameters to $\alpha=0$, $\beta = 1.2$, and three search rounds, which slightly improves the performance.

It should be noted that Base PTQ adopts a parallel quantization paradigm, which makes it essentially different from sequential quantization paradigms such as EasyQuant. In sequential quantization, the input of the layer currently being quantized is generated with the weights and activations of all previous layers already quantized; in parallel quantization, the input of the current layer is simply the raw output of the previous layer. In practice, we found that sequential quantization of vision transformers suffers from significant accuracy degradation on small calibration datasets, while parallel quantization remains robust. Therefore, we choose parallel quantization for both Base PTQ and PTQ4ViT.

### More Ablation Study
We provide more ablation studies on the hyper-parameters. It is enough to set the number of quantization intervals $\ge$ 20 (accuracy change $< 0.3\%$). It is enough to set the upper bound of $m$ to $\ge$ 15 (no accuracy change). The best settings of $\alpha$ and $\beta$ vary across layers; it is appropriate to set $\alpha=0$ and $\beta=1/2^{k-1}$, which has little impact on search efficiency. We observe that the number of search rounds has little impact on the prediction accuracy (accuracy change $<$ 0.05\% when search rounds $>1$).

We quantize each model 20 times, each time with 32 randomly sampled calibration images, and observe that the fluctuation is not significant. The mean/std of accuracies are: ViT-S/32 $75.55\%/0.055\%$, ViT-S $80.96\%/0.046\%$, ViT-B $84.12\%/0.068\%$, DeiT-S $79.45\%/0.094\%$, and Swin-S $83.11\%/0.035\%$.

*15/01/2022* Add saved quantized models with PTQ4ViT.
| model | link |
|:------------:|:--------:|
| ViT-S/224/32 | [Google](https://drive.google.com/file/d/195JJJKULvaukte6PA9U08oezjd176CTs/view?usp=sharing) |
| ViT-S/224 | [Google](https://drive.google.com/file/d/14uEDgRmDBYoKoZtpO9IWMfG8Uvkt_OuL/view?usp=sharing) |
| ViT-B/224 | [Google](https://drive.google.com/file/d/1ou6s9Vd-_iyQ7sj7VYET-pRvJA6WMMLA/view?usp=sharing) |
| ViT-B/384 | [Google](https://drive.google.com/file/d/1tuU8or8SfQomtoWam7WFTnUxtuw3n7fs/view?usp=sharing) |
| DeiT-S/224 | [Google](https://drive.google.com/file/d/1673fX-SuiRlHhm7k0Yyyx_3ynwtvUPyf/view?usp=sharing) |
| DeiT-B/224 | [Google](https://drive.google.com/file/d/1WRAtmPF0kDR9iTLc9gv_63aEkOCZ_zOI/view?usp=sharing) |
| DeiT-B/384 | [Google](https://drive.google.com/file/d/1mPPlM2ioe4zts_rdKdjZTCUj8KcbquyA/view?usp=sharing) |
| Swin-T/224 | [Google](https://drive.google.com/file/d/1bSahHgtL3yFaHPlG-SDtu__YY0zJ8lxr/view?usp=sharing) |
| Swin-S/224 | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing) |
| Swin-B/224 | [Google](https://drive.google.com/file/d/19UUUQYJGs5SQaDe27PjY3x1QTBU5hwXm/view?usp=sharing) |
| Swin-B/384 | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing) |

*10/12/2021* Add `utils/integer.py`; you can now:
1. convert a calibrated FP32 model into int8
2. register pre-forward hooks in the model and fetch activations in int8 (we use uint8 to store the results of twin quantization; please refer to the paper for the bit layout)

## Install
### Requirement
- python>=3.5
- pytorch>=1.5
- matplotlib
- pandas
- timm

### Datasets
To run the example tests, you should put your ImageNet2012 dataset at `/datasets/imagenet`. We use `ViTImageNetLoaderGenerator` in `utils/datasets.py` to initialize our DataLoader. If your ImageNet dataset is stored elsewhere, you'll need to pass its root manually as an argument when instantiating a `ViTImageNetLoaderGenerator`.

## Usage
### 1. Run example quantization
To test all models with BasePTQ/PTQ4ViT, run
```bash
python example/test_all.py
```
To run the ablation tests, run
```bash
python example/test_ablation.py
```
You can run the testing scripts with multiple GPUs. For example, calling
```bash
python example/test_all.py --multiprocess --n_gpu 6
```
will use 6 GPUs to run the test.
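For reference, a single model can also be quantized programmatically; the following minimal sketch mirrors `experiment_basic` in `example/test_vit.py` (run it from the repository root; the model name and the `/datasets/imagenet` root are just the defaults used in the examples, so adjust them to your setup):

```python
import utils.datasets as datasets
import utils.net_wrap as net_wrap
from utils.models import get_net
from utils.quant_calib import HessianQuantCalibrator
from example.test_vit import init_config, test_classification

quant_cfg = init_config("PTQ4ViT")      # freshly (re)loads configs/PTQ4ViT.py
net = get_net("vit_small_patch16_224")  # build the FP32 model by name (see utils/models.py)
wrapped_modules = net_wrap.wrap_modules_in_net(net, quant_cfg)  # swap in quantized layers

g = datasets.ViTImageNetLoaderGenerator('/datasets/imagenet', 'imagenet', 32, 32, 16, kwargs={"model": net})
calib_loader = g.calib_loader(num=32)   # 32 calibration images

# pre-compute outputs/gradients, then grid-search scaling factor candidates in batches
quant_calibrator = HessianQuantCalibrator(net, wrapped_modules, calib_loader, sequential=False, batch_size=4)
quant_calibrator.batching_quant_calib()

test_classification(net, g.test_loader())  # top-1 accuracy of the quantized model
```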
### 2. Download quantized model checkpoints
(Coming soon)

## Results
### Results of BasePTQ
| model | original | w8a8 | w6a6 |
|:------------:|:--------:|:------:|:-------:|
| ViT-S/224/32 | 75.99 | 73.61 | 60.144 |
| ViT-S/224 | 81.39 | 80.468 | 70.244 |
| ViT-B/224 | 84.54 | 83.896 | 75.668 |
| ViT-B/384 | 86.00 | 85.352 | 46.886 |
| DeiT-S/224 | 79.80 | 77.654 | 72.268 |
| DeiT-B/224 | 81.80 | 80.946 | 78.786 |
| DeiT-B/384 | 83.11 | 82.33 | 68.442 |
| Swin-T/224 | 81.39 | 80.962 | 78.456 |
| Swin-S/224 | 83.23 | 82.758 | 81.742 |
| Swin-B/224 | 85.27 | 84.792 | 83.354 |
| Swin-B/384 | 86.44 | 86.168 | 85.226 |

### Results of PTQ4ViT
| model | original | w8a8 | w6a6 |
|:------------:|:--------:|:------:|:-------:|
| ViT-S/224/32 | 75.99 | 75.582 | 71.908 |
| ViT-S/224 | 81.39 | 81.002 | 78.63 |
| ViT-B/224 | 84.54 | 84.25 | 81.65 |
| ViT-B/384 | 86.00 | 85.828 | 83.348 |
| DeiT-S/224 | 79.80 | 79.474 | 76.282 |
| DeiT-B/224 | 81.80 | 81.482 | 80.25 |
| DeiT-B/384 | 83.11 | 82.974 | 81.55 |
| Swin-T/224 | 81.39 | 81.246 | 80.47 |
| Swin-S/224 | 83.23 | 83.106 | 82.38 |
| Swin-B/224 | 85.27 | 85.146 | 84.012 |
| Swin-B/384 | 86.44 | 86.394 | 85.388 |

### Results of Ablation
- ViT-S/224 (original top-1 accuracy 81.39%)

| Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 |
|:--------------:|:------------:|:---------:|:------:|:-------:|
| | | | 80.47 | 70.24 |
| ✓ | | | 80.93 | 77.20 |
| ✓ | ✓ | | 81.11 | 78.57 |
| ✓ | | ✓ | 80.84 | 76.93 |
| | ✓ | ✓ | 79.25 | 74.07 |
| ✓ | ✓ | ✓ | 81.00 | 78.63 |

- ViT-B/224 (original top-1 accuracy 84.54%)

| Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 |
|:--------------:|:------------:|:---------:|:------:|:-------:|
| | | | 83.90 | 75.67 |
| ✓ | | | 83.97 | 79.90 |
| ✓ | ✓ | | 84.07 | 80.76 |
| ✓ | | ✓ | 84.10 | 80.82 |
| | ✓ | ✓ | 83.40 | 78.86 |
| ✓ | ✓ | ✓ | 84.25 | 81.65 |

- ViT-B/384 (original top-1 accuracy 86.00%)

| Hessian Guided | Softmax Twin | GELU Twin | W8A8 | W6A6 |
|:--------------:|:------------:|:---------:|:------:|:-------:|
| | | | 85.35 | 46.89 |
| ✓ | | | 85.42 | 79.99 |
| ✓ | ✓ | | 85.67 | 82.01 |
| ✓ | | ✓ | 85.60 | 82.21 |
| | ✓ | ✓ | 84.35 | 80.86 |
| ✓ | ✓ | ✓ | 85.89 | 83.19 |

## Citation
```
@article{PTQ4ViT_arxiv2022,
  title={PTQ4ViT: Post-Training Quantization Framework for Vision Transformers},
  author={Zhihang Yuan and Chenhao Xue and Yiqi Chen and Qiang Wu and Guangyu Sun},
  journal={arXiv preprint arXiv:2111.12293},
  year={2022},
}
```
1.2, "eq_n": 100, 'search_round': 1, "n_G_A": 1, "n_V_A": 1, "n_H_A": 1, "n_G_B": 1, "n_V_B": 1, "n_H_B": 1, } def get_module(module_type, *args, **kwargs): if module_type == "qconv": kwargs.update(ptqsl_conv2d_kwargs) module=BatchingEasyQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization # module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization elif "qlinear" in module_type: kwargs.update(ptqsl_linear_kwargs) if module_type == "qlinear_qkv": kwargs["n_V"] *= 3 # q, k, v module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) else: module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) elif "qmatmul" in module_type: kwargs.update(ptqsl_matmul_kwargs) module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) return module ================================================ FILE: configs/PTQ4ViT.py ================================================ from quant_layers.conv import PTQSLQuantConv2d, ChannelwiseBatchingQuantConv2d from quant_layers.linear import PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear from quant_layers.matmul import PTQSLBatchingQuantMatMul, SoSPTQSLBatchingQuantMatMul no_softmax = False no_postgelu = False bit = 8 conv_fc_name_list = ["qconv", "qlinear_qkv", "qlinear_proj", "qlinear_MLP_1", "qlinear_MLP_2", "qlinear_classifier", "qlinear_reduction"] matmul_name_list = [ "qmatmul_qk", "qmatmul_scorev"] w_bit = {name: bit for name in conv_fc_name_list} a_bit = {name: bit for name in conv_fc_name_list} A_bit = {name: bit for name in matmul_name_list} B_bit = {name: bit for name in matmul_name_list} ptqsl_conv2d_kwargs = { "metric": "hessian", "eq_alpha": 0.01, "eq_beta": 1.2, "eq_n": 100, 'search_round': 3, "n_V": 1, "n_H": 1, } ptqsl_linear_kwargs = { "metric": "hessian", "eq_alpha": 0.01, "eq_beta": 1.2, "eq_n": 100, 'search_round': 3, "n_V": 1, "n_H": 1, "n_a": 1, "bias_correction":True # Conventionally I'll not add an actual bias correction in linear } ptqsl_matmul_kwargs = { "metric": "hessian", "eq_alpha": 0.01, "eq_beta": 1.2, "eq_n": 100, 'search_round': 3, "n_G_A": 1, "n_V_A": 1, "n_H_A": 1, "n_G_B": 1, "n_V_B": 1, "n_H_B": 1, } def get_module(module_type, *args, **kwargs): if module_type == "qconv": kwargs.update(ptqsl_conv2d_kwargs) module=ChannelwiseBatchingQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization # module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit["qconv"],a_bit=32) # turn off activation quantization elif "qlinear" in module_type: kwargs.update(ptqsl_linear_kwargs) if module_type == "qlinear_qkv": kwargs["n_V"] *= 3 # q, k, v module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) elif module_type == "qlinear_MLP_2": if no_postgelu: module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) else: module=PostGeluPTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) elif module_type == "qlinear_classifier": kwargs["n_V"] = 1 module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) else: module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type]) elif "qmatmul" in module_type: kwargs.update(ptqsl_matmul_kwargs) if module_type == "qmatmul_qk": 
module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) elif module_type == "qmatmul_scorev": if no_softmax: module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) else: module=SoSPTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type]) return module ================================================ FILE: example/get_int.py ================================================ import sys sys.path.insert(0,'..') sys.path.insert(0,'.') from example.test_vit import * import utils.net_wrap as net_wrap import utils.datasets as datasets import utils.integer as integer from utils.quant_calib import HessianQuantCalibrator from itertools import product def get_int_weights(name, config_name): quant_cfg = init_config(config_name) net = get_net(name) wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg) g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net}) test_loader=g.test_loader() calib_loader=g.calib_loader(num=32) quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 quant_calibrator.batching_quant_calib() int_weights = integer.get_model_int_weight(wrapped_modules) torch.save(int_weights, f"./int_weights/{name}.pth") if __name__ == "__main__": args = parse_args() names = [ # "vit_tiny_patch16_224", # "vit_small_patch32_224", # "vit_small_patch16_224", # "vit_base_patch16_224", "vit_base_patch16_384", # "deit_tiny_patch16_224", # "deit_small_patch16_224", # "deit_base_patch16_224", # "deit_base_patch16_384", # "swin_tiny_patch4_window7_224", # "swin_small_patch4_window7_224", # "swin_base_patch4_window7_224", # "swin_base_patch4_window12_384", ] config_names = ["PTQ4ViT", "BasePTQ"] cfg_list = [] for name, config in product(names, config_names): cfg_list.append({"name":name, "config_name":config}) if args.multiprocess: multiprocess(get_int_weights, cfg_list, n_gpu=args.n_gpu) else: for cfg in cfg_list: get_int_weights(**cfg) ================================================ FILE: example/test_ablation.py ================================================ from torch.nn.modules import module from test_vit import * from quant_layers.conv import MinMaxQuantConv2d from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul import matplotlib.pyplot as plt from utils.net_wrap import wrap_certain_modules_in_net from tqdm import tqdm import torch.nn.functional as F import pickle as pkl from itertools import product import types from utils.quant_calib import HessianQuantCalibrator, QuantCalibrator from utils.models import get_net import time def test_all_ablation(name, cfg_modifier=lambda x: x, calib_size=32): quant_cfg = init_config("PTQ4ViT") quant_cfg = cfg_modifier(quant_cfg) net = get_net(name) wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg) g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net}) test_loader=g.test_loader() calib_loader=g.calib_loader(num=calib_size) quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 quant_calibrator.batching_quant_calib() acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs["metric"]) print(f"model: {name} \n") print(f"calibration size: {calib_size} \n") print(f"bit settings: 
{quant_cfg.bit} \n") print(f"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \n") print(f"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \n") print(f"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \n") print(f"accuracy: {acc} \n\n") class cfg_modifier(): def __init__(self, **kwargs): for name, value in kwargs.items(): setattr(self,name,value) def __call__(self, cfg): # bit setting cfg.bit = self.bit_setting cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list} cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list} cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} # conv2d configs cfg.ptqsl_conv2d_kwargs["n_V"] = self.linear_ptq_setting[0] cfg.ptqsl_conv2d_kwargs["n_H"] = self.linear_ptq_setting[1] cfg.ptqsl_conv2d_kwargs["metric"] = self.metric cfg.ptqsl_conv2d_kwargs["search_round"] = self.search_round cfg.ptqsl_conv2d_kwargs["parallel_eq_n"] = 1 # maximum 7 , reserve 4Gb for gradient cfg.ptqsl_conv2d_kwargs["init_layerwise"] = False # linear configs cfg.ptqsl_linear_kwargs["n_V"] = self.linear_ptq_setting[0] cfg.ptqsl_linear_kwargs["n_H"] = self.linear_ptq_setting[1] cfg.ptqsl_linear_kwargs["n_a"] = self.linear_ptq_setting[2] cfg.ptqsl_linear_kwargs["metric"] = self.metric cfg.ptqsl_linear_kwargs["search_round"] = self.search_round cfg.ptqsl_linear_kwargs["parallel_eq_n"] = 1 # maximum 7, reserve 4Gb for gradient cfg.ptqsl_linear_kwargs["init_layerwise"] = False # matmul configs cfg.ptqsl_matmul_kwargs["metric"] = self.metric cfg.ptqsl_matmul_kwargs["search_round"] = self.search_round cfg.ptqsl_matmul_kwargs["parallel_eq_n"] = 1 # maximum 3! cfg.ptqsl_matmul_kwargs["init_layerwise"] = False # ablation cfg.no_softmax = self.no_softmax cfg.no_postgelu = self.no_postgelu return cfg if __name__=='__main__': args = parse_args() names = [ "vit_small_patch16_224", "vit_base_patch16_224", "vit_base_patch16_384", ] metrics = ["hessian", "cosine"] linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a search_rounds = [3] calib_sizes = [32] bit_settings = [(8,8), (6,6)] # weight, activation no_softmaxs = [True, False] no_postgelus = [True, False] cfg_list = [] for name, metric, linear_ptq_setting, search_round, calib_size, bit_setting, no_softmax, no_postgelu in product(names, metrics, linear_ptq_settings, search_rounds, calib_sizes, bit_settings, no_softmaxs, no_postgelus): cfg_list.append({ "name": name, "cfg_modifier":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, search_round=search_round, bit_setting=bit_setting, no_softmax=no_softmax, no_postgelu=no_postgelu), "calib_size":calib_size, }) if args.multiprocess: multiprocess(test_all_ablation, cfg_list, n_gpu=args.n_gpu) else: for cfg in cfg_list: test_all_ablation(**cfg) ================================================ FILE: example/test_all.py ================================================ from timm.models.layers import config from torch.nn.modules import module from test_vit import * from quant_layers.conv import MinMaxQuantConv2d from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul import matplotlib.pyplot as plt from utils.net_wrap import wrap_certain_modules_in_net from tqdm import tqdm import torch.nn.functional as F import pickle as pkl from itertools import product import types from utils.quant_calib import HessianQuantCalibrator, QuantCalibrator from utils.models import get_net import time 
def test_all(name, cfg_modifier=lambda x: x, calib_size=32, config_name="PTQ4ViT"): quant_cfg = init_config(config_name) quant_cfg = cfg_modifier(quant_cfg) net = get_net(name) wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg) g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={"model":net}) test_loader=g.test_loader() calib_loader=g.calib_loader(num=calib_size) # add timing calib_start_time = time.time() quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 quant_calibrator.batching_quant_calib() calib_end_time = time.time() acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs["metric"]) print(f"model: {name} \n") print(f"calibration size: {calib_size} \n") print(f"bit settings: {quant_cfg.bit} \n") print(f"config: {config_name} \n") print(f"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \n") print(f"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \n") print(f"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \n") print(f"calibration time: {(calib_end_time-calib_start_time)/60}min \n") print(f"accuracy: {acc} \n\n") class cfg_modifier(): def __init__(self, **kwargs): for name, value in kwargs.items(): setattr(self,name,value) def __call__(self, cfg): # bit setting cfg.bit = self.bit_setting cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list} cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list} cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list} # conv2d configs cfg.ptqsl_conv2d_kwargs["n_V"] = self.linear_ptq_setting[0] cfg.ptqsl_conv2d_kwargs["n_H"] = self.linear_ptq_setting[1] cfg.ptqsl_conv2d_kwargs["metric"] = self.metric cfg.ptqsl_conv2d_kwargs["init_layerwise"] = False # linear configs cfg.ptqsl_linear_kwargs["n_V"] = self.linear_ptq_setting[0] cfg.ptqsl_linear_kwargs["n_H"] = self.linear_ptq_setting[1] cfg.ptqsl_linear_kwargs["n_a"] = self.linear_ptq_setting[2] cfg.ptqsl_linear_kwargs["metric"] = self.metric cfg.ptqsl_linear_kwargs["init_layerwise"] = False # matmul configs cfg.ptqsl_matmul_kwargs["metric"] = self.metric cfg.ptqsl_matmul_kwargs["init_layerwise"] = False return cfg if __name__=='__main__': args = parse_args() names = [ "vit_tiny_patch16_224", "vit_small_patch32_224", "vit_small_patch16_224", "vit_base_patch16_224", "vit_base_patch16_384", "deit_tiny_patch16_224", "deit_small_patch16_224", "deit_base_patch16_224", "deit_base_patch16_384", "swin_tiny_patch4_window7_224", "swin_small_patch4_window7_224", "swin_base_patch4_window7_224", "swin_base_patch4_window12_384", ] metrics = ["hessian"] linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a calib_sizes = [32,128] bit_settings = [(8,8), (6,6)] # weight, activation config_names = ["PTQ4ViT", "BasePTQ"] cfg_list = [] for name, metric, linear_ptq_setting, calib_size, bit_setting, config_name in product(names, metrics, linear_ptq_settings, calib_sizes, bit_settings, config_names): cfg_list.append({ "name": name, "cfg_modifier":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, bit_setting=bit_setting), "calib_size":calib_size, "config_name": config_name }) if args.multiprocess: multiprocess(test_all, cfg_list, n_gpu=args.n_gpu) else: for cfg in cfg_list: test_all(**cfg) ================================================ FILE: example/test_vit.py ================================================ import sys 
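# make the repository root importable whether the script is run from example/ or the repo root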
sys.path.insert(0,'..') sys.path.insert(0,'.') import torch import torch.nn as nn from tqdm import tqdm import argparse from importlib import reload,import_module import multiprocessing import os import time from itertools import product import utils.datasets as datasets import utils.net_wrap as net_wrap from utils.quant_calib import QuantCalibrator, HessianQuantCalibrator from utils.models import get_net def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--n_gpu", type=int, default=6) parser.add_argument("--multiprocess", action='store_true') args = parser.parse_args() return args def test_classification(net,test_loader,max_iteration=None, description=None): pos=0 tot=0 i = 0 max_iteration = len(test_loader) if max_iteration is None else max_iteration with torch.no_grad(): q=tqdm(test_loader, desc=description) for inp,target in q: i+=1 inp=inp.cuda() target=target.cuda() out=net(inp) pos_num=torch.sum(out.argmax(1)==target).item() pos+=pos_num tot+=inp.size(0) q.set_postfix({"acc":pos/tot}) if i >= max_iteration: break print(pos/tot) return pos/tot def process(pid, experiment_process, args_queue, n_gpu): """ worker process. """ gpu_id=pid%n_gpu os.environ['CUDA_VISIBLE_DEVICES']=f'{gpu_id}' tot_run=0 while args_queue.qsize(): test_args=args_queue.get() print(f"Run {test_args} on pid={pid} gpu_id={gpu_id}") experiment_process(**test_args) time.sleep(0.5) tot_run+=1 # run_experiment(**args) print(f"{pid} tot_run {tot_run}") def multiprocess(experiment_process, cfg_list=None, n_gpu=6): """ run experiment processes on "n_gpu" cards via "n_gpu" worker process. "cfg_list" arranges kwargs for each test point, and worker process will fetch kwargs and carry out an experiment. """ args_queue = multiprocessing.Queue() for cfg in cfg_list: args_queue.put(cfg) ps=[] for pid in range(n_gpu): p=multiprocessing.Process(target=process,args=(pid,experiment_process,args_queue,n_gpu)) p.start() ps.append(p) for p in ps: p.join() def init_config(config_name): """initialize the config. Use reload to make sure it's fresh one!""" _,_,files = next(os.walk("./configs")) if config_name+".py" in files: quant_cfg = import_module(f"configs.{config_name}") else: raise NotImplementedError(f"Invalid config name {config_name}") reload(quant_cfg) return quant_cfg def experiment_basic(net='vit_base_patch16_384', config="PTQ4ViT"): """ A basic testbench. 
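    Quantizes `net` under `config`: wraps its layers, calibrates on 32 images with the Hessian calibrator, and reports top-1 accuracy.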
""" quant_cfg = init_config(config) net = get_net(net) wrapped_modules = net_wrap.wrap_modules_in_net(net,quant_cfg) g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16,kwargs={"model":net}) test_loader=g.test_loader() calib_loader=g.calib_loader(num=32) quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16 quant_calibrator.batching_quant_calib() test_classification(net,test_loader) if __name__=='__main__': args = parse_args() cfg_list = [] nets = ['vit_tiny_patch16_224', "deit_base_patch16_384"] configs= ['PTQ4ViT'] cfg_list = [{ "net":net, "config":config, } for net, config in product(nets, configs) ] if args.multiprocess: multiprocess(experiment_basic, cfg_list, n_gpu=args.n_gpu) else: for cfg in cfg_list: experiment_basic(**cfg) ================================================ FILE: quant_layers/conv.py ================================================ from numpy import not_equal from torch import tensor from quant_layers.linear import MinMaxQuantLinear import torch import torch.nn as nn import torch.nn.functional as F from itertools import product class MinMaxQuantConv2d(nn.Conv2d): """ MinMax quantize weight and output """ def __init__(self,in_channels: int, out_channels: int, kernel_size, stride = 1, padding = 0, dilation = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None): super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode) self.n_calibration_steps=2 self.mode=mode self.w_bit=w_bit self.a_bit=a_bit self.bias_bit=bias_bit assert bias_bit is None,"No support bias bit now" self.w_interval=None self.a_interval=None self.bias_interval=None self.raw_input=None self.raw_out=None self.metric=None self.next_nodes=[] self.w_qmax=2**(self.w_bit-1) self.a_qmax=2**(self.a_bit-1) # self.bias_qmax=2**(self.bias_bit-1) def forward(self, x): if self.mode=='raw': out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) elif self.mode=="quant_forward": out=self.quant_forward(x) elif self.mode=="calibration_step1": out=self.calibration_step1(x) elif self.mode=="calibration_step2": out=self.calibration_step2(x) else: raise NotImplementedError return out def quant_weight_bias(self): w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1) w_sim=w.mul_(self.w_interval) if self.bias is not None: return w_sim,self.bias # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1) # bias_sim=bias*self.bias_interval # return w_sim,bias_sim else: return w_sim,None def quant_input(self,x): x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1) x_sim.mul_(self.a_interval) return x_sim def quant_forward(self,x): assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" w_sim,bias_sim=self.quant_weight_bias() x_sim=self.quant_input(x) out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) return out def calibration_step1(self,x): # step1: collection the FP32 values out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) self.raw_input=x.cpu().detach() self.raw_out=out.cpu().detach() return out def calibration_step2(self,x): # step2: search for the best S^w and S^a of each layer self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach() 
self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach() self.calibrated=True out=self.quant_forward(x) return out class QuantileQuantConv2d(MinMaxQuantConv2d): """ Quantile quantize weight and output """ def __init__(self, in_channels: int, out_channels: int, kernel_size, stride = 1, padding = 0, dilation = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', mode='raw',w_bit=8,a_bit=8,bias_bit=None, w_quantile=0.9999,a_quantile=0.9999): super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode,mode,w_bit,a_bit,bias_bit) self.w_quantile = w_quantile self.a_quantile = a_quantile def _quantile(self, tensor, quantile): if tensor.numel() >= 16777216: n = tensor.numel()//16777216 return torch.quantile(tensor.view(-1)[:16777216*n].view(n,16777216),quantile,1).mean() else: return torch.quantile(tensor,quantile) def calibration_step2(self,x): # step2: search for the best S^w and S^o of each layer self.w_interval=(self._quantile(self.weight.data.abs(),self.w_quantile)/(self.w_qmax-0.5)).detach() self.a_interval=(self._quantile(x.abs(),self.a_quantile)/(self.a_qmax-0.5)).detach() self.calibrated=True out=self.quant_forward(x) return out class PTQSLQuantConv2d(MinMaxQuantConv2d): """ PTQSL on Conv2d weight: (oc,ic,kw,kh) -> (oc,ic*kw*kh) -> divide into sub-matrixs and quantize input: (B,ic,W,H), keep this shape Only support SL quantization on weights. """ def __init__(self, in_channels: int, out_channels: int, kernel_size, stride = 1, padding = 0, dilation = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None, metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, n_V=1, n_H=1, init_layerwise=False): super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit) self.metric = metric self.search_round = search_round self.eq_alpha = eq_alpha self.eq_beta = eq_beta self.eq_n = eq_n self.parallel_eq_n = parallel_eq_n self.n_H = n_H self.n_V = n_V self.init_layerwise = init_layerwise self.raw_grad = None def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1): """ tensor_raw: *, features tensor_sim: *, features similarity: * It's your job to calculate mean on * dims! 
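        `metric` selects cosine, L1/L2, weighted L2, or the hessian-guided similarity (which reads self.raw_grad).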
""" if metric == "cosine": similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) else: if metric == "L1_norm": similarity = -torch.abs(tensor_raw - tensor_sim) elif metric == "L2_norm": similarity = -(tensor_raw - tensor_sim) ** 2 elif metric == "linear_weighted_L2_norm": similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 elif metric == "square_weighted_L2_norm": similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 elif metric == "hessian": raw_grad = self.raw_grad.reshape_as(tensor_raw) similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 else: raise NotImplementedError(f"metric {metric} not implemented!") similarity = torch.mean(similarity, dim=dim) return similarity def quant_weight_bias(self): # self.weight_interval shape: n_V, 1, n_H, 1 oc,ic,kw,kh=self.weight.data.shape w_sim = self.weight.view(self.n_V, oc//self.n_V, self.n_H, (ic*kw*kh)//self.n_H) w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval) w_sim = w_sim.view(oc,ic,kw,kh) return w_sim, self.bias def _search_best_w_interval(self, x, weight_interval_candidates): """ Modularization of searching best weight intervals """ tmp_w_interval = self.w_interval.unsqueeze(0) for v,h in product(range(self.n_V), range(self.n_H)): similarities = [] for p_st in range(0, self.eq_n, self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1) cur_w_interval[:,v:v+1,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,v:v+1,:,h:h+1,:] # quantize weight and bias oc,ic,kw,kh=self.weight.data.shape w_sim = self.weight.view(self.n_V,oc//self.n_V,self.n_H,-1).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols w_sim = w_sim.view(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None # quantize input x_sim = self.quant_input(x) # calculate similarity and store them out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: B,parallel_eq_n*oc,fw,fh out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: B,parallel_eq_n,oc,fw,fh similarity = self._get_similarity(self.raw_out, out_sim, self.metric, dim=2) # shape: B,parallel_eq_n,fw,fh similarity = torch.mean(similarity, [0,2,3]) # shape: parallel_eq_n similarities.append(similarity) # store best weight interval of h into tmp_w_interval similarities = torch.cat(similarities, dim=0) # shape: eq_n best_index = similarities.argmax(dim=0).reshape(-1,1,1,1,1) tmp_w_interval[:,v:v+1,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,v:v+1,:,h:h+1,:],dim=0,index=best_index) self.w_interval = tmp_w_interval.squeeze(dim=0) def _search_best_a_interval(self, x, input_interval_candidates): similarities = [] for p_st in range(0,self.eq_n,self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_a_interval = input_interval_candidates[p_st:p_ed] # quantize weight and bias w_sim, bias_sim = self.quant_weight_bias() # quantize input B,ic,iw,ih = x.shape x_sim=x.unsqueeze(0) # shape: 1,B,ic,iw,ih x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,B,ic,iw,ih x_sim=x_sim.view(-1,ic,iw,ih) # calculate similarity and store them out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, 
self.dilation, self.groups) # shape: parallel_eq_n*B,oc,fw,fh out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,B,oc,fw,fh similarity = self._get_similarity(self.raw_out.transpose(0,1), out_sim, self.metric, dim=2) # shape: parallel_eq_n,B,fw,fh similarity = torch.mean(similarity, dim=[1,2,3]) # shape: parallel_eq_n similarities.append(similarity) # store best input interval and store in tmp_a_interval similarities = torch.cat(similarities, dim=0) # shape: eq_n a_best_index = similarities.argmax(dim=0).view(1,1,1,1,1) self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze() def _initialize_intervals(self, x): self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach() if self.init_layerwise: self.w_interval = ((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) else: self.w_interval = (self.weight.view(self.n_V,self.out_channels//self.n_V,self.n_H,-1).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) def calibration_step2(self, x): # initialize intervals with minmax intervals self._initialize_intervals(x) # put raw outs on GPU self.raw_out = self.raw_out.to(x.device).unsqueeze(1) # shape: B,1,oc,W,H # put raw grad on GPU self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None # prepare weight intervals and similarities weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1 input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: nq_n,1,1,1,1 for e in range(self.search_round): # search for best weight interval self._search_best_w_interval(x, weight_interval_candidates) # search for best input interval self._search_best_a_interval(x, input_interval_candidates) self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None self.calibrated = True out=self.quant_forward(x) del self.raw_input, self.raw_out, self.raw_grad return out class BatchingEasyQuantConv2d(PTQSLQuantConv2d): """An agile implementation of Layerwise Easyquant""" def __init__(self, in_channels: int, out_channels: int, kernel_size, stride = 1, padding = 0, dilation = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None, metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, n_V=1, n_H=1, init_layerwise=False): super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_V=n_V, n_H=n_H, init_layerwise=init_layerwise) self.n_V = 1 self.n_H = 1 def _initialize_calib_parameters(self): """ set parameters for feeding calibration data """ self.calib_size = int(self.raw_input.shape[0]) self.calib_batch_size = int(self.raw_input.shape[0]) while True: numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU self.parallel_eq_n = int((15*1024*1024*1024/4)//numel) if self.parallel_eq_n <= 1: self.calib_need_batching = True self.calib_batch_size //= 2 else: break def 
_initialize_intervals(self): self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach() tmp_a_intervals = [] for b_st in range(0,self.calib_size,self.calib_batch_size): b_ed = min(self.calib_size, b_st+self.calib_batch_size) x_ = self.raw_input[b_st:b_ed].cuda() a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1) tmp_a_intervals.append(a_interval_) self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False) def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1, raw_grad=None): """ tensor_raw: *, features tensor_sim: *, features similarity: * It's your job to calculate mean on * dims! """ if metric == "cosine": similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) elif metric == "pearson": # calculate similarity w.r.t complete feature map, but maintain dimension requirement b, parallel_eq_n = tensor_sim.shape[0], tensor_sim.shape[1] similarity = F.cosine_similarity(tensor_raw.view(b,1,-1), tensor_sim.view(b,parallel_eq_n,-1), dim=dim).view(b,parallel_eq_n,1,1) else: if metric == "L1_norm": similarity = -torch.abs(tensor_raw - tensor_sim) elif metric == "L2_norm": similarity = -(tensor_raw - tensor_sim) ** 2 elif metric == "linear_weighted_L2_norm": similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 elif metric == "square_weighted_L2_norm": similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 elif metric == "hessian": assert raw_grad != None, f"No raw grad!" raw_grad = raw_grad.reshape_as(tensor_raw) similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 else: raise NotImplementedError(f"metric {metric} not implemented!") similarity = torch.mean(similarity, dim=dim) return similarity def quant_weight_bias(self): w_sim = self.weight w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval) return w_sim, self.bias def quant_forward(self, x): assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" w_sim,bias_sim=self.quant_weight_bias() x_sim=self.quant_input(x) if self.a_bit < 32 else x out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) return out def _search_best_w_interval(self, weight_interval_candidates): batch_similarities = [] for b_st in range(0,self.calib_size,self.calib_batch_size): b_ed = min(self.calib_size, b_st+self.calib_batch_size) x = self.raw_input[b_st:b_ed].cuda() raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh raw_grad = self.raw_grad[b_st:b_ed].cuda() similarities = [] for p_st in range(0, self.eq_n, self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1 # quantize weight and bias oc,ic,kw,kh = self.weight.data.shape w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None # quantize input x_sim = self.quant_input(x) # calculate similarity and store them out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh similarity = self._get_similarity(raw_out, out_sim, 
self.metric, dim=-3, raw_grad=raw_grad) # shape: b,parallel_eq_n,fw,fh similarity = torch.mean(similarity, [2,3]) # shape: b,parallel_eq_n similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n similarities.append(similarity) # store best weight interval of h into tmp_w_interval similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n batch_similarities.append(similarities) batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n best_index = batch_similarities.argmax(dim=0).reshape(1,1,1,1,1) # shape: 1,1,1,1,1 self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0) def _search_best_a_interval(self, input_interval_candidates): batch_similarities = [] for b_st in range(0,self.calib_size,self.calib_batch_size): b_ed = min(self.calib_size, b_st+self.calib_batch_size) x = self.raw_input[b_st:b_ed].cuda() raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(0) # shape: 1,b,oc,fw,fh raw_grad = self.raw_grad[b_st:b_ed].cuda() similarities = [] for p_st in range(0,self.eq_n,self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1 # quantize weight and bias w_sim, bias_sim = self.quant_weight_bias() # quantize input B,ic,iw,ih = x.shape x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih # calculate similarity and store them out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh similarity = self._get_similarity(raw_out, out_sim, self.metric, dim=-3, raw_grad=raw_grad) # shape: parallel_eq_n,b,fw,fh similarity = torch.mean(similarity, dim=[3,4]) # shape: parallel_eq_n,b similarity = torch.sum(similarity, dim=1, keepdim=True) # shape: parallel_eq_n,1 similarities.append(similarity) similarities = torch.cat(similarities, dim=0) # shape: eq_n, 1 batch_similarities.append(similarities) batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1) self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze() def calibration_step2(self): self._initialize_calib_parameters() self._initialize_intervals() weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval # shape: eq_n,1,1,1,1 input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1 for e in range(self.search_round): # search for best weight interval self._search_best_w_interval(weight_interval_candidates) # search for best input interval if self.a_bit < 32: self._search_best_a_interval(input_interval_candidates) self.calibrated = True del self.raw_input, self.raw_out, self.raw_grad class ChannelwiseBatchingQuantConv2d(PTQSLQuantConv2d): """ Only implemented acceleration with batching_calibration_step2 setting a_bit to >= 32 will use minmax quantization, which means turning off activation 
quantization """ def __init__(self, in_channels: int, out_channels: int, kernel_size, stride = 1, padding = 0, dilation = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None, metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, n_V=1, n_H=1, init_layerwise=False): super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_V=n_V, n_H=n_H, init_layerwise=init_layerwise) self.n_V = self.out_channels self.n_H = 1 def _initialize_calib_parameters(self): """ set parameters for feeding calibration data """ self.calib_size = int(self.raw_input.shape[0]) self.calib_batch_size = int(self.raw_input.shape[0]) while True: numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU self.parallel_eq_n = int((15*1024*1024*1024/4)//numel) if self.parallel_eq_n <= 1: self.calib_need_batching = True self.calib_batch_size //= 2 else: break def _initialize_intervals(self): # weight intervals: shape oc,1,1,1 if self.init_layerwise: self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.out_channels,1,1,1) else: self.w_interval=((self.weight.abs().amax([1,2,3],keepdim=True))/(self.w_qmax-0.5)) # activation intervals: shape 1 tmp_a_intervals = [] for b_st in range(0,self.calib_size,self.calib_batch_size): b_ed = min(self.calib_size, b_st+self.calib_batch_size) x_ = self.raw_input[b_st:b_ed].cuda() a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1) tmp_a_intervals.append(a_interval_) self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False) def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None): """ tensor_raw: *, features tensor_sim: *, features similarity: *, features """ if metric == "cosine": # support cosine on patch dim, which is sub-optimal # not supporting search best a interval b, parallel_eq_n, oc = tensor_sim.shape[0], tensor_sim.shape[1], tensor_sim.shape[2] similarity = F.cosine_similarity(tensor_raw.view(b,1,oc,-1), tensor_sim.view(b,parallel_eq_n,oc,-1), dim=-1).view(b,parallel_eq_n,oc,1,1) else: if metric == "L1_norm": similarity = -torch.abs(tensor_raw - tensor_sim) elif metric == "L2_norm": similarity = -(tensor_raw - tensor_sim) ** 2 elif metric == "linear_weighted_L2_norm": similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 elif metric == "square_weighted_L2_norm": similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 elif metric == "hessian": assert raw_grad != None, f"raw_grad is None in _get_similarity!" 
raw_grad = raw_grad.reshape_as(tensor_raw) similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 else: raise NotImplementedError(f"metric {metric} not implemented!") return similarity def _search_best_w_interval(self, weight_interval_candidates): batch_similarities = [] for b_st in range(0,self.calib_size,self.calib_batch_size): b_ed = min(self.calib_size, b_st+self.calib_batch_size) x = self.raw_input[b_st:b_ed].cuda() raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh raw_grad = self.raw_grad[b_st:b_ed].cuda() similarities = [] for p_st in range(0, self.eq_n, self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,oc,1,1,1 # quantize weight and bias oc,ic,kw,kh = self.weight.data.shape w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None # quantize input x_sim = self.quant_input(x) if self.a_bit < 32 else x # calculate similarity and store them out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad) # shape: b,parallel_eq_n,oc,fw,fh similarity = torch.mean(similarity, [3,4]) # shape: b,parallel_eq_n,oc similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n, oc similarities.append(similarity) # store best weight interval of h into tmp_w_interval similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n,oc batch_similarities.append(similarities) batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n,oc best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,oc,1,1,1 self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0) def _search_best_a_interval(self, input_interval_candidates): batch_similarities = [] for b_st in range(0,self.calib_size,self.calib_batch_size): b_ed = min(self.calib_size, b_st+self.calib_batch_size) x = self.raw_input[b_st:b_ed].cuda() raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh raw_grad = self.raw_grad[b_st:b_ed].cuda() similarities = [] for p_st in range(0,self.eq_n,self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1 # quantize weight and bias w_sim, bias_sim = self.quant_weight_bias() # quantize input B,ic,iw,ih = x.shape x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih # calculate similarity and store them out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh out_sim = out_sim.transpose_(0, 1) # shape: b,parallel_eq_n,oc,fw,fh similarity = self._get_similarity(raw_out, 
out_sim, self.metric, raw_grad=raw_grad) # shape: b,parallel_eq_n,oc,fw,fh similarity = torch.mean(similarity, dim=[2,3,4]) # shape: b,parallel_eq_n similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1,parallel_eq_n similarities.append(similarity) similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n batch_similarities.append(similarities) batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1) self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze() def calibration_step2(self): self._initialize_calib_parameters() self._initialize_intervals() weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,oc,1,1,1 input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1 for e in range(self.search_round): # search for best weight interval self._search_best_w_interval(weight_interval_candidates) # search for best input interval if self.a_bit < 32: self._search_best_a_interval(input_interval_candidates) self.calibrated = True del self.raw_input, self.raw_out, self.raw_grad def quant_weight_bias(self): w_sim = (self.weight/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval) return w_sim, self.bias def quant_forward(self, x): assert self.calibrated is not None,f"You should run calibrate_forward before run quant_forward for {self}" w_sim,bias_sim=self.quant_weight_bias() x_sim=self.quant_input(x) if self.a_bit < 32 else x out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) return out ================================================ FILE: quant_layers/linear.py ================================================ from quant_layers.matmul import PTQSLBatchingQuantMatMul import torch import torch.nn as nn import torch.nn.functional as F class MinMaxQuantLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, bias: bool = True, mode = "raw", w_bit = 8, a_bit = 8, bias_bit = None, bias_correction=False): super().__init__(in_features,out_features,bias) self.n_calibration_step=2 self.mode = mode self.w_bit = w_bit self.a_bit = a_bit self.bias_bit=bias_bit assert bias_bit is None,"No support bias bit now" self.w_interval=None self.a_interval=None self.raw_input=None self.raw_out=None self.metric=None self.next_nodes=[] self.w_qmax=2**(self.w_bit-1) self.a_qmax=2**(self.a_bit-1) self.bias_correction = bias_correction def forward(self, x): if self.mode=='raw': out=F.linear(x, self.weight, self.bias) elif self.mode=="quant_forward": out=self.quant_forward(x) elif self.mode=="calibration_step1": out=self.calibration_step1(x) elif self.mode=="calibration_step2": out=self.calibration_step2(x) else: raise NotImplementedError return out def quant_weight_bias(self): w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1) w_sim=w.mul_(self.w_interval) if self.bias is not None: return w_sim,self.bias # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1) # bias_sim=bias*self.bias_interval # return w_sim,bias_sim else: return w_sim,None def quant_input(self, x): x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1) 
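        # dequantize in place: scale the integer codes back onto the float grid (simulated quantization)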

================================================
FILE: quant_layers/linear.py
================================================
from quant_layers.matmul import PTQSLBatchingQuantMatMul
import torch
import torch.nn as nn
import torch.nn.functional as F

class MinMaxQuantLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, mode = "raw", w_bit = 8, a_bit = 8, bias_bit = None, bias_correction=False):
        super().__init__(in_features,out_features,bias)
        self.n_calibration_step=2
        self.mode = mode
        self.w_bit = w_bit
        self.a_bit = a_bit
        self.bias_bit=bias_bit
        assert bias_bit is None,"No support bias bit now"
        self.w_interval=None
        self.a_interval=None
        self.raw_input=None
        self.raw_out=None
        self.metric=None
        self.next_nodes=[]
        self.w_qmax=2**(self.w_bit-1)
        self.a_qmax=2**(self.a_bit-1)
        self.bias_correction = bias_correction

    def forward(self, x):
        if self.mode=='raw':
            out=F.linear(x, self.weight, self.bias)
        elif self.mode=="quant_forward":
            out=self.quant_forward(x)
        elif self.mode=="calibration_step1":
            out=self.calibration_step1(x)
        elif self.mode=="calibration_step2":
            out=self.calibration_step2(x)
        else:
            raise NotImplementedError
        return out

    def quant_weight_bias(self):
        w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1)
        w_sim=w.mul_(self.w_interval)
        if self.bias is not None:
            return w_sim,self.bias
            # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1)
            # bias_sim=bias*self.bias_interval
            # return w_sim,bias_sim
        else:
            return w_sim,None

    def quant_input(self, x):
        x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1)
        x_sim.mul_(self.a_interval)
        return x_sim

    def quant_forward(self,x):
        assert self.calibrated is not None, f"You should run calibration_step2 before running quant_forward for {self}"
        w_sim,bias_sim=self.quant_weight_bias()
        x_sim=self.quant_input(x)
        out=F.linear(x_sim, w_sim, bias_sim)
        return out

    def _bias_correction_quant_forward(self, x):
        if self.bias_correction and self.bias is not None:
            w_sim = self.quant_weight_bias()[0]
            x_sim = self.quant_input(x)
            eps = F.linear(x_sim, w_sim-self.weight.data, None)
            eps = torch.mean(eps, dim=(list(range(len(eps.shape)-1))), keepdim=False)
            self.bias -= eps
            self.bias_correction = False
        return self.quant_forward(x)

    def calibration_step1(self,x):
        # step1: collect the FP32 values
        out=F.linear(x, self.weight, self.bias)
        self.raw_input=x.cpu().detach()
        self.raw_out=out.cpu().detach()
        return out

    def calibration_step2(self,x):
        # step2: search for the best S^w and S^o of each layer
        self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach()
        self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach()
        self.calibrated=True
        out=self._bias_correction_quant_forward(x)
        return out
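# A small, self-contained round-trip example (assumed values, not part of the original
# file) of the uniform quantizer implemented by MinMaxQuantLinear above: the interval is
# max|w|/(qmax-0.5), and each value is rounded to one of 2^w_bit signed levels.
if __name__ == "__main__":
    w = torch.tensor([0.30, -0.11, 0.02])
    w_qmax = 2**(8-1)
    interval = w.abs().max()/(w_qmax-0.5)
    w_sim = (w/interval).round().clamp(-w_qmax, w_qmax-1) * interval
    print(w_sim, (w - w_sim).abs().max())  # dequantized weights; error <= interval/2 inside the clip range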
""" if metric == "cosine": similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1) elif metric == "pearson": similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=-1,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=-1,keepdim=True), dim=-1) else: if metric == "L1_norm": similarity = -torch.abs(tensor_raw - tensor_sim) elif metric == "L2_norm": similarity = -(tensor_raw - tensor_sim) ** 2 elif metric == "linear_weighted_L2_norm": similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 elif metric == "square_weighted_L2_norm": similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 elif metric == "hessian": raw_grad = self.raw_grad.reshape_as(tensor_raw) similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 else: raise NotImplementedError(f"metric {metric} not implemented!") similarity = torch.mean(similarity, dim=-1) return similarity def quant_weight_bias(self): # self.w_interval shape: n_V, 1, n_H, 1 w=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols)/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1) w_sim=w.mul_(self.w_interval).view(self.out_features,self.in_features) if self.bias is not None: return w_sim,self.bias # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1) # bias_sim=bias*self.bias_interval # return w_sim,bias_sim else: return w_sim,None def quant_input(self, x): # self.a_interval shape: n_a,1 x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2) x_sim=(x_sim.div_(self.a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1) x_sim = x_sim.mul_(self.a_interval).reshape_as(x) return x_sim def _search_best_w_interval(self, x, weight_interval_candidates, raw_out_expanded_chunked): """ Modularization of searching best weight intervals """ tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1 for h in range(self.n_H): similarities = [] for p_st in range(0,self.eq_n,self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1) cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:] # quantize weight and bias w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None # quantize input x_sim = self.quant_input(x) # calculate similarity and store them out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n*oc out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,oc out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,n_V,crb_rows similarity = self._get_similarity(raw_out_expanded_chunked, out_sim, self.metric) # shape: B,*,parallel_eq_n,n_V similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-2))) # shape: parallel_eq_n, n_V similarities.append(similarity) # store best weight interval of h into tmp_w_interval similarities = torch.cat(similarities, dim=0) # shape: eq_n, n_V h_best_index = similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1 tmp_w_interval[:,:,:,h:h+1,:] = 
torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index) self.w_interval = tmp_w_interval.squeeze(dim=0) def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded): tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1 for a in range(self.n_a): similarities = [] for p_st in range(0,self.eq_n,self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed] # quantize weight and bias w_sim, bias_sim = self.quant_weight_bias() # quantize input x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1) x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic # calculate similarity and store them out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n similarities.append(similarity) # store best input interval and store in tmp_a_interval similarities = torch.cat(similarities, dim=0) # shape: eq_n a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1) tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index) self.a_interval = tmp_a_interval.squeeze(-1) def _initialize_intervals(self, x): if self.init_layerwise: self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1) else: self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) self.a_interval=((x.view(*x.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1) def calibration_step2(self,x): # initialize intervals with minmax intervals self._initialize_intervals(x) # put raw outs on GPU raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2) # shape: B,*,1,oc raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows # put raw grad on GPU self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None # prepare weight intervals and similarities weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1 input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n for e in range(self.search_round): # search for best weight interval self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked) # search for best input interval self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded) self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad != None else None self.calibrated = True 
out=self._bias_correction_quant_forward(x) del self.raw_input, self.raw_out, self.raw_grad return out class PostGeluPTQSLQuantLinear(PTQSLQuantLinear): def __init__(self, in_features: int, out_features: int, bias: bool = True, mode = "raw", w_bit = 8, a_bit = 8, bias_bit = None, bias_correction = False, metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False): super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise) def quant_input(self, x): """ self.a_interval = [a_interval_pos, a_interval_neg] """ # self.a_interval[0] shape: n_a,1 # self.a_interval[1] shape: 1 x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2) x_pos=(x_/(self.a_interval[0])).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval[0]) x_neg=(x_/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0).mul_(self.a_interval[1]) return (x_pos + x_neg).reshape_as(x) def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded): tmp_a_interval = self.a_interval[0].unsqueeze(-1) # shape: n_a,1,1 for a in range(self.n_a): similarities = [] for p_st in range(0,self.eq_n,self.parallel_eq_n): p_ed = min(self.eq_n, p_st+self.parallel_eq_n) cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed] # quantize weight and bias w_sim, bias_sim = self.quant_weight_bias() # quantize input x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1) x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n x_neg=(x_sim/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0)*(self.a_interval[1]) # shape: B,*,n_a,crb_acts,1 x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic # calculate similarity and store them out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n similarities.append(similarity) # store best input interval and store in tmp_a_interval similarities = torch.cat(similarities, dim=0) # shape: eq_n a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1) tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index) self.a_interval[0] = tmp_a_interval.squeeze(-1) def _initialize_intervals(self, x): if self.init_layerwise: self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1) self.a_interval=[(x.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)] else: self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5)) self.a_interval=[((x.view(*x.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)] self.a_interval.append(0.16997124254703522/self.a_qmax) def calibration_step2(self,x): # initialize intervals with 

class PostGeluPTQSLQuantLinear(PTQSLQuantLinear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, mode = "raw", w_bit = 8, a_bit = 8, bias_bit = None, bias_correction = False, metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):
        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)

    def quant_input(self, x):
        """
        self.a_interval = [a_interval_pos, a_interval_neg]
        """
        # self.a_interval[0] shape: n_a,1
        # self.a_interval[1] shape: 1
        x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)
        x_pos=(x_/(self.a_interval[0])).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval[0])
        x_neg=(x_/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0).mul_(self.a_interval[1])
        return (x_pos + x_neg).reshape_as(x)

    def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded):
        tmp_a_interval = self.a_interval[0].unsqueeze(-1) # shape: n_a,1,1
        for a in range(self.n_a):
            similarities = []
            for p_st in range(0,self.eq_n,self.parallel_eq_n):
                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
                cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n
                cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]
                # quantize weight and bias
                w_sim, bias_sim = self.quant_weight_bias()
                # quantize input
                x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)
                x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n
                x_neg=(x_sim/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0)*(self.a_interval[1]) # shape: B,*,n_a,crb_acts,1
                x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic
                # calculate similarity and store it
                out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc
                similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n
                similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n
                similarities.append(similarity)
            # store the best input interval into tmp_a_interval
            similarities = torch.cat(similarities, dim=0) # shape: eq_n
            a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)
            tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)
        self.a_interval[0] = tmp_a_interval.squeeze(-1)

    def _initialize_intervals(self, x):
        if self.init_layerwise:
            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
            self.a_interval=[(x.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)]
        else:
            self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
            self.a_interval=[((x.view(*x.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)]
        # fixed interval for the negative half; 0.16997... is approximately |min GELU(x)|
        self.a_interval.append(0.16997124254703522/self.a_qmax)

    def calibration_step2(self,x):
        # initialize intervals with minmax intervals
        self._initialize_intervals(x)
        # put raw outs on GPU
        raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2) # shape: B,*,1,oc
        raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows
        # put raw grad on GPU
        self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad is not None else None
        # prepare weight intervals and similarities
        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1
        input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval[0].unsqueeze(-1) # shape: n_a,1,eq_n
        for e in range(self.search_round):
            # search for the best weight interval
            self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked)
            # search for the best input interval
            self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded)
        self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad is not None else None
        self.calibrated = True
        out=self._bias_correction_quant_forward(x)
        del self.raw_input, self.raw_out, self.raw_grad
        return out
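# A hypothetical illustration (not part of the original file) of the twin uniform
# quantization used for post-GELU activations above: positive values use a searched
# interval, while negative values reuse a small fixed interval derived from 0.16997...,
# which is approximately |min GELU(x)| (GELU bottoms out near -0.17).
if __name__ == "__main__":
    a_qmax = 2**(8-1)
    x = torch.tensor([-0.15, -0.05, 0.4, 2.1])     # made-up post-GELU values
    pos_interval = x.max()/(a_qmax-0.5)            # coarse step for the long positive tail
    neg_interval = 0.16997124254703522/a_qmax      # fine step covering [-0.17, 0]
    x_pos = (x/pos_interval).round().clamp(0, a_qmax-1) * pos_interval
    x_neg = (x/neg_interval).round().clamp(-a_qmax, 0) * neg_interval
    print(x_pos + x_neg)  # the negative half keeps far finer resolution than the positive half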

class PTQSLBatchingQuantLinear(PTQSLQuantLinear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, mode = "raw", w_bit = 8, a_bit = 8, bias_bit = None, bias_correction = False, metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):
        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)
        self.calib_size = None
        self.calib_batch_size = None
        self.calib_need_batching = False

    def _initialize_calib_parameters(self):
        """
        set parameters for feeding calibration data
        """
        self.calib_size = int(self.raw_input.shape[0])
        self.calib_batch_size = int(self.raw_input.shape[0])
        while True:
            numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU
            self.parallel_eq_n = int((3*1024*1024*1024/4)//numel)
            if self.parallel_eq_n <= 1:
                self.calib_need_batching = True
                self.calib_batch_size //= 2
            else:
                break

    def _initialize_intervals(self):
        # weight intervals
        if self.init_layerwise:
            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
        else:
            self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
        # activation intervals
        tmp_a_intervals = []
        for b_st in range(0,self.calib_size,self.calib_batch_size):
            b_ed = min(self.calib_size, b_st+self.calib_batch_size)
            x_ = self.raw_input[b_st:b_ed].cuda()
            if self.init_layerwise:
                a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)
            else:
                a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)
            tmp_a_intervals.append(a_interval_)
        self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)

    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None):
        """
        tensor_raw: *, features
        tensor_sim: *, features
        similarity: *
        It's your job to calculate the mean on the * dims!
        """
        if metric == "cosine":
            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1)
        else:
            if metric == "L1_norm":
                similarity = -torch.abs(tensor_raw - tensor_sim)
            elif metric == "L2_norm":
                similarity = -(tensor_raw - tensor_sim) ** 2
            elif metric == "linear_weighted_L2_norm":
                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
            elif metric == "square_weighted_L2_norm":
                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
            elif metric == "hessian":
                assert raw_grad is not None, "raw_grad is None in _get_similarity!"
                raw_grad = raw_grad.reshape_as(tensor_raw)
                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
            else:
                raise NotImplementedError(f"metric {metric} not implemented!")
            similarity = torch.mean(similarity, dim=-1)
        return similarity

    def _get_pearson_w(self, tensor_raw, tensor_sim):
        """
        Quick implementation of similarity-aware linear quantization
        tensor_sim: b,*,parallel_eq_n,n_V,crb_rows
        tensor_raw: b,*,1,n_V,crb_rows
        """
        b, parallel_eq_n, n_V = tensor_sim.shape[0],tensor_sim.shape[-3],tensor_sim.shape[-2]
        tensor_sim = tensor_sim.transpose(-1,-3).contiguous().view(b,-1,n_V,parallel_eq_n)
        tensor_raw = tensor_raw.transpose(-1,-3).view(b,-1,n_V,1)
        tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True)
        tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True)
        similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,n_V,parallel_eq_n
        similarity = similarity.permute(0,2,1).contiguous()
        return similarity

    def _get_pearson_a(self, tensor_raw, tensor_sim):
        """
        Quick implementation of similarity-aware linear quantization
        tensor_sim: b,*,parallel_eq_n,oc
        tensor_raw: b,*,1,oc
        """
        b, parallel_eq_n = tensor_sim.shape[0],tensor_sim.shape[-2]
        tensor_sim = tensor_sim.transpose(-1,-2).contiguous().view(b,-1,parallel_eq_n)
        tensor_raw = tensor_raw.transpose(-1,-2).view(b,-1,1)
        tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True)
        tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True)
        similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,parallel_eq_n
        return similarity

    def _search_best_w_interval(self, weight_interval_candidates):
        tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1
        for h in range(self.n_H):
            batch_similarities = [] # similarities, need to concatenate and calculate the sum (equivalent to the mean with argmax)
            for b_st in range(0, self.calib_size, self.calib_batch_size):
                b_ed = min(self.calib_size, b_st + self.calib_batch_size)
                x = self.raw_input[b_st:b_ed].cuda()
                raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc
                raw_out_expanded = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,1,n_V,crb_rows
                raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later
                similarities = []
                for p_st in range(0,self.eq_n,self.parallel_eq_n):
                    p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
                    cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1)
                    cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:]
                    # quantize weight and bias
                    w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols
                    w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols
                    w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic
                    bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None
                    # quantize input
                    x_sim = self.quant_input(x)
                    # calculate similarity and store it
                    out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n*oc
                    out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,oc
                    out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,n_V,crb_rows
                    if self.metric != "pearson":
                        similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n,n_V
                        if len(similarity.shape) > 3:
                            similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-2))) # shape: b,parallel_eq_n,n_V
                    else:
                        similarity = self._get_pearson_w(raw_out_expanded, out_sim)
                    similarity = similarity.sum(dim=0, keepdim=True) # shape: 1,parallel_eq_n,n_V
                    similarities.append(similarity)
                # store the best weight interval of h into tmp_w_interval
                similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n,n_V
                batch_similarities.append(similarities)
            batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n,n_V
            h_best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1
            tmp_w_interval[:,:,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index)
        self.w_interval = tmp_w_interval.squeeze(dim=0)

    def _search_best_a_interval(self, input_interval_candidates):
        tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1
        for a in range(self.n_a):
            batch_similarities = [] # similarities, need to concatenate and calculate the sum (equivalent to the mean with argmax)
            for b_st in range(0, self.calib_size, self.calib_batch_size):
                b_ed = min(self.calib_size, b_st + self.calib_batch_size)
                x = self.raw_input[b_st:b_ed].cuda()
                raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc
                raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later
                similarities = []
                for p_st in range(0,self.eq_n,self.parallel_eq_n):
                    p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
                    cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n
                    cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]
                    # quantize weight and bias
                    w_sim, bias_sim = self.quant_weight_bias()
                    # quantize input
                    x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)
                    x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n
                    x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic
                    # calculate similarity and store it
                    out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc
                    if self.metric != "pearson":
                        similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n
                        if len(similarity.shape) > 2:
                            similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b,parallel_eq_n
                    else:
                        similarity = self._get_pearson_a(raw_out_expanded, out_sim)
                    similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1,parallel_eq_n
                    similarities.append(similarity)
                # store the best input interval into tmp_a_interval
                similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n
                batch_similarities.append(similarities)
            batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n
            a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)
            tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)
        self.a_interval = tmp_a_interval.squeeze(-1)

    def calibration_step2(self):
        """
        Only use cached raw inputs/outs/grads
        """
        self._initialize_calib_parameters()
        self._initialize_intervals()
        # prepare weight intervals and similarities
        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1
        input_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n
        for e in range(self.search_round):
            # search for the best weight interval
            self._search_best_w_interval(weight_interval_candidates)
            # search for the best input interval
            self._search_best_a_interval(input_interval_candidates)
        self.calibrated = True
        # self._bias_correction_quant_forward(self.raw_input.cuda()) # debugging
        del self.raw_input, self.raw_out, self.raw_grad
        return None
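# A sketch with assumed sizes (not from the original code) of the batching heuristic in
# _initialize_calib_parameters above: parallel_eq_n is chosen so that roughly 3 GiB of
# float32 buffers fit on the GPU; when not even two candidates fit, the calibration
# batch size is halved instead.
if __name__ == "__main__":
    calib_size = calib_batch_size = 32
    raw_input_numel = 32*197*768            # hypothetical ViT token activations
    raw_out_numel = 32*197*3072
    numel = 2*(raw_input_numel+raw_out_numel)/calib_size*calib_batch_size
    parallel_eq_n = int((3*1024*1024*1024/4)//numel)
    print(parallel_eq_n)  # -> 16 candidate intervals evaluated per pass in this example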

class PostGeluPTQSLBatchingQuantLinear(PTQSLBatchingQuantLinear):
    """
    An agile implementation of PostGeluPTQSLBatchingQuantLinear:
    use a_interval for positive activation quantization and a_neg_interval for negative activation quantization
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True, mode = "raw", w_bit = 8, a_bit = 8, bias_bit = None, bias_correction = False, metric="L2_norm", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):
        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)
        # fixed interval for the negative half; 0.16997... is approximately |min GELU(x)|
        self.a_neg_interval = 0.16997124254703522/self.a_qmax

    def _initialize_intervals(self):
        # weight intervals
        if self.init_layerwise:
            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)
        else:
            self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))
        # activation intervals (for the positive parts)
        if self.init_layerwise:
            tmp_a_intervals = []
            for b_st in range(0,self.calib_size,self.calib_batch_size):
                b_ed = min(self.calib_size, b_st+self.calib_batch_size)
                x_ = self.raw_input[b_st:b_ed].cuda()
                a_interval_=(x_.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)
                tmp_a_intervals.append(a_interval_)
            self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)
        else:
            tmp_a_intervals = []
            for b_st in range(0,self.calib_size,self.calib_batch_size):
                b_ed = min(self.calib_size, b_st+self.calib_batch_size)
                x_ = self.raw_input[b_st:b_ed].cuda()
                a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)
                tmp_a_intervals.append(a_interval_)
            self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)

    def quant_input(self, x):
        # self.a_interval shape: n_a,1
        # self.a_neg_interval shape: 1
        x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)
        x_pos=(x_/(self.a_interval)).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval)
        x_neg=(x_/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0).mul_(self.a_neg_interval)
        return (x_pos + x_neg).reshape_as(x)

    def _search_best_a_interval(self, input_interval_candidates):
        tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1
        for a in range(self.n_a):
            batch_similarities = [] # similarities, need to concatenate and calculate the sum (equivalent to the mean with argmax)
            for b_st in range(0, self.calib_size, self.calib_batch_size):
                b_ed = min(self.calib_size, b_st + self.calib_batch_size)
                x = self.raw_input[b_st:b_ed].cuda()
                raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc
                raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later
                similarities = []
                for p_st in range(0,self.eq_n,self.parallel_eq_n):
                    p_ed = min(self.eq_n, p_st+self.parallel_eq_n)
                    cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n
                    cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]
                    # quantize weight and bias
                    w_sim, bias_sim = self.quant_weight_bias()
                    # quantize input
                    x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)
                    x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n
                    x_neg=(x_sim/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0)*(self.a_neg_interval) # shape: b,*,n_a,crb_acts,1
                    x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic
                    # calculate similarity and store it
                    out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc
                    similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n
                    similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b,parallel_eq_n
                    similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1,parallel_eq_n
                    similarities.append(similarity)
                # store the best input interval into tmp_a_interval
                similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n
                batch_similarities.append(similarities)
            batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n
            a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)
            tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)
        self.a_interval = tmp_a_interval.squeeze(-1)

================================================
FILE: quant_layers/matmul.py
================================================
import numpy as np
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
from itertools import product

class MinMaxQuantMatMul(nn.Module):
    """Matrix Multiplication base class"""
    def __init__(self, A_bit=8, B_bit=8, mode="raw"):
        super().__init__()
        self.A_bit=A_bit
        self.B_bit=B_bit
        self.A_interval=None
        self.B_interval=None
        self.A_qmax=2**(self.A_bit-1)
        self.B_qmax=2**(self.B_bit-1)
        self.mode=mode
        self.raw_input = None
        self.raw_out = None

    def forward(self, A,B):
        if self.mode=='raw':
            out=A @ B
        elif self.mode=="quant_forward":
            out=self.quant_forward(A,B)
        elif self.mode=="calibration_step1":
            out=self.calibration_step1(A,B)
        elif self.mode=="calibration_step2":
            out=self.calibration_step2(A,B)
        else:
            raise NotImplementedError
        return out

    def quant_input(self,x,interval,qmax):
        x_sim=(x/interval).round_().clamp_(-qmax,qmax-1)
        x_sim.mul_(interval)
        return x_sim

    def quant_forward(self,A,B):
        assert self.calibrated is not None, f"You should run calibration_step2 before running quant_forward for {self}"
        A_sim=self.quant_input(A,self.A_interval,self.A_qmax)
        B_sim=self.quant_input(B,self.B_interval,self.B_qmax)
        out=A_sim@B_sim
        return out

    def calibration_step1(self,A,B):
        # step1: collect the FP32 values
        self.raw_input=A.cpu().detach(), B.cpu().detach()
        out=A@B
        self.raw_out=out.cpu().detach()
        return out

    def calibration_step2(self,A,B):
        # step2: search for the best S^w and S^o of each layer
        self.A_interval=(A.data.abs().max()/(self.A_qmax-0.5)).detach()
        self.B_interval=(B.data.abs().max()/(self.B_qmax-0.5)).detach()
        self.calibrated=True
        out=self.quant_forward(A,B)
        return out
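# A minimal, CPU-only usage sketch (hypothetical, not part of the original file) of the
# two-step calibration protocol that MinMaxQuantMatMul implements. The shapes follow the
# Q @ K^T case: batch 2, 3 heads, 4 tokens, head dim 8.
if __name__ == "__main__":
    m = MinMaxQuantMatMul(A_bit=8, B_bit=8, mode="raw")
    A, B = torch.randn(2, 3, 4, 8), torch.randn(2, 3, 8, 4)
    m.mode = "calibration_step1"; m(A, B)   # cache the fp32 operands and output
    m.mode = "calibration_step2"; m(A, B)   # pick minmax intervals, mark as calibrated
    m.mode = "quant_forward"
    print((m(A, B) - A @ B).abs().max())    # error of the simulated 8-bit matmul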

class PTQSLQuantMatMul(MinMaxQuantMatMul):
    """
    Chunk matrices into blocks and quantize.
    Chunking follows a naive padding strategy.
    Alternately search for the best intervals of each individual block for A and B.

    two different scenarios:
    - Q @ K:
        - A's shape: B,H,S,W
        - B's shape: B,H,W,S
    - scores @ V:
        - A's shape: B,H,S,S
        - B's shape: B,H,S,W
    - interval shape: 1,n_G,1,n_V,1,n_H,1
    """
    def __init__(self, A_bit=8, B_bit=8, mode="raw", metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False):
        super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode)
        self.metric = metric
        self.search_round = search_round
        self.eq_alpha = eq_alpha
        self.eq_beta = eq_beta
        self.eq_n = eq_n
        self.parallel_eq_n = parallel_eq_n
        self.n_G_A = n_G_A
        self.n_V_A = n_V_A
        self.n_H_A = n_H_A
        self.n_G_B = n_G_B
        self.n_V_B = n_V_B
        self.n_H_B = n_H_B
        # init these parameters in self.calibration_step1
        self.crb_groups_A = None
        self.crb_groups_B = None
        self.crb_rows_A = None
        self.crb_cols_A = None
        self.crb_rows_B = None
        self.crb_cols_B = None
        self.pad_groups_A = None
        self.pad_groups_B = None
        self.pad_rows_A = None
        self.pad_rows_B = None
        self.pad_cols_A = None
        self.pad_cols_B = None
        self.raw_grad = None
        self.init_layerwise = init_layerwise

    def _get_padding_parameters(self, A, B):
        self.crb_groups_A = (A.shape[1]+self.n_G_A-1) // self.n_G_A
        self.crb_groups_B = (B.shape[1]+self.n_G_B-1) // self.n_G_B
        self.crb_rows_A = (A.shape[2]+self.n_V_A-1) // self.n_V_A
        self.crb_cols_A = (A.shape[3]+self.n_H_A-1) // self.n_H_A
        self.crb_rows_B = (B.shape[2]+self.n_V_B-1) // self.n_V_B
        self.crb_cols_B = (B.shape[3]+self.n_H_B-1) // self.n_H_B
        self.pad_groups_A = self.crb_groups_A*self.n_G_A - A.shape[1]
        self.pad_rows_A = self.crb_rows_A*self.n_V_A - A.shape[2]
        self.pad_cols_A = self.crb_cols_A*self.n_H_A - A.shape[3]
        self.pad_groups_B = self.crb_groups_B*self.n_G_B - B.shape[1]
        self.pad_rows_B = self.crb_rows_B*self.n_V_B - B.shape[2]
        self.pad_cols_B = self.crb_cols_B*self.n_H_B - B.shape[3]

    def quant_input_A(self, x):
        x = F.pad(x, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A])
        x = x.view(-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)
        x = (x/self.A_interval).round_().clamp(-self.A_qmax,self.A_qmax-1).mul_(self.A_interval)
        x = x.view(-1,self.n_G_A*self.crb_groups_A,self.n_V_A*self.crb_rows_A,self.n_H_A*self.crb_cols_A)
        x = x[:,:x.shape[1]-self.pad_groups_A,:x.shape[2]-self.pad_rows_A,:x.shape[3]-self.pad_cols_A]
        return x

    def quant_input_B(self, x):
        x = F.pad(x, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B])
        x = x.view(-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
        x = (x/self.B_interval).round_().clamp(-self.B_qmax,self.B_qmax-1).mul_(self.B_interval)
        x = x.view(-1,self.n_G_B*self.crb_groups_B,self.n_V_B*self.crb_rows_B,self.n_H_B*self.crb_cols_B)
        x = x[:,:x.shape[1]-self.pad_groups_B,:x.shape[2]-self.pad_rows_B,:x.shape[3]-self.pad_cols_B]
        return x

    def quant_forward(self, A, B):
        assert self.calibrated is not None, f"You should run calibration_step2 before running quant_forward for {self}"
        A_sim=self.quant_input_A(A)
        B_sim=self.quant_input_B(B)
        out=A_sim@B_sim
        return out

    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1):
        """
        tensor_raw: *, features, *
        tensor_sim: *, features, *
        similarity: *
        It's your job to calculate the mean on the non-feature * dims!

        Similarity without an inherent feature structure is more amenable to parallelism.
        """
        if metric == "cosine":
            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # should only support dim=-1 and cannot be paralleled
        elif metric == "pearson":
            similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw), tensor_sim-torch.mean(tensor_sim), dim=dim)
        else:
            if metric == "L1_norm":
                similarity = -torch.abs(tensor_raw - tensor_sim)
            elif metric == "L2_norm":
                similarity = -(tensor_raw - tensor_sim) ** 2
            elif metric == "linear_weighted_L2_norm":
                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2
            elif metric == "square_weighted_L2_norm":
                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2
            elif metric == "hessian":
                raw_grad = self.raw_grad.reshape_as(tensor_raw)
                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2
            else:
                raise NotImplementedError(f"metric {metric} not implemented!")
            similarity = torch.mean(similarity, dim=dim)
        return similarity

    def _search_best_A_interval(self, A, B, A_interval_candidates):
        """
        Modularization of searching the best interval
        """
        # recalculate A_pad
        A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)
        tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1
        # out-of-loop optimization
        B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3
        for v, h in product(range(self.n_V_A), range(self.n_H_A)):
            similarities = []
            for p_st in range(0, self.eq_n, self.parallel_eq_n):
                p_ed = min(self.eq_n,p_st+self.parallel_eq_n)
                # quantize A
                cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)
                cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]
                A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval)
                A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stands for padding)
                A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,B,H,dim1,dim2
                # quantize B, this quantization is optimized out of the loop
                # calculate similarity and store it
                out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3
                similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1
                similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (the remaining mean operation will be done later on)
                similarities.append(similarity)
            # calculate the best similarity for this block
            similarities = torch.cat(similarities, 0) # shape: eq_n,H
            similarities = F.pad(similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n,n_G_A
            best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)
            tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)
        self.A_interval = tmp_A_interval.squeeze(0)

    def _search_best_B_interval(self, A, B, B_interval_candidates):
        """
        Modularization of searching the best interval
        """
        # recalculate B_pad
        B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
        tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1
        # out-of-loop optimization
        A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2
        for v, h in product(range(self.n_V_B), range(self.n_H_B)):
            similarities = []
            for p_st in range(0, self.eq_n, self.parallel_eq_n):
                p_ed = min(self.eq_n,p_st+self.parallel_eq_n)
                # quantize A, this quantization is optimized out of the loop
                # quantize B
                cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)
                cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]
                B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval)
                B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,B,H*,dim2*,dim3* (* stands for padding)
                B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,B,H,dim2,dim3
                # calculate similarity and store it
                out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3
                similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1
                similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (the remaining mean operation will be done later on)
                similarities.append(similarity)
            # calculate the best similarity for this block
            similarities = torch.cat(similarities, 0) # shape: eq_n,H
            similarities = F.pad(similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n,n_G_B
            best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)
            tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)
        self.B_interval = tmp_B_interval.squeeze(0)

    def _initialize_intervals(self, A, B):
        # pad A and B for future quantization
        self._get_padding_parameters(A, B) # put it here because hessian does not use calibration step 1
        A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) # shape: 1,B,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols
        B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
        # initialize intervals with minmax intervals
        if self.init_layerwise:
            self.A_interval = (A.abs().max()/(self.A_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_A,1,self.n_V_A,1,self.n_H_A,1)
            self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)
        else:
            self.A_interval=(A_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.A_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1
            self.B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1
    def calibration_step2(self, A, B):
        # put raw outs/grads on GPU
        self.raw_out = self.raw_out.unsqueeze(0).to(A.device)
        self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad is not None else None
        self._initialize_intervals(A, B)
        # prepare weight intervals and similarities
        A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0)
        B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)
        for e in range(self.search_round):
            # search for the best A interval
            self._search_best_A_interval(A, B, A_interval_candidates)
            # search for the best B interval
            self._search_best_B_interval(A, B, B_interval_candidates)
        # put raw data back on the CPU
        self.raw_out = self.raw_out.squeeze(0).to("cpu")
        self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad is not None else None
        # finish calibration and output the result
        self.calibrated = True
        del self.raw_input, self.raw_out, self.raw_grad
        out=self.quant_forward(A,B)
        return out

class SoSPTQSLQuantMatMul(PTQSLQuantMatMul):
    """
    Sublayerwise PTQ on matmul modules with Split-of-Softmax (SoS) on the score matrix.

    Data after softmaxing has a highly biased distribution, making it difficult to quantize with uniform quantization.
    An elegant tradeoff between the great majority of unimportant values and the few crucial values is impossible under low bit quantization.
    Therefore, we propose to split the complete interval (0, 1) into several smaller intervals and perform uniform quantization on each.
    We could manually assign or search for the best split point.
    Currently, we only consider single split point scenarios, since this proves to be effective enough.

    The algorithm no longer requires PTQSL on the score matrix, and will ignore the relevant parameters.

    With a proper hardware implementation, we don't need to use a sign bit anymore.
    """
    def __init__(self, A_bit=8, B_bit=8, mode="raw", metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False, split=None):
        super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise)
        self.n_G_A = 1
        self.n_V_A = 1
        self.n_H_A = 1
        self.A_qmax = 2**(self.A_bit-1) # well, still need it
        self.split = split
        if split is not None:
            self.A_interval = self.split/(self.A_qmax-1)

    def quant_input_A(self, x):
        x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)
        x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval
        return x_high + x_low

    def _search_best_A_interval(self, A, B, split_candidates):
        """
        search for the best split point
        """
        # out-of-loop optimization
        A_ = A.unsqueeze(0)
        # B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3
        B_sim = B.unsqueeze(0)
        similarities = []
        for i in range(len(split_candidates)):
            # quantize A
            cur_A_interval = split_candidates[i]/(self.A_qmax-1)
            A_high = (A_.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)
            A_low = (A_.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval
            A_sim = A_high + A_low # shape: 1,B,H,S,S
            # quantize B, this quantization is optimized out of the loop
            # calculate similarity and store it (dim1=dim2=S, dim3=W)
            out_sim = A_sim @ B_sim # shape: 1,B,H,dim1,dim3
            similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1
            similarity = similarity.mean([1,2,3]) # shape: 1
            similarities.append(similarity)
        # calculate the best similarity for this block
        similarities = torch.cat(similarities, 0) # shape: eq_n
        best_index = torch.argmax(similarities, dim=0, keepdim=False)
        self.split = split_candidates[best_index]
        self.A_interval = self.split/(self.A_qmax-1)
        # debugging
        # print(f"best split: {self.split}")

    def _initialize_intervals(self, A, B):
        # pad A and B for future quantization
        self._get_padding_parameters(A, B)
        B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)
        # initialize intervals with minmax intervals
        self.split = 0.01
        self.A_interval = self.split/(self.A_qmax-1)
        if self.init_layerwise:
            self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)
        else:
            self.B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1

    def calibration_step2(self, A, B):
        # put raw outs/grads on GPU
        self.raw_out = self.raw_out.unsqueeze(0).to(A.device)
        self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad is not None else None
        self._initialize_intervals(A, B)
        # prepare weight intervals and similarities
        A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda()
        # split_eq_alpha, split_eq_beta, split_eq_n = 0.002, 0.03, 50
        # A_split_candidates = torch.tensor([split_eq_alpha + (split_eq_beta- split_eq_alpha)*i/split_eq_n for i in range(split_eq_n + 1)]).cuda()
        B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)
        for e in range(self.search_round):
            # search for the best A interval
            self._search_best_A_interval(A, B, A_split_candidates)
            # search for the best B interval
            self._search_best_B_interval(A, B, B_interval_candidates)
        # put raw data back on the CPU
        self.raw_out = self.raw_out.squeeze(0).to("cpu")
        self.raw_grad = self.raw_grad.to("cpu") if self.raw_grad is not None else None
        # finish calibration and output the result
        self.calibrated = True
        del self.raw_input, self.raw_out, self.raw_grad
        out=self.quant_forward(A,B)
        return out
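# A hypothetical worked example (not part of the original file) of the split-of-softmax
# idea above: the (0, 1) range of attention scores is cut at a small, power-of-two split
# point, so the many near-zero scores keep a fine step while the few large scores are
# covered by a coarse step on [split, 1].
if __name__ == "__main__":
    A_qmax = 2**(8-1)
    split = 2**-7                        # one of the power-of-two split candidates
    low_interval = split/(A_qmax-1)      # fine step for scores in (0, split)
    high_step = 1/(A_qmax-1)             # coarse step for scores in (split, 1)
    print(low_interval, high_step)       # ~6.2e-05 vs ~7.9e-03: two orders of magnitude finer below the split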
*, features, * similarity: * It's your job to calculate mean on non-feature * dims! Similarity without inherent feature structure is more welcome to parallelism. """ if metric == "cosine": similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # should only support dim=-1 and cannot be paralleled elif metric == "pearson": similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=dim,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=dim,keepdim=True), dim=dim) # should only support dim=-1 and cannot be paralleled # a quick implementation of pearson similarity # tensor_raw: 1,B,H,dim1,dim3 # tensor_sim: parallel_eq_n,B,H,dim1,dim3 # parallel_eq_n,B,H,dim1,dim3 = tensor_sim.shape # tensor_sim = tensor_sim.view(parallel_eq_n,B,-1) # tensor_raw = tensor_raw.view(1,B,-1) # tensor_sim_mean = tensor_sim.mean(dim=[1,2],keepdim=True) # tensor_raw_mean = tensor_raw.mean(dim=[1,2],keepdim=True) # similarity = F.cosine_similarity(tensor_raw-tensor_raw_mean,tensor_sim-tensor_sim_mean,dim=-1) # shape: parallel_eq_n,B # similarity = similarity.reshape(parallel_eq_n,B,1,1) # restore two dims else: if metric == "L1_norm": similarity = -torch.abs(tensor_raw - tensor_sim) elif metric == "L2_norm": similarity = -(tensor_raw - tensor_sim) ** 2 elif metric == "linear_weighted_L2_norm": similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2 elif metric == "square_weighted_L2_norm": similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2 elif metric == "hessian": assert raw_grad != None, f"No raw_grad in PTQSLBatchingQuantMatMul!" raw_grad = raw_grad.reshape_as(tensor_raw) similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2 else: raise NotImplementedError(f"metric {metric} not implemented!") similarity = torch.mean(similarity, dim=dim) return similarity def _search_best_A_interval(self, A_interval_candidates): """ Modularization of searching best interval """ tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1 # out-of-loop optimization for v, h in product(range(self.n_V_A), range(self.n_H_A)): batch_similarities = [] # similarities, need to concatenate and calculate sum for b_st in range(0, self.calib_size, self.calib_batch_size): b_ed = min(self.calib_size, b_st + self.calib_batch_size) A = self.raw_input[0][b_st:b_ed].cuda() A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) B = self.raw_input[1][b_st:b_ed].cuda() B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,b,H,dim2,dim3 raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda() raw_grad = self.raw_grad[b_st:b_ed].cuda() similarities = [] for p_st in range(0, self.eq_n, self.parallel_eq_n): p_ed = min(self.eq_n,p_st+self.parallel_eq_n) # quantize A cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1) cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:] A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval) A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stand for padding) A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,b,H,dim1,dim2 # quantize B, this quantization is optimized out of loop # calculate similarity and store them out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3 similarity = 
self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1 similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on) similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H similarities.append(similarity) # calculate best similarity for this block similarities = torch.cat(similarities, 0) # shape: eq_n,1,H batch_similarities.append(similarities) batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n, n_G_A best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1) tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index) self.A_interval = tmp_A_interval.squeeze(0) def _search_best_B_interval(self, B_interval_candidates): """ Modularization of searching best interval """ tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1 # out-of-loop optimization for v, h in product(range(self.n_V_B), range(self.n_H_B)): batch_similarities = [] # similarities, need to concatenate and calculate sum for b_st in range(0, self.calib_size, self.calib_batch_size): b_ed = min(self.calib_size, b_st + self.calib_batch_size) A = self.raw_input[0][b_st:b_ed].cuda() A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2 B = self.raw_input[1][b_st:b_ed].cuda() B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B) raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda() raw_grad = self.raw_grad[b_st:b_ed].cuda() similarities = [] for p_st in range(0, self.eq_n, self.parallel_eq_n): p_ed = min(self.eq_n,p_st+self.parallel_eq_n) # quantize A, this quantization is optimized out of loop # quantize B cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1) cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:] B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval) B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,b,H*,dim2*,dim3* (* stand for padding) B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,b,H,dim2,dim3 # calculate similarity and store them out_sim = A_sim @ B_sim # shape: parallel_eq_n,b,H,dim1,dim3 similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1 similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on) similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H similarities.append(similarity) # calculate best similarity for this block similarities = torch.cat(similarities, 0) # shape: eq_n,1,H batch_similarities.append(similarities) batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n, n_G_B best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1) tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = 
torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index) self.B_interval = tmp_B_interval.squeeze(0) def calibration_step2(self): self._initialize_calib_parameters() self._initialize_intervals() A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0) B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0) for e in range(self.search_round): # search for best A interval self._search_best_A_interval(A_interval_candidates) # search for best B interval self._search_best_B_interval(B_interval_candidates) self.calibrated = True del self.raw_input, self.raw_out, self.raw_grad class SoSPTQSLBatchingQuantMatMul(PTQSLBatchingQuantMatMul): def __init__(self, A_bit=8, B_bit=8, mode="raw", metric="L2_norm", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10, n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False, split=None): super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise) self.n_G_A = 1 self.n_V_A = 1 self.n_H_A = 1 # with proper hardware implementation, we don't need to use a sign bit anymore self.A_qmax = 2**(self.A_bit-1) self.split = split if split is not None: self.A_interval = self.split/(self.A_qmax-1) def quant_input_A(self, x): x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1) x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval return x_high + x_low def _search_best_A_interval(self, split_candidates): batch_similarities = [] for b_st in range(0, self.calib_size, self.calib_batch_size): b_ed = min(self.calib_size, b_st + self.calib_batch_size) A = self.raw_input[0][b_st:b_ed].unsqueeze(0).cuda() B = self.raw_input[1][b_st:b_ed].unsqueeze(0).cuda() B_sim = B raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda() raw_grad = self.raw_grad[b_st:b_ed].cuda() similarities = [] for i in range(len(split_candidates)): # quantize A cur_A_interval = split_candidates[i]/(self.A_qmax-1) A_high = (A.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1) A_low = (A.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval A_sim = A_high + A_low # shape: 1,b,H,S,S # quantize B, this quantization is optimized out of loop # calculate similarity and store them (dim1=dim2=S, dim3=W) out_sim = A_sim @ B_sim # shape: 1,b,H,dim1,dim3 similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: 1,b,H,dim1 similarity = similarity.mean([2,3]) # shape: 1,b similarity = similarity.sum(dim=1,keepdim=True) # shape: 1,1 similarities.append(similarity) # calculate best similarity for this block similarities = torch.cat(similarities, 0) # shape: eq_n,1 batch_similarities.append(similarities) batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) # shape: eq_n best_index = torch.argmax(batch_similarities, dim=0, keepdim=False) self.split = split_candidates[best_index] self.A_interval = self.split/(self.A_qmax-1) # 
debugging # print(f"best split: {self.split}") def calibration_step2(self): self._initialize_calib_parameters() self._initialize_intervals() A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda() B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0) for e in range(self.search_round): # search for best A interval self._search_best_A_interval(A_split_candidates) # search for best B interval self._search_best_B_interval(B_interval_candidates) self.calibrated = True del self.raw_input, self.raw_out, self.raw_grad ================================================ FILE: utils/datasets.py ================================================ """ Reuse version v4 Author: Hahn Yuan """ import PIL import torch import argparse import numpy as np import os import copy import torchvision.transforms as transforms import torchvision.datasets as datasets from torchvision.datasets import ImageFolder,DatasetFolder import torch.utils.data import re import warnings from PIL import Image from PIL import ImageFile import random import torch.nn.functional as F from torch.utils.data import Dataset def calculate_n_correct(outputs,targets): _, predicted = outputs.max(1) n_correct= predicted.eq(targets).sum().item() return n_correct class SetSplittor(): def __init__(self,fraction=0.2): self.fraction=fraction def split(self,dataset): pass class LoaderGenerator(): """ """ def __init__(self,root,dataset_name,train_batch_size=1,test_batch_size=1,num_workers=0,kwargs={}): self.root=root self.dataset_name=str.lower(dataset_name) self.train_batch_size=train_batch_size self.test_batch_size=test_batch_size self.num_workers=num_workers self.kwargs=kwargs self.items=[] self._train_set=None self._test_set=None self._calib_set=None self.train_transform=None self.test_transform=None self.train_loader_kwargs = { 'num_workers': self.num_workers , 'pin_memory': kwargs.get('pin_memory',True), 'drop_last':kwargs.get('drop_last',False) } self.test_loader_kwargs=self.train_loader_kwargs.copy() self.load() @property def train_set(self): pass @property def test_set(self): pass def load(self): pass def train_loader(self): assert self.train_set is not None return torch.utils.data.DataLoader(self.train_set, batch_size=self.train_batch_size, shuffle=True, **self.train_loader_kwargs) def test_loader(self,shuffle=False,batch_size=None): assert self.test_set is not None if batch_size is None: batch_size=self.test_batch_size return torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs) def val_loader(self): assert self.val_set is not None return torch.utils.data.DataLoader(self.val_set, batch_size=self.test_batch_size, shuffle=False, **self.test_loader_kwargs) def trainval_loader(self): assert self.trainval_set is not None return torch.utils.data.DataLoader(self.trainval_set, batch_size=self.train_batch_size, shuffle=True, **self.train_loader_kwargs) def calib_loader(self,num=1024,seed=3): if self._calib_set is None: np.random.seed(seed) inds=np.random.permutation(len(self.train_set))[:num] self._calib_set=torch.utils.data.Subset(copy.deepcopy(self.train_set),inds) self._calib_set.dataset.transform=self.test_transform return torch.utils.data.DataLoader(self._calib_set, batch_size=num, shuffle=False, **self.train_loader_kwargs) class CIFARLoaderGenerator(LoaderGenerator): def load(self): if self.dataset_name=='cifar100': 
self.dataset_fn=datasets.CIFAR100 normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762]) elif self.dataset_name=='cifar10': self.dataset_fn=datasets.CIFAR10 normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]) else: raise NotImplementedError self.train_transform = transforms.Compose([ transforms.RandomCrop(32,padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ]) self.test_transform = transforms.Compose([ transforms.ToTensor(), normalize, ]) @property def train_set(self): if self._train_set is None: self._train_set=self.dataset_fn(self.root, train=True, download=True, transform=self.train_transform) return self._train_set @property def test_set(self): if self._test_set is None: self._test_set=self.dataset_fn(self.root, train=False, transform=self.test_transform) return self._test_set class COCOLoaderGenerator(LoaderGenerator): def load(self): # download from https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh self.train_set = DetectionListDataset(os.path.join(self.root,'trainvalno5k.txt'),transform=augmentation_detection_tansforms) self.test_set = DetectionListDataset(os.path.join(self.root,'5k.txt'),transform=detection_tansforms,multiscale=False) self.train_loader_kwargs={"collate_fn":self.train_set.collate_fn} self.test_loader_kwargs={"collate_fn":self.test_set.collate_fn} class DetectionListDataset(Dataset): def __init__(self, list_path, img_size=416, multiscale=True, transform=None): with open(list_path, "r") as file: self.img_files = [path for path in file.readlines()] self.label_files = [ path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt") for path in self.img_files ] self.img_size = img_size self.max_objects = 100 self.multiscale = multiscale self.min_size = self.img_size - 3 * 32 self.max_size = self.img_size + 3 * 32 self.batch_count = 0 self.transform = transform def __getitem__(self, index): try: img_path = self.img_files[index % len(self.img_files)].rstrip() img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8) except Exception as e: print(f"Could not read image '{img_path}'.") return try: label_path = self.label_files[index % len(self.img_files)].rstrip() # Ignore warning if file is empty with warnings.catch_warnings(): warnings.simplefilter("ignore") boxes = np.loadtxt(label_path).reshape(-1, 5) except Exception as e: print(f"Could not read label '{label_path}'.") return if self.transform: try: img, bb_targets = self.transform((img, boxes)) except: print(f"Could not apply transform.") return return img_path, img, bb_targets def collate_fn(self, batch): self.batch_count += 1 # Drop invalid images batch = [data for data in batch if data is not None] paths, imgs, bb_targets = list(zip(*batch)) # Selects new image size every tenth batch if self.multiscale and self.batch_count % 10 == 0: self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32)) # Resize images to input shape imgs = torch.stack([F.interpolate(img.unsqueeze(0), size=self.img_size, mode="nearest").squeeze(0) for img in imgs]) # Add sample index to targets for i, boxes in enumerate(bb_targets): boxes[:, 0] = i bb_targets = torch.cat(bb_targets, 0) return paths, imgs, bb_targets def __len__(self): return len(self.img_files) # def faster_im_loader(path): # with open(path,'rb') as f: # bgr_array = TurboJPEG().decode(f.read()) # rgb_array=np.concatenate([bgr_array[:,:,2:3],bgr_array[:,:,1:2],bgr_array[:,:,0:1]],-1) # return 
torch.Tensor(rgb_array)/255 class ImageNetLoaderGenerator(LoaderGenerator): def load(self): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.train_transform = transforms.Compose([ transforms.Resize(256), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ]) self.test_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ]) @property def train_set(self): if self._train_set is None: self._train_set=ImageFolder(os.path.join(self.root,'train'), self.train_transform) return self._train_set @property def test_set(self): if self._test_set is None: self._test_set=ImageFolder(os.path.join(self.root,'val'), self.test_transform) return self._test_set class CacheDataset(Dataset): def __init__(self,datas,targets) -> None: super().__init__() self.datas=datas self.targets=targets def __getitem__(self,idx): return self.datas[idx],self.targets[idx] def __len__(self): return len(self.datas) class FasterImageNetLoaderGenerator(ImageNetLoaderGenerator): def test_loader(self,shuffle=False,batch_size=None): cache='/dev/shm/imagenet.pkl' assert self.test_set is not None if batch_size is None: batch_size=self.test_batch_size if os.path.exists(cache): print("Loading the dataset from shared memory") datas,targets=torch.load(cache) else: print("Preprocessing the dataset and save it to shared memory") loader=torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs) datas=[] targets=[] for data,target in loader: datas.append(data) targets.append(target) datas=torch.cat(datas,0) targets=torch.cat(targets,0) torch.save([datas,targets],cache) dataset=CacheDataset(datas,targets) return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **self.test_loader_kwargs) class DebugLoaderGenerator(LoaderGenerator): def load(self): version=re.findall("\d+",self.dataset_name)[0] class DebugSet(torch.utils.data.Dataset): def __getitem__(self,idx): if version=='0': return torch.ones([1,4,4]),0 if version=='1': return torch.ones([1,8,8]),0 if version=='2': return torch.ones([1,1,1]),0 if version=='3': return torch.ones([1,3,3]),0 else: raise NotImplementedError(f"version {version} of Debug dataset is not supported") def __len__(self): return 1 self.train_set=DebugSet() self.test_set=DebugSet() def get_dataset(args:argparse.Namespace): """ Preparing Datasets, args: dataset (required): MNIST, cifar10/100, ImageNet, coco dataset_root: str, default='./datasets' num_workers: int batch_size: int test_batch_size: int val_fraction: float, default=0 """ dataset_name=str.lower(args.dataset) dataset_root=getattr(args,'dataset_root','./datasets') num_workers=args.num_workers if hasattr(args,'num_workers') else 4 batch_size=args.batch_size if hasattr(args,'batch_size') else 64 test_batch_size=args.test_batch_size if hasattr(args,'test_batch_size') else batch_size val_fraction=args.val_fraction if hasattr(args,"val_fraction") else 0 if "cifar" in dataset_name: # Data loading code g=CIFARLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) elif "coco" in dataset_name: g=COCOLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) elif "debug" in dataset_name: g=DebugLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) elif args.dataset=='ImageNet': 
g=ImageNetLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers) else: raise NotImplementedError return g.train_loader(),g.test_loader() import timm from timm.models.vision_transformer import VisionTransformer from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform class ViTImageNetLoaderGenerator(ImageNetLoaderGenerator): """ DataLoader for Vision Transformer. To comply with timm's framework, we use the model's corresponding transform. """ def __init__(self, root, dataset_name, train_batch_size, test_batch_size, num_workers, kwargs={}): kwargs.update({"pin_memory":False}) super().__init__(root, dataset_name, train_batch_size=train_batch_size, test_batch_size=test_batch_size, num_workers=num_workers, kwargs=kwargs) def load(self): model = self.kwargs.get("model", None) assert model is not None, "No model in ViTImageNetLoaderGenerator!" config = resolve_data_config({}, model=model) self.train_transform = create_transform(**config, is_training=True) self.test_transform = create_transform(**config) ================================================ FILE: utils/integer.py ================================================ from quant_layers.matmul import MinMaxQuantMatMul, PTQSLBatchingQuantMatMul, PTQSLQuantMatMul, SoSPTQSLBatchingQuantMatMul, SoSPTQSLQuantMatMul from quant_layers.linear import MinMaxQuantLinear, PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear, PostGeluPTQSLQuantLinear import torch import torch.nn as nn import torch.nn.functional as F def quantize_int_weight(module): """ Get the weight of a quantized module as int8. Biases are not quantized, so the raw bias can be used directly. """ assert hasattr(module, 'weight'), f"module {module} does not have weight" assert module.w_bit == 8, f"module {module}'s weight is quantized with {module.w_bit} bits" w_int = (module.weight/module.w_interval).round_().clamp_(-module.w_qmax, module.w_qmax-1) w_int = w_int.cpu().detach().to(torch.int8) return w_int def dequantize_int_weight(module, w_int): """ Make sure it's the same module that generated w_int """ w_sim = module.w_interval.cpu() * w_int.float() return w_sim def quantize_matmul_input(input, interval, qmax, n_G, n_V, n_H, crb_groups, crb_rows, crb_cols): """ Quantize an input matrix of a matmul operation, with respect to the sub-layerwise padding settings """ pad_groups = crb_groups*n_G - input.shape[1] pad_rows = crb_rows*n_V - input.shape[2] pad_cols = crb_cols*n_H - input.shape[3] x = F.pad(input, [0,pad_cols,0,pad_rows,0,pad_groups]) x = x.view(-1,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols) x = (x/interval).round_().clamp(-qmax,qmax-1) x = x.view(-1,n_G*crb_groups,n_V*crb_rows,n_H*crb_cols) x = x[:,:x.shape[1]-pad_groups,:x.shape[2]-pad_rows,:x.shape[3]-pad_cols] return x def quantize_int_activation(module, input): """ Quantize current inputs into int8 (uint8 for twin quantization) and store them as an attribute of the module. The function is a pre-forward hook that needs to be manually added to the calibrated model. You need to consume the cached data before feeding in another batch of images. Currently only 8-bit settings are supported. For twin quantization: - For softmax, the MSB being 1 means using the large interval, while the MSB being 0 means using the small interval. - For post-GELU, the MSB serves as the sign bit. We use 1 for positive values and 0 for negative values. 
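Example (editor's illustrative sketch, not part of the original API; `net`, `wrapped_modules` and `images` are assumed to come from an already calibrated 8-bit pipeline). The hook signature (module, input) matches this function, so it can be registered directly with PyTorch's register_forward_pre_hook:

    handles = [m.register_forward_pre_hook(quantize_int_activation) for m in wrapped_modules.values()]
    net(images.cuda())  # one batch; each handled module now caches module.int_input
    int_inputs = {name: m.int_input for name, m in wrapped_modules.items() if hasattr(m, "int_input")}
    for h in handles: h.remove()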
""" if isinstance(module, PostGeluPTQSLQuantLinear) or isinstance(module, PostGeluPTQSLBatchingQuantLinear): assert module.a_bit == 8, f"module {module}'s activation is quantized with {module.a_bit} bits" x = input[0] int_input_pos = (x/module.a_interval).round_().clamp_(0, module.a_qmax-1) int_input_pos = int_input_pos.detach().to(torch.uint8) + 128 int_input_neg = (x/module.a_neg_interval).round_().clamp_(-module.a_qmax+1, 0).abs() int_input_neg = int_input_neg.detach().to(torch.uint8) int_input = (int_input_pos + int_input_neg).cpu() module.int_input = [int_input] elif isinstance(module, MinMaxQuantLinear): assert module.a_bit == 8, f"module {module}'s activation is quantized with {module.a_bit} bits" x = input[0] int_input = (x/module.a_interval).round_().clamp_(-module.a_qmax, module.a_qmax-1) int_input = int_input.cpu().detach().to(torch.int8) module.int_input = [int_input] elif isinstance(module, SoSPTQSLQuantMatMul) or isinstance(module, SoSPTQSLBatchingQuantMatMul): assert module.A_bit == 8, f"module {module}'s matrix A is quantized with {module.A_bit} bits" assert module.B_bit == 8, f"module {module}'s matrix B is quantized with {module.B_bit} bits" A, B = input[0], input[1] A_high = (A.clamp(module.split, 1)*(module.A_qmax-1)).round_().clamp_(0,module.A_qmax-1) A_high = A_high.detach().to(torch.uint8) + 128 A_low = (A.clamp(0, module.split)/module.A_interval).round_().clamp_(0,module.A_qmax-1) A_low = A_low.detach().to(torch.uint8) A_int = (A_high + A_low).cpu() B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B) B_int = B_int.cpu().detach().to(torch.int8) module.int_input = [A_int, B_int] elif isinstance(module, PTQSLQuantMatMul) or isinstance(module, PTQSLBatchingQuantMatMul): assert module.A_bit == 8, f"module {module}'s matrix A is quantized with {module.A_bit} bits" assert module.B_bit == 8, f"module {module}'s matrix B is quantized with {module.B_bit} bits" A, B = input[0], input[1] A_int = quantize_matmul_input(A,module.A_interval,module.A_qmax,module.n_G_A,module.n_V_A,module.n_H_A,module.crb_groups_A,module.crb_rows_A,module.crb_cols_A) A_int = A_int.cpu().detach().to(torch.int8) B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B) B_int = B_int.cpu().detach().to(torch.int8) module.int_input = [A_int, B_int] def get_model_int_weight(wrapped_modules): """ Get quantized weights (in int8) of a model. Return: A dict, with modules' names as keys, and int weights as values. 
""" int_weights = {} for name, m in wrapped_modules.items(): try: int_weights[name] = quantize_int_weight(m) except: pass return int_weights ================================================ FILE: utils/models.py ================================================ from types import MethodType import torch import torch.nn as nn import torch.nn.functional as F import timm from timm.models import vision_transformer from timm.models.vision_transformer import Attention from timm.models.swin_transformer import WindowAttention def attention_forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) # attn = (q @ k.transpose(-2, -1)) * self.scale attn = self.matmul1(q, k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) del q, k # x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.matmul2(attn, v).transpose(1, 2).reshape(B, N, C) del attn, v x = self.proj(x) x = self.proj_drop(x) return x def window_attention_forward(self, x, mask = None): B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale # attn = (q @ k.transpose(-2, -1)) attn = self.matmul1(q, k.transpose(-2,-1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) # x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.matmul2(attn, v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x class MatMul(nn.Module): def forward(self, A, B): return A @ B def get_net(name): """ Get a vision transformer model. This will replace matrix multiplication operations with matmul modules in the model. Currently support almost all models in timm.models.transformers, including: - vit_tiny/small/base/large_patch16/patch32_224/384, - deit_tiny/small/base(_distilled)_patch16_224, - deit_base(_distilled)_patch16_384, - swin_tiny/small/base/large_patch4_window7_224, - swin_base/large_patch4_window12_384 These models are finetuned on imagenet-1k and should use ViTImageNetLoaderGenerator for calibration and testing. 
""" net = timm.create_model(name, pretrained=True) for name, module in net.named_modules(): if isinstance(module, Attention): setattr(module, "matmul1", MatMul()) setattr(module, "matmul2", MatMul()) module.forward = MethodType(attention_forward, module) if isinstance(module, WindowAttention): setattr(module, "matmul1", MatMul()) setattr(module, "matmul2", MatMul()) module.forward = MethodType(window_attention_forward, module) net.cuda() net.eval() return net ================================================ FILE: utils/net_wrap.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from utils.models import MatMul import re def _fold_bn(conv_module, bn_module): w = conv_module.weight.data y_mean = bn_module.running_mean y_var = bn_module.running_var safe_std = torch.sqrt(y_var + bn_module.eps) w_view = (conv_module.out_channels, 1, 1, 1) if bn_module.affine: weight = w * (bn_module.weight / safe_std).view(w_view) beta = bn_module.bias - bn_module.weight * y_mean / safe_std if conv_module.bias is not None: bias = bn_module.weight * conv_module.bias / safe_std + beta else: bias = beta else: weight = w / safe_std.view(w_view) beta = -y_mean / safe_std if conv_module.bias is not None: bias = conv_module.bias / safe_std + beta else: bias = beta return weight, bias def fold_bn_into_conv(conv_module, bn_module): w, b = _fold_bn(conv_module, bn_module) if conv_module.bias is None: conv_module.bias = nn.Parameter(b.data) else: conv_module.bias.data = b.data conv_module.weight.data = w.data def wrap_modules_in_net(net,cfg): wrapped_modules={} module_dict={} module_types = {"qkv":"qlinear_qkv", "proj":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':"qlinear_MLP_2", 'head':'qlinear_classifier','matmul1':"qmatmul_qk", 'matmul2':"qmatmul_scorev", "reduction": "qlinear_reduction"} it=[(name,m) for name,m in net.named_modules()] for name,m in it: module_dict[name]=m idx=name.rfind('.') if idx==-1: idx=0 father_name=name[:idx] if father_name in module_dict: father_module=module_dict[father_name] else: raise RuntimeError(f"father module {father_name} not found") if isinstance(m,nn.Conv2d): # Embedding Layer idx = idx+1 if idx != 0 else idx new_m=cfg.get_module("qconv",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode) new_m.weight.data=m.weight.data new_m.bias=m.bias replace_m=new_m wrapped_modules[name] = new_m setattr(father_module,name[idx:],replace_m) elif isinstance(m,nn.Linear): # Linear Layer idx = idx+1 if idx != 0 else idx new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features) new_m.weight.data=m.weight.data new_m.bias=m.bias replace_m=new_m wrapped_modules[name] = new_m setattr(father_module,name[idx:],replace_m) elif isinstance(m,MatMul): # Matmul Layer idx = idx+1 if idx != 0 else idx new_m = cfg.get_module(module_types[name[idx:]]) replace_m=new_m wrapped_modules[name] = new_m setattr(father_module,name[idx:],replace_m) print("Completed net wrap.") return wrapped_modules def wrap_certain_modules_in_net(net,cfg,layers,modules_to_wrap,wrap_embedding=False): """ wrap specific module inside transformer block of specific layer layers: list of integers, indicating layers to wrap modules_to_wrap: list of modules to wrap """ wrapped_modules={} module_dict={} module_types = {"qkv":"qlinear_qkv", "proj":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':"qlinear_MLP_2", 'head':'qlinear_classifier','matmul1':"qmatmul_qk", 'matmul2':"qmatmul_scorev"} it=[(name,m) 
for name,m in net.named_modules()] for name,m in it: module_dict[name]=m idx=name.rfind('.') if idx==-1: idx=0 father_name=name[:idx] if father_name in module_dict: father_module=module_dict[father_name] else: raise RuntimeError(f"father module {father_name} not found") layer = re.search(r'\d+', name) if layer is not None: # inside a transformer block layer = int(name[layer.span()[0]:layer.span()[1]]) if layer not in layers: continue if isinstance(m,nn.Conv2d): # Embedding Layer idx = idx+1 if idx != 0 else idx if not wrap_embedding: continue # timm patch_embed uses proj as well... # if name[idx:] not in modules_to_wrap: continue new_m=cfg.get_module("qconv",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode) new_m.weight.data=m.weight.data new_m.bias=m.bias replace_m=new_m wrapped_modules[name] = new_m setattr(father_module,name[idx:],replace_m) elif isinstance(m,nn.Linear): # Linear Layer idx = idx+1 if idx != 0 else idx if name[idx:] not in modules_to_wrap: continue new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features) new_m.weight.data=m.weight.data new_m.bias=m.bias replace_m=new_m wrapped_modules[name] = new_m setattr(father_module,name[idx:],replace_m) elif isinstance(m,MatMul): # Matmul Layer idx = idx+1 if idx != 0 else idx if name[idx:] not in modules_to_wrap: continue new_m = cfg.get_module(module_types[name[idx:]]) replace_m=new_m wrapped_modules[name] = new_m setattr(father_module,name[idx:],replace_m) print("Completed net wrap.") return wrapped_modules ================================================ FILE: utils/quant_calib.py ================================================ import torch from quant_layers.conv import MinMaxQuantConv2d from quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear from quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul import torch.nn.functional as F from tqdm import tqdm class QuantCalibrator(): """ Modularization of quant calib. Notice: every quant module has a method "calibration_step1" that should only store raw inputs and outputs; every quant module has a method "calibration_step2" that should only decide its quantization intervals; and we assume all calibration data can be fed in one batch, without backward propagation. Sequential calibration is memory-friendly, while parallel calibration may consume hundreds of GB of memory. """ def __init__(self, net, wrapped_modules, calib_loader, sequential=True, batch_size=1): self.net = net self.wrapped_modules = wrapped_modules self.calib_loader = calib_loader self.sequential = sequential self.batch_size = batch_size # micro-batch size used by batching_quant_calib self.calibrated = False def sequential_quant_calib(self): """ A quick implementation of calibration. Assumes the whole calibration dataset can be fed at once. 
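Example (editor's illustrative sketch; `net`, `wrapped_modules` and `calib_loader` are assumed to have been prepared beforehand, e.g. via get_net, wrap_modules_in_net and a calibration loader):

    calibrator = QuantCalibrator(net, wrapped_modules, calib_loader, sequential=True)
    calibrator.quant_calib()  # dispatches to this method when sequential=True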
""" # run calibration n_calibration_steps=2 for step in range(n_calibration_steps): print(f"Start calibration step={step+1}") for name,module in self.wrapped_modules.items(): # corner cases for calibrated modules if hasattr(module, "calibrated"): if step == 1: module.mode = "raw" elif step == 2: module.mode = "quant_forward" else: module.mode=f'calibration_step{step+1}' with torch.no_grad(): for inp,target in self.calib_loader: inp=inp.cuda() self.net(inp) # finish calibration for name,module in self.wrapped_modules.items(): module.mode='quant_forward' torch.cuda.empty_cache() # memory footprint cleanup print("sequential calibration finished") def parallel_quant_calib(self): """ A quick implementation of parallel quant calib Assume calibration dataset could be fed at once, and memory could hold all raw inputs/outs """ # calibration step1: collect raw data print(f"Start calibration step=1") for name,module in self.wrapped_modules.items(): # corner cases for calibrated modules if hasattr(module, "calibrated"): module.mode = "raw" else: module.mode=f'calibration_step1' with torch.no_grad(): for inp,target in self.calib_loader: inp=inp.cuda() self.net(inp) # calibration step2: each module run calibration with collected raw data for name,module in self.wrapped_modules.items(): if hasattr(module, "calibrated"): continue else: module.mode=f"calibration_step2" with torch.no_grad(): if isinstance(module, MinMaxQuantLinear): module.forward(module.raw_input.cuda()) elif isinstance(module, MinMaxQuantConv2d): module.forward(module.raw_input.cuda()) elif isinstance(module, MinMaxQuantMatMul): module.forward(module.raw_input[0].cuda(), module.raw_input[1].cuda()) torch.cuda.empty_cache() # finish calibration for name,module in self.wrapped_modules.items(): module.mode='quant_forward' torch.cuda.empty_cache() # memory footprint cleanup print("calibration finished") def quant_calib(self): calib_layers=[] for name,module in self.wrapped_modules.items(): calib_layers.append(name) print(f"prepare parallel calibration for {calib_layers}") if self.sequential: self.sequential_quant_calib() else: self.parallel_quant_calib() self.calibrated = True def batching_quant_calib(self): calib_layers=[] for name,module in self.wrapped_modules.items(): calib_layers.append(name) print(f"prepare parallel calibration for {calib_layers}") print("start calibration") # assume wrapped modules are in order (true for dict in python>=3.5) q = tqdm(self.wrapped_modules.items(), desc="Brecq") for name, module in q: q.set_postfix_str(name) # add fp and bp hooks to current modules, which bypass calibration step 1 # precedent modules are using quant forward hooks = [] if isinstance(module, MinMaxQuantLinear): hooks.append(module.register_forward_hook(linear_forward_hook)) if isinstance(module, MinMaxQuantConv2d): hooks.append(module.register_forward_hook(conv2d_forward_hook)) if isinstance(module, MinMaxQuantMatMul): hooks.append(module.register_forward_hook(matmul_forward_hook)) # feed in calibration data, and store the data for inp, target in self.calib_loader: for batch_st in range(0,self.calib_loader.batch_size,self.batch_size): self.net.zero_grad() inp_ = inp[batch_st:batch_st+self.batch_size].cuda() self.net(inp_) del inp, target torch.cuda.empty_cache() # replace cached raw_inputs, raw_outs if isinstance(module, MinMaxQuantLinear): module.raw_input = torch.cat(module.raw_input, dim=0) module.raw_out = torch.cat(module.raw_out, dim=0) if isinstance(module, MinMaxQuantConv2d): module.raw_input = torch.cat(module.raw_input, dim=0) 
module.raw_out = torch.cat(module.raw_out, dim=0) if isinstance(module, MinMaxQuantMatMul): module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input] module.raw_out = torch.cat(module.raw_out, dim=0) for hook in hooks: hook.remove() # run calibration step2 with torch.no_grad(): if isinstance(module, MinMaxQuantLinear): module.calibration_step2() if isinstance(module, MinMaxQuantConv2d): module.calibration_step2() if isinstance(module, MinMaxQuantMatMul): module.calibration_step2() torch.cuda.empty_cache() # finishing up current module calibration if self.sequential: module.mode = "quant_forward" else: module.mode = "raw" # finish calibration for name, module in self.wrapped_modules.items(): module.mode = "quant_forward" print("calibration finished") def grad_hook(module, grad_input, grad_output): if module.raw_grad is None: module.raw_grad = [] module.raw_grad.append(grad_output[0].cpu().detach()) # grad_output is a tuple; take its first element def linear_forward_hook(module, input, output): if module.raw_input is None: module.raw_input = [] if module.raw_out is None: module.raw_out = [] module.raw_input.append(input[0].cpu().detach()) module.raw_out.append(output.cpu().detach()) def conv2d_forward_hook(module, input, output): if module.raw_input is None: module.raw_input = [] if module.raw_out is None: module.raw_out = [] module.raw_input.append(input[0].cpu().detach()) module.raw_out.append(output.cpu().detach()) def matmul_forward_hook(module, input, output): if module.raw_input is None: module.raw_input = [[],[]] if module.raw_out is None: module.raw_out = [] module.raw_input[0].append(input[0].cpu().detach()) module.raw_input[1].append(input[1].cpu().detach()) module.raw_out.append(output.cpu().detach()) class HessianQuantCalibrator(QuantCalibrator): """ Modularization of hessian quant calib. The hessian metric needs gradients of layer outputs to weight the loss, which calls for back propagation during calibration, both sequential and parallel. Despite the extra complexity of back propagation, the hessian quant calibrator remains compatible with other, gradient-free quantization metrics. """ def __init__(self, net, wrapped_modules, calib_loader, sequential=False, batch_size=1): super().__init__(net, wrapped_modules, calib_loader, sequential=sequential) self.batch_size = batch_size def quant_calib(self): """ An implementation of the original hessian calibration. 
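Example (editor's illustrative sketch; batch_size is the micro-batch size used for the KL-divergence backward passes):

    calibrator = HessianQuantCalibrator(net, wrapped_modules, calib_loader, sequential=False, batch_size=4)
    calibrator.quant_calib()  # batching_quant_calib() is the lower-memory, layer-by-layer variant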
""" calib_layers=[] for name,module in self.wrapped_modules.items(): calib_layers.append(name) print(f"prepare parallel calibration for {calib_layers}") print("start hessian calibration") # get raw_pred as target distribution with torch.no_grad(): for inp, _ in self.calib_loader: raw_pred = self.net(inp.cuda()) raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach() torch.cuda.empty_cache() # assume wrapped modules are in order (true for dict in python>=3.5) q = tqdm(self.wrapped_modules.items(), desc="Brecq") for name, module in q: q.set_postfix_str(name) # add fp and bp hooks to current modules, which bypass calibration step 1 # precedent modules are using quant forward hooks = [] if isinstance(module, MinMaxQuantLinear): hooks.append(module.register_forward_hook(linear_forward_hook)) if isinstance(module, MinMaxQuantConv2d): hooks.append(module.register_forward_hook(conv2d_forward_hook)) if isinstance(module, MinMaxQuantMatMul): hooks.append(module.register_forward_hook(matmul_forward_hook)) if hasattr(module, "metric") and module.metric == "hessian": hooks.append(module.register_backward_hook(grad_hook)) # feed in calibration data, and store the data for inp, target in self.calib_loader: for batch_st in range(0,self.calib_loader.batch_size,self.batch_size): self.net.zero_grad() inp_ = inp[batch_st:batch_st+self.batch_size].cuda() pred = self.net(inp_) loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction="batchmean") loss.backward() del inp, target, pred, loss torch.cuda.empty_cache() # replace cached raw_inputs, raw_outs if isinstance(module, MinMaxQuantLinear): module.raw_input = torch.cat(module.raw_input, dim=0) module.raw_out = torch.cat(module.raw_out, dim=0) if isinstance(module, MinMaxQuantConv2d): module.raw_input = torch.cat(module.raw_input, dim=0) module.raw_out = torch.cat(module.raw_out, dim=0) if isinstance(module, MinMaxQuantMatMul): module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input] module.raw_out = torch.cat(module.raw_out, dim=0) if hasattr(module, "metric") and module.metric == "hessian": module.raw_grad = torch.cat(module.raw_grad, dim=0) for hook in hooks: hook.remove() # run calibration step2 with torch.no_grad(): if isinstance(module, MinMaxQuantLinear): module.calibration_step2(module.raw_input.cuda()) if isinstance(module, MinMaxQuantConv2d): module.calibration_step2(module.raw_input.cuda()) if isinstance(module, MinMaxQuantMatMul): module.calibration_step2(module.raw_input[0].cuda(), module.raw_input[1].cuda()) torch.cuda.empty_cache() # finishing up current module calibration if self.sequential: module.mode = "quant_forward" else: module.mode = "raw" # finish calibration for name, module in self.wrapped_modules.items(): module.mode = "quant_forward" print("hessian calibration finished") def batching_quant_calib(self): calib_layers=[] for name,module in self.wrapped_modules.items(): calib_layers.append(name) print(f"prepare parallel calibration for {calib_layers}") print("start hessian calibration") # get raw_pred as target distribution with torch.no_grad(): for inp, _ in self.calib_loader: raw_pred = self.net(inp.cuda()) raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach() torch.cuda.empty_cache() # assume wrapped modules are in order (true for dict in python>=3.5) q = tqdm(self.wrapped_modules.items(), desc="Hessian") for name, module in q: q.set_postfix_str(name) # add fp and bp hooks to current modules, which bypass calibration step 1 # precedent modules are using quant forward hooks 
= [] if isinstance(module, MinMaxQuantLinear): hooks.append(module.register_forward_hook(linear_forward_hook)) if isinstance(module, MinMaxQuantConv2d): hooks.append(module.register_forward_hook(conv2d_forward_hook)) if isinstance(module, MinMaxQuantMatMul): hooks.append(module.register_forward_hook(matmul_forward_hook)) if hasattr(module, "metric"): hooks.append(module.register_backward_hook(grad_hook)) # feed in calibration data, and store the data for inp, target in self.calib_loader: for batch_st in range(0,self.calib_loader.batch_size,self.batch_size): self.net.zero_grad() inp_ = inp[batch_st:batch_st+self.batch_size].cuda() pred = self.net(inp_) loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction="batchmean") loss.backward() del inp, target, pred, loss torch.cuda.empty_cache() # replace cached raw_inputs, raw_outs if isinstance(module, MinMaxQuantLinear): module.raw_input = torch.cat(module.raw_input, dim=0) module.raw_out = torch.cat(module.raw_out, dim=0) if isinstance(module, MinMaxQuantConv2d): module.raw_input = torch.cat(module.raw_input, dim=0) module.raw_out = torch.cat(module.raw_out, dim=0) if isinstance(module, MinMaxQuantMatMul): module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input] module.raw_out = torch.cat(module.raw_out, dim=0) if hasattr(module, "metric"): module.raw_grad = torch.cat(module.raw_grad, dim=0) for hook in hooks: hook.remove() # run calibration step2 with torch.no_grad(): if isinstance(module, MinMaxQuantLinear): module.calibration_step2() if isinstance(module, MinMaxQuantConv2d): module.calibration_step2() if isinstance(module, MinMaxQuantMatMul): module.calibration_step2() torch.cuda.empty_cache() # finishing up current module calibration if self.sequential: module.mode = "quant_forward" else: module.mode = "raw" # finish calibration for name, module in self.wrapped_modules.items(): module.mode = "quant_forward" print("hessian calibration finished")
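# ---------------------------------------------------------------------------
# (editor's addition) End-to-end usage sketch assembled from the pieces in this
# repository. It is illustrative, not part of the original file: the ImageNet
# path is a placeholder, and configs.PTQ4ViT is assumed to expose the
# get_module factory that wrap_modules_in_net expects.
if __name__ == "__main__":
    import importlib
    from utils.models import get_net
    from utils.net_wrap import wrap_modules_in_net
    from utils.datasets import ViTImageNetLoaderGenerator

    net = get_net("vit_base_patch16_224")             # patched timm model, on CUDA, in eval mode
    cfg = importlib.import_module("configs.PTQ4ViT")  # assumed to provide cfg.get_module(...)
    wrapped_modules = wrap_modules_in_net(net, cfg)   # swap Linear/Conv2d/MatMul for quant layers
    g = ViTImageNetLoaderGenerator("data/imagenet", "imagenet", 32, 32, 16, kwargs={"model": net})
    calib_loader = g.calib_loader(num=32)             # 32 calibration images, as in the README tables
    HessianQuantCalibrator(net, wrapped_modules, calib_loader,
                           sequential=False, batch_size=4).batching_quant_calib()
    # all wrapped modules are now in "quant_forward" mode and ready for evaluation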