[
  {
    "path": ".gitignore",
    "content": "tmp\n*.pyc\n__pycache__\n*.pth\n.vscode\ncheckpoints\n*.log\n*.csv\n*.png\n*.jpg\noutput\n*.weights\n*.tmp.*\ndata\nckt\n*.out\n*.zip\n*.json\ntest.ipynb\nint_weights"
  },
  {
    "path": "README.md",
    "content": "# PTQ4ViT\nPost-Training Quantization Framework for Vision Transformers.\nWe use the twin uniform quantization method to reduce the quantization error on these activation values.\nAnd we use a Hessian guided metric to evaluate different scaling factors, which improves the accuracy of calibration with a small cost.\nThe quantized vision transformers (ViT, DeiT, and Swin) achieve near-lossless prediction accuracy (less than 0.5\\% drop at 8-bit quantization) on the ImageNet classification task. Please read the [paper](https://arxiv.org/abs/2111.12293) for details.\n\n## Updates\n\n*19/07/2022*\nAdd discussion on Base PTQ, and provide more ablation study results.\n\n### Number of Calibration Images\n\n| Model        | W8A8 #ims=32 | W6A6 #ims=32 | W8A8 #ims=128 | W6A6 #ims=128 |\n|:------------:|:------------:|:------------:|:-------------:|:-------------:|\n| ViT-S/224/32 | 75.58        | 71.91        |  75.54        | 72.29         |\n| ViT-S/224    | 81.00        | 78.63        |  80.99        | 78.44         |\n| ViT-B/224    | 84.25        | 81.65        |  84.27        | 81.84         |\n| ViT-B/384    | 85.83        | 83.35        |  85.81        | 83.84         |\n| DeiT-S/224   | 79.47        | 76.28        |  79.41        | 76.51         |\n| DeiT-B/224   | 81.48        | 80.25        |  81.54        | 80.30         |\n| DeiT-B/384   | 82.97        | 81.55        |  83.01        | 81.67         |\n| Swin-T/224   | 81.25        | 80.47        |  81.27        | 80.30         |\n| Swin-S/224   | 83.11        | 82.38        |  83.15        | 82.38         |\n| Swin-B/224   | 85.15        | 84.01        |  85.17        | 84.15         |\n| Swin-B/384   | 86.39        | 85.39        |  86.36        | 85.45         |\n\n| Model        | Time #ims=32 | Time #ims=128 |\n|:------------:|:------------:|:-------------:|\n| ViT-S/224/32 | 2 min        | 5 min         |\n| ViT-S/224    | 3 min        | 7 min         |\n| ViT-B/224    | 4 min        | 13 min        |\n| ViT-B/384    | 12 min       | 43 min        |\n| DeiT-S/224   | 3 min        | 7 min         |\n| DeiT-B/224   | 4 min        | 16 min        |\n| DeiT-B/384   | 14 min       | 52 min        |\n| Swin-T/224   | 3 min        | 9 min         |\n| Swin-S/224   | 8 min        | 17 min        |\n| Swin-B/224   | 10 min       | 23 min        |\n| Swin-B/384   | 25 min       | 69 min        |\n\nOne of the targets of PTQ4ViT is to quickly quantize a vision transformer. \nWe have proposed to pre-compute the output and gradient of each layer and compute the influence of scaling factor candidates in batches to reduce the quantization time. \nAs demonstrated in the second table, PTQ4ViT can quantize most vision transformers in several minutes using 32 calibration images. \nUsing 128 calibration images significantly  increases  the  quantization  time.  \nWe observe the Top-1 accuracy varies slightly in the first table, demonstrating PTQ4ViT is not very sensitive to the number of calibration images.\n\n### Base PTQ\nBase PTQ is a simple quantization strategy and serves as a benchmark for our experiments. \nLike PTQ4ViT, we quantize all weights and inputs for fully-connect layers (including the first projection layer and the last prediction layer), as well as all input matrices of matrix multiplication operations. \nFor fully-connected layers, we use layerwise scaling factors $\\Delta_W$ for weight quantization and $\\Delta_X$ for input quantization; while for matrix multiplication operations, we use $\\Delta_A$ and $\\Delta_B$ for A's quantization and B's quantization respectively. \n\nTo get the best scaling factors, we apply a linear grid search on the search space. \nThe same as EasyQuantand Liu et al., we take hyper-parameters $\\alpha=0.5$, $\\beta = 1.2$, one search round and use cosine distance as the metric. \nNote that in PTQ4ViT, we change the hyper-parameters to $\\alpha=0$, $\\beta = 1.2$ and three search rounds, which slightly improves the performance.\n\nIt should be noticed that Base PTQ adopts a parallel quantization paradigm, which makes it essentially different from sequential quantization paradigms such as EasyQuant. \nIn sequential quantization, the input data of the current quantizing layer is generated with all previous layers quantizing weights and activations. \nWhile in parallel quantization, the input data of the current quantizing layer is simply the raw output of the previous layer. \n\nIn practice, we found sequential quantization on vision transformers suffers from significant accuracy degradation on small calibration datasets. \nWhile parallel quantization shows robustness on small calibration datasets. \nTherefore, we choose parallel quantization for both Base PTQ and PTQ4ViT.\n\n### More Ablation Study\n\nWe supply more ablation studies for the hyper-parameters.\nIt is enough to set the number of quantization intervals $\\ge$ 20 (accuracy change $< 0.3\\%$).\nIt is enough to set the upper bound of m $\\ge$ 15 (no accuracy change).\nThe best settings of alpha and beta vary from different layers. \nIt is appropriate to set $\\alpha=0$ and $\\beta=1/2^{k-1}$, which has little impact on search efficiency.\nWe observe that search rounds has little impact on the prediction accuracy (accuracy change $<$ 0.05\\% when search rounds $>1$).\n\nWe randomly take 32 calibration images to quantize different models 20 times and we observe the fluctuation is not significant. \nThe mean/std of accuracies are: ViT-S/32 $75.55\\%/0.055\\%$ , ViT-S $80.96\\%/0.046\\%$, ViT-B $84.12\\%/0.068\\%$, DeiT-S $79.45\\%/0.094\\%$ , and Swin-S $83.11\\%/0.035\\%$.\n\n\n*15/01/2022*\nAdd saved quantized models with PTQ4ViT.\n| model        |   link   |\n|:------------:|:--------:|\n| ViT-S/224/32 | [Google](https://drive.google.com/file/d/195JJJKULvaukte6PA9U08oezjd176CTs/view?usp=sharing)   |\n| ViT-S/224    | [Google](https://drive.google.com/file/d/14uEDgRmDBYoKoZtpO9IWMfG8Uvkt_OuL/view?usp=sharing)   |\n| ViT-B/224    | [Google](https://drive.google.com/file/d/1ou6s9Vd-_iyQ7sj7VYET-pRvJA6WMMLA/view?usp=sharing)   |\n| ViT-B/384    | [Google](https://drive.google.com/file/d/1tuU8or8SfQomtoWam7WFTnUxtuw3n7fs/view?usp=sharing)   |\n| DeiT-S/224   | [Google](https://drive.google.com/file/d/1673fX-SuiRlHhm7k0Yyyx_3ynwtvUPyf/view?usp=sharing)   |\n| DeiT-B/224   | [Google](https://drive.google.com/file/d/1WRAtmPF0kDR9iTLc9gv_63aEkOCZ_zOI/view?usp=sharing)   |\n| DeiT-B/384   | [Google](https://drive.google.com/file/d/1mPPlM2ioe4zts_rdKdjZTCUj8KcbquyA/view?usp=sharing)   |\n| Swin-T/224   | [Google](https://drive.google.com/file/d/1bSahHgtL3yFaHPlG-SDtu__YY0zJ8lxr/view?usp=sharing)   |\n| Swin-S/224   | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing)   |\n| Swin-B/224   | [Google](https://drive.google.com/file/d/19UUUQYJGs5SQaDe27PjY3x1QTBU5hwXm/view?usp=sharing)   |\n| Swin-B/384   | [Google](https://drive.google.com/file/d/1SxAdDTwQaeJFWnHLFXncVocxMNBIPDOE/view?usp=sharing)   |\n\n*10/12/2021*\nAdd `utils/integer.py`, you can now:\n1. convert calibrated fp32 model into int8\n2. register pre-forward hook in the model, and fetch activation in int8. (We use uint8 to store results\n    of twin quantization, please refer to the paper to see the bits' layout).\n\n## Install\n\n### Requirement \n- python>=3.5\n- pytorch>=1.5\n- matplotlib\n- pandas\n- timm\n\n### Datasets\nTo run example testing, you should put your ImageNet2012 dataset in path `/datasets/imagenet`.\n\nWe use `ViTImageNetLoaderGenerator` in `utils/datasets.py` to initialize our DataLoader.\nIf your Imagenet datasets are stored elsewhere, you'll need to manually pass its root as an argument when instantiating a `ViTImageNetLoaderGenerator`.\n\n## Usage\n\n### 1. Run example quantization\nTo test on all models with BasePTQ/PTQ4ViT, run\n```bash\npython example/test_all.py\n```\n\nTo run ablation testing, run\n```bash\npython example/test_ablation.py\n```\n\nYou can run the testing scripts with multiple GPUs. For example, calling\n```bash\npython example/test_all.py --multigpu --n_gpu 6\n```\nwill use 6 gpus to run the test.\n\n### 2. Download quantized model checkpoints\n(Coming soon)\n\n## Results\n### Results of BasePTQ\n\n| model        | original | w8a8   | w6a6    |\n|:------------:|:--------:|:------:|:-------:|\n| ViT-S/224/32 | 75.99    | 73.61  | 60.144  |\n| ViT-S/224    | 81.39    | 80.468 | 70.244  |\n| ViT-B/224    | 84.54    | 83.896 | 75.668  |\n| ViT-B/384    | 86.00    | 85.352 | 46.886  |\n| DeiT-S/224   | 79.80    | 77.654 | 72.268  |\n| DeiT-B/224   | 81.80    | 80.946 | 78.786  |\n| DeiT-B/384   | 83.11    | 82.33  | 68.442  |\n| Swin-T/224   | 81.39    | 80.962 | 78.456  |\n| Swin-S/224   | 83.23    | 82.758 | 81.742  |\n| Swin-B/224   | 85.27    | 84.792 | 83.354  |\n| Swin-B/384   | 86.44    | 86.168 | 85.226  |\n\nResults of PTQ4ViT\n\n| model        | original | w8a8   | w6a6    |\n|:------------:|:--------:|:------:|:-------:|\n| ViT-S/224/32 | 75.99    | 75.582 | 71.908  |\n| ViT-S/224    | 81.39    | 81.002 | 78.63   |\n| ViT-B/224    | 84.54    | 84.25  | 81.65   |\n| ViT-B/384    | 86.00    | 85.828 | 83.348  |\n| DeiT-S/224   | 79.80    | 79.474 | 76.282  |\n| DeiT-B/224   | 81.80    | 81.482 | 80.25   |\n| DeiT-B/384   | 83.11    | 82.974 | 81.55   |\n| Swin-T/224   | 81.39    | 81.246 | 80.47   |\n| Swin-S/224   | 83.23    | 83.106 | 82.38   |\n| Swin-B/224   | 85.27    | 85.146 | 84.012  |\n| Swin-B/384   | 86.44    | 86.394 | 85.388  |\n\n### Results of Ablation\n- ViT-S/224 (original top-1 accuracy 81.39%)\n\n| Hessian Guided | Softmax Twin | GELU Twin | W8A8   | W6A6    |\n|:--------------:|:------------:|:---------:|:------:|:-------:|\n|                |              |           | 80.47  | 70.24   |\n| ✓              |              |           | 80.93  | 77.20   |\n| ✓              | ✓            |           | 81.11  | 78.57   |\n| ✓              |              | ✓         | 80.84  | 76.93   |\n|                | ✓            | ✓         | 79.25  | 74.07   |\n| ✓              | ✓            | ✓         | 81.00  | 78.63   |\n\n- ViT-B/224 (original top-1 accuracy 84.54%)\n\n| Hessian Guided | Softmax Twin | GELU Twin | W8A8   | W6A6    |\n|:--------------:|:------------:|:---------:|:------:|:-------:|\n|                |              |           | 83.90  | 75.67   |\n| ✓              |              |           | 83.97  | 79.90   |\n| ✓              | ✓            |           | 84.07  | 80.76   |\n| ✓              |              | ✓         | 84.10  | 80.82   |\n|                | ✓            | ✓         | 83.40  | 78.86   |\n| ✓              | ✓            | ✓         | 84.25  | 81.65   |\n\n- ViT-B/384 (original top-1 accuracy 86.00%)\n\n| Hessian Guided | Softmax Twin | GELU Twin | W8A8   | W6A6    |\n|:--------------:|:------------:|:---------:|:------:|:-------:|\n|                |              |           | 85.35  | 46.89   |\n| ✓              |              |           | 85.42  | 79.99   |\n| ✓              | ✓            |           | 85.67  | 82.01   |\n| ✓              |              | ✓         | 85.60  | 82.21   |\n|                | ✓            | ✓         | 84.35  | 80.86   |\n| ✓              | ✓            | ✓         | 85.89  | 83.19   |\n\n## Citation\n```\n@article{PTQ4ViT_arixv2022,\n    title={PTQ4ViT: Post-Training Quantization Framework for Vision Transformers},\n    author={Zhihang Yuan, Chenhao Xue, Yiqi Chen, Qiang Wu, Guangyu Sun},\n    journal={arXiv preprint arXiv:2111.12293},\n    year={2022},\n}\n```\n"
  },
  {
    "path": "configs/BasePTQ.py",
    "content": "from quant_layers.conv import PTQSLQuantConv2d, BatchingEasyQuantConv2d\nfrom quant_layers.linear import PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear\nfrom quant_layers.matmul import PTQSLBatchingQuantMatMul, SoSPTQSLBatchingQuantMatMul\n\nbit = 8\nconv_fc_name_list = [\"qconv\", \"qlinear_qkv\", \"qlinear_proj\", \"qlinear_MLP_1\", \"qlinear_MLP_2\", \"qlinear_classifier\", \"qlinear_reduction\"]\nmatmul_name_list = [ \"qmatmul_qk\", \"qmatmul_scorev\"]\nw_bit = {name: bit for name in conv_fc_name_list}\na_bit = {name: bit for name in conv_fc_name_list}\nA_bit = {name: bit for name in matmul_name_list}\nB_bit = {name: bit for name in matmul_name_list}\n\nptqsl_conv2d_kwargs = {\n    \"metric\": \"cosine\",\n    \"eq_alpha\": 0.5,\n    \"eq_beta\": 1.2,\n    \"eq_n\": 100,\n    'search_round': 1,\n    \"n_V\": 1,\n    \"n_H\": 1,\n}\nptqsl_linear_kwargs = {\n    \"metric\": \"cosine\",\n    \"eq_alpha\": 0.5,\n    \"eq_beta\": 1.2,\n    \"eq_n\": 100,\n    'search_round': 1,\n    \"n_V\": 1,\n    \"n_H\": 1,\n    \"n_a\": 1,\n}\nptqsl_matmul_kwargs = {\n    \"metric\": \"cosine\",\n    \"eq_alpha\": 0.5,\n    \"eq_beta\": 1.2,\n    \"eq_n\": 100,\n    'search_round': 1,\n    \"n_G_A\": 1,\n    \"n_V_A\": 1,\n    \"n_H_A\": 1,\n    \"n_G_B\": 1,\n    \"n_V_B\": 1,\n    \"n_H_B\": 1,\n}\n\n\ndef get_module(module_type, *args, **kwargs):\n    if module_type == \"qconv\":\n        kwargs.update(ptqsl_conv2d_kwargs)\n        module=BatchingEasyQuantConv2d(*args,**kwargs,w_bit=w_bit[\"qconv\"],a_bit=32) # turn off activation quantization\n        # module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit[\"qconv\"],a_bit=32) # turn off activation quantization\n    elif \"qlinear\" in module_type:\n        kwargs.update(ptqsl_linear_kwargs)\n        if module_type == \"qlinear_qkv\":\n            kwargs[\"n_V\"] *= 3  # q, k, v\n            module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])\n        else:\n            module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])\n    elif \"qmatmul\" in module_type:\n        kwargs.update(ptqsl_matmul_kwargs)\n        module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])\n    return module"
  },
  {
    "path": "configs/PTQ4ViT.py",
    "content": "from quant_layers.conv import PTQSLQuantConv2d, ChannelwiseBatchingQuantConv2d\nfrom quant_layers.linear import PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear\nfrom quant_layers.matmul import PTQSLBatchingQuantMatMul, SoSPTQSLBatchingQuantMatMul\n\nno_softmax = False\nno_postgelu = False\n\nbit = 8\nconv_fc_name_list = [\"qconv\", \"qlinear_qkv\", \"qlinear_proj\", \"qlinear_MLP_1\", \"qlinear_MLP_2\", \"qlinear_classifier\", \"qlinear_reduction\"]\nmatmul_name_list = [ \"qmatmul_qk\", \"qmatmul_scorev\"]\nw_bit = {name: bit for name in conv_fc_name_list}\na_bit = {name: bit for name in conv_fc_name_list}\nA_bit = {name: bit for name in matmul_name_list}\nB_bit = {name: bit for name in matmul_name_list}\n\nptqsl_conv2d_kwargs = {\n    \"metric\": \"hessian\",\n    \"eq_alpha\": 0.01,\n    \"eq_beta\": 1.2,\n    \"eq_n\": 100,\n    'search_round': 3,\n    \"n_V\": 1,\n    \"n_H\": 1,\n}\nptqsl_linear_kwargs = {\n    \"metric\": \"hessian\",\n    \"eq_alpha\": 0.01,\n    \"eq_beta\": 1.2,\n    \"eq_n\": 100,\n    'search_round': 3,\n    \"n_V\": 1,\n    \"n_H\": 1,\n    \"n_a\": 1,\n    \"bias_correction\":True # Conventionally I'll not add an actual bias correction in linear\n}\nptqsl_matmul_kwargs = {\n    \"metric\": \"hessian\",\n    \"eq_alpha\": 0.01,\n    \"eq_beta\": 1.2,\n    \"eq_n\": 100,\n    'search_round': 3,\n    \"n_G_A\": 1,\n    \"n_V_A\": 1,\n    \"n_H_A\": 1,\n    \"n_G_B\": 1,\n    \"n_V_B\": 1,\n    \"n_H_B\": 1,\n}\n\n\ndef get_module(module_type, *args, **kwargs):\n    if module_type == \"qconv\":\n        kwargs.update(ptqsl_conv2d_kwargs)\n        module=ChannelwiseBatchingQuantConv2d(*args,**kwargs,w_bit=w_bit[\"qconv\"],a_bit=32) # turn off activation quantization\n        # module=PTQSLQuantConv2d(*args,**kwargs,w_bit=w_bit[\"qconv\"],a_bit=32) # turn off activation quantization\n    elif \"qlinear\" in module_type:\n        kwargs.update(ptqsl_linear_kwargs)\n        if module_type == \"qlinear_qkv\":\n            kwargs[\"n_V\"] *= 3  # q, k, v\n            module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])\n        elif module_type == \"qlinear_MLP_2\":\n            if no_postgelu:\n                module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])\n            else:\n                module=PostGeluPTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])\n        elif module_type == \"qlinear_classifier\":\n            kwargs[\"n_V\"] = 1\n            module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])\n        else:\n            module=PTQSLBatchingQuantLinear(*args,**kwargs,w_bit=w_bit[module_type],a_bit=a_bit[module_type])\n    elif \"qmatmul\" in module_type:\n        kwargs.update(ptqsl_matmul_kwargs)\n        if module_type == \"qmatmul_qk\":\n            module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])\n        elif module_type == \"qmatmul_scorev\":\n            if no_softmax:\n                module=PTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])\n            else:\n                module=SoSPTQSLBatchingQuantMatMul(*args,**kwargs,A_bit=A_bit[module_type],B_bit=B_bit[module_type])\n    return module"
  },
  {
    "path": "example/get_int.py",
    "content": "import sys\nsys.path.insert(0,'..')\nsys.path.insert(0,'.')\nfrom example.test_vit import *\nimport utils.net_wrap as net_wrap\nimport utils.datasets as datasets\nimport utils.integer as integer\nfrom utils.quant_calib import HessianQuantCalibrator\n\nfrom itertools import product\n\ndef get_int_weights(name, config_name):\n    quant_cfg = init_config(config_name)\n\n    net = get_net(name)\n\n    wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg)\n    \n    g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={\"model\":net})\n    test_loader=g.test_loader()\n    calib_loader=g.calib_loader(num=32)\n    \n    quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16\n    quant_calibrator.batching_quant_calib()\n\n    int_weights = integer.get_model_int_weight(wrapped_modules)\n    torch.save(int_weights, f\"./int_weights/{name}.pth\")\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n\n    names = [\n        # \"vit_tiny_patch16_224\",\n        # \"vit_small_patch32_224\",\n        # \"vit_small_patch16_224\",\n        # \"vit_base_patch16_224\",\n        \"vit_base_patch16_384\",\n\n        # \"deit_tiny_patch16_224\",\n        # \"deit_small_patch16_224\",\n        # \"deit_base_patch16_224\",\n        # \"deit_base_patch16_384\",\n\n        # \"swin_tiny_patch4_window7_224\",\n        # \"swin_small_patch4_window7_224\",\n        # \"swin_base_patch4_window7_224\",\n        # \"swin_base_patch4_window12_384\",\n        ]\n    config_names = [\"PTQ4ViT\", \"BasePTQ\"]\n\n    cfg_list = []\n    for name, config in product(names, config_names):\n        cfg_list.append({\"name\":name, \"config_name\":config})\n    \n    if args.multiprocess:\n        multiprocess(get_int_weights, cfg_list, n_gpu=args.n_gpu)\n    else:\n        for cfg in cfg_list:\n            get_int_weights(**cfg)"
  },
  {
    "path": "example/test_ablation.py",
    "content": "from torch.nn.modules import module\nfrom test_vit import *\nfrom quant_layers.conv import MinMaxQuantConv2d\nfrom quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear\nfrom quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul\nimport matplotlib.pyplot as plt\nfrom utils.net_wrap import wrap_certain_modules_in_net\nfrom tqdm import tqdm\nimport torch.nn.functional as F\nimport pickle as pkl\nfrom itertools import product\nimport types\nfrom utils.quant_calib import HessianQuantCalibrator, QuantCalibrator\nfrom utils.models import get_net\nimport time\n\ndef test_all_ablation(name, cfg_modifier=lambda x: x, calib_size=32):\n    quant_cfg = init_config(\"PTQ4ViT\")\n    quant_cfg = cfg_modifier(quant_cfg)\n\n    net = get_net(name)\n\n    wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg)\n    \n    g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={\"model\":net})\n    test_loader=g.test_loader()\n    calib_loader=g.calib_loader(num=calib_size)\n    \n    quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16\n    quant_calibrator.batching_quant_calib()\n\n    acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs[\"metric\"])\n\n    print(f\"model: {name} \\n\")\n    print(f\"calibration size: {calib_size} \\n\")\n    print(f\"bit settings: {quant_cfg.bit} \\n\")\n    print(f\"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \\n\")\n    print(f\"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \\n\")\n    print(f\"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \\n\")\n    print(f\"accuracy: {acc} \\n\\n\")\n\nclass cfg_modifier():\n    def __init__(self, **kwargs):\n        for name, value in kwargs.items():\n            setattr(self,name,value)\n\n    def __call__(self, cfg):\n        # bit setting\n        cfg.bit = self.bit_setting\n        cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list}\n        cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list}\n        cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}\n        cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}\n\n        # conv2d configs\n        cfg.ptqsl_conv2d_kwargs[\"n_V\"] = self.linear_ptq_setting[0]\n        cfg.ptqsl_conv2d_kwargs[\"n_H\"] = self.linear_ptq_setting[1]\n        cfg.ptqsl_conv2d_kwargs[\"metric\"] = self.metric\n        cfg.ptqsl_conv2d_kwargs[\"search_round\"] = self.search_round\n        cfg.ptqsl_conv2d_kwargs[\"parallel_eq_n\"] = 1 # maximum 7 , reserve 4Gb for gradient \n        cfg.ptqsl_conv2d_kwargs[\"init_layerwise\"] = False\n\n        # linear configs\n        cfg.ptqsl_linear_kwargs[\"n_V\"] = self.linear_ptq_setting[0]\n        cfg.ptqsl_linear_kwargs[\"n_H\"] = self.linear_ptq_setting[1]\n        cfg.ptqsl_linear_kwargs[\"n_a\"] = self.linear_ptq_setting[2]\n        cfg.ptqsl_linear_kwargs[\"metric\"] = self.metric\n        cfg.ptqsl_linear_kwargs[\"search_round\"] = self.search_round\n        cfg.ptqsl_linear_kwargs[\"parallel_eq_n\"] = 1 # maximum 7, reserve 4Gb for gradient \n        cfg.ptqsl_linear_kwargs[\"init_layerwise\"] = False\n\n        # matmul configs\n        cfg.ptqsl_matmul_kwargs[\"metric\"] = self.metric\n        cfg.ptqsl_matmul_kwargs[\"search_round\"] = self.search_round\n        cfg.ptqsl_matmul_kwargs[\"parallel_eq_n\"] = 1 # maximum 3!\n        cfg.ptqsl_matmul_kwargs[\"init_layerwise\"] = False\n\n        # ablation\n        cfg.no_softmax = self.no_softmax\n        cfg.no_postgelu = self.no_postgelu\n\n        return cfg\n\nif __name__=='__main__':\n    args = parse_args()\n\n    names = [\n        \"vit_small_patch16_224\",\n        \"vit_base_patch16_224\",\n        \"vit_base_patch16_384\",\n        ]\n    metrics = [\"hessian\", \"cosine\"]\n    linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a\n    search_rounds = [3]\n    calib_sizes = [32]\n    bit_settings = [(8,8), (6,6)] # weight, activation\n    no_softmaxs = [True, False]\n    no_postgelus = [True, False]\n\n    cfg_list = []\n    for name, metric, linear_ptq_setting, search_round, calib_size, bit_setting, no_softmax, no_postgelu in product(names, metrics, linear_ptq_settings, search_rounds, calib_sizes, bit_settings, no_softmaxs, no_postgelus):\n        cfg_list.append({\n            \"name\": name,\n            \"cfg_modifier\":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, search_round=search_round, bit_setting=bit_setting, no_softmax=no_softmax, no_postgelu=no_postgelu),\n            \"calib_size\":calib_size,\n        })\n    \n    if args.multiprocess:\n        multiprocess(test_all_ablation, cfg_list, n_gpu=args.n_gpu)\n    else:\n        for cfg in cfg_list:\n            test_all_ablation(**cfg)"
  },
  {
    "path": "example/test_all.py",
    "content": "from timm.models.layers import config\nfrom torch.nn.modules import module\nfrom test_vit import *\nfrom quant_layers.conv import MinMaxQuantConv2d\nfrom quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear\nfrom quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul\nimport matplotlib.pyplot as plt\nfrom utils.net_wrap import wrap_certain_modules_in_net\nfrom tqdm import tqdm\nimport torch.nn.functional as F\nimport pickle as pkl\nfrom itertools import product\nimport types\nfrom utils.quant_calib import HessianQuantCalibrator, QuantCalibrator\nfrom utils.models import get_net\nimport time\n\ndef test_all(name, cfg_modifier=lambda x: x, calib_size=32, config_name=\"PTQ4ViT\"):\n    quant_cfg = init_config(config_name)\n    quant_cfg = cfg_modifier(quant_cfg)\n\n    net = get_net(name)\n\n    wrapped_modules=net_wrap.wrap_modules_in_net(net,quant_cfg)\n    \n    g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16, kwargs={\"model\":net})\n    test_loader=g.test_loader()\n    calib_loader=g.calib_loader(num=calib_size)\n    \n    # add timing\n    calib_start_time = time.time()\n    quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16\n    quant_calibrator.batching_quant_calib()\n    calib_end_time = time.time()\n\n    acc = test_classification(net,test_loader, description=quant_cfg.ptqsl_linear_kwargs[\"metric\"])\n\n    print(f\"model: {name} \\n\")\n    print(f\"calibration size: {calib_size} \\n\")\n    print(f\"bit settings: {quant_cfg.bit} \\n\")\n    print(f\"config: {config_name} \\n\")\n    print(f\"ptqsl_conv2d_kwargs: {quant_cfg.ptqsl_conv2d_kwargs} \\n\")\n    print(f\"ptqsl_linear_kwargs: {quant_cfg.ptqsl_linear_kwargs} \\n\")\n    print(f\"ptqsl_matmul_kwargs: {quant_cfg.ptqsl_matmul_kwargs} \\n\")\n    print(f\"calibration time: {(calib_end_time-calib_start_time)/60}min \\n\")\n    print(f\"accuracy: {acc} \\n\\n\")\n\nclass cfg_modifier():\n    def __init__(self, **kwargs):\n        for name, value in kwargs.items():\n            setattr(self,name,value)\n\n    def __call__(self, cfg):\n        # bit setting\n        cfg.bit = self.bit_setting\n        cfg.w_bit = {name: self.bit_setting[0] for name in cfg.conv_fc_name_list}\n        cfg.a_bit = {name: self.bit_setting[1] for name in cfg.conv_fc_name_list}\n        cfg.A_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}\n        cfg.B_bit = {name: self.bit_setting[1] for name in cfg.matmul_name_list}\n\n        # conv2d configs\n        cfg.ptqsl_conv2d_kwargs[\"n_V\"] = self.linear_ptq_setting[0]\n        cfg.ptqsl_conv2d_kwargs[\"n_H\"] = self.linear_ptq_setting[1]\n        cfg.ptqsl_conv2d_kwargs[\"metric\"] = self.metric\n        cfg.ptqsl_conv2d_kwargs[\"init_layerwise\"] = False\n\n        # linear configs\n        cfg.ptqsl_linear_kwargs[\"n_V\"] = self.linear_ptq_setting[0]\n        cfg.ptqsl_linear_kwargs[\"n_H\"] = self.linear_ptq_setting[1]\n        cfg.ptqsl_linear_kwargs[\"n_a\"] = self.linear_ptq_setting[2]\n        cfg.ptqsl_linear_kwargs[\"metric\"] = self.metric\n        cfg.ptqsl_linear_kwargs[\"init_layerwise\"] = False\n\n        # matmul configs\n        cfg.ptqsl_matmul_kwargs[\"metric\"] = self.metric\n        cfg.ptqsl_matmul_kwargs[\"init_layerwise\"] = False\n\n        return cfg\n\nif __name__=='__main__':\n    args = parse_args()\n\n    names = [\n        \"vit_tiny_patch16_224\",\n        \"vit_small_patch32_224\",\n        \"vit_small_patch16_224\",\n        \"vit_base_patch16_224\",\n        \"vit_base_patch16_384\",\n\n        \"deit_tiny_patch16_224\",\n        \"deit_small_patch16_224\",\n        \"deit_base_patch16_224\",\n        \"deit_base_patch16_384\",\n\n        \"swin_tiny_patch4_window7_224\",\n        \"swin_small_patch4_window7_224\",\n        \"swin_base_patch4_window7_224\",\n        \"swin_base_patch4_window12_384\",\n        ]\n    metrics = [\"hessian\"]\n    linear_ptq_settings = [(1,1,1)] # n_V, n_H, n_a\n    calib_sizes = [32,128]\n    bit_settings = [(8,8), (6,6)] # weight, activation\n    config_names = [\"PTQ4ViT\", \"BasePTQ\"]\n\n    cfg_list = []\n    for name, metric, linear_ptq_setting, calib_size, bit_setting, config_name in product(names, metrics, linear_ptq_settings, calib_sizes, bit_settings, config_names):\n        cfg_list.append({\n            \"name\": name,\n            \"cfg_modifier\":cfg_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, bit_setting=bit_setting),\n            \"calib_size\":calib_size,\n            \"config_name\": config_name\n        })\n    \n    if args.multiprocess:\n        multiprocess(test_all, cfg_list, n_gpu=args.n_gpu)\n    else:\n        for cfg in cfg_list:\n            test_all(**cfg)"
  },
  {
    "path": "example/test_vit.py",
    "content": "import sys\nsys.path.insert(0,'..')\nsys.path.insert(0,'.')\nimport torch\nimport torch.nn as nn\nfrom tqdm import tqdm\nimport argparse\nfrom importlib import reload,import_module\nimport multiprocessing\nimport os\nimport time\nfrom itertools import product\n\nimport utils.datasets as datasets\nimport utils.net_wrap as net_wrap\nfrom utils.quant_calib import QuantCalibrator, HessianQuantCalibrator\nfrom utils.models import get_net\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--n_gpu\", type=int, default=6)\n    parser.add_argument(\"--multiprocess\", action='store_true')\n    args = parser.parse_args()\n    return args\n\ndef test_classification(net,test_loader,max_iteration=None, description=None):\n    pos=0\n    tot=0\n    i = 0\n    max_iteration = len(test_loader) if max_iteration is None else max_iteration\n    with torch.no_grad():\n        q=tqdm(test_loader, desc=description)\n        for inp,target in q:\n            i+=1\n            inp=inp.cuda()\n            target=target.cuda()\n            out=net(inp)\n            pos_num=torch.sum(out.argmax(1)==target).item()\n            pos+=pos_num\n            tot+=inp.size(0)\n            q.set_postfix({\"acc\":pos/tot})\n            if i >= max_iteration:\n                break\n    print(pos/tot)\n    return pos/tot\n\ndef process(pid, experiment_process, args_queue, n_gpu):\n    \"\"\"\n    worker process. \n    \"\"\"\n    gpu_id=pid%n_gpu\n    os.environ['CUDA_VISIBLE_DEVICES']=f'{gpu_id}'\n\n    tot_run=0\n    while args_queue.qsize():\n        test_args=args_queue.get()\n        print(f\"Run {test_args} on pid={pid} gpu_id={gpu_id}\")\n        experiment_process(**test_args)\n        time.sleep(0.5)\n        tot_run+=1\n        # run_experiment(**args)\n    print(f\"{pid} tot_run {tot_run}\")\n\n\ndef multiprocess(experiment_process, cfg_list=None, n_gpu=6):\n    \"\"\"\n    run experiment processes on \"n_gpu\" cards via \"n_gpu\" worker process.\n    \"cfg_list\" arranges kwargs for each test point, and worker process will fetch kwargs and carry out an experiment.\n    \"\"\"\n    args_queue = multiprocessing.Queue()\n    for cfg in cfg_list:\n        args_queue.put(cfg)\n\n    ps=[]\n    for pid in range(n_gpu):\n        p=multiprocessing.Process(target=process,args=(pid,experiment_process,args_queue,n_gpu))\n        p.start()\n        ps.append(p)\n    for p in ps:\n        p.join()\n\ndef init_config(config_name):\n    \"\"\"initialize the config. Use reload to make sure it's fresh one!\"\"\"\n    _,_,files =  next(os.walk(\"./configs\"))\n    if config_name+\".py\" in files:\n        quant_cfg = import_module(f\"configs.{config_name}\")\n    else:\n        raise NotImplementedError(f\"Invalid config name {config_name}\")\n    reload(quant_cfg)\n    return quant_cfg\n        \n\ndef experiment_basic(net='vit_base_patch16_384', config=\"PTQ4ViT\"):\n    \"\"\"\n    A basic testbench.\n    \"\"\"\n    quant_cfg = init_config(config)\n    net = get_net(net)\n    wrapped_modules = net_wrap.wrap_modules_in_net(net,quant_cfg)\n    \n    g=datasets.ViTImageNetLoaderGenerator('/datasets/imagenet','imagenet',32,32,16,kwargs={\"model\":net})\n    test_loader=g.test_loader()\n    calib_loader=g.calib_loader(num=32)\n    \n    quant_calibrator = HessianQuantCalibrator(net,wrapped_modules,calib_loader,sequential=False,batch_size=4) # 16 is too big for ViT-L-16\n    quant_calibrator.batching_quant_calib()\n    \n    test_classification(net,test_loader)\n\nif __name__=='__main__':\n    args = parse_args()\n    cfg_list = []\n\n    nets = ['vit_tiny_patch16_224', \"deit_base_patch16_384\"]\n    configs= ['PTQ4ViT']\n\n    cfg_list = [{\n        \"net\":net,\n        \"config\":config,\n        }\n        for net, config in product(nets, configs) \n    ]\n\n    if args.multiprocess:\n        multiprocess(experiment_basic, cfg_list, n_gpu=args.n_gpu)\n    else:\n        for cfg in cfg_list:\n            experiment_basic(**cfg)\n    \n\n"
  },
  {
    "path": "quant_layers/conv.py",
    "content": "from numpy import not_equal\nfrom torch import tensor\nfrom quant_layers.linear import MinMaxQuantLinear\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom itertools import product\n\nclass MinMaxQuantConv2d(nn.Conv2d):\n    \"\"\"\n    MinMax quantize weight and output\n    \"\"\"\n    def __init__(self,in_channels: int,\n        out_channels: int,\n        kernel_size,\n        stride = 1,\n        padding = 0,\n        dilation = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None):\n        super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode)\n        self.n_calibration_steps=2\n        self.mode=mode\n        self.w_bit=w_bit\n        self.a_bit=a_bit\n        self.bias_bit=bias_bit\n        assert bias_bit is None,\"No support bias bit now\"\n        self.w_interval=None\n        self.a_interval=None\n        self.bias_interval=None\n        self.raw_input=None\n        self.raw_out=None\n        self.metric=None\n        self.next_nodes=[]\n        self.w_qmax=2**(self.w_bit-1)\n        self.a_qmax=2**(self.a_bit-1)\n        # self.bias_qmax=2**(self.bias_bit-1)\n        \n    def forward(self, x):\n        if self.mode=='raw':\n            out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\n        elif self.mode==\"quant_forward\":\n            out=self.quant_forward(x)\n        elif self.mode==\"calibration_step1\":\n            out=self.calibration_step1(x)\n        elif self.mode==\"calibration_step2\":\n            out=self.calibration_step2(x)\n        else:\n            raise NotImplementedError\n        return out\n            \n    def quant_weight_bias(self):\n        w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1)\n        w_sim=w.mul_(self.w_interval)\n        if self.bias is not None:\n            return w_sim,self.bias\n            # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1)\n            # bias_sim=bias*self.bias_interval\n            # return w_sim,bias_sim\n        else:\n            return w_sim,None\n    \n    def quant_input(self,x):\n        x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1)\n        x_sim.mul_(self.a_interval)\n        return x_sim\n    \n    def quant_forward(self,x):\n        assert self.calibrated is not None,f\"You should run calibrate_forward before run quant_forward for {self}\"\n        w_sim,bias_sim=self.quant_weight_bias()\n        x_sim=self.quant_input(x)\n        out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups)\n        return out\n\n    def calibration_step1(self,x):\n        # step1: collection the FP32 values\n        out=F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\n        self.raw_input=x.cpu().detach()\n        self.raw_out=out.cpu().detach()\n        return out\n    \n    def calibration_step2(self,x):\n        # step2: search for the best S^w and S^a of each layer\n        self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach()\n        self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach()\n        self.calibrated=True\n        out=self.quant_forward(x)        \n        return out\n\nclass QuantileQuantConv2d(MinMaxQuantConv2d):\n    \"\"\"\n    Quantile quantize weight and output\n    \"\"\"\n    def __init__(self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size,\n        stride = 1,\n        padding = 0,\n        dilation = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = 'zeros',\n        mode='raw',w_bit=8,a_bit=8,bias_bit=None,\n        w_quantile=0.9999,a_quantile=0.9999):\n        super().__init__(in_channels,out_channels,kernel_size,stride,padding,dilation,groups,bias,padding_mode,mode,w_bit,a_bit,bias_bit)\n        self.w_quantile = w_quantile\n        self.a_quantile = a_quantile\n\n    def _quantile(self, tensor, quantile):\n        if tensor.numel() >= 16777216:\n            n = tensor.numel()//16777216\n            return torch.quantile(tensor.view(-1)[:16777216*n].view(n,16777216),quantile,1).mean()\n        else:\n            return torch.quantile(tensor,quantile)\n\n    def calibration_step2(self,x):\n        # step2: search for the best S^w and S^o of each layer\n        self.w_interval=(self._quantile(self.weight.data.abs(),self.w_quantile)/(self.w_qmax-0.5)).detach()\n        self.a_interval=(self._quantile(x.abs(),self.a_quantile)/(self.a_qmax-0.5)).detach()\n        self.calibrated=True\n        out=self.quant_forward(x)        \n        return out\n\nclass PTQSLQuantConv2d(MinMaxQuantConv2d):\n    \"\"\"\n    PTQSL on Conv2d\n    weight: (oc,ic,kw,kh) -> (oc,ic*kw*kh) -> divide into sub-matrixs and quantize\n    input: (B,ic,W,H), keep this shape\n\n    Only support SL quantization on weights.\n    \"\"\"\n    def __init__(self, in_channels: int,\n        out_channels: int,\n        kernel_size,\n        stride = 1,\n        padding = 0,\n        dilation = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None,\n        metric=\"L2_norm\", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,\n        n_V=1, n_H=1, init_layerwise=False):\n        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit)\n        self.metric = metric\n        self.search_round = search_round\n        self.eq_alpha = eq_alpha\n        self.eq_beta = eq_beta\n        self.eq_n = eq_n\n        self.parallel_eq_n = parallel_eq_n\n        self.n_H = n_H\n        self.n_V = n_V\n        self.init_layerwise = init_layerwise\n        self.raw_grad = None\n    \n    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1):\n        \"\"\"\n        tensor_raw: *, features\n        tensor_sim: *, features\n        similarity: *\n        It's your job to calculate mean on * dims!\n        \"\"\"\n        if metric == \"cosine\":\n            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim)\n        else:\n            if metric == \"L1_norm\":\n                similarity = -torch.abs(tensor_raw - tensor_sim)\n            elif metric == \"L2_norm\":\n                similarity = -(tensor_raw - tensor_sim) ** 2\n            elif metric == \"linear_weighted_L2_norm\":\n                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2\n            elif metric == \"square_weighted_L2_norm\":\n                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2\n            elif metric == \"hessian\":\n                raw_grad = self.raw_grad.reshape_as(tensor_raw)\n                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2\n            else:\n                raise NotImplementedError(f\"metric {metric} not implemented!\")\n            similarity = torch.mean(similarity, dim=dim)\n        return similarity\n\n    def quant_weight_bias(self):\n        # self.weight_interval shape: n_V, 1, n_H, 1\n        oc,ic,kw,kh=self.weight.data.shape\n        w_sim = self.weight.view(self.n_V, oc//self.n_V, self.n_H, (ic*kw*kh)//self.n_H)\n        w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval)\n        w_sim = w_sim.view(oc,ic,kw,kh)\n        return w_sim, self.bias\n    \n    def _search_best_w_interval(self, x, weight_interval_candidates):\n        \"\"\"\n        Modularization of searching best weight intervals\n        \"\"\"\n        tmp_w_interval = self.w_interval.unsqueeze(0)\n        for v,h in product(range(self.n_V), range(self.n_H)):\n            similarities = []\n            for p_st in range(0, self.eq_n, self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1)\n                cur_w_interval[:,v:v+1,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,v:v+1,:,h:h+1,:]\n                # quantize weight and bias \n                oc,ic,kw,kh=self.weight.data.shape\n                w_sim = self.weight.view(self.n_V,oc//self.n_V,self.n_H,-1).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols\n                w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols\n                w_sim = w_sim.view(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh\n                bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None\n                # quantize input\n                x_sim = self.quant_input(x)\n                # calculate similarity and store them\n                out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: B,parallel_eq_n*oc,fw,fh\n                out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: B,parallel_eq_n,oc,fw,fh\n                similarity = self._get_similarity(self.raw_out, out_sim, self.metric, dim=2) # shape: B,parallel_eq_n,fw,fh\n                similarity = torch.mean(similarity, [0,2,3]) # shape: parallel_eq_n\n                similarities.append(similarity)\n            # store best weight interval of h into tmp_w_interval\n            similarities = torch.cat(similarities, dim=0) # shape: eq_n\n            best_index = similarities.argmax(dim=0).reshape(-1,1,1,1,1)\n            tmp_w_interval[:,v:v+1,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,v:v+1,:,h:h+1,:],dim=0,index=best_index)\n        self.w_interval = tmp_w_interval.squeeze(dim=0)\n\n    def _search_best_a_interval(self, x, input_interval_candidates):\n        similarities = []\n        for p_st in range(0,self.eq_n,self.parallel_eq_n):\n            p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n            cur_a_interval = input_interval_candidates[p_st:p_ed]\n            # quantize weight and bias \n            w_sim, bias_sim = self.quant_weight_bias()\n            # quantize input\n            B,ic,iw,ih = x.shape\n            x_sim=x.unsqueeze(0) # shape: 1,B,ic,iw,ih\n            x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,B,ic,iw,ih\n            x_sim=x_sim.view(-1,ic,iw,ih)\n            # calculate similarity and store them\n            out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*B,oc,fw,fh\n            out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,B,oc,fw,fh\n            similarity = self._get_similarity(self.raw_out.transpose(0,1), out_sim, self.metric, dim=2) # shape: parallel_eq_n,B,fw,fh\n            similarity = torch.mean(similarity, dim=[1,2,3]) # shape: parallel_eq_n\n            similarities.append(similarity)\n        # store best input interval and store in tmp_a_interval\n        similarities = torch.cat(similarities, dim=0) # shape: eq_n\n        a_best_index = similarities.argmax(dim=0).view(1,1,1,1,1)\n        self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze()\n\n\n    def _initialize_intervals(self, x):\n        self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach()\n        if self.init_layerwise:\n            self.w_interval = ((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)\n        else:\n            self.w_interval = (self.weight.view(self.n_V,self.out_channels//self.n_V,self.n_H,-1).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))\n    \n    def calibration_step2(self, x):\n        # initialize intervals with minmax intervals\n        self._initialize_intervals(x)\n\n        # put raw outs on GPU\n        self.raw_out = self.raw_out.to(x.device).unsqueeze(1)  # shape: B,1,oc,W,H\n\n        # put raw grad on GPU\n        self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None\n\n        # prepare weight intervals and similarities\n        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1\n        input_interval_candidates =  torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: nq_n,1,1,1,1\n        for e in range(self.search_round):\n            # search for best weight interval\n            self._search_best_w_interval(x, weight_interval_candidates)\n            # search for best input interval\n            self._search_best_a_interval(x, input_interval_candidates)\n\n        self.raw_grad = self.raw_grad.to(\"cpu\") if self.raw_grad != None else None\n\n        self.calibrated = True\n        out=self.quant_forward(x)\n        del self.raw_input, self.raw_out, self.raw_grad\n        return out  \n\nclass BatchingEasyQuantConv2d(PTQSLQuantConv2d):\n    \"\"\"An agile implementation of Layerwise Easyquant\"\"\"\n    def __init__(self, in_channels: int,\n        out_channels: int,\n        kernel_size,\n        stride = 1,\n        padding = 0,\n        dilation = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None,\n        metric=\"L2_norm\", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,\n        n_V=1, n_H=1, init_layerwise=False):\n        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, \n                         mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_V=n_V, n_H=n_H, init_layerwise=init_layerwise)\n        self.n_V = 1\n        self.n_H = 1\n\n    def _initialize_calib_parameters(self):\n        \"\"\" \n        set parameters for feeding calibration data\n        \"\"\"\n        self.calib_size = int(self.raw_input.shape[0])\n        self.calib_batch_size = int(self.raw_input.shape[0])\n        while True:\n            numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU\n            self.parallel_eq_n = int((15*1024*1024*1024/4)//numel)\n            if self.parallel_eq_n <= 1:\n                self.calib_need_batching = True\n                self.calib_batch_size //= 2\n            else:\n                break\n\n    def _initialize_intervals(self):\n        self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach()\n        tmp_a_intervals = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            x_ = self.raw_input[b_st:b_ed].cuda()\n            a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1)\n            tmp_a_intervals.append(a_interval_)\n        self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False)\n\n    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1, raw_grad=None):\n        \"\"\"\n        tensor_raw: *, features\n        tensor_sim: *, features\n        similarity: *\n        It's your job to calculate mean on * dims!\n        \"\"\"\n        if metric == \"cosine\":\n            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim)\n        elif metric == \"pearson\":\n            # calculate similarity w.r.t complete feature map, but maintain dimension requirement\n            b, parallel_eq_n = tensor_sim.shape[0], tensor_sim.shape[1]\n            similarity = F.cosine_similarity(tensor_raw.view(b,1,-1), tensor_sim.view(b,parallel_eq_n,-1), dim=dim).view(b,parallel_eq_n,1,1)\n        else:\n            if metric == \"L1_norm\":\n                similarity = -torch.abs(tensor_raw - tensor_sim)\n            elif metric == \"L2_norm\":\n                similarity = -(tensor_raw - tensor_sim) ** 2\n            elif metric == \"linear_weighted_L2_norm\":\n                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2\n            elif metric == \"square_weighted_L2_norm\":\n                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2\n            elif metric == \"hessian\":\n                assert raw_grad != None, f\"No raw grad!\"\n                raw_grad = raw_grad.reshape_as(tensor_raw)\n                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2\n            else:\n                raise NotImplementedError(f\"metric {metric} not implemented!\")\n            similarity = torch.mean(similarity, dim=dim)\n        return similarity\n\n    def quant_weight_bias(self):\n        w_sim = self.weight\n        w_sim = (w_sim/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval)\n        return w_sim, self.bias\n\n    def quant_forward(self, x):\n        assert self.calibrated is not None,f\"You should run calibrate_forward before run quant_forward for {self}\"\n        w_sim,bias_sim=self.quant_weight_bias()\n        x_sim=self.quant_input(x) if self.a_bit < 32 else x\n        out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups)\n        return out\n\n    def _search_best_w_interval(self, weight_interval_candidates):\n        batch_similarities = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            x = self.raw_input[b_st:b_ed].cuda()\n            raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh\n            raw_grad = self.raw_grad[b_st:b_ed].cuda()\n            similarities = []\n            for p_st in range(0, self.eq_n, self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1\n                # quantize weight and bias\n                oc,ic,kw,kh = self.weight.data.shape\n                w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh\n                w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh\n                w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh\n                bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None\n                # quantize input\n                x_sim = self.quant_input(x)\n                # calculate similarity and store them\n                out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh\n                out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh\n                similarity = self._get_similarity(raw_out, out_sim, self.metric, dim=-3, raw_grad=raw_grad) # shape: b,parallel_eq_n,fw,fh\n                similarity = torch.mean(similarity, [2,3]) # shape: b,parallel_eq_n\n                similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n\n                similarities.append(similarity)\n            # store best weight interval of h into tmp_w_interval\n            similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n\n            batch_similarities.append(similarities)\n        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n\n        best_index = batch_similarities.argmax(dim=0).reshape(1,1,1,1,1) # shape: 1,1,1,1,1\n        self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0)\n\n    def _search_best_a_interval(self, input_interval_candidates):\n        batch_similarities = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            x = self.raw_input[b_st:b_ed].cuda()\n            raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(0) # shape: 1,b,oc,fw,fh\n            raw_grad = self.raw_grad[b_st:b_ed].cuda()\n            similarities = []\n            for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1\n                # quantize weight and bias \n                w_sim, bias_sim = self.quant_weight_bias()\n                # quantize input\n                B,ic,iw,ih = x.shape\n                x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih\n                x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih\n                x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih\n                # calculate similarity and store them\n                out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh\n                out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh\n                similarity = self._get_similarity(raw_out, out_sim, self.metric, dim=-3, raw_grad=raw_grad) # shape: parallel_eq_n,b,fw,fh\n                similarity = torch.mean(similarity, dim=[3,4]) # shape: parallel_eq_n,b\n                similarity = torch.sum(similarity, dim=1, keepdim=True) # shape: parallel_eq_n,1\n                similarities.append(similarity)\n            similarities = torch.cat(similarities, dim=0) # shape: eq_n, 1\n            batch_similarities.append(similarities)\n        batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n\n        a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1)\n        self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze()\n\n    def calibration_step2(self):\n        self._initialize_calib_parameters()\n        self._initialize_intervals()\n        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval # shape: eq_n,1,1,1,1\n        input_interval_candidates =  torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1\n        for e in range(self.search_round):\n            # search for best weight interval\n            self._search_best_w_interval(weight_interval_candidates)\n            # search for best input interval\n            if self.a_bit < 32:\n                self._search_best_a_interval(input_interval_candidates)\n        self.calibrated = True\n        del self.raw_input, self.raw_out, self.raw_grad\n\n\nclass ChannelwiseBatchingQuantConv2d(PTQSLQuantConv2d):\n    \"\"\"\n    Only implemented acceleration with batching_calibration_step2\n\n    setting a_bit to >= 32 will use minmax quantization, which means turning off activation quantization\n    \"\"\"\n    def __init__(self, in_channels: int,\n        out_channels: int,\n        kernel_size,\n        stride = 1,\n        padding = 0,\n        dilation = 1,\n        groups: int = 1,\n        bias: bool = True,\n        padding_mode: str = 'zeros',mode='raw',w_bit=8,a_bit=8,bias_bit=None,\n        metric=\"L2_norm\", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,\n        n_V=1, n_H=1, init_layerwise=False):\n        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, \n                         mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n,\n                         n_V=n_V, n_H=n_H, init_layerwise=init_layerwise)\n        self.n_V = self.out_channels\n        self.n_H = 1\n    \n    def _initialize_calib_parameters(self):\n        \"\"\" \n        set parameters for feeding calibration data\n        \"\"\"\n        self.calib_size = int(self.raw_input.shape[0])\n        self.calib_batch_size = int(self.raw_input.shape[0])\n        while True:\n            numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU\n            self.parallel_eq_n = int((15*1024*1024*1024/4)//numel)\n            if self.parallel_eq_n <= 1:\n                self.calib_need_batching = True\n                self.calib_batch_size //= 2\n            else:\n                break\n    \n    def _initialize_intervals(self):\n        # weight intervals: shape oc,1,1,1\n        if self.init_layerwise:\n            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.out_channels,1,1,1)\n        else:\n            self.w_interval=((self.weight.abs().amax([1,2,3],keepdim=True))/(self.w_qmax-0.5))\n\n        # activation intervals: shape 1\n        tmp_a_intervals = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            x_ = self.raw_input[b_st:b_ed].cuda()\n            a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1)\n            tmp_a_intervals.append(a_interval_)\n        self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=False)\n\n    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None):\n        \"\"\"\n        tensor_raw: *, features\n        tensor_sim: *, features\n        similarity: *, features\n        \"\"\"\n        if metric == \"cosine\":\n            # support cosine on patch dim, which is sub-optimal\n            # not supporting search best a interval\n            b, parallel_eq_n, oc = tensor_sim.shape[0], tensor_sim.shape[1], tensor_sim.shape[2]\n            similarity = F.cosine_similarity(tensor_raw.view(b,1,oc,-1), tensor_sim.view(b,parallel_eq_n,oc,-1), dim=-1).view(b,parallel_eq_n,oc,1,1)\n        else:\n            if metric == \"L1_norm\":\n                similarity = -torch.abs(tensor_raw - tensor_sim)\n            elif metric == \"L2_norm\":\n                similarity = -(tensor_raw - tensor_sim) ** 2\n            elif metric == \"linear_weighted_L2_norm\":\n                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2\n            elif metric == \"square_weighted_L2_norm\":\n                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2\n            elif metric == \"hessian\":\n                assert raw_grad != None, f\"raw_grad is None in _get_similarity!\"\n                raw_grad = raw_grad.reshape_as(tensor_raw)\n                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2\n            else:\n                raise NotImplementedError(f\"metric {metric} not implemented!\")\n        return similarity\n\n    def _search_best_w_interval(self, weight_interval_candidates):\n        batch_similarities = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            x = self.raw_input[b_st:b_ed].cuda()\n            raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh\n            raw_grad = self.raw_grad[b_st:b_ed].cuda()\n            similarities = []\n            for p_st in range(0, self.eq_n, self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_w_interval = weight_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,oc,1,1,1\n                # quantize weight and bias\n                oc,ic,kw,kh = self.weight.data.shape\n                w_sim = self.weight.unsqueeze(0) # shape: 1,oc,ic,kw,kh\n                w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,oc,ic,kw,kh\n                w_sim = w_sim.reshape(-1,ic,kw,kh) # shape: parallel_eq_n*oc,ic,kw,kh\n                bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None\n                # quantize input\n                x_sim = self.quant_input(x) if self.a_bit < 32 else x\n                # calculate similarity and store them\n                out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: b,parallel_eq_n*oc,fw,fh\n                out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(1), chunks=p_ed-p_st, dim=2), dim=1) # shape: b,parallel_eq_n,oc,fw,fh\n                similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad) # shape: b,parallel_eq_n,oc,fw,fh\n                similarity = torch.mean(similarity, [3,4]) # shape: b,parallel_eq_n,oc\n                similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n, oc\n                similarities.append(similarity)\n            # store best weight interval of h into tmp_w_interval\n            similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n,oc\n            batch_similarities.append(similarities)\n        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n,oc\n        best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,oc,1,1,1\n        self.w_interval = torch.gather(weight_interval_candidates,dim=0,index=best_index).squeeze(dim=0)\n\n    def _search_best_a_interval(self, input_interval_candidates):\n        batch_similarities = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            x = self.raw_input[b_st:b_ed].cuda()\n            raw_out = self.raw_out[b_st:b_ed].cuda().unsqueeze(1) # shape: b,1,oc,fw,fh\n            raw_grad = self.raw_grad[b_st:b_ed].cuda()\n            similarities = []\n            for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_a_interval = input_interval_candidates[p_st:p_ed] # shape: parallel_eq_n,1,1,1,1\n                # quantize weight and bias \n                w_sim, bias_sim = self.quant_weight_bias()\n                # quantize input\n                B,ic,iw,ih = x.shape\n                x_sim=x.unsqueeze(0) # shape: 1,b,ic,iw,ih\n                x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: parallel_eq_n,b,ic,iw,ih\n                x_sim=x_sim.view(-1,ic,iw,ih) # shape: parallel_eq_n*b,ic,iw,ih\n                # calculate similarity and store them\n                out_sim = F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups) # shape: parallel_eq_n*b,oc,fw,fh\n                out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(0), chunks=p_ed-p_st, dim=1), dim=0) # shape: parallel_eq_n,b,oc,fw,fh\n                out_sim = out_sim.transpose_(0, 1) # shape: b,parallel_eq_n,oc,fw,fh\n                similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: b,parallel_eq_n,oc,fw,fh\n                similarity = torch.mean(similarity, dim=[2,3,4]) # shape: b,parallel_eq_n\n                similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1,parallel_eq_n\n                similarities.append(similarity)\n            similarities = torch.cat(similarities, dim=1) # shape: 1,eq_n\n            batch_similarities.append(similarities)\n        batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) #shape: eq_n\n        a_best_index = batch_similarities.argmax(dim=0).view(1,1,1,1,1)\n        self.a_interval = torch.gather(input_interval_candidates,dim=0,index=a_best_index).squeeze()\n\n    def calibration_step2(self):\n        self._initialize_calib_parameters()\n        self._initialize_intervals()\n        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,oc,1,1,1\n        input_interval_candidates =  torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.a_interval # shape: eq_n,1,1,1,1\n        for e in range(self.search_round):\n            # search for best weight interval\n            self._search_best_w_interval(weight_interval_candidates)\n            # search for best input interval\n            if self.a_bit < 32:\n                self._search_best_a_interval(input_interval_candidates)\n        self.calibrated = True\n        del self.raw_input, self.raw_out, self.raw_grad\n    \n    def quant_weight_bias(self):\n        w_sim = (self.weight/self.w_interval).round_().clamp(-self.w_qmax,self.w_qmax-1).mul_(self.w_interval)\n        return w_sim, self.bias\n\n    def quant_forward(self, x):\n        assert self.calibrated is not None,f\"You should run calibrate_forward before run quant_forward for {self}\"\n        w_sim,bias_sim=self.quant_weight_bias()\n        x_sim=self.quant_input(x) if self.a_bit < 32 else x\n        out=F.conv2d(x_sim, w_sim, bias_sim, self.stride, self.padding, self.dilation, self.groups)\n        return out"
  },
  {
    "path": "quant_layers/linear.py",
    "content": "from quant_layers.matmul import PTQSLBatchingQuantMatMul\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass MinMaxQuantLinear(nn.Linear):\n    def __init__(self,\n        in_features: int,\n        out_features: int,\n        bias: bool = True,\n        mode = \"raw\",\n        w_bit = 8,\n        a_bit = 8,\n        bias_bit = None,\n        bias_correction=False):\n        super().__init__(in_features,out_features,bias)\n        self.n_calibration_step=2\n        self.mode = mode\n        self.w_bit = w_bit\n        self.a_bit = a_bit\n        self.bias_bit=bias_bit\n        assert bias_bit is None,\"No support bias bit now\"\n        self.w_interval=None\n        self.a_interval=None\n        self.raw_input=None\n        self.raw_out=None\n        self.metric=None\n        self.next_nodes=[]\n        self.w_qmax=2**(self.w_bit-1)\n        self.a_qmax=2**(self.a_bit-1)\n        self.bias_correction = bias_correction\n\n    def forward(self, x):\n        if self.mode=='raw':\n            out=F.linear(x, self.weight, self.bias)\n        elif self.mode==\"quant_forward\":\n            out=self.quant_forward(x)\n        elif self.mode==\"calibration_step1\":\n            out=self.calibration_step1(x)\n        elif self.mode==\"calibration_step2\":\n            out=self.calibration_step2(x)\n        else:\n            raise NotImplementedError\n        return out\n    \n    def quant_weight_bias(self):\n        w=(self.weight/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1)\n        w_sim=w.mul_(self.w_interval)\n        if self.bias is not None:\n            return w_sim,self.bias\n            # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1)\n            # bias_sim=bias*self.bias_interval\n            # return w_sim,bias_sim\n        else:\n            return w_sim,None\n    \n    def quant_input(self, x):\n        x_sim=(x/self.a_interval).round_().clamp_(-self.a_qmax,self.a_qmax-1)\n        x_sim.mul_(self.a_interval)\n        return x_sim\n    \n    def quant_forward(self,x):\n        assert self.calibrated is not None,f\"You should run calibrate_forward before run quant_forward for {self}\"\n        w_sim,bias_sim=self.quant_weight_bias()\n        x_sim=self.quant_input(x)\n        out=F.linear(x_sim, w_sim, bias_sim)\n        return out\n    \n    def _bias_correction_quant_forward(self, x):\n        if self.bias_correction and self.bias != None:\n            w_sim = self.quant_weight_bias()[0]\n            x_sim = self.quant_input(x)\n            eps = F.linear(x_sim, w_sim-self.weight.data, None)\n            eps = torch.mean(eps, dim=(list(range(len(eps.shape)-1))), keepdim=False)\n            self.bias -= eps\n            self.bias_correction = False\n        return self.quant_forward(x)\n\n    def calibration_step1(self,x):\n        # step1: collection the FP32 values\n        out=F.linear(x, self.weight, self.bias)\n        self.raw_input=x.cpu().detach()\n        self.raw_out=out.cpu().detach()\n        return out\n    \n    def calibration_step2(self,x):\n        # step2: search for the best S^w and S^o of each layer\n        self.w_interval=(self.weight.data.abs().max()/(self.w_qmax-0.5)).detach()\n        self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach()\n        self.calibrated=True\n        out=self._bias_correction_quant_forward(x)\n        return out\n\nclass PTQSLQuantLinear(MinMaxQuantLinear):\n    \"\"\"\n    PTQSL on linear modules.\n    \"\"\"\n    def __init__(self, \n        in_features: int,\n        out_features: int,\n        bias: bool = True,\n        mode = \"raw\",\n        w_bit = 8,\n        a_bit = 8,\n        bias_bit = None,\n        bias_correction = False,\n        metric=\"L2_norm\", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):\n        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction)\n        self.metric = metric\n        self.search_round = search_round\n        self.eq_alpha = eq_alpha\n        self.eq_beta = eq_beta\n        self.eq_n = eq_n\n        self.n_H = n_H\n        self.n_V = n_V\n        self.n_a = n_a\n        self.crb_rows = out_features // n_V\n        self.crb_cols = in_features // n_H # ignore remnent != 0 situations\n        self.crb_acts = in_features // n_a\n        self.parallel_eq_n = parallel_eq_n\n        self.init_layerwise = init_layerwise\n        self.raw_grad = None\n\n    def _get_similarity(self, tensor_raw, tensor_sim, metric=None):\n        \"\"\"\n        tensor_raw: *, features\n        tensor_sim: *, features\n        similarity: *\n        It's your job to calculate mean on * dims!\n        \"\"\"\n        if metric == \"cosine\":\n            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1)\n        elif metric == \"pearson\":\n            similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=-1,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=-1,keepdim=True), dim=-1)\n        else:\n            if metric == \"L1_norm\":\n                similarity = -torch.abs(tensor_raw - tensor_sim)\n            elif metric == \"L2_norm\":\n                similarity = -(tensor_raw - tensor_sim) ** 2\n            elif metric == \"linear_weighted_L2_norm\":\n                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2\n            elif metric == \"square_weighted_L2_norm\":\n                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2\n            elif metric == \"hessian\":\n                raw_grad = self.raw_grad.reshape_as(tensor_raw)\n                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2\n            else:\n                raise NotImplementedError(f\"metric {metric} not implemented!\")\n            similarity = torch.mean(similarity, dim=-1)\n        return similarity\n    \n    def quant_weight_bias(self):\n        # self.w_interval shape: n_V, 1, n_H, 1\n        w=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols)/self.w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1)\n        w_sim=w.mul_(self.w_interval).view(self.out_features,self.in_features)\n        if self.bias is not None:\n            return w_sim,self.bias\n            # bias=(self.bias/self.bias_interval).round_().clamp_(-self.bias_qmax,self.bias_qmax-1)\n            # bias_sim=bias*self.bias_interval\n            # return w_sim,bias_sim\n        else:\n            return w_sim,None\n    \n    def quant_input(self, x):\n        # self.a_interval shape: n_a,1\n        x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)\n        x_sim=(x_sim.div_(self.a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)\n        x_sim = x_sim.mul_(self.a_interval).reshape_as(x)\n        return x_sim\n\n    def _search_best_w_interval(self, x, weight_interval_candidates, raw_out_expanded_chunked):\n        \"\"\"\n        Modularization of searching best weight intervals\n        \"\"\"\n        tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1\n        for h in range(self.n_H):\n            similarities = []\n            for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1)\n                cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:]\n                # quantize weight and bias \n                w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols\n                w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols\n                w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic\n                bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None\n                # quantize input\n                x_sim = self.quant_input(x)\n                # calculate similarity and store them\n                out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n*oc\n                out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,oc\n                out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,parallel_eq_n,n_V,crb_rows\n                similarity = self._get_similarity(raw_out_expanded_chunked, out_sim, self.metric) # shape: B,*,parallel_eq_n,n_V\n                similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-2))) # shape: parallel_eq_n, n_V\n                similarities.append(similarity)\n            # store best weight interval of h into tmp_w_interval\n            similarities = torch.cat(similarities, dim=0) # shape: eq_n, n_V\n            h_best_index = similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1\n            tmp_w_interval[:,:,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index)\n        self.w_interval = tmp_w_interval.squeeze(dim=0)\n    \n    def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded):\n        tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1\n        for a in range(self.n_a):\n            similarities = []\n            for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n\n                cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]\n                # quantize weight and bias \n                w_sim, bias_sim = self.quant_weight_bias()\n                # quantize input\n                x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)\n                x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n\n                x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic\n                # calculate similarity and store them\n                out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc\n                similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n\n                similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n\n                similarities.append(similarity)\n            # store best input interval and store in tmp_a_interval\n            similarities = torch.cat(similarities, dim=0) # shape: eq_n\n            a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)\n            tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)\n        self.a_interval = tmp_a_interval.squeeze(-1)\n\n    def _initialize_intervals(self, x):\n        if self.init_layerwise:\n            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)\n            self.a_interval=(x.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)\n        else:\n            self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))\n            self.a_interval=((x.view(*x.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)\n\n    def calibration_step2(self,x):\n        # initialize intervals with minmax intervals\n        self._initialize_intervals(x)\n\n        # put raw outs on GPU\n        raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2)  # shape: B,*,1,oc\n        raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows\n\n        # put raw grad on GPU\n        self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None\n\n        # prepare weight intervals and similarities\n        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1\n        input_interval_candidates =  torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n\n        for e in range(self.search_round):\n            # search for best weight interval\n            self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked)\n            # search for best input interval\n            self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded)\n\n        self.raw_grad = self.raw_grad.to(\"cpu\") if self.raw_grad != None else None\n\n        self.calibrated = True\n        out=self._bias_correction_quant_forward(x)\n        del self.raw_input, self.raw_out, self.raw_grad\n        return out    \n\nclass PostGeluPTQSLQuantLinear(PTQSLQuantLinear):\n    def __init__(self, \n        in_features: int,\n        out_features: int,\n        bias: bool = True,\n        mode = \"raw\",\n        w_bit = 8,\n        a_bit = 8,\n        bias_bit = None,\n        bias_correction = False,\n        metric=\"L2_norm\", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):\n        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction,\n                         metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)\n    \n    def quant_input(self, x):\n        \"\"\"\n        self.a_interval = [a_interval_pos, a_interval_neg]\n        \"\"\"\n        # self.a_interval[0] shape: n_a,1\n        # self.a_interval[1] shape: 1\n        x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)\n        x_pos=(x_/(self.a_interval[0])).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval[0])\n        x_neg=(x_/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0).mul_(self.a_interval[1])\n        return (x_pos + x_neg).reshape_as(x)\n\n    def _search_best_a_interval(self, x, input_interval_candidates, raw_out_expanded):\n        tmp_a_interval = self.a_interval[0].unsqueeze(-1) # shape: n_a,1,1\n        for a in range(self.n_a):\n            similarities = []\n            for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n\n                cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]\n                # quantize weight and bias \n                w_sim, bias_sim = self.quant_weight_bias()\n                # quantize input\n                x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)\n                x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: B,*,n_a,crb_acts,parallel_eq_n\n                x_neg=(x_sim/(self.a_interval[1])).round_().clamp_(-self.a_qmax,0)*(self.a_interval[1]) # shape: B,*,n_a,crb_acts,1\n                x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: B,*,parallel_eq_n,ic\n                # calculate similarity and store them\n                out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: B,*,parallel_eq_n,oc\n                similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric) # shape: B,*,parallel_eq_n\n                similarity = torch.mean(similarity, dim=list(range(len(similarity.shape)-1))) # shape: parallel_eq_n\n                similarities.append(similarity)\n            # store best input interval and store in tmp_a_interval\n            similarities = torch.cat(similarities, dim=0) # shape: eq_n\n            a_best_index = similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)\n            tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)\n        self.a_interval[0] = tmp_a_interval.squeeze(-1)\n\n    def _initialize_intervals(self, x):\n        if self.init_layerwise:\n            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)\n            self.a_interval=[(x.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)]\n        else:\n            self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))\n            self.a_interval=[((x.view(*x.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)]\n        self.a_interval.append(0.16997124254703522/self.a_qmax)\n\n    def calibration_step2(self,x):\n        # initialize intervals with minmax intervals\n        self._initialize_intervals(x)\n\n        # put raw outs on GPU\n        raw_out_expanded = self.raw_out.to(x.device).unsqueeze(-2)  # shape: B,*,1,oc\n        raw_out_expanded_chunked = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: B,*,1,n_V,crb_rows\n\n        # put raw grad on GPU\n        self.raw_grad = self.raw_grad.to(x.device) if self.raw_grad != None else None\n\n        # prepare weight intervals and similarities\n        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1\n        input_interval_candidates =  torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval[0].unsqueeze(-1) # shape: n_a,1,eq_n\n        for e in range(self.search_round):\n            # search for best weight interval\n            self._search_best_w_interval(x, weight_interval_candidates, raw_out_expanded_chunked)\n            # search for best input interval\n            self._search_best_a_interval(x, input_interval_candidates, raw_out_expanded)\n\n        self.raw_grad = self.raw_grad.to(\"cpu\") if self.raw_grad != None else None\n\n        self.calibrated = True\n        out=self._bias_correction_quant_forward(x)\n        del self.raw_input, self.raw_out, self.raw_grad\n        return out    \n\nclass PTQSLBatchingQuantLinear(PTQSLQuantLinear):\n    def __init__(self, \n        in_features: int,\n        out_features: int,\n        bias: bool = True,\n        mode = \"raw\",\n        w_bit = 8,\n        a_bit = 8,\n        bias_bit = None,\n        bias_correction = False,\n        metric=\"L2_norm\", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):\n        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)\n        self.calib_size = None\n        self.calib_batch_size = None\n        self.calib_need_batching = False\n\n    def _initialize_calib_parameters(self):\n        \"\"\" \n        set parameters for feeding calibration data\n        \"\"\"\n        self.calib_size = int(self.raw_input.shape[0])\n        self.calib_batch_size = int(self.raw_input.shape[0])\n        while True:\n            numel = (2*(self.raw_input.numel()+self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU\n            self.parallel_eq_n = int((3*1024*1024*1024/4)//numel)\n            if self.parallel_eq_n <= 1:\n                self.calib_need_batching = True\n                self.calib_batch_size //= 2\n            else:\n                break\n    \n    def _initialize_intervals(self):\n        # weight intervals \n        if self.init_layerwise:\n            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)\n        else:\n            self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))\n\n        # activation intervals\n        tmp_a_intervals = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            x_ = self.raw_input[b_st:b_ed].cuda()\n            if self.init_layerwise:\n                a_interval_=(x_.abs().max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)\n            else:\n                a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).abs().amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)\n            tmp_a_intervals.append(a_interval_)\n        self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)\n\n    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, raw_grad=None):\n        \"\"\"\n        tensor_raw: *, features\n        tensor_sim: *, features\n        similarity: *\n        It's your job to calculate mean on * dims!\n        \"\"\"\n        if metric == \"cosine\":\n            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=-1)\n        else:\n            if metric == \"L1_norm\":\n                similarity = -torch.abs(tensor_raw - tensor_sim)\n            elif metric == \"L2_norm\":\n                similarity = -(tensor_raw - tensor_sim) ** 2\n            elif metric == \"linear_weighted_L2_norm\":\n                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2\n            elif metric == \"square_weighted_L2_norm\":\n                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2\n            elif metric == \"hessian\":\n                assert raw_grad != None, f\"raw_grad is None in _get_similarity!\"\n                raw_grad = raw_grad.reshape_as(tensor_raw)\n                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2\n            else:\n                raise NotImplementedError(f\"metric {metric} not implemented!\")\n            similarity = torch.mean(similarity, dim=-1)\n        return similarity\n\n    def _get_pearson_w(self, tensor_raw, tensor_sim):\n        \"\"\"\n        Quick implementation of similarity-aware linear quantization\n        tensor_sim: b,*,parallel_eq_n,n_V,crb_rows\n        tensor_raw: b,*,1,n_V,crb_rows\n        \"\"\"\n        b, parallel_eq_n, n_V = tensor_sim.shape[0],tensor_sim.shape[-3],tensor_sim.shape[-2]\n        tensor_sim = tensor_sim.transpose(-1,-3).contiguous_().view(b,-1,n_V,parallel_eq_n)\n        tensor_raw = tensor_raw.transpose(-1,-3).view(b,-1,n_V,1)\n        tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True)\n        tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True)\n        similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,n_V,parallel_eq_n\n        similarity = similarity.permute(0,2,1).contiguous_()\n        return similarity\n    \n    def _get_pearson_a(self, tensor_raw, tensor_sim):\n        \"\"\"\n        Quick implementation of similarity-aware linear quantization\n        tensor_sim: b,*,parallel_eq_n,oc\n        tensor_raw: b,*,1,oc\n        \"\"\"\n        b, parallel_eq_n = tensor_sim.shape[0],tensor_sim.shape[-2]\n        tensor_sim = tensor_sim.transpose(-1,-2).contiguous_().view(b,-1,parallel_eq_n)\n        tensor_raw = tensor_raw.transpose(-1,-2).view(b,-1,1)\n        tensor_sim_mean = tensor_sim.mean(dim=[0,1],keepdim=True)\n        tensor_raw_mean = tensor_raw.mean(dim=[0,1],keepdim=True)\n        similarity = torch.cosine_similarity(tensor_raw-tensor_raw_mean, tensor_sim-tensor_sim_mean, dim=1) # shape: b,parallel_eq_n\n        return similarity\n\n    def _search_best_w_interval(self, weight_interval_candidates):\n        tmp_w_interval = self.w_interval.unsqueeze(0) # shape: 1,n_V,1,n_H,1\n        for h in range(self.n_H):\n            batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax)\n            for b_st in range(0, self.calib_size, self.calib_batch_size):\n                b_ed = min(self.calib_size, b_st + self.calib_batch_size)\n                x = self.raw_input[b_st:b_ed].cuda()\n                raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc\n                raw_out_expanded = torch.cat(torch.chunk(raw_out_expanded.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,1,n_V,crb_rows\n                raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later\n                similarities = []\n                for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                    p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                    cur_w_interval = tmp_w_interval.repeat(p_ed-p_st,1,1,1,1)\n                    cur_w_interval[:,:,:,h:h+1,:] = weight_interval_candidates[p_st:p_ed,:,:,h:h+1,:]\n                    # quantize weight and bias \n                    w_sim = self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).unsqueeze(0) # shape: 1,n_V,crb_rows,n_H,crb_cols\n                    w_sim = (w_sim/cur_w_interval).round_().clamp_(-self.w_qmax,self.w_qmax-1).mul_(cur_w_interval) # shape: parallel_eq_n,n_V,crb_rows,n_H,crb_cols\n                    w_sim = w_sim.view(-1,self.in_features) # shape: parallel_eq_n*oc,ic\n                    bias_sim = self.bias.repeat(p_ed-p_st) if self.bias is not None else None\n                    # quantize input\n                    x_sim = self.quant_input(x)\n                    # calculate similarity and store them\n                    out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n*oc\n                    out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=p_ed-p_st, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,oc\n                    out_sim = torch.cat(torch.chunk(out_sim.unsqueeze(-2), chunks=self.n_V, dim=-1), dim=-2) # shape: b,*,parallel_eq_n,n_V,crb_rows\n                    if self.metric != \"pearson\":\n                        similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n,n_V\n                        if len(similarity.shape) > 3:\n                            similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-2))) # shape: b, parallel_eq_n, n_V\n                    else:\n                        similarity = self._get_pearson_w(raw_out_expanded, out_sim)\n                    similarity = similarity.sum(dim=0, keepdim=True) # shape: 1, parallel_eq_n, n_V\n                    similarities.append(similarity)\n                # store best weight interval of h into tmp_w_interval\n                similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n, n_V\n                batch_similarities.append(similarities)\n            batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n, n_V\n            h_best_index = batch_similarities.argmax(dim=0).reshape(1,-1,1,1,1) # shape: 1,n_V,1,1,1\n            tmp_w_interval[:,:,:,h:h+1,:] = torch.gather(weight_interval_candidates[:,:,:,h:h+1,:],dim=0,index=h_best_index)\n        self.w_interval = tmp_w_interval.squeeze(dim=0)\n\n    def _search_best_a_interval(self, input_interval_candidates):\n        tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1\n        for a in range(self.n_a):\n            batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax)\n            for b_st in range(0, self.calib_size, self.calib_batch_size):\n                b_ed = min(self.calib_size, b_st + self.calib_batch_size)\n                x = self.raw_input[b_st:b_ed].cuda()\n                raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc\n                raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later\n                similarities = []\n                for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                    p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                    cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n\n                    cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]\n                    # quantize weight and bias \n                    w_sim, bias_sim = self.quant_weight_bias()\n                    # quantize input\n                    x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)\n                    x_sim=(x_sim/(cur_a_interval)).round_().clamp_(-self.a_qmax,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n\n                    x_sim = x_sim.permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic\n                    # calculate similarity and store them\n                    out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc\n                    if self.metric != \"pearson\":\n                        similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n\n                        if len(similarity.shape) > 2:\n                            similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b, parallel_eq_n\n                    else:\n                        similarity = self._get_pearson_a(raw_out_expanded, out_sim)\n                    similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n\n                    similarities.append(similarity)\n                # store best input interval and store in tmp_a_interval\n                similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n\n                batch_similarities.append(similarities)\n            batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n\n            a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)\n            tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)\n        self.a_interval = tmp_a_interval.squeeze(-1)\n\n\n    def calibration_step2(self):\n        \"\"\"\n        Only use cached raw inputs/outs/grads\n        \"\"\"\n        self._initialize_calib_parameters()\n        self._initialize_intervals()\n\n        # prepare weight intervals and similarities\n        weight_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1) * self.w_interval.unsqueeze(0) # shape: eq_n,n_V,1,n_H,1\n        input_interval_candidates =  torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(1,1,-1) * self.a_interval.unsqueeze(-1) # shape: n_a,1,eq_n\n        for e in range(self.search_round):\n            # search for best weight interval\n            self._search_best_w_interval(weight_interval_candidates)\n            # search for best input interval\n            self._search_best_a_interval(input_interval_candidates)\n\n        self.calibrated = True\n        # self._bias_correction_quant_forward(self.raw_input.cuda()) # debugging\n        del self.raw_input, self.raw_out, self.raw_grad\n        return None\n\nclass PostGeluPTQSLBatchingQuantLinear(PTQSLBatchingQuantLinear):\n    \"\"\" \n    An Agile implementation of PostGeluPTQSLBatchingQuantLinear\n    use a_interval for positive activation quantization and a_neg_interval for negative activation quantization\n    \"\"\"\n    def __init__(self, \n        in_features: int,\n        out_features: int,\n        bias: bool = True,\n        mode = \"raw\",\n        w_bit = 8,\n        a_bit = 8,\n        bias_bit = None,\n        bias_correction = False,\n        metric=\"L2_norm\", search_round=1, eq_alpha=0, eq_beta=1, eq_n=100, parallel_eq_n=10, n_H=1, n_V=1, n_a=1, init_layerwise=False):\n        super().__init__(in_features, out_features, bias=bias, mode=mode, w_bit=w_bit, a_bit=a_bit, bias_bit=bias_bit, bias_correction=bias_correction,\n                         metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_H=n_H, n_V=n_V, n_a=n_a, init_layerwise=init_layerwise)\n        self.a_neg_interval = 0.16997124254703522/self.a_qmax\n\n    def _initialize_intervals(self):\n        # weight intervals \n        if self.init_layerwise:\n            self.w_interval=((self.weight.abs().max())/(self.w_qmax-0.5)).view(1,1,1,1).repeat(self.n_V,1,self.n_H,1)\n        else:\n            self.w_interval=(self.weight.view(self.n_V,self.crb_rows,self.n_H,self.crb_cols).abs().amax([1,3],keepdim=True)/(self.w_qmax-0.5))\n\n        # activation intervals (for positive parts)\n        if self.init_layerwise:\n            tmp_a_intervals = []\n            for b_st in range(0,self.calib_size,self.calib_batch_size):\n                b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n                x_ = self.raw_input[b_st:b_ed].cuda()\n                a_interval_=(x_.max()/(self.a_qmax-0.5)).detach().view(1,1).repeat(self.n_a,1)\n                tmp_a_intervals.append(a_interval_)\n            self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)\n        else:\n            tmp_a_intervals = []\n            for b_st in range(0,self.calib_size,self.calib_batch_size):\n                b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n                x_ = self.raw_input[b_st:b_ed].cuda()\n                a_interval_=((x_.view(*x_.shape[:-1],self.n_a,self.crb_acts).amax(list(range(len(x_.shape)-1))+[-1],keepdim=False))/(self.a_qmax-0.5)).unsqueeze(-1)\n                tmp_a_intervals.append(a_interval_)\n            self.a_interval = torch.cat(tmp_a_intervals, dim=1).amax(dim=1, keepdim=True)\n\n    def quant_input(self, x):\n        # self.a_interval shape: n_a,1\n        # self.a_neg_interval shape: 1\n        x_=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2)\n        x_pos=(x_/(self.a_interval)).round_().clamp_(0,self.a_qmax-1).mul_(self.a_interval)\n        x_neg=(x_/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0).mul_(self.a_neg_interval)\n        return (x_pos + x_neg).reshape_as(x)\n\n    def _search_best_a_interval(self, input_interval_candidates):\n        tmp_a_interval = self.a_interval.unsqueeze(-1) # shape: n_a,1,1\n        for a in range(self.n_a):\n            batch_similarities = [] # similarities, need to concatenate and calculate sum (equivalent to mean with argmax)\n            for b_st in range(0, self.calib_size, self.calib_batch_size):\n                b_ed = min(self.calib_size, b_st + self.calib_batch_size)\n                x = self.raw_input[b_st:b_ed].cuda()\n                raw_out_expanded = self.raw_out[b_st:b_ed].cuda().unsqueeze(-2) # shape: b,*,1,oc\n                raw_grad = self.raw_grad[b_st:b_ed].cuda() # will be reshaped later\n                similarities = []\n                for p_st in range(0,self.eq_n,self.parallel_eq_n):\n                    p_ed = min(self.eq_n, p_st+self.parallel_eq_n)\n                    cur_a_interval = tmp_a_interval.repeat(1,1,p_ed-p_st) # shape: n_a,1,parallel_eq_n\n                    cur_a_interval[a:a+1,:,:] = input_interval_candidates[a:a+1,:,p_st:p_ed]\n                    # quantize weight and bias \n                    w_sim, bias_sim = self.quant_weight_bias()\n                    # quantize input\n                    x_sim=torch.cat(torch.chunk(x.unsqueeze(-2), chunks=self.n_a, dim=-1), dim=-2).unsqueeze(-1)\n                    x_pos=(x_sim/(cur_a_interval)).round_().clamp_(0,self.a_qmax-1)*(cur_a_interval) # shape: b,*,n_a,crb_acts,parallel_eq_n\n                    x_neg=(x_sim/(self.a_neg_interval)).round_().clamp_(-self.a_qmax,0)*(self.a_neg_interval) # shape: b,*,n_a,crb_acts,1\n                    x_sim = (x_pos + x_neg).permute(*list(range(len(x_sim.shape)-3)),-1,-3,-2).reshape(*x.shape[:-1],p_ed-p_st,x.shape[-1]) # shape: b,*,parallel_eq_n,ic\n                    # calculate similarity and store them\n                    out_sim = F.linear(x_sim, w_sim, bias_sim) # shape: b,*,parallel_eq_n,oc\n                    similarity = self._get_similarity(raw_out_expanded, out_sim, self.metric, raw_grad) # shape: b,*,parallel_eq_n\n                    similarity = torch.mean(similarity, dim=list(range(1,len(similarity.shape)-1))) # shape: b, parallel_eq_n\n                    similarity = torch.sum(similarity, dim=0, keepdim=True) # shape: 1, parallel_eq_n\n                    similarities.append(similarity)\n                # store best input interval and store in tmp_a_interval\n                similarities = torch.cat(similarities, dim=1) # shape: 1, eq_n\n                batch_similarities.append(similarities)\n            batch_similarities = torch.cat(batch_similarities, dim=0).sum(dim=0, keepdim=False) # shape: eq_n\n            a_best_index = batch_similarities.argmax(dim=0, keepdim=True).reshape(1,1,-1)\n            tmp_a_interval[a:a+1,:,:] = torch.gather(input_interval_candidates[a:a+1,:,:],dim=2,index=a_best_index)\n        self.a_interval = tmp_a_interval.squeeze(-1)"
  },
  {
    "path": "quant_layers/matmul.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch import Tensor \nfrom torch.nn import functional as F\nfrom itertools import product     \n\nclass MinMaxQuantMatMul(nn.Module):\n    \"\"\"Matrix Multiplication base class\"\"\"\n    def __init__(self, A_bit=8, B_bit=8, mode=\"raw\"):\n        super().__init__()\n        self.A_bit=A_bit\n        self.B_bit=B_bit\n        self.A_interval=None\n        self.B_interval=None\n        self.A_qmax=2**(self.A_bit-1)\n        self.B_qmax=2**(self.B_bit-1)\n        self.mode=mode\n        self.raw_input = None\n        self.raw_out = None\n    \n    def forward(self, A,B):\n        if self.mode=='raw':\n            out=A @ B\n        elif self.mode==\"quant_forward\":\n            out=self.quant_forward(A,B)\n        elif self.mode==\"calibration_step1\":\n            out=self.calibration_step1(A,B)\n        elif self.mode==\"calibration_step2\":\n            out=self.calibration_step2(A,B)\n        else:\n            raise NotImplementedError\n        return out\n    \n    def quant_input(self,x,interval,qmax):\n        x_sim=(x/interval).round_().clamp_(-qmax,qmax-1)\n        x_sim.mul_(interval)\n        return x_sim\n    \n    def quant_forward(self,A,B):\n        assert self.calibrated is not None,f\"You should run calibrate_forward before run quant_forward for {self}\"\n        A_sim=self.quant_input(A,self.A_interval,self.A_qmax)\n        B_sim=self.quant_input(B,self.B_interval,self.B_qmax)\n        out=A_sim@B_sim\n        return out\n\n    def calibration_step1(self,A,B):\n        # step1: collection the FP32 values\n        self.raw_input=A.cpu().detach(), B.cpu().detach()\n        out=A@B\n        self.raw_out=out.cpu().detach()\n        return out\n    \n    def calibration_step2(self,A,B):\n        # step2: search for the best S^w and S^o of each layer\n        self.A_interval=(A.data.abs().max()/(self.A_qmax-0.5)).detach()\n        self.B_interval=(B.data.abs().max()/(self.B_qmax-0.5)).detach()\n        self.calibrated=True\n        out=self.quant_forward(A,B)        \n        return out\n\nclass PTQSLQuantMatMul(MinMaxQuantMatMul):\n    \"\"\"\n    Chunk matrix into blockes and quantize.\n    Chunking follows naive padding strategy.\n    Alternately search for best intervals of each individual blocks for A and B.\n\n    two different scenarios:\n    - Q @ K:\n        - A's shape: B,H,S,W\n        - B's shape: B,H,W,S\n    - scores @ V:\n        - A's shape: B,H,S,S\n        - B's shape: B,H,S,W\n    - interval shape: 1,n_G,1,n_V,1,n_H,1\n    \"\"\"\n    def __init__(self, A_bit=8, B_bit=8, mode=\"raw\",\n                 metric=\"L2_norm\", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,\n                 n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False):\n        super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode)\n        self.metric = metric\n        self.search_round = search_round\n        self.eq_alpha = eq_alpha\n        self.eq_beta = eq_beta\n        self.eq_n = eq_n\n        self.parallel_eq_n = parallel_eq_n\n        self.n_G_A = n_G_A\n        self.n_V_A = n_V_A\n        self.n_H_A = n_H_A\n        self.n_G_B = n_G_B\n        self.n_V_B = n_V_B\n        self.n_H_B = n_H_B\n        # init these parameters in self.calibration_step1\n        self.crb_groups_A = None\n        self.crb_groups_B = None\n        self.crb_rows_A = None\n        self.crb_cols_A = None\n        self.crb_rows_B = None\n        self.crb_cols_B = None\n        self.pad_groups_A = None\n        self.pad_groups_B = None\n        self.pad_rows_A = None\n        self.pad_rows_B = None\n        self.pad_cols_A = None\n        self.pad_cols_B = None\n        self.raw_grad = None\n        self.init_layerwise = init_layerwise\n\n    def _get_padding_parameters(self, A, B):\n        self.crb_groups_A = (A.shape[1]+self.n_G_A-1) // self.n_G_A\n        self.crb_groups_B = (B.shape[1]+self.n_G_B-1) // self.n_G_B\n        self.crb_rows_A = (A.shape[2]+self.n_V_A-1) // self.n_V_A\n        self.crb_cols_A = (A.shape[3]+self.n_H_A-1) // self.n_H_A\n        self.crb_rows_B = (B.shape[2]+self.n_V_B-1) // self.n_V_B\n        self.crb_cols_B = (B.shape[3]+self.n_H_B-1) // self.n_H_B\n\n        self.pad_groups_A = self.crb_groups_A*self.n_G_A - A.shape[1]\n        self.pad_rows_A = self.crb_rows_A*self.n_V_A - A.shape[2]\n        self.pad_cols_A = self.crb_cols_A*self.n_H_A - A.shape[3]\n        self.pad_groups_B = self.crb_groups_B*self.n_G_B - B.shape[1]\n        self.pad_rows_B = self.crb_rows_B*self.n_V_B - B.shape[2]\n        self.pad_cols_B = self.crb_cols_B*self.n_H_B - B.shape[3]\n\n    def quant_input_A(self, x):\n        x = F.pad(x, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A])\n        x = x.view(-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)\n        x = (x/self.A_interval).round_().clamp(-self.A_qmax,self.A_qmax-1).mul_(self.A_interval)\n        x = x.view(-1,self.n_G_A*self.crb_groups_A,self.n_V_A*self.crb_rows_A,self.n_H_A*self.crb_cols_A)\n        x = x[:,:x.shape[1]-self.pad_groups_A,:x.shape[2]-self.pad_rows_A,:x.shape[3]-self.pad_cols_A]\n        return x\n    \n    def quant_input_B(self, x):\n        x = F.pad(x, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B])\n        x = x.view(-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)\n        x = (x/self.B_interval).round_().clamp(-self.B_qmax,self.B_qmax-1).mul_(self.B_interval)\n        x = x.view(-1,self.n_G_B*self.crb_groups_B,self.n_V_B*self.crb_rows_B,self.n_H_B*self.crb_cols_B)\n        x = x[:,:x.shape[1]-self.pad_groups_B,:x.shape[2]-self.pad_rows_B,:x.shape[3]-self.pad_cols_B]\n        return x\n\n    def quant_forward(self, A, B):\n        assert self.calibrated is not None,f\"You should run calibrate_forward before run quant_forward for {self}\"\n        A_sim=self.quant_input_A(A)\n        B_sim=self.quant_input_B(B)\n        out=A_sim@B_sim\n        return out\n\n    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1):\n        \"\"\"\n        tensor_raw: *, features, *\n        tensor_sim: *, features, *\n        similarity: *\n        It's your job to calculate mean on non-feature * dims!\n\n        Similarity without inherent feature structure is more welcome to parallelism.\n        \"\"\"\n        if metric == \"cosine\":\n            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # should only support dim=-1 and cannot be paralleled\n        elif metric == \"pearson\":\n            similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw), tensor_sim-torch.mean(tensor_sim), dim=dim)\n        else:\n            if metric == \"L1_norm\":\n                similarity = -torch.abs(tensor_raw - tensor_sim)\n            elif metric == \"L2_norm\":\n                similarity = -(tensor_raw - tensor_sim) ** 2\n            elif metric == \"linear_weighted_L2_norm\":\n                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2\n            elif metric == \"square_weighted_L2_norm\":\n                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2\n            elif metric == \"hessian\":\n                raw_grad = self.raw_grad.reshape_as(tensor_raw)\n                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2\n            else:\n                raise NotImplementedError(f\"metric {metric} not implemented!\")\n            similarity = torch.mean(similarity, dim=dim)\n        return similarity\n\n    def _search_best_A_interval(self, A, B, A_interval_candidates):\n        \"\"\"\n        Modularization of searching best interval\n        \"\"\"\n        # recalculate A_pad\n        A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)\n\n        tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1\n        # out-of-loop optimization\n        B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3\n        for v, h in product(range(self.n_V_A), range(self.n_H_A)):\n            similarities = []\n            for p_st in range(0, self.eq_n, self.parallel_eq_n):\n                p_ed = min(self.eq_n,p_st+self.parallel_eq_n)\n                # quantize A\n                cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)\n                cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]\n                A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval)\n                A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stand for padding)\n                A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,B,H,dim1,dim2\n                # quantize B, this quantization is optimized out of loop\n                # calculate similarity and store them\n                out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3\n                similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1\n                similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (remaining mean operation will be done later on)\n                similarities.append(similarity)\n            # calculate best similarity for this block\n            similarities = torch.cat(similarities, 0) # shape: eq_n,H\n            similarities = F.pad(similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n, n_G_A\n            best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)\n            tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)\n        self.A_interval = tmp_A_interval.squeeze(0)\n\n    def _search_best_B_interval(self, A, B, B_interval_candidates):\n        \"\"\"\n        Modularization of searching best interval\n        \"\"\"\n        # recalculate B_pad\n        B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)\n\n        tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1\n        # out-of-loop optimization\n        A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2\n        for v, h in product(range(self.n_V_B), range(self.n_H_B)):\n            similarities = []\n            for p_st in range(0, self.eq_n, self.parallel_eq_n):\n                p_ed = min(self.eq_n,p_st+self.parallel_eq_n)\n                # quantize A, this quantization is optimized out of loop\n                # quantize B\n                cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)\n                cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]\n                B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval)\n                B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,B,H*,dim2*,dim3* (* stand for padding)\n                B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,B,H,dim2,dim3\n                # calculate similarity and store them\n                out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3\n                similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1\n                similarity = similarity.mean([1,3]) # shape: parallel_eq_n,H (remaining mean operation will be done later on)\n                similarities.append(similarity)\n            # calculate best similarity for this block\n            similarities = torch.cat(similarities, 0) # shape: eq_n,H\n            similarities = F.pad(similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n, n_G_B\n            best_index = torch.argmax(similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)\n            tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)\n        self.B_interval = tmp_B_interval.squeeze(0)\n\n    def _initialize_intervals(self, A, B):\n        # pad A and B for future quantization\n        self._get_padding_parameters(A, B) # put it here because hessian does not use calibration step 1\n        A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A) # shape: 1,B,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols\n        B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)\n\n        # initialize intervals with minmax intervals\n        if self.init_layerwise:\n            self.A_interval = (A.abs().max()/(self.A_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_A,1,self.n_V_A,1,self.n_H_A,1)\n            self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)\n        else:\n            self.A_interval=(A_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.A_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1\n            self.B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1\n\n    def calibration_step2(self, A, B):\n        # put raw outs/grads on GPU\n        self.raw_out = self.raw_out.unsqueeze(0).to(A.device)\n        self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad != None else None\n\n        self._initialize_intervals(A, B)\n\n        # prepare weight intervals and similarities\n        A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0)\n        B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)\n\n        for e in range(self.search_round):\n            # search for best A interval\n            self._search_best_A_interval(A, B, A_interval_candidates)\n            # search for best B interval\n            self._search_best_B_interval(A, B, B_interval_candidates)\n\n        # put raw data back to cpu\n        self.raw_out = self.raw_out.squeeze(0).to(\"cpu\")\n        self.raw_grad = self.raw_grad.to(\"cpu\") if self.raw_grad != None else None\n\n        # finish calibration and output the result\n        self.calibrated = True\n        del self.raw_input, self.raw_out, self.raw_grad\n        out=self.quant_forward(A,B)\n        return out    \n\nclass SoSPTQSLQuantMatMul(PTQSLQuantMatMul):\n    \"\"\"\n    Sublayerwise PTQ on matmul modules with Split-of-Softmax (SoS) on score matrix.\n    \n    Data after softmaxing has highly biased distribution, making it difficult to quantize with uniform quantization.\n    An elegant tradeoff between great majority of unimportant values and few crucial values is impossible under low bit quantization.\n    Therefore, we propose to split complete interval of (0, 1) into several smaller intervals and perform uniform quantization on each.\n    We could manually assgin or search for the best split point.\n    Currently, we only consider single split point scenarios, since this proves to be effective enough.\n\n    The algorithm no longer requires PTQSL on score matrix, and will ignore relevant parameters.\n\n    with proper hardware implementation, we don't need to use a sign bit anymore.\n    \"\"\"\n    def __init__(self, A_bit=8, B_bit=8, mode=\"raw\",\n                 metric=\"L2_norm\", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,\n                 n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False,\n                 split=None):\n        super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, \n                         metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, \n                         n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise)\n        self.n_G_A = 1\n        self.n_V_A = 1\n        self.n_H_A = 1\n        self.A_qmax = 2**(self.A_bit-1) # well, still need it \n        self.split = split\n        if split != None:\n            self.A_interval = self.split/(self.A_qmax-1)\n\n    def quant_input_A(self, x):\n        x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)\n        x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval\n        return x_high + x_low\n\n    def _search_best_A_interval(self, A, B, split_candidates):\n        \"\"\"\n        search for best split point\n        \"\"\"\n        # out-of-loop optimization\n        A_ = A.unsqueeze(0)\n        # B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,B,H,dim2,dim3\n        B_sim = B.unsqueeze(0)\n\n        similarities = []\n        for i in range(len(split_candidates)):\n            # quantize A\n            cur_A_interval = split_candidates[i]/(self.A_qmax-1)\n            A_high = (A_.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)\n            A_low =( A_.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval\n            A_sim = A_high + A_low # shape: 1,B,H,S,S\n            # quantize B, this quantization is optimized out of loop\n            # calculate similarity and store them (dim1=dim2=S, dim3=W)\n            out_sim = A_sim @ B_sim # shape: 1,B,H,dim1,dim3\n            similarity = self._get_similarity(self.raw_out, out_sim, self.metric) # shape: parallel_eq_n,B,H,dim1\n            similarity = similarity.mean([1,2,3]) # shape: 1\n            similarities.append(similarity)\n        # calculate best similarity for this block\n        similarities = torch.cat(similarities, 0) # shape: eq_n\n        best_index = torch.argmax(similarities, dim=0, keepdim=False)\n        self.split = split_candidates[best_index]\n        self.A_interval = self.split/(self.A_qmax-1)\n        # debugging\n        # print(f\"best split: {self.split}\")\n\n    def _initialize_intervals(self, A, B):\n        # pad A and B for future quantization\n        self._get_padding_parameters(A, B)\n        B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)\n\n        # initialize intervals with minmax intervals\n        self.split = 0.01\n        self.A_interval = self.split/(self.A_qmax-1)\n        if self.init_layerwise:\n            self.B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)\n        else:\n            self.B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1\n    \n    def calibration_step2(self, A, B):\n        # put raw outs/grads on GPU\n        self.raw_out = self.raw_out.unsqueeze(0).to(A.device)\n        self.raw_grad = self.raw_grad.to(A.device) if self.raw_grad != None else None\n\n        self._initialize_intervals(A, B)\n\n        # prepare weight intervals and similarities\n        A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda()\n        # split_eq_alpha, split_eq_beta, split_eq_n = 0.002, 0.03, 50\n        # A_split_candidates = torch.tensor([split_eq_alpha + (split_eq_beta- split_eq_alpha)*i/split_eq_n for i in range(split_eq_n + 1)]).cuda()\n        B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)\n\n        for e in range(self.search_round):\n            # search for best A interval\n            self._search_best_A_interval(A, B, A_split_candidates)\n            # search for best B interval\n            self._search_best_B_interval(A, B, B_interval_candidates)\n\n        # put raw data back to cpu\n        self.raw_out = self.raw_out.squeeze(0).to(\"cpu\")\n        self.raw_grad = self.raw_grad.to(\"cpu\") if self.raw_grad != None else None\n\n        # finish calibration and output the result\n        self.calibrated = True\n        del self.raw_input, self.raw_out, self.raw_grad\n        out=self.quant_forward(A,B)\n        return out    \n\nclass PTQSLBatchingQuantMatMul(PTQSLQuantMatMul):\n    def __init__(self, A_bit=8, B_bit=8, mode=\"raw\",\n                 metric=\"L2_norm\", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,\n                 n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False):\n        super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise)\n\n    def _initialize_calib_parameters(self):\n        \"\"\" \n        set parameters for feeding calibration data\n        \"\"\"\n        self.calib_size = int(self.raw_input[0].shape[0])\n        self.calib_batch_size = int(self.raw_input[0].shape[0])\n        while True:\n            numel = ((self.raw_input[0].numel()+self.raw_input[1].numel()+2*self.raw_out.numel())/self.calib_size*self.calib_batch_size) # number of parameters on GPU\n            self.parallel_eq_n = int((3*1024*1024*1024/4)//numel)\n            if self.parallel_eq_n <= 1:\n                self.calib_need_batching = True\n                self.calib_batch_size //= 2\n            else:\n                break\n\n    def _get_padding_parameters(self, A, B):\n        \"\"\"\n        We adopt a head-wise quantization here\n        \"\"\"\n        self.n_G_A = A.shape[1]\n        self.n_G_B = B.shape[1]\n        super()._get_padding_parameters(A,B)\n    \n    def _initialize_intervals(self):\n        # pad A and B for future quantization\n        self._get_padding_parameters(self.raw_input[0], self.raw_input[1]) # put it here because hessian does not use calibration step 1\n\n        # initialize intervals with minmax intervals\n        tmp_A_intervals = []\n        tmp_B_intervals = []\n        for b_st in range(0,self.calib_size,self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st+self.calib_batch_size)\n            A, B = self.raw_input[0][b_st:b_ed].cuda(), self.raw_input[1][b_st:b_ed].cuda()\n            if self.init_layerwise:\n                A_interval = (A.abs().max()/(self.A_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_A,1,self.n_V_A,1,self.n_H_A,1)\n                B_interval = (B.abs().max()/(self.B_qmax-0.5)).detach().view(1,1,1,1,1,1,1).repeat(1,self.n_G_B,1,self.n_V_B,1,self.n_H_B,1)\n            else:\n                A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)\n                B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)\n                A_interval=(A_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.A_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1\n                B_interval=(B_pad.abs().amax([0,1,3,5,7], keepdim=True)/(self.B_qmax-0.5)).detach().squeeze(0) # shape: 1,n_G,1,n_V,1,n_H,1\n            tmp_A_intervals.append(A_interval)\n            tmp_B_intervals.append(B_interval)\n        self.A_interval = torch.cat(tmp_A_intervals, dim=0).amax(0, keepdim=True)\n        self.B_interval = torch.cat(tmp_B_intervals, dim=0).amax(0, keepdim=True)\n\n    def _get_similarity(self, tensor_raw, tensor_sim, metric=None, dim=-1, raw_grad=None):\n        \"\"\"\n        tensor_raw: *, features, *\n        tensor_sim: *, features, *\n        similarity: *\n        It's your job to calculate mean on non-feature * dims!\n\n        Similarity without inherent feature structure is more welcome to parallelism.\n        \"\"\"\n        if metric == \"cosine\":\n            similarity = F.cosine_similarity(tensor_raw, tensor_sim, dim=dim) # should only support dim=-1 and cannot be paralleled\n        elif metric == \"pearson\":\n            similarity = F.cosine_similarity(tensor_raw-torch.mean(tensor_raw,dim=dim,keepdim=True), tensor_sim-torch.mean(tensor_sim,dim=dim,keepdim=True), dim=dim) # should only support dim=-1 and cannot be paralleled\n            # a quick implementation of pearson similarity\n            # tensor_raw: 1,B,H,dim1,dim3\n            # tensor_sim: parallel_eq_n,B,H,dim1,dim3\n            # parallel_eq_n,B,H,dim1,dim3 = tensor_sim.shape\n            # tensor_sim = tensor_sim.view(parallel_eq_n,B,-1)\n            # tensor_raw = tensor_raw.view(1,B,-1)\n            # tensor_sim_mean = tensor_sim.mean(dim=[1,2],keepdim=True)\n            # tensor_raw_mean = tensor_raw.mean(dim=[1,2],keepdim=True)\n            # similarity = F.cosine_similarity(tensor_raw-tensor_raw_mean,tensor_sim-tensor_sim_mean,dim=-1) # shape: parallel_eq_n,B\n            # similarity = similarity.reshape(parallel_eq_n,B,1,1) # restore two dims\n        else:\n            if metric == \"L1_norm\":\n                similarity = -torch.abs(tensor_raw - tensor_sim)\n            elif metric == \"L2_norm\":\n                similarity = -(tensor_raw - tensor_sim) ** 2\n            elif metric == \"linear_weighted_L2_norm\":\n                similarity = -tensor_raw.abs() * (tensor_raw - tensor_sim) ** 2\n            elif metric == \"square_weighted_L2_norm\":\n                similarity = -(tensor_raw * (tensor_raw - tensor_sim)) ** 2\n            elif metric == \"hessian\":\n                assert raw_grad != None, f\"No raw_grad in PTQSLBatchingQuantMatMul!\"\n                raw_grad = raw_grad.reshape_as(tensor_raw)\n                similarity = -(raw_grad * (tensor_raw - tensor_sim)) ** 2\n            else:\n                raise NotImplementedError(f\"metric {metric} not implemented!\")\n            similarity = torch.mean(similarity, dim=dim)\n        return similarity\n\n    def _search_best_A_interval(self, A_interval_candidates):\n        \"\"\"\n        Modularization of searching best interval\n        \"\"\"\n        tmp_A_interval = self.A_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1\n        # out-of-loop optimization\n        for v, h in product(range(self.n_V_A), range(self.n_H_A)):\n            batch_similarities = [] # similarities, need to concatenate and calculate sum\n            for b_st in range(0, self.calib_size, self.calib_batch_size):\n                b_ed = min(self.calib_size, b_st + self.calib_batch_size)\n                A = self.raw_input[0][b_st:b_ed].cuda()\n                A_pad = F.pad(A, [0,self.pad_cols_A,0,self.pad_rows_A,0,self.pad_groups_A]).unsqueeze(0).view(1,-1,self.n_G_A,self.crb_groups_A,self.n_V_A,self.crb_rows_A,self.n_H_A,self.crb_cols_A)\n                B = self.raw_input[1][b_st:b_ed].cuda()\n                B_sim = self.quant_input_B(B).unsqueeze(0) # shape: 1,b,H,dim2,dim3\n                raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda()\n                raw_grad = self.raw_grad[b_st:b_ed].cuda()\n                similarities = []\n                for p_st in range(0, self.eq_n, self.parallel_eq_n):\n                    p_ed = min(self.eq_n,p_st+self.parallel_eq_n)\n                    # quantize A\n                    cur_A_interval = tmp_A_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)\n                    cur_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = A_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]\n                    A_sim = (A_pad/cur_A_interval).round_().clamp_(-self.A_qmax,self.A_qmax-1).mul_(cur_A_interval)\n                    A_sim = A_sim.view(p_ed-p_st,-1,A.shape[1]+self.pad_groups_A,A.shape[2]+self.pad_rows_A,A.shape[3]+self.pad_cols_A) # shape: parallel_eq_n,B,H*,dim1*,dim2* (* stand for padding)\n                    A_sim = A_sim[:,:,:A.shape[1],:A.shape[2],:A.shape[3]] # shape: parallel_eq_n,b,H,dim1,dim2\n                    # quantize B, this quantization is optimized out of loop\n                    # calculate similarity and store them\n                    out_sim = A_sim @ B_sim # shape: parallel_eq_n,B,H,dim1,dim3\n                    similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1\n                    similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on)\n                    similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H\n                    similarities.append(similarity)\n                # calculate best similarity for this block\n                similarities = torch.cat(similarities, 0) # shape: eq_n,1,H\n                batch_similarities.append(similarities)\n            batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H\n            batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_A]).view(self.eq_n,self.n_G_A,self.crb_groups_A).mean(-1) # shape: eq_n, n_G_A\n            best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)\n            tmp_A_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(A_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)\n        self.A_interval = tmp_A_interval.squeeze(0)\n\n    def _search_best_B_interval(self, B_interval_candidates):\n        \"\"\"\n        Modularization of searching best interval\n        \"\"\"\n        tmp_B_interval = self.B_interval.unsqueeze(0) # shape: 1,1,n_G,1,n_V,1,n_H,1\n        # out-of-loop optimization\n        for v, h in product(range(self.n_V_B), range(self.n_H_B)):\n            batch_similarities = [] # similarities, need to concatenate and calculate sum\n            for b_st in range(0, self.calib_size, self.calib_batch_size):\n                b_ed = min(self.calib_size, b_st + self.calib_batch_size)\n                A = self.raw_input[0][b_st:b_ed].cuda()\n                A_sim = self.quant_input_A(A).unsqueeze(0) # shape: 1,B,H,dim1,dim2\n                B = self.raw_input[1][b_st:b_ed].cuda()\n                B_pad = F.pad(B, [0,self.pad_cols_B,0,self.pad_rows_B,0,self.pad_groups_B]).unsqueeze(0).view(1,-1,self.n_G_B,self.crb_groups_B,self.n_V_B,self.crb_rows_B,self.n_H_B,self.crb_cols_B)\n                raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda()\n                raw_grad = self.raw_grad[b_st:b_ed].cuda()\n                similarities = []\n                for p_st in range(0, self.eq_n, self.parallel_eq_n):\n                    p_ed = min(self.eq_n,p_st+self.parallel_eq_n)\n                    # quantize A, this quantization is optimized out of loop\n                    # quantize B\n                    cur_B_interval = tmp_B_interval.repeat(p_ed-p_st,1,1,1,1,1,1,1)\n                    cur_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = B_interval_candidates[p_st:p_ed,:,:,:,v:v+1,:,h:h+1,:]\n                    B_sim = (B_pad/cur_B_interval).round_().clamp_(-self.B_qmax,self.B_qmax-1).mul_(cur_B_interval)\n                    B_sim = B_sim.view(p_ed-p_st,-1,B.shape[1]+self.pad_groups_B,B.shape[2]+self.pad_rows_B,B.shape[3]+self.pad_cols_B) # shape: parallel_eq_n,b,H*,dim2*,dim3* (* stand for padding)\n                    B_sim = B_sim[:,:,:B.shape[1],:B.shape[2],:B.shape[3]] # shape: parallel_eq_n,b,H,dim2,dim3\n                    # calculate similarity and store them\n                    out_sim = A_sim @ B_sim # shape: parallel_eq_n,b,H,dim1,dim3\n                    similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1\n                    similarity = similarity.mean([3]) # shape: parallel_eq_n,b,H (remaining mean operation will be done later on)\n                    similarity = similarity.sum(dim=1, keepdim=True) # shape: parallel_eq_n,1,H\n                    similarities.append(similarity)\n                # calculate best similarity for this block\n                similarities = torch.cat(similarities, 0) # shape: eq_n,1,H\n                batch_similarities.append(similarities)\n            batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n,H\n            batch_similarities = F.pad(batch_similarities, [0,self.pad_groups_B]).view(self.eq_n,self.n_G_B,self.crb_groups_B).mean(-1) # shape: eq_n, n_G_B\n            best_index = torch.argmax(batch_similarities, dim=0, keepdim=False).view(1,1,-1,1,1,1,1,1)\n            tmp_B_interval[:,:,:,:,v:v+1,:,h:h+1,:] = torch.gather(B_interval_candidates[:,:,:,:,v:v+1,:,h:h+1,:],dim=0,index=best_index)\n        self.B_interval = tmp_B_interval.squeeze(0)\n\n    def calibration_step2(self):\n        self._initialize_calib_parameters()\n        self._initialize_intervals()\n        A_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.A_interval.unsqueeze(0)\n        B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)\n        for e in range(self.search_round):\n            # search for best A interval\n            self._search_best_A_interval(A_interval_candidates)\n            # search for best B interval\n            self._search_best_B_interval(B_interval_candidates)\n        self.calibrated = True\n        del self.raw_input, self.raw_out, self.raw_grad\n\nclass SoSPTQSLBatchingQuantMatMul(PTQSLBatchingQuantMatMul):\n    def __init__(self, A_bit=8, B_bit=8, mode=\"raw\",\n                 metric=\"L2_norm\", search_round=1, eq_alpha=0.1, eq_beta=2, eq_n=100, parallel_eq_n=10,\n                 n_G_A=1, n_V_A=1, n_H_A=1, n_G_B=1, n_V_B=1, n_H_B=1, init_layerwise=False,\n                 split=None):\n        super().__init__(A_bit=A_bit, B_bit=B_bit, mode=mode, \n                         metric=metric, search_round=search_round, eq_alpha=eq_alpha, eq_beta=eq_beta, eq_n=eq_n, parallel_eq_n=parallel_eq_n, \n                         n_G_A=n_G_A, n_V_A=n_V_A, n_H_A=n_H_A, n_G_B=n_G_B, n_V_B=n_V_B, n_H_B=n_H_B, init_layerwise=init_layerwise)\n        self.n_G_A = 1\n        self.n_V_A = 1\n        self.n_H_A = 1\n        # with proper hardware implementation, we don't need to use a sign bit anymore\n        self.A_qmax = 2**(self.A_bit-1)\n        self.split = split\n        if split != None:\n            self.A_interval = self.split/(self.A_qmax-1)\n\n    def quant_input_A(self, x):\n        x_high = (x.clamp(self.split, 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)\n        x_low = (x.clamp(0, self.split)/self.A_interval).round_().clamp_(0,self.A_qmax-1)*self.A_interval\n        return x_high + x_low\n\n    def _search_best_A_interval(self, split_candidates):\n        batch_similarities = []\n        for b_st in range(0, self.calib_size, self.calib_batch_size):\n            b_ed = min(self.calib_size, b_st + self.calib_batch_size)\n            A = self.raw_input[0][b_st:b_ed].unsqueeze(0).cuda()\n            B = self.raw_input[1][b_st:b_ed].unsqueeze(0).cuda()\n            B_sim = B\n            raw_out = self.raw_out[b_st:b_ed].unsqueeze(0).cuda()\n            raw_grad = self.raw_grad[b_st:b_ed].cuda()\n            similarities = []\n            for i in range(len(split_candidates)):\n                # quantize A\n                cur_A_interval = split_candidates[i]/(self.A_qmax-1)\n                A_high = (A.clamp(split_candidates[i], 1)*(self.A_qmax-1)).round_().clamp_(0,self.A_qmax-1)/(self.A_qmax-1)\n                A_low =( A.clamp(0, split_candidates[i])/cur_A_interval).round_().clamp_(0,self.A_qmax-1)*cur_A_interval\n                A_sim = A_high + A_low # shape: 1,b,H,S,S\n                # quantize B, this quantization is optimized out of loop\n                # calculate similarity and store them (dim1=dim2=S, dim3=W)\n                out_sim = A_sim @ B_sim # shape: 1,b,H,dim1,dim3\n                similarity = self._get_similarity(raw_out, out_sim, self.metric, raw_grad=raw_grad) # shape: parallel_eq_n,b,H,dim1\n                similarity = similarity.mean([2,3]) # shape: parallel_eq_n, b\n                similarity = similarity.sum(dim=1,keepdim=True) # parallel_eq_n, 1\n                similarities.append(similarity)\n            # calculate best similarity for this block\n            similarities = torch.cat(similarities, 0) # shape: eq_n, 1\n            batch_similarities.append(similarities)\n        batch_similarities = torch.cat(batch_similarities, dim=1).sum(dim=1, keepdim=False) #shape: eq_n\n        best_index = torch.argmax(batch_similarities, dim=0, keepdim=False)\n        self.split = split_candidates[best_index]\n        self.A_interval = self.split/(self.A_qmax-1)\n        # debugging\n        # print(f\"best split: {self.split}\")\n\n    def calibration_step2(self):\n        self._initialize_calib_parameters()\n        self._initialize_intervals()\n        A_split_candidates = torch.tensor([2**(-i) for i in range(20)]).cuda()\n        B_interval_candidates = torch.tensor([self.eq_alpha + i*(self.eq_beta - self.eq_alpha)/self.eq_n for i in range(self.eq_n + 1)]).cuda().view(-1,1,1,1,1,1,1,1) * self.B_interval.unsqueeze(0)\n        for e in range(self.search_round):\n            # search for best A interval\n            self._search_best_A_interval(A_split_candidates)\n            # search for best B interval\n            self._search_best_B_interval(B_interval_candidates)\n        self.calibrated = True\n        del self.raw_input, self.raw_out, self.raw_grad"
  },
  {
    "path": "utils/datasets.py",
    "content": "\"\"\"\nReuse version v4\nAuthor: Hahn Yuan\n\"\"\"\nimport PIL\nimport torch\nimport argparse\nimport numpy as np\nimport os\nimport copy\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nfrom torchvision.datasets import ImageFolder,DatasetFolder\nimport torch.utils.data\nimport re\nimport warnings\nfrom PIL import Image\nfrom PIL import ImageFile\nimport random\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset\n\ndef calculate_n_correct(outputs,targets):\n    _, predicted = outputs.max(1)\n    n_correct= predicted.eq(targets).sum().item()\n    return n_correct\n\nclass SetSplittor():\n    def __init__(self,fraction=0.2):\n        self.fraction=fraction\n    \n    def split(self,dataset):\n        pass\n\nclass LoaderGenerator():\n    \"\"\"\n    \"\"\"\n    def __init__(self,root,dataset_name,train_batch_size=1,test_batch_size=1,num_workers=0,kwargs={}):\n        self.root=root\n        self.dataset_name=str.lower(dataset_name)\n        self.train_batch_size=train_batch_size\n        self.test_batch_size=test_batch_size\n        self.num_workers=num_workers\n        self.kwargs=kwargs\n        self.items=[]\n        self._train_set=None\n        self._test_set=None\n        self._calib_set=None\n        self.train_transform=None\n        self.test_transform=None\n        self.train_loader_kwargs = {\n            'num_workers': self.num_workers ,\n            'pin_memory': kwargs.get('pin_memory',True),\n            'drop_last':kwargs.get('drop_last',False)\n            }\n        self.test_loader_kwargs=self.train_loader_kwargs.copy()\n        self.load()\n    \n    @property\n    def train_set(self):\n        pass\n    \n    @property\n    def test_set(self):\n        pass\n    \n    def load(self):\n        pass\n    \n    def train_loader(self):\n        assert self.train_set is not None\n        return torch.utils.data.DataLoader(self.train_set, batch_size=self.train_batch_size, shuffle=True,  **self.train_loader_kwargs)\n    \n    def test_loader(self,shuffle=False,batch_size=None):\n        assert self.test_set is not None\n        if batch_size is None:\n            batch_size=self.test_batch_size\n        return torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle,  **self.test_loader_kwargs)\n    \n    def val_loader(self):\n        assert self.val_set is not None\n        return torch.utils.data.DataLoader(self.val_set, batch_size=self.test_batch_size, shuffle=False,  **self.test_loader_kwargs)\n    \n    def trainval_loader(self):\n        assert self.trainval_set is not None\n        return torch.utils.data.DataLoader(self.trainval_set, batch_size=self.train_batch_size, shuffle=True,  **self.train_loader_kwargs)\n\n    def calib_loader(self,num=1024,seed=3):\n        if self._calib_set is None:\n            np.random.seed(seed)\n            inds=np.random.permutation(len(self.train_set))[:num]\n            self._calib_set=torch.utils.data.Subset(copy.deepcopy(self.train_set),inds)\n            self._calib_set.dataset.transform=self.test_transform\n        return torch.utils.data.DataLoader(self._calib_set, batch_size=num, shuffle=False,  **self.train_loader_kwargs)\n        \nclass CIFARLoaderGenerator(LoaderGenerator):\n    def load(self):\n        if self.dataset_name=='cifar100':\n            self.dataset_fn=datasets.CIFAR100\n            normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],\n                                             std=[0.2673, 0.2564, 0.2762])\n        elif self.dataset_name=='cifar10':\n            self.dataset_fn=datasets.CIFAR10\n            normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],\n                                             std=[0.2470, 0.2435, 0.2616])\n        else:\n            raise NotImplementedError\n        self.train_transform = transforms.Compose([\n            transforms.RandomCrop(32,padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            normalize,\n        ])\n        self.test_transform = transforms.Compose([\n            transforms.ToTensor(),\n            normalize,\n        ])\n    @property\n    def train_set(self):\n        if self._train_set is None:\n            self._train_set=self.dataset_fn(self.root, train=True, download=True, transform=self.train_transform)\n        return self._train_set\n\n    @property\n    def test_set(self):\n        if self._test_set is None:\n            self._test_set=self.dataset_fn(self.root, train=False, transform=self.test_transform)\n        return self._test_set\n\nclass COCOLoaderGenerator(LoaderGenerator):\n    def load(self):\n        # download from https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh\n        self.train_set = DetectionListDataset(os.path.join(self.root,'trainvalno5k.txt'),transform=augmentation_detection_tansforms)\n        self.test_set = DetectionListDataset(os.path.join(self.root,'5k.txt'),transform=detection_tansforms,multiscale=False)\n        self.train_loader_kwargs={\"collate_fn\":self.train_set.collate_fn}\n        self.test_loader_kwargs={\"collate_fn\":self.test_set.collate_fn}\n        \nclass DetectionListDataset(Dataset):\n    def __init__(self, list_path, img_size=416, multiscale=True, transform=None):\n        with open(list_path, \"r\") as file:\n            self.img_files = [path for path in file.readlines()]\n        self.label_files = [\n            path.replace(\"images\", \"labels\").replace(\".png\", \".txt\").replace(\".jpg\", \".txt\")\n            for path in self.img_files\n        ]\n        self.img_size = img_size\n        self.max_objects = 100\n        self.multiscale = multiscale\n        self.min_size = self.img_size - 3 * 32\n        self.max_size = self.img_size + 3 * 32\n        self.batch_count = 0\n        self.transform = transform\n\n    def __getitem__(self, index):\n        try:\n            img_path = self.img_files[index % len(self.img_files)].rstrip()\n            img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8)\n        except Exception as e:\n            print(f\"Could not read image '{img_path}'.\")\n            return\n        try:\n            label_path = self.label_files[index % len(self.img_files)].rstrip()\n            # Ignore warning if file is empty\n            with warnings.catch_warnings():\n                warnings.simplefilter(\"ignore\")\n                boxes = np.loadtxt(label_path).reshape(-1, 5)\n        except Exception as e:\n            print(f\"Could not read label '{label_path}'.\")\n            return\n        if self.transform:\n            try:\n                img, bb_targets = self.transform((img, boxes))\n            except:\n                print(f\"Could not apply transform.\")\n                return\n        return img_path, img, bb_targets\n\n    def collate_fn(self, batch):\n        self.batch_count += 1\n        # Drop invalid images\n        batch = [data for data in batch if data is not None]\n        \n        paths, imgs, bb_targets = list(zip(*batch))\n        # Selects new image size every tenth batch\n        if self.multiscale and self.batch_count % 10 == 0:\n            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))\n        # Resize images to input shape\n        imgs = torch.stack([F.interpolate(img.unsqueeze(0), size=self.img_size, mode=\"nearest\").squeeze(0) for img in imgs])\n        # Add sample index to targets\n        for i, boxes in enumerate(bb_targets):\n            boxes[:, 0] = i\n        bb_targets = torch.cat(bb_targets, 0)\n        return paths, imgs, bb_targets\n\n    def __len__(self):\n        return len(self.img_files)\n\n# def faster_im_loader(path):\n#     with open(path,'rb') as f:\n#         bgr_array = TurboJPEG().decode(f.read())\n#     rgb_array=np.concatenate([bgr_array[:,:,2:3],bgr_array[:,:,1:2],bgr_array[:,:,0:1]],-1)\n#     return torch.Tensor(rgb_array)/255\n\nclass ImageNetLoaderGenerator(LoaderGenerator):\n    def load(self):\n        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n                                     std=[0.229, 0.224, 0.225])\n        self.train_transform = transforms.Compose([\n            transforms.Resize(256),\n                transforms.RandomResizedCrop(224),\n                transforms.RandomHorizontalFlip(),\n                transforms.ToTensor(),\n                normalize,\n            ])\n\n        self.test_transform = transforms.Compose([\n                transforms.Resize(256),\n                transforms.CenterCrop(224),\n                transforms.ToTensor(),\n                normalize,\n            ])\n    \n    @property\n    def train_set(self):\n        if self._train_set is None:\n            self._train_set=ImageFolder(os.path.join(self.root,'train'), self.train_transform)\n        return self._train_set\n\n    @property\n    def test_set(self):\n        if self._test_set is None:\n            self._test_set=ImageFolder(os.path.join(self.root,'val'), self.test_transform)\n        return self._test_set\n\nclass CacheDataset(Dataset):\n    def __init__(self,datas,targets) -> None:\n        super().__init__()\n        self.datas=datas\n        self.targets=targets\n        \n    def __getitem__(self,idx):\n        return self.datas[idx],self.targets[idx]\n\n    def __len__(self):\n        return len(self.datas)\n\nclass FasterImageNetLoaderGenerator(ImageNetLoaderGenerator):\n    def test_loader(self,shuffle=False,batch_size=None):\n        cache='/dev/shm/imagenet.pkl'\n        assert self.test_set is not None\n        if batch_size is None:\n            batch_size=self.test_batch_size\n        if os.path.exists(cache):\n            print(\"Loading the dataset from shared memory\")\n            datas,targets=torch.load(cache)\n        else:\n            print(\"Preprocessing the dataset and save it to shared memory\")\n            loader=torch.utils.data.DataLoader(self.test_set, batch_size=batch_size, shuffle=shuffle,  **self.test_loader_kwargs)\n            datas=[]\n            targets=[]\n            for data,target in loader:\n                datas.append(data)\n                targets.append(target)\n            datas=torch.cat(datas,0)\n            targets=torch.cat(targets,0)\n            torch.save([datas,targets],cache)\n        dataset=CacheDataset(datas,targets)\n        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,  **self.test_loader_kwargs)\n            \nclass DebugLoaderGenerator(LoaderGenerator):\n\n    def load(self):\n        version=re.findall(\"\\d+\",self.dataset_name)[0]\n        class DebugSet(torch.utils.data.Dataset):\n            def __getitem__(self,idx):\n                if version=='0':\n                    return torch.ones([1,4,4]),0\n                if version=='1':\n                    return torch.ones([1,8,8]),0\n                if version=='2':\n                    return torch.ones([1,1,1]),0\n                if version=='3':\n                    return torch.ones([1,3,3]),0\n                else:\n                    raise NotImplementedError(f\"version {version} of Debug dataset is not supported\")\n            def __len__(self): return 1\n        self.train_set=DebugSet()\n        self.test_set=DebugSet()\n\ndef get_dataset(args:argparse.Namespace):\n    \"\"\" Preparing Datasets, args: \n        dataset (required): MNIST, cifar10/100, ImageNet, coco\n        dataset_root: str, default='./datasets'\n        num_workers: int\n        batch_size: int\n        test_batch_size: int\n        val_fraction: float, default=0\n        \n    \"\"\"\n    dataset_name=str.lower(args.dataset)\n    dataset_root=getattr(args,'dataset_root','./datasets') \n    num_workers=args.num_workers if hasattr(args,'num_workers') else 4\n    batch_size=args.batch_size if hasattr(args,'batch_size') else 64\n    test_batch_size=args.test_batch_size if hasattr(args,'test_batch_size') else batch_size\n    val_fraction=args.val_fraction if hasattr(args,\"val_fraction\") else 0\n    if \"cifar\" in dataset_name:\n        # Data loading code\n        g=CIFARLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)\n    elif \"coco\" in dataset_name:\n        g=COCOLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)\n    elif \"debug\" in dataset_name:\n        g=DebugLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)\n    elif args.dataset=='ImageNet':\n        g=ImageNetLoaderGenerator(dataset_root,args.dataset,batch_size,test_batch_size,num_workers)\n    else:\n        raise NotImplementedError\n    return g.train_loader(),g.test_loader()\n    \n\nimport timm\nfrom timm.models.vision_transformer import VisionTransformer\nfrom timm.data import resolve_data_config\nfrom timm.data.transforms_factory import create_transform\n\nclass ViTImageNetLoaderGenerator(ImageNetLoaderGenerator):\n    \"\"\"\n    DataLoader for Vision Transformer. \n    To comply with timm's framework, we use the model's corresponding transform.\n    \"\"\"\n    def __init__(self, root, dataset_name, train_batch_size, test_batch_size, num_workers, kwargs={}):\n        kwargs.update({\"pin_memory\":False})\n        super().__init__(root, dataset_name, train_batch_size=train_batch_size, test_batch_size=test_batch_size, num_workers=num_workers, kwargs=kwargs)\n\n    def load(self):\n        model = self.kwargs.get(\"model\", None)\n        assert model != None, f\"No model in ViTImageNetLoaderGenerator!\"\n\n        config = resolve_data_config({}, model=model)\n        self.train_transform = create_transform(**config, is_training=True)\n        self.test_transform = create_transform(**config)\n\n"
  },
  {
    "path": "utils/integer.py",
    "content": "from numpy import dtype\nfrom quant_layers.matmul import MinMaxQuantMatMul, PTQSLBatchingQuantMatMul, PTQSLQuantMatMul, SoSPTQSLBatchingQuantMatMul, SoSPTQSLQuantMatMul\nfrom quant_layers.linear import MinMaxQuantLinear, PTQSLBatchingQuantLinear, PostGeluPTQSLBatchingQuantLinear, PostGeluPTQSLQuantLinear\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef quantize_int_weight(module):\n    \"\"\"\n    get weight of type 'uint8' of a quantized module.\n    Bias are not quantized and you can use raw bias.\n    \"\"\"\n    assert hasattr(module, 'weight'), f\"module {module} does not have weight\"\n    assert module.w_bit == 8, f\"module {module}'s weight is quantized with {module.w_bit} bits\"\n\n    w_int = (module.weight/module.w_interval).round_().clamp_(-module.w_qmax, module.w_qmax-1)\n    w_int = w_int.cpu().detach().to(torch.int8)\n    return w_int\n\ndef dequantize_int_weight(module, w_int):\n    \"\"\"\n    Make sure it's the same module that generates w_int\n    \"\"\"\n    w_sim = module.w_interval.cpu() * w_int.float()\n    return w_sim\n\ndef quantize_matmul_input(input, interval, qmax, n_G, n_V, n_H, crb_groups, crb_rows, crb_cols):\n    \"\"\"\n    quantize input matrix of matmul operation, with respect to sublayerwise padding settings\n    \"\"\"\n    pad_groups = crb_groups*n_G - input.shape[1]\n    pad_rows = crb_rows*n_V - input.shape[2]\n    pad_cols = crb_cols*n_H - input.shape[3]\n\n    x = F.pad(input, [0,pad_cols,0,pad_rows,0,pad_groups])\n    x = x.view(-1,n_G,crb_groups,n_V,crb_rows,n_H,crb_cols)\n    x = (x/interval).round_().clamp(-qmax,qmax-1)\n    x = x.view(-1,n_G*crb_groups,n_V*crb_rows,n_H*crb_cols)\n    x = x[:,:x.shape[1]-pad_groups,:x.shape[2]-pad_rows,:x.shape[3]-pad_cols]\n\n    return x\n\n\ndef quantize_int_activation(module, input):\n    \"\"\"\n    Quantize current inputs into uint8 and store them as an attribute of the module.\n\n    The function is a pre-forward hook that need to be manually added to the calibrated model.\n    You need to manipulate the cached data before feeding another batch of pictures.\n    Currently only support int8. (For twin quantization, we use uint8)\n\n    For twin quantization:\n    - For softmax, the MSB being 1 means using large interval, while MSB being 0 means using small interval.\n    - For post-GELU, the MSB serves as sign bit. We use 1 for positive values and 0 for negative values.\n    \"\"\"\n    if isinstance(module, PostGeluPTQSLQuantLinear) or isinstance(module, PostGeluPTQSLBatchingQuantLinear):\n        assert module.a_bit == 8, f\"module {module}'s activation is quantized with {module.a_bit} bits\"\n        \n        x = input[0]\n        \n        int_input_pos = (x/module.a_interval).round_().clamp_(0, module.a_qmax-1)\n        int_input_pos = int_input_pos.detach().to(torch.uint8) + 128\n\n        int_input_neg = (x/module.a_neg_interval).round_().clamp_(-module.a_qmax+1, 0).abs()\n        int_input_neg = int_input_neg.detach().to(torch.uint8)\n\n        int_input = (int_input_pos + int_input_neg).cpu()\n        module.int_input = [int_input]\n    \n    elif isinstance(module, MinMaxQuantLinear):\n        assert module.a_bit == 8, f\"module {module}'s activation is quantized with {module.a_bit} bits\"\n\n        x = input[0]\n        int_input = (x/module.a_interval).round_().clamp_(-module.a_qmax, module.a_qmax-1)\n        int_input = int_input.cpu().detach().to(torch.int8)\n\n        module.int_input = [int_input]\n    \n    elif isinstance(module, SoSPTQSLQuantMatMul) or isinstance(module, SoSPTQSLBatchingQuantMatMul):\n        assert module.A_bit == 8, f\"module {module}'s matrix A is quantized with {module.A_bit} bits\"\n        assert module.B_bit == 8, f\"module {module}'s matrix B is quantized with {module.B_bit} bits\"\n\n        A, B = input[0], input[1]\n\n        A_high = (A.clamp(module.split, 1)*(module.A_qmax-1)).round_().clamp_(0,module.A_qmax-1)\n        A_high = A_high.detach().to(torch.uint8) + 128\n\n        A_low = (A.clamp(0, module.split)/module.A_interval).round_().clamp_(0,module.A_qmax-1)\n        A_low = A_low.detach().to(torch.uint8)\n        \n        A_int = (A_high + A_low).cpu()\n\n        B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B)\n        B_int = B_int.cpu().detach().to(torch.int8)\n\n        module.int_input = [A_int, B_int]\n\n    elif isinstance(module, PTQSLQuantMatMul) or isinstance(module, PTQSLBatchingQuantMatMul):\n        assert module.A_bit == 8, f\"module {module}'s matrix A is quantized with {module.A_bit} bits\"\n        assert module.B_bit == 8, f\"module {module}'s matrix B is quantized with {module.B_bit} bits\"\n\n        A, B = input[0], input[1]\n\n        A_int = quantize_matmul_input(A,module.A_interval,module.A_qmax,module.n_G_A,module.n_V_A,module.n_H_A,module.crb_groups_A,module.crb_rows_A,module.crb_cols_A)\n        A_int = A_int.cpu().detach().to(torch.int8)\n\n        B_int = quantize_matmul_input(B,module.B_interval,module.B_qmax,module.n_G_B,module.n_V_B,module.n_H_B,module.crb_groups_B,module.crb_rows_B,module.crb_cols_B)\n        B_int = B_int.cpu().detach().to(torch.int8)\n\n        module.int_input = [A_int, B_int]\n\n\ndef get_model_int_weight(wrapped_modules):\n    \"\"\"\n    Get quantized weights (in int8) of a model.\n\n    Return:\n        A dict, with modules' names as keys, and int weights as values.\n    \"\"\"\n\n    int_weights = {}\n\n    for name, m in wrapped_modules.items():\n        try:\n            int_weights[name] = quantize_int_weight(m)\n        except:\n            pass\n    \n    return int_weights\n"
  },
  {
    "path": "utils/models.py",
    "content": "from types import MethodType\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport timm\nfrom timm.models import vision_transformer\nfrom timm.models.vision_transformer import Attention\nfrom timm.models.swin_transformer import WindowAttention\n\ndef attention_forward(self, x):\n    B, N, C = x.shape\n    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n    q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)\n\n    # attn = (q @ k.transpose(-2, -1)) * self.scale\n    attn = self.matmul1(q, k.transpose(-2, -1)) * self.scale\n    attn = attn.softmax(dim=-1)\n    attn = self.attn_drop(attn)\n    del q, k\n\n    # x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n    x = self.matmul2(attn, v).transpose(1, 2).reshape(B, N, C)\n    del attn, v\n    x = self.proj(x)\n    x = self.proj_drop(x)\n    return x\n\ndef window_attention_forward(self, x, mask = None):\n    B_, N, C = x.shape\n    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)\n\n    q = q * self.scale\n    # attn = (q @ k.transpose(-2, -1))\n    attn = self.matmul1(q, k.transpose(-2,-1))\n\n    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n    attn = attn + relative_position_bias.unsqueeze(0)\n\n    if mask is not None:\n        nW = mask.shape[0]\n        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n        attn = attn.view(-1, self.num_heads, N, N)\n        attn = self.softmax(attn)\n    else:\n        attn = self.softmax(attn)\n\n    attn = self.attn_drop(attn)\n\n    # x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n    x = self.matmul2(attn, v).transpose(1, 2).reshape(B_, N, C)\n    x = self.proj(x)\n    x = self.proj_drop(x)\n    return x\n\nclass MatMul(nn.Module):\n    def forward(self, A, B):\n        return A @ B\n\ndef get_net(name):\n    \"\"\"\n    Get a vision transformer model.\n    This will replace matrix multiplication operations with matmul modules in the model.\n\n    Currently support almost all models in timm.models.transformers, including:\n    - vit_tiny/small/base/large_patch16/patch32_224/384,\n    - deit_tiny/small/base(_distilled)_patch16_224,\n    - deit_base(_distilled)_patch16_384,\n    - swin_tiny/small/base/large_patch4_window7_224,\n    - swin_base/large_patch4_window12_384\n\n    These models are finetuned on imagenet-1k and should use ViTImageNetLoaderGenerator\n    for calibration and testing.\n    \"\"\"\n    net = timm.create_model(name, pretrained=True)\n\n    for name, module in net.named_modules():\n        if isinstance(module, Attention):\n            setattr(module, \"matmul1\", MatMul())\n            setattr(module, \"matmul2\", MatMul())\n            module.forward = MethodType(attention_forward, module)\n        if isinstance(module, WindowAttention):\n            setattr(module, \"matmul1\", MatMul())\n            setattr(module, \"matmul2\", MatMul())\n            module.forward = MethodType(window_attention_forward, module)\n\n    net.cuda()\n    net.eval()\n    return net\n"
  },
  {
    "path": "utils/net_wrap.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom utils.models import MatMul\nimport re\n\n\ndef _fold_bn(conv_module, bn_module):\n    w = conv_module.weight.data\n    y_mean = bn_module.running_mean\n    y_var = bn_module.running_var\n    safe_std = torch.sqrt(y_var + bn_module.eps)\n    w_view = (conv_module.out_channels, 1, 1, 1)\n    if bn_module.affine:\n        weight = w * (bn_module.weight / safe_std).view(w_view)\n        beta = bn_module.bias - bn_module.weight * y_mean / safe_std\n        if conv_module.bias is not None:\n            bias = bn_module.weight * conv_module.bias / safe_std + beta\n        else:\n            bias = beta\n    else:\n        weight = w / safe_std.view(w_view)\n        beta = -y_mean / safe_std\n        if conv_module.bias is not None:\n            bias = conv_module.bias / safe_std + beta\n        else:\n            bias = beta\n    return weight, bias\n\ndef fold_bn_into_conv(conv_module, bn_module):\n    w, b = _fold_bn(conv_module, bn_module)\n    if conv_module.bias is None:\n        conv_module.bias = nn.Parameter(b.data)\n    else:\n        conv_module.bias.data = b.data\n    conv_module.weight.data = w.data\n\n\ndef wrap_modules_in_net(net,cfg):\n    wrapped_modules={}\n    module_dict={}\n    module_types = {\"qkv\":\"qlinear_qkv\", \"proj\":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':\"qlinear_MLP_2\", 'head':'qlinear_classifier','matmul1':\"qmatmul_qk\", 'matmul2':\"qmatmul_scorev\", \"reduction\": \"qlinear_reduction\"}\n    \n    it=[(name,m) for name,m in net.named_modules()]\n    for name,m in it:\n        module_dict[name]=m\n        idx=name.rfind('.')\n        if idx==-1:\n            idx=0\n        father_name=name[:idx]\n        if father_name in module_dict:\n            father_module=module_dict[father_name]\n        else:\n            raise RuntimeError(f\"father module {father_name} not found\")\n        if isinstance(m,nn.Conv2d):\n            # Embedding Layer\n            idx = idx+1 if idx != 0 else idx\n            new_m=cfg.get_module(\"qconv\",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode)\n            new_m.weight.data=m.weight.data\n            new_m.bias=m.bias\n            replace_m=new_m\n            wrapped_modules[name] = new_m\n            setattr(father_module,name[idx:],replace_m)\n        elif isinstance(m,nn.Linear):\n            # Linear Layer\n            idx = idx+1 if idx != 0 else idx\n            new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features)\n            new_m.weight.data=m.weight.data\n            new_m.bias=m.bias\n            replace_m=new_m\n            wrapped_modules[name] = new_m\n            setattr(father_module,name[idx:],replace_m)\n        elif isinstance(m,MatMul):\n            # Matmul Layer\n            idx = idx+1 if idx != 0 else idx\n            new_m = cfg.get_module(module_types[name[idx:]])\n            replace_m=new_m\n            wrapped_modules[name] = new_m\n            setattr(father_module,name[idx:],replace_m)\n    print(\"Completed net wrap.\")\n    return wrapped_modules\n\ndef wrap_certain_modules_in_net(net,cfg,layers,modules_to_wrap,wrap_embedding=False):\n    \"\"\"\n    wrap specific module inside transformer block of specific layer\n    layers: list of integers, indicating layers to wrap\n    modules_to_wrap: list of modules to wrap\n    \"\"\"\n    wrapped_modules={}\n    module_dict={}\n    module_types = {\"qkv\":\"qlinear_qkv\", \"proj\":'qlinear_proj', 'fc1':'qlinear_MLP_1', 'fc2':\"qlinear_MLP_2\", 'head':'qlinear_classifier','matmul1':\"qmatmul_qk\", 'matmul2':\"qmatmul_scorev\"}\n    \n    it=[(name,m) for name,m in net.named_modules()]\n    for name,m in it:\n        module_dict[name]=m\n        idx=name.rfind('.')\n        if idx==-1:\n            idx=0\n        father_name=name[:idx]\n        if father_name in module_dict:\n            father_module=module_dict[father_name]\n        else:\n            raise RuntimeError(f\"father module {father_name} not found\")\n        layer = re.search('\\d+', name)\n        if layer is not None: # inside a transformer block\n            layer = int(name[layer.span()[0]:layer.span()[1]])\n            if layer not in layers: continue\n        if isinstance(m,nn.Conv2d):\n            # Embedding Layer\n            idx = idx+1 if idx != 0 else idx\n            if not wrap_embedding:\n                continue  # timm patch_embed use proj as well...\n            # if name[idx:] not in modules_to_wrap: continue\n            new_m=cfg.get_module(\"qconv\",m.in_channels,m.out_channels,m.kernel_size,m.stride,m.padding,m.dilation,m.groups,m.bias is not None,m.padding_mode)\n            new_m.weight.data=m.weight.data\n            new_m.bias=m.bias\n            replace_m=new_m\n            wrapped_modules[name] = new_m\n            setattr(father_module,name[idx:],replace_m)\n        elif isinstance(m,nn.Linear):\n            # Linear Layer\n            idx = idx+1 if idx != 0 else idx\n            if name[idx:] not in modules_to_wrap: continue\n            new_m = cfg.get_module(module_types[name[idx:]],m.in_features,m.out_features)\n            new_m.weight.data=m.weight.data\n            new_m.bias=m.bias\n            replace_m=new_m\n            wrapped_modules[name] = new_m\n            setattr(father_module,name[idx:],replace_m)\n        elif isinstance(m,MatMul):\n            # Matmul Layer\n            idx = idx+1 if idx != 0 else idx\n            if name[idx:] not in modules_to_wrap: continue\n            new_m = cfg.get_module(module_types[name[idx:]])\n            replace_m=new_m\n            wrapped_modules[name] = new_m\n            setattr(father_module,name[idx:],replace_m)\n    print(\"Completed net wrap.\")\n    return wrapped_modules\n"
  },
  {
    "path": "utils/quant_calib.py",
    "content": "from numpy import isin\nimport torch\nfrom quant_layers.conv import MinMaxQuantConv2d\nfrom quant_layers.linear import MinMaxQuantLinear, PTQSLQuantLinear\nfrom quant_layers.matmul import MinMaxQuantMatMul, PTQSLQuantMatMul\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nclass QuantCalibrator():\n    \"\"\"\n    Modularization of quant calib.\n\n    Notice: \n    all quant modules has method \"calibration_step1\" that should only store raw inputs and outputs\n    all quant modules has method \"calibration_step2\" that should only quantize its intervals\n    and we assume we could feed in all calibration data in one batch, without backward propagations\n\n    sequential calibration is memory-friendly, while parallel calibration may consume \n    hundreds of GB of memory.\n    \"\"\"\n    def __init__(self, net, wrapped_modules, calib_loader, sequential=True):\n        self.net = net\n        self.wrapped_modules = wrapped_modules\n        self.calib_loader = calib_loader\n        self.sequential = sequential\n        self.calibrated = False\n    \n    def sequential_quant_calib(self):\n        \"\"\"\n        A quick implementation of calibration.\n        Assume calibration dataset could be fed at once.\n        \"\"\"\n        # run calibration\n        n_calibration_steps=2\n        for step in range(n_calibration_steps):\n            print(f\"Start calibration step={step+1}\")\n            for name,module in self.wrapped_modules.items():\n                # corner cases for calibrated modules\n                if hasattr(module, \"calibrated\"):\n                    if step == 1:\n                        module.mode = \"raw\"\n                    elif step == 2:\n                        module.mode = \"quant_forward\"\n                else:\n                    module.mode=f'calibration_step{step+1}'\n            with torch.no_grad():\n                for inp,target in self.calib_loader:\n                    inp=inp.cuda()\n                    self.net(inp)\n        \n        # finish calibration\n        for name,module in self.wrapped_modules.items():\n            module.mode='quant_forward'\n        torch.cuda.empty_cache() # memory footprint cleanup\n        print(\"sequential calibration finished\")\n    \n    def parallel_quant_calib(self):\n        \"\"\"\n        A quick implementation of parallel quant calib\n        Assume calibration dataset could be fed at once, and memory could hold all raw inputs/outs\n        \"\"\"\n        # calibration step1: collect raw data\n        print(f\"Start calibration step=1\")\n        for name,module in self.wrapped_modules.items():\n            # corner cases for calibrated modules\n            if hasattr(module, \"calibrated\"):\n                module.mode = \"raw\"\n            else:\n                module.mode=f'calibration_step1'\n        with torch.no_grad():\n            for inp,target in self.calib_loader:\n                inp=inp.cuda()\n                self.net(inp)\n        # calibration step2: each module run calibration with collected raw data\n        for name,module in self.wrapped_modules.items():\n            if hasattr(module, \"calibrated\"):\n                continue\n            else:\n                module.mode=f\"calibration_step2\"\n                with torch.no_grad():\n                    if isinstance(module, MinMaxQuantLinear):\n                        module.forward(module.raw_input.cuda())\n                    elif isinstance(module, MinMaxQuantConv2d):\n                        module.forward(module.raw_input.cuda())\n                    elif isinstance(module, MinMaxQuantMatMul):\n                        module.forward(module.raw_input[0].cuda(), module.raw_input[1].cuda())\n                    torch.cuda.empty_cache()\n                \n        # finish calibration\n        for name,module in self.wrapped_modules.items():\n            module.mode='quant_forward'\n        torch.cuda.empty_cache() # memory footprint cleanup\n        print(\"calibration finished\")\n    \n    def quant_calib(self):\n        calib_layers=[]\n        for name,module in self.wrapped_modules.items():\n            calib_layers.append(name)\n        print(f\"prepare parallel calibration for {calib_layers}\")\n        if self.sequential:\n            self.sequential_quant_calib()\n        else:\n            self.parallel_quant_calib()\n        self.calibrated = True\n\n    def batching_quant_calib(self):\n        calib_layers=[]\n        for name,module in self.wrapped_modules.items():\n            calib_layers.append(name)\n        print(f\"prepare parallel calibration for {calib_layers}\")\n\n        print(\"start calibration\")\n\n        # assume wrapped modules are in order (true for dict in python>=3.5)\n        q = tqdm(self.wrapped_modules.items(), desc=\"Brecq\")\n        for name, module in q:\n            q.set_postfix_str(name)\n\n            # add fp and bp hooks to current modules, which bypass calibration step 1\n            # precedent modules are using quant forward\n            hooks = []\n            if isinstance(module, MinMaxQuantLinear):\n                hooks.append(module.register_forward_hook(linear_forward_hook))\n            if isinstance(module, MinMaxQuantConv2d):\n                hooks.append(module.register_forward_hook(conv2d_forward_hook))\n            if isinstance(module, MinMaxQuantMatMul):\n                hooks.append(module.register_forward_hook(matmul_forward_hook))\n            \n            # feed in calibration data, and store the data\n            for inp, target in self.calib_loader:\n                for batch_st in range(0,self.calib_loader.batch_size,self.batch_size):\n                    self.net.zero_grad()\n                    inp_ = inp[batch_st:batch_st+self.batch_size].cuda()\n                    self.net(inp_)\n                del inp, target\n                torch.cuda.empty_cache()\n            \n            # replace cached raw_inputs, raw_outs\n            if isinstance(module, MinMaxQuantLinear):\n                module.raw_input = torch.cat(module.raw_input, dim=0)\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if isinstance(module, MinMaxQuantConv2d):\n                module.raw_input = torch.cat(module.raw_input, dim=0)\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if isinstance(module, MinMaxQuantMatMul):\n                module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input]\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            for hook in hooks:\n                hook.remove()\n\n            # run calibration step2\n            with torch.no_grad():\n                if isinstance(module, MinMaxQuantLinear):\n                    module.calibration_step2()\n                if isinstance(module, MinMaxQuantConv2d):\n                    module.calibration_step2()\n                if isinstance(module, MinMaxQuantMatMul):\n                    module.calibration_step2()\n                torch.cuda.empty_cache()\n            \n            # finishing up current module calibration\n            if self.sequential:\n                module.mode = \"quant_forward\"\n            else:\n                module.mode = \"raw\"\n\n        # finish calibration\n        for name, module in self.wrapped_modules.items():\n            module.mode = \"quant_forward\"\n        \n        print(\"calibration finished\")\n\ndef grad_hook(module, grad_input, grad_output):\n    if module.raw_grad is None:\n        module.raw_grad = []\n    module.raw_grad.append(grad_output[0].cpu().detach())   # that's a tuple!\n\ndef linear_forward_hook(module, input, output):\n    if module.raw_input is None:\n        module.raw_input = []\n    if module.raw_out is None:\n        module.raw_out = []\n    module.raw_input.append(input[0].cpu().detach())\n    module.raw_out.append(output.cpu().detach())\n\ndef conv2d_forward_hook(module, input, output):\n    if module.raw_input is None:\n        module.raw_input = []\n    if module.raw_out is None:\n        module.raw_out = []\n    module.raw_input.append(input[0].cpu().detach())\n    module.raw_out.append(output.cpu().detach())\n\ndef matmul_forward_hook(module, input, output):\n    if module.raw_input is None:\n        module.raw_input = [[],[]]\n    if module.raw_out is None:\n        module.raw_out = []\n    module.raw_input[0].append(input[0].cpu().detach())\n    module.raw_input[1].append(input[1].cpu().detach())\n    module.raw_out.append(output.cpu().detach())\n\nclass HessianQuantCalibrator(QuantCalibrator):\n    \"\"\"\n    Modularization of hessian_quant_calib\n\n    Hessian metric needs gradients of layer outputs to weigh the loss,\n    which calls for back propagation in calibration, both sequentially\n    and parallelly. Despite the complexity of bp, hessian quant calibrator\n    is compatible with other non-gradient quantization metrics.\n    \"\"\"\n    def __init__(self, net, wrapped_modules, calib_loader, sequential=False, batch_size=1):\n        super().__init__(net, wrapped_modules, calib_loader, sequential=sequential)\n        self.batch_size = batch_size\n\n    def quant_calib(self):\n        \"\"\"\n        An implementation of original hessian calibration.\n        \"\"\"\n\n        calib_layers=[]\n        for name,module in self.wrapped_modules.items():\n            calib_layers.append(name)\n        print(f\"prepare parallel calibration for {calib_layers}\")\n\n        print(\"start hessian calibration\")\n\n        # get raw_pred as target distribution \n        with torch.no_grad():\n            for inp, _ in self.calib_loader:\n                raw_pred = self.net(inp.cuda())\n                raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach()\n            torch.cuda.empty_cache()\n\n        # assume wrapped modules are in order (true for dict in python>=3.5)\n        q = tqdm(self.wrapped_modules.items(), desc=\"Brecq\")\n        for name, module in q:\n            q.set_postfix_str(name)\n\n            # add fp and bp hooks to current modules, which bypass calibration step 1\n            # precedent modules are using quant forward\n            hooks = []\n            if isinstance(module, MinMaxQuantLinear):\n                hooks.append(module.register_forward_hook(linear_forward_hook))\n            if isinstance(module, MinMaxQuantConv2d):\n                hooks.append(module.register_forward_hook(conv2d_forward_hook))\n            if isinstance(module, MinMaxQuantMatMul):\n                hooks.append(module.register_forward_hook(matmul_forward_hook))\n            if hasattr(module, \"metric\") and module.metric == \"hessian\":\n                hooks.append(module.register_backward_hook(grad_hook))\n            \n            # feed in calibration data, and store the data\n            for inp, target in self.calib_loader:\n                for batch_st in range(0,self.calib_loader.batch_size,self.batch_size):\n                    self.net.zero_grad()\n                    inp_ = inp[batch_st:batch_st+self.batch_size].cuda()\n                    pred = self.net(inp_)\n                    loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction=\"batchmean\")\n                    loss.backward()\n                del inp, target, pred, loss\n                torch.cuda.empty_cache()\n            \n            # replace cached raw_inputs, raw_outs\n            if isinstance(module, MinMaxQuantLinear):\n                module.raw_input = torch.cat(module.raw_input, dim=0)\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if isinstance(module, MinMaxQuantConv2d):\n                module.raw_input = torch.cat(module.raw_input, dim=0)\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if isinstance(module, MinMaxQuantMatMul):\n                module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input]\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if hasattr(module, \"metric\") and module.metric == \"hessian\":\n                module.raw_grad = torch.cat(module.raw_grad, dim=0)\n            for hook in hooks:\n                hook.remove()\n\n            # run calibration step2\n            with torch.no_grad():\n                if isinstance(module, MinMaxQuantLinear):\n                    module.calibration_step2(module.raw_input.cuda())\n                if isinstance(module, MinMaxQuantConv2d):\n                    module.calibration_step2(module.raw_input.cuda())\n                if isinstance(module, MinMaxQuantMatMul):\n                    module.calibration_step2(module.raw_input[0].cuda(), module.raw_input[1].cuda())\n                torch.cuda.empty_cache()\n            \n            # finishing up current module calibration\n            if self.sequential:\n                module.mode = \"quant_forward\"\n            else:\n                module.mode = \"raw\"\n\n        # finish calibration\n        for name, module in self.wrapped_modules.items():\n            module.mode = \"quant_forward\"\n        \n        print(\"hessian calibration finished\")\n\n    def batching_quant_calib(self):\n        calib_layers=[]\n        for name,module in self.wrapped_modules.items():\n            calib_layers.append(name)\n        print(f\"prepare parallel calibration for {calib_layers}\")\n\n        print(\"start hessian calibration\")\n\n        # get raw_pred as target distribution \n        with torch.no_grad():\n            for inp, _ in self.calib_loader:\n                raw_pred = self.net(inp.cuda())\n                raw_pred_softmax = F.softmax(raw_pred, dim=-1).detach()\n            torch.cuda.empty_cache()\n\n        # assume wrapped modules are in order (true for dict in python>=3.5)\n        q = tqdm(self.wrapped_modules.items(), desc=\"Hessian\")\n        for name, module in q:\n            q.set_postfix_str(name)\n\n            # add fp and bp hooks to current modules, which bypass calibration step 1\n            # precedent modules are using quant forward\n            hooks = []\n            if isinstance(module, MinMaxQuantLinear):\n                hooks.append(module.register_forward_hook(linear_forward_hook))\n            if isinstance(module, MinMaxQuantConv2d):\n                hooks.append(module.register_forward_hook(conv2d_forward_hook))\n            if isinstance(module, MinMaxQuantMatMul):\n                hooks.append(module.register_forward_hook(matmul_forward_hook))\n            if hasattr(module, \"metric\"):\n                hooks.append(module.register_backward_hook(grad_hook))\n            \n            # feed in calibration data, and store the data\n            for inp, target in self.calib_loader:\n                for batch_st in range(0,self.calib_loader.batch_size,self.batch_size):\n                    self.net.zero_grad()\n                    inp_ = inp[batch_st:batch_st+self.batch_size].cuda()\n                    pred = self.net(inp_)\n                    loss = F.kl_div(F.log_softmax(pred, dim=-1), raw_pred_softmax[batch_st:batch_st+self.batch_size], reduction=\"batchmean\")\n                    loss.backward()\n                del inp, target, pred, loss\n                torch.cuda.empty_cache()\n            \n            # replace cached raw_inputs, raw_outs\n            if isinstance(module, MinMaxQuantLinear):\n                module.raw_input = torch.cat(module.raw_input, dim=0)\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if isinstance(module, MinMaxQuantConv2d):\n                module.raw_input = torch.cat(module.raw_input, dim=0)\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if isinstance(module, MinMaxQuantMatMul):\n                module.raw_input = [torch.cat(_, dim=0) for _ in module.raw_input]\n                module.raw_out = torch.cat(module.raw_out, dim=0)\n            if hasattr(module, \"metric\"):\n                module.raw_grad = torch.cat(module.raw_grad, dim=0)\n            for hook in hooks:\n                hook.remove()\n\n            # run calibration step2\n            with torch.no_grad():\n                if isinstance(module, MinMaxQuantLinear):\n                    module.calibration_step2()\n                if isinstance(module, MinMaxQuantConv2d):\n                    module.calibration_step2()\n                if isinstance(module, MinMaxQuantMatMul):\n                    module.calibration_step2()\n                torch.cuda.empty_cache()\n            \n            # finishing up current module calibration\n            if self.sequential:\n                module.mode = \"quant_forward\"\n            else:\n                module.mode = \"raw\"\n\n        # finish calibration\n        for name, module in self.wrapped_modules.items():\n            module.mode = \"quant_forward\"\n        \n        print(\"hessian calibration finished\")"
  }
]