[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 xmu-xiaoma666\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "\n<img src=\"./FightingCVimg/LOGO.gif\" height=\"200\" width=\"400\"/>\n\n简体中文 | [English](./README_EN.md)\n\n# FightingCV 代码库， 包含 [***Attention***](#attention-series),[***Backbone***](#backbone-series), [***MLP***](#mlp-series), [***Re-parameter***](#re-parameter-series), [**Convolution**](#convolution-series)\n\n![](https://img.shields.io/badge/fightingcv-v0.0.1-brightgreen)\n![](https://img.shields.io/badge/python->=v3.0-blue)\n![](https://img.shields.io/badge/pytorch->=v1.4-red)\n\n<!--\n-------\n*If this project is helpful to you, welcome to give a ***star***.* \n\n*Don't forget to ***follow*** me to learn about project updates.*\n\n-->\n\n\n<!--\n\n\nHello，大家好，我是小马🚀🚀🚀\n\n***For 小白（Like Me）：***\n最近在读论文的时候会发现一个问题，有时候论文核心思想非常简单，核心代码可能也就十几行。但是打开作者release的源码时，却发现提出的模块嵌入到分类、检测、分割等任务框架中，导致代码比较冗余，对于特定任务框架不熟悉的我，**很难找到核心代码**，导致在论文和网络思想的理解上会有一定困难。\n\n***For 进阶者（Like You）：***\n如果把Conv、FC、RNN这些基本单元看做小的Lego积木，把Transformer、ResNet这些结构看成已经搭好的Lego城堡。那么本项目提供的模块就是一个个具有完整语义信息的Lego组件。**让科研工作者们避免反复造轮子**，只需思考如何利用这些“Lego组件”，搭建出更多绚烂多彩的作品。\n\n***For 大神（May Be Like You）：***\n能力有限，**不喜轻喷**！！！\n\n***For All：***\n本项目致力于实现一个既能**让深度学习小白也能搞懂**，又能**服务科研和工业社区**的代码库。\n\n-->\n\n<!--\n\n作为[**FightingCV公众号**](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA)和 **[FightingCV-Paper-Reading](https://github.com/xmu-xiaoma666/FightingCV-Paper-Reading)** 的补充，本项目的宗旨是从代码角度，实现🚀**让世界上没有难读的论文**🚀。\n\n\n（同时也非常欢迎各位科研工作者将自己的工作的核心代码整理到本项目中，推动科研社区的发展，会在readme中注明代码的作者~）\n\n\n\n\n## 技术交流 <img title=\"\" src=\"https://user-images.githubusercontent.com/48054808/157800467-2a9946ad-30d1-49a9-b9db-ba33413d9c90.png\" alt=\"\" width=\"20\">\n\n欢迎大家关注公众号：**FightingCV**\n\n\n\n| FightingCV公众号 | 小助手微信 （备注【**公司/学校+方向+ID**】）|\n:-------------------------:|:-------------------------:\n<img src='./FightingCVimg/FightingCV.jpg' width='200px'>  |  <img src='./FightingCVimg/xiaozhushou.jpg' width='200px'> \n\n- 公众号**每天**都会进行**论文、算法和代码的干货分享**哦~\n\n- **交流群每天分享一些最新的论文和解析**，欢迎大家一起**学习交流**哈~~~\n\n- 
We strongly recommend following our [**Zhihu**](https://www.zhihu.com/people/jason-14-58-38/posts) account and the [**FightingCV WeChat official account**](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA) to keep up with the latest high-quality resources.\n\n\n-------\n\n\n-->\n\n## 🌟 Star History\n\n\n[![Star History Chart](https://api.star-history.com/svg?repos=xmu-xiaoma666/External-Attention-pytorch&type=Date)](https://star-history.com/#xmu-xiaoma666/External-Attention-pytorch&Date)\n\n## Usage\n\n### Installation\n\n Install directly with pip\n\n  ```shell\n  pip install fightingcv-attention\n  ```\n\n\nOr clone this repository\n\n  ```shell\n  git clone https://github.com/xmu-xiaoma666/External-Attention-pytorch.git\n\n  cd External-Attention-pytorch\n  ```\n\n### Demo\n\n#### Using the pip package\n```python\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# Using the pip package\n\nfrom fightingcv_attention.attention.MobileViTv2Attention import *\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n```\n\n - For the modules bundled in the pip package, see the [fightingcv-attention documentation](./README_pip.md)\n\n#### Using the git clone\n```python\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# Unlike the pip approach, replace `fightingcv_attention` with `model`\n\nfrom model.attention.MobileViTv2Attention import *\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n```\n\n-------\n\n\n\n# Contents\n\n- [Attention Series](#attention-series)\n    - [1. External Attention Usage](#1-external-attention-usage)\n\n    - [2. Self Attention Usage](#2-self-attention-usage)\n\n    - [3. Simplified Self Attention Usage](#3-simplified-self-attention-usage)\n\n    - [4. Squeeze-and-Excitation Attention Usage](#4-squeeze-and-excitation-attention-usage)\n\n    - [5. SK Attention Usage](#5-sk-attention-usage)\n\n    - [6. CBAM Attention Usage](#6-cbam-attention-usage)\n\n    - [7. BAM Attention Usage](#7-bam-attention-usage)\n    \n    - [8. 
ECA Attention Usage](#8-eca-attention-usage)\n\n    - [9. DANet Attention Usage](#9-danet-attention-usage)\n\n    - [10. Pyramid Split Attention (PSA) Usage](#10-Pyramid-Split-Attention-Usage)\n\n    - [11. Efficient Multi-Head Self-Attention(EMSA) Usage](#11-Efficient-Multi-Head-Self-Attention-Usage)\n\n    - [12. Shuffle Attention Usage](#12-Shuffle-Attention-Usage)\n    \n    - [13. MUSE Attention Usage](#13-MUSE-Attention-Usage)\n  \n    - [14. SGE Attention Usage](#14-SGE-Attention-Usage)\n\n    - [15. A2 Attention Usage](#15-A2-Attention-Usage)\n\n    - [16. AFT Attention Usage](#16-AFT-Attention-Usage)\n\n    - [17. Outlook Attention Usage](#17-Outlook-Attention-Usage)\n\n    - [18. ViP Attention Usage](#18-ViP-Attention-Usage)\n\n    - [19. CoAtNet Attention Usage](#19-CoAtNet-Attention-Usage)\n\n    - [20. HaloNet Attention Usage](#20-HaloNet-Attention-Usage)\n\n    - [21. Polarized Self-Attention Usage](#21-Polarized-Self-Attention-Usage)\n\n    - [22. CoTAttention Usage](#22-CoTAttention-Usage)\n\n    - [23. Residual Attention Usage](#23-Residual-Attention-Usage)\n  \n    - [24. S2 Attention Usage](#24-S2-Attention-Usage)\n\n    - [25. GFNet Attention Usage](#25-GFNet-Attention-Usage)\n\n    - [26. Triplet Attention Usage](#26-TripletAttention-Usage)\n\n    - [27. Coordinate Attention Usage](#27-Coordinate-Attention-Usage)\n\n    - [28. MobileViT Attention Usage](#28-MobileViT-Attention-Usage)\n\n    - [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)\n\n    - [30. UFO Attention Usage](#30-UFO-Attention-Usage)\n\n    - [31. ACmix Attention Usage](#31-Acmix-Attention-Usage)\n  \n    - [32. MobileViTv2 Attention Usage](#32-MobileViTv2-Attention-Usage)\n\n    - [33. DAT Attention Usage](#33-DAT-Attention-Usage)\n\n    - [34. CrossFormer Attention Usage](#34-CrossFormer-Attention-Usage)\n\n    - [35. MOATransformer Attention Usage](#35-MOATransformer-Attention-Usage)\n\n    - [36. 
CrissCrossAttention Attention Usage](#36-CrissCrossAttention-Attention-Usage)\n\n    - [37. Axial_attention Attention Usage](#37-Axial_attention-Attention-Usage)\n\n- [Backbone Series](#Backbone-series)\n\n    - [1. ResNet Usage](#1-ResNet-Usage)\n\n    - [2. ResNeXt Usage](#2-ResNeXt-Usage)\n\n    - [3. MobileViT Usage](#3-MobileViT-Usage)\n\n    - [4. ConvMixer Usage](#4-ConvMixer-Usage)\n\n    - [5. ShuffleTransformer Usage](#5-ShuffleTransformer-Usage)\n\n    - [6. ConTNet Usage](#6-ConTNet-Usage)\n\n    - [7. HATNet Usage](#7-HATNet-Usage)\n\n    - [8. CoaT Usage](#8-CoaT-Usage)\n\n    - [9. PVT Usage](#9-PVT-Usage)\n\n    - [10. CPVT Usage](#10-CPVT-Usage)\n\n    - [11. PIT Usage](#11-PIT-Usage)\n\n    - [12. CrossViT Usage](#12-CrossViT-Usage)\n\n    - [13. TnT Usage](#13-TnT-Usage)\n\n    - [14. DViT Usage](#14-DViT-Usage)\n\n    - [15. CeiT Usage](#15-CeiT-Usage)\n\n    - [16. ConViT Usage](#16-ConViT-Usage)\n\n    - [17. CaiT Usage](#17-CaiT-Usage)\n\n    - [18. PatchConvnet Usage](#18-PatchConvnet-Usage)\n\n    - [19. DeiT Usage](#19-DeiT-Usage)\n\n    - [20. LeViT Usage](#20-LeViT-Usage)\n\n    - [21. VOLO Usage](#21-VOLO-Usage)\n    \n    - [22. Container Usage](#22-Container-Usage)\n\n    - [23. CMT Usage](#23-CMT-Usage)\n\n    - [24. EfficientFormer Usage](#24-EfficientFormer-Usage)\n\n    - [25. ConvNeXtV2 Usage](#25-ConvNeXtV2-Usage)\n\n\n\n- [MLP Series](#mlp-series)\n\n    - [1. RepMLP Usage](#1-RepMLP-Usage)\n\n    - [2. MLP-Mixer Usage](#2-MLP-Mixer-Usage)\n\n    - [3. ResMLP Usage](#3-ResMLP-Usage)\n\n    - [4. gMLP Usage](#4-gMLP-Usage)\n\n    - [5. sMLP Usage](#5-sMLP-Usage)\n\n    - [6. vip-mlp Usage](#6-vip-mlp-Usage)\n\n- [Re-Parameter(ReP) Series](#Re-Parameter-series)\n\n    - [1. RepVGG Usage](#1-RepVGG-Usage)\n\n    - [2. ACNet Usage](#2-ACNet-Usage)\n\n    - [3. Diverse Branch Block(DDB) Usage](#3-Diverse-Branch-Block-Usage)\n\n- [Convolution Series](#Convolution-series)\n\n    - [1. 
Depthwise Separable Convolution Usage](#1-Depthwise-Separable-Convolution-Usage)\n\n    - [2. MBConv Usage](#2-MBConv-Usage)\n\n    - [3. Involution Usage](#3-Involution-Usage)\n\n    - [4. DynamicConv Usage](#4-DynamicConv-Usage)\n\n    - [5. CondConv Usage](#5-CondConv-Usage)\n\n***\n\n\n\n# Attention Series\n\n- Pytorch implementation of [\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks---arXiv 2021.05.05\"](https://arxiv.org/abs/2105.02358)\n\n- Pytorch implementation of [\"Attention Is All You Need---NIPS2017\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n- Pytorch implementation of [\"Squeeze-and-Excitation Networks---CVPR2018\"](https://arxiv.org/abs/1709.01507)\n\n- Pytorch implementation of [\"Selective Kernel Networks---CVPR2019\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n- Pytorch implementation of [\"CBAM: Convolutional Block Attention Module---ECCV2018\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n- Pytorch implementation of [\"BAM: Bottleneck Attention Module---BMVC2018\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n- Pytorch implementation of [\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n- Pytorch implementation of [\"Dual Attention Network for Scene Segmentation---CVPR2019\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n- Pytorch implementation of [\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network---arXiv 2021.05.30\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n- Pytorch implementation of [\"ResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28\"](https://arxiv.org/abs/2105.13677)\n\n- Pytorch implementation of [\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS---ICASSP 2021\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n- Pytorch implementation of [\"MUSE: Parallel Multi-Scale Attention for 
Sequence to Sequence Learning---arXiv 2019.11.17\"](https://arxiv.org/abs/1911.09483)\n\n- Pytorch implementation of [\"Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks---arXiv 2019.05.23\"](https://arxiv.org/pdf/1905.09646.pdf)\n\n- Pytorch implementation of [\"A2-Nets: Double Attention Networks---NIPS2018\"](https://arxiv.org/pdf/1810.11579.pdf)\n\n\n- Pytorch implementation of [\"An Attention Free Transformer---ICLR2021 (Apple New Work)\"](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition---arXiv 2021.06.24](https://arxiv.org/abs/2106.13112) \n  [Paper analysis](https://zhuanlan.zhihu.com/p/385561050)\n\n\n- Pytorch implementation of [Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition---arXiv 2021.06.23](https://arxiv.org/abs/2106.12368) \n  [Paper analysis](https://mp.weixin.qq.com/s/5gonUQgBho_m2O54jyXF_Q)\n\n\n- Pytorch implementation of [CoAtNet: Marrying Convolution and Attention for All Data Sizes---arXiv 2021.06.09](https://arxiv.org/abs/2106.04803) \n  [Paper analysis](https://zhuanlan.zhihu.com/p/385578588)\n\n\n- Pytorch implementation of [Scaling Local Self-Attention for Parameter Efficient Visual Backbones---CVPR2021 Oral](https://arxiv.org/pdf/2103.12731.pdf)  [Paper analysis](https://zhuanlan.zhihu.com/p/388598744)\n\n\n\n- Pytorch implementation of [Polarized Self-Attention: Towards High-quality Pixel-wise Regression---arXiv 2021.07.02](https://arxiv.org/abs/2107.00782)  [Paper analysis](https://zhuanlan.zhihu.com/p/389770482) \n\n\n- Pytorch implementation of [Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292)  [Paper analysis](https://zhuanlan.zhihu.com/p/394795481) \n\n\n- Pytorch implementation of [Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n- Pytorch implementation of [S²-MLPv2: Improved Spatial-Shift MLP 
Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) [Paper analysis](https://zhuanlan.zhihu.com/p/397003638) \n\n- Pytorch implementation of [Global Filter Networks for Image Classification---arXiv 2021.07.01](https://arxiv.org/abs/2107.00645) \n\n- Pytorch implementation of [Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n- Pytorch implementation of [Coordinate Attention for Efficient Mobile Network Design---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n- Pytorch implementation of [UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n- Pytorch implementation of [Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n- Pytorch implementation of [On the Integration of Self-Attention and Convolution---ArXiv 2022.03.14](https://arxiv.org/pdf/2111.14556.pdf)\n\n- Pytorch implementation of [CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n- Pytorch implementation of [Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n- Pytorch implementation of [CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n- Pytorch implementation of [Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n***\n\n\n### 1. External Attention Usage\n#### 1.1. Paper\n[\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks\"](https://arxiv.org/abs/2105.02358)\n\n#### 1.2. 
Overview\n![](./model/img/External_Attention.png)\n\n#### 1.3. Usage Code\n```python\nfrom model.attention.ExternalAttention import ExternalAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nea = ExternalAttention(d_model=512,S=8)\noutput=ea(input)\nprint(output.shape)\n```\n\n***\n\n\n### 2. Self Attention Usage\n#### 2.1. Paper\n[\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n#### 2.2. Overview\n![](./model/img/SA.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.attention.SelfAttention import ScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nsa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n```\n\n***\n\n### 3. Simplified Self Attention Usage\n#### 3.1. Paper\n[None]()\n\n#### 3.2. Overview\n![](./model/img/SSA.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)\noutput=ssa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n### 4. Squeeze-and-Excitation Attention Usage\n#### 4.1. Paper\n[\"Squeeze-and-Excitation Networks\"](https://arxiv.org/abs/1709.01507)\n\n#### 4.2. Overview\n![](./model/img/SE.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.attention.SEAttention import SEAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SEAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n\n***\n\n### 5. SK Attention Usage\n#### 5.1. Paper\n[\"Selective Kernel Networks\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n#### 5.2. Overview\n![](./model/img/SK.png)\n\n#### 5.3. Usage Code\n```python\nfrom model.attention.SKAttention import SKAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SKAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n***\n\n### 6. 
CBAM Attention Usage\n#### 6.1. Paper\n[\"CBAM: Convolutional Block Attention Module\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n#### 6.2. Overview\n![](./model/img/CBAM1.png)\n\n![](./model/img/CBAM2.png)\n\n#### 6.3. Usage Code\n```python\nfrom model.attention.CBAM import CBAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nkernel_size=input.shape[2]\ncbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\noutput=cbam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 7. BAM Attention Usage\n#### 7.1. Paper\n[\"BAM: Bottleneck Attention Module\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n#### 7.2. Overview\n![](./model/img/BAM.png)\n\n#### 7.3. Usage Code\n```python\nfrom model.attention.BAM import BAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nbam = BAMBlock(channel=512,reduction=16,dia_val=2)\noutput=bam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 8. ECA Attention Usage\n#### 8.1. Paper\n[\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n#### 8.2. Overview\n![](./model/img/ECA.png)\n\n#### 8.3. Usage Code\n```python\nfrom model.attention.ECAAttention import ECAAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\neca = ECAAttention(kernel_size=3)\noutput=eca(input)\nprint(output.shape)\n\n```\n\n***\n\n### 9. DANet Attention Usage\n#### 9.1. Paper\n[\"Dual Attention Network for Scene Segmentation\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n#### 9.2. Overview\n![](./model/img/danet.png)\n\n#### 9.3. Usage Code\n```python\nfrom model.attention.DANet import DAModule\nimport torch\n\ninput=torch.randn(50,512,7,7)\ndanet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\nprint(danet(input).shape)\n\n```\n\n***\n\n### 10. Pyramid Split Attention Usage\n\n#### 10.1. 
Paper\n[\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n#### 10.2. Overview\n![](./model/img/psa.png)\n\n#### 10.3. Usage Code\n```python\nfrom model.attention.PSA import PSA\nimport torch\n\ninput=torch.randn(50,512,7,7)\npsa = PSA(channel=512,reduction=8)\noutput=psa(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 11. Efficient Multi-Head Self-Attention Usage\n\n#### 11.1. Paper\n[\"ResT: An Efficient Transformer for Visual Recognition\"](https://arxiv.org/abs/2105.13677)\n\n#### 11.2. Overview\n![](./model/img/EMSA.png)\n\n#### 11.3. Usage Code\n```python\n\nfrom model.attention.EMSA import EMSA\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,64,512)\nemsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\noutput=emsa(input,input,input)\nprint(output.shape)\n    \n```\n\n***\n\n\n### 12. Shuffle Attention Usage\n\n#### 12.1. Paper\n[\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n#### 12.2. Overview\n![](./model/img/ShuffleAttention.jpg)\n\n#### 12.3. Usage Code\n```python\n\nfrom model.attention.ShuffleAttention import ShuffleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,512,7,7)\nse = ShuffleAttention(channel=512,G=8)\noutput=se(input)\nprint(output.shape)\n\n    \n```\n\n\n***\n\n\n### 13. MUSE Attention Usage\n\n#### 13.1. Paper\n[\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning\"](https://arxiv.org/abs/1911.09483)\n\n#### 13.2. Overview\n![](./model/img/MUSE.png)\n\n#### 13.3. 
Usage Code\n```python\nfrom model.attention.MUSEAttention import MUSEAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,49,512)\nsa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 14. SGE Attention Usage\n\n#### 14.1. Paper\n[Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks](https://arxiv.org/pdf/1905.09646.pdf)\n\n#### 14.2. Overview\n![](./model/img/SGE.png)\n\n#### 14.3. Usage Code\n```python\nfrom model.attention.SGE import SpatialGroupEnhance\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nsge = SpatialGroupEnhance(groups=8)\noutput=sge(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 15. A2 Attention Usage\n\n#### 15.1. Paper\n[A2-Nets: Double Attention Networks](https://arxiv.org/pdf/1810.11579.pdf)\n\n#### 15.2. Overview\n![](./model/img/A2.png)\n\n#### 15.3. Usage Code\n```python\nfrom model.attention.A2Atttention import DoubleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\na2 = DoubleAttention(512,128,128,True)\noutput=a2(input)\nprint(output.shape)\n\n```\n\n\n\n### 16. AFT Attention Usage\n\n#### 16.1. Paper\n[An Attention Free Transformer](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n#### 16.2. Overview\n![](./model/img/AFT.jpg)\n\n#### 16.3. Usage Code\n```python\nfrom model.attention.AFT import AFT_FULL\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,49,512)\naft_full = AFT_FULL(d_model=512, n=49)\noutput=aft_full(input)\nprint(output.shape)\n\n```\n\n\n\n\n\n\n### 17. Outlook Attention Usage\n\n#### 17.1. Paper\n\n\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n\n#### 17.2. Overview\n![](./model/img/OutlookAttention.png)\n\n#### 17.3. 
Usage Code\n```python\nfrom model.attention.OutlookAttention import OutlookAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,28,28,512)\noutlook = OutlookAttention(dim=512)\noutput=outlook(input)\nprint(output.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 18. ViP Attention Usage\n\n#### 18.1. Paper\n\n\n[Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition](https://arxiv.org/abs/2106.12368)\n\n\n#### 18.2. Overview\n![](./model/img/ViP.png)\n\n#### 18.3. Usage Code\n```python\n\nfrom model.attention.ViP import WeightedPermuteMLP\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(64,8,8,512)\nseg_dim=8\nvip=WeightedPermuteMLP(512,seg_dim)\nout=vip(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n### 19. CoAtNet Attention Usage\n\n#### 19.1. Paper\n\n\n[CoAtNet: Marrying Convolution and Attention for All Data Sizes](https://arxiv.org/abs/2106.04803) \n\n\n#### 19.2. Overview\nNone\n\n\n#### 19.3. Usage Code\n```python\n\nfrom model.attention.CoAtNet import CoAtNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=CoAtNet(in_ch=3,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 20. HaloNet Attention Usage\n\n#### 20.1. Paper\n\n\n[Scaling Local Self-Attention for Parameter Efficient Visual Backbones](https://arxiv.org/pdf/2103.12731.pdf) \n\n\n#### 20.2. Overview\n\n![](./model/img/HaloNet.png)\n\n#### 20.3. Usage Code\n```python\n\nfrom model.attention.HaloAttention import HaloAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,8,8)\nhalo = HaloAttention(dim=512,\n    block_size=2,\n    halo_size=1,)\noutput=halo(input)\nprint(output.shape)\n\n```\n\n\n***\n\n### 21. Polarized Self-Attention Usage\n\n#### 21.1. 
Paper\n\n[Polarized Self-Attention: Towards High-quality Pixel-wise Regression](https://arxiv.org/abs/2107.00782)  \n\n\n#### 21.2. Overview\n\n![](./model/img/PoSA.png)\n\n#### 21.3. Usage Code\n```python\n\nfrom model.attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention,SequentialPolarizedSelfAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,7,7)\npsa = SequentialPolarizedSelfAttention(channel=512)\noutput=psa(input)\nprint(output.shape)\n\n\n```\n\n\n***\n\n\n### 22. CoTAttention Usage\n\n#### 22.1. Paper\n\n[Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292) \n\n\n#### 22.2. Overview\n\n![](./model/img/CoT.png)\n\n#### 22.3. Usage Code\n```python\n\nfrom model.attention.CoTAttention import CoTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ncot = CoTAttention(dim=512,kernel_size=3)\noutput=cot(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n### 23. Residual Attention Usage\n\n#### 23.1. Paper\n\n[Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n#### 23.2. Overview\n\n![](./model/img/ResAtt.png)\n\n#### 23.3. Usage Code\n```python\n\nfrom model.attention.ResidualAttention import ResidualAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nresatt = ResidualAttention(channel=512,num_class=1000,la=0.2)\noutput=resatt(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n\n### 24. S2 Attention Usage\n\n#### 24.1. Paper\n\n[S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) \n\n\n#### 24.2. Overview\n\n![](./model/img/S2Attention.png)\n\n#### 24.3. 
Usage Code\n```python\nfrom model.attention.S2Attention import S2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ns2att = S2Attention(channels=512)\noutput=s2att(input)\nprint(output.shape)\n\n```\n\n***\n\n\n\n### 25. GFNet Attention Usage\n\n#### 25.1. Paper\n\n[Global Filter Networks for Image Classification---arXiv 2021.07.01](https://arxiv.org/abs/2107.00645) \n\n\n#### 25.2. Overview\n\n![](./model/img/GFNet.jpg)\n\n#### 25.3. Usage Code - Implemented by [Wenliang Zhao (Author)](https://scholar.google.com/citations?user=lyPWvuEAAAAJ&hl=en)\n\n```python\nfrom model.attention.gfnet import GFNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nx = torch.randn(1, 3, 224, 224)\ngfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)\nout = gfnet(x)\nprint(out.shape)\n\n```\n\n***\n\n\n### 26. TripletAttention Usage\n\n#### 26.1. Paper\n\n[Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n#### 26.2. Overview\n\n![](./model/img/triplet.png)\n\n#### 26.3. Usage Code - Implemented by [digantamisra98](https://github.com/digantamisra98)\n\n```python\nfrom model.attention.TripletAttention import TripletAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\ninput=torch.randn(50,512,7,7)\ntriplet = TripletAttention()\noutput=triplet(input)\nprint(output.shape)\n```\n\n\n***\n\n\n### 27. Coordinate Attention Usage\n\n#### 27.1. Paper\n\n[Coordinate Attention for Efficient Mobile Network Design---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n\n#### 27.2. Overview\n\n![](./model/img/CoordAttention.png)\n\n#### 27.3. 
Usage Code - Implemented by [Andrew-Qibin](https://github.com/Andrew-Qibin)\n\n```python\nfrom model.attention.CoordAttention import CoordAtt\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninp=torch.rand([2, 96, 56, 56])\ninp_dim, oup_dim = 96, 96\nreduction=32\n\ncoord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)\noutput=coord_attention(inp)\nprint(output.shape)\n```\n\n***\n\n\n### 28. MobileViT Attention Usage\n\n#### 28.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n\n#### 28.2. Overview\n\n![](./model/img/MobileViTAttention.png)\n\n#### 28.3. Usage Code\n\n```python\nfrom model.attention.MobileViTAttention import MobileViTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    m=MobileViTAttention()\n    input=torch.randn(1,3,49,49)\n    output=m(input)\n    print(output.shape)  #output:(1,3,49,49)\n    \n```\n\n***\n\n\n### 29. ParNet Attention Usage\n\n#### 29.1. Paper\n\n[Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n\n#### 29.2. Overview\n\n![](./model/img/ParNet.png)\n\n#### 29.3. Usage Code\n\n```python\nfrom model.attention.ParNetAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    pna = ParNetAttention(channel=512)\n    output=pna(input)\n    print(output.shape) #50,512,7,7\n    \n```\n\n***\n\n\n### 30. UFO Attention Usage\n\n#### 30.1. Paper\n\n[UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n\n#### 30.2. Overview\n\n![](./model/img/UFO.png)\n\n#### 30.3. 
Usage Code\n\n```python\nfrom model.attention.UFOAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=ufo(input,input,input)\n    print(output.shape) #[50, 49, 512]\n    \n```\n\n***\n\n### 31. ACmix Attention Usage\n\n#### 31.1. Paper\n\n[On the Integration of Self-Attention and Convolution](https://arxiv.org/pdf/2111.14556.pdf)\n\n#### 31.2. Usage Code\n\n```python\nfrom model.attention.ACmix import ACmix\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,256,7,7)\n    acmix = ACmix(in_planes=256, out_planes=256)\n    output=acmix(input)\n    print(output.shape)\n    \n```\n\n### 32. MobileViTv2 Attention Usage\n\n#### 32.1. Paper\n\n[Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n\n#### 32.2. Overview\n\n![](./model/img/MobileViTv2.png)\n\n#### 32.3. Usage Code\n\n```python\nfrom model.attention.MobileViTv2Attention import MobileViTv2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n    \n```\n\n### 33. DAT Attention Usage\n\n#### 33.1. Paper\n\n[Vision Transformer with Deformable Attention---CVPR2022](https://arxiv.org/abs/2201.00520)\n\n#### 33.2. 
Usage Code\n\n```python\nfrom model.attention.DAT import DAT\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DAT(\n        img_size=224,\n        patch_size=4,\n        num_classes=1000,\n        expansion=4,\n        dim_stem=96,\n        dims=[96, 192, 384, 768],\n        depths=[2, 2, 6, 2],\n        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],\n        heads=[3, 6, 12, 24],\n        window_sizes=[7, 7, 7, 7] ,\n        groups=[-1, -1, 3, 6],\n        use_pes=[False, False, True, True],\n        dwc_pes=[False, False, False, False],\n        strides=[-1, -1, 1, 1],\n        sr_ratios=[-1, -1, -1, -1],\n        offset_range_factor=[-1, -1, 2, 2],\n        no_offs=[False, False, False, False],\n        fixed_pes=[False, False, False, False],\n        use_dwc_mlps=[False, False, False, False],\n        use_conv_patches=False,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.2,\n    )\n    output=model(input)\n    print(output[0].shape)\n    \n```\n\n### 34. CrossFormer Attention Usage\n\n#### 34.1. Paper\n\n[CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n#### 34.2. 
Usage Code\n\n```python\nfrom model.attention.Crossformer import CrossFormer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CrossFormer(img_size=224,\n        patch_size=[4, 8, 16, 32],\n        in_chans= 3,\n        num_classes=1000,\n        embed_dim=48,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        group_size=[7, 7, 7, 7],\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        merge_size=[[2, 4], [2,4], [2, 4]]\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 35. MOATransformer Attention Usage\n\n#### 35.1. Paper\n\n[Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n#### 35.2. Usage Code\n\n```python\nfrom model.attention.MOATransformer import MOATransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = MOATransformer(\n        img_size=224,\n        patch_size=4,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6],\n        num_heads=[3, 6, 12],\n        window_size=14,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 36. CrissCrossAttention Attention Usage\n\n#### 36.1. Paper\n\n[CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n#### 36.2. 
Usage Code\n\n```python\nfrom model.attention.CrissCrossAttention import CrissCrossAttention\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 64, 7, 7)\n    model = CrissCrossAttention(64)\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n### 37. Axial_attention Attention Usage\n\n#### 37.1. Paper\n\n[Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n\n#### 37.2. Usage Code\n\n```python\nfrom model.attention.Axial_attention import AxialImageTransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 128, 7, 7)\n    model = AxialImageTransformer(\n        dim = 128,\n        depth = 12,\n        reversible = True\n    )\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n***\n\n\n# Backbone Series\n\n- Pytorch implementation of [\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n- Pytorch implementation of [\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2020.10.05](https://arxiv.org/abs/2103.02907)\n\n- Pytorch implementation of [Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n\n- Pytorch implementation of [Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer---ArXiv 2021.06.07](https://arxiv.org/abs/2106.03650)\n\n- Pytorch implementation of [ConTNet: Why not use convolution and transformer at the same time?---ArXiv 2021.04.27](https://arxiv.org/abs/2104.13497)\n\n- Pytorch implementation of [Vision Transformers with Hierarchical Attention---ArXiv 2022.06.15](https://arxiv.org/abs/2106.03180)\n\n- Pytorch implementation of [Co-Scale Conv-Attentional Image Transformers---ArXiv 2021.08.26](https://arxiv.org/abs/2104.06399)\n\n- Pytorch 
implementation of [Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n- Pytorch implementation of [Rethinking Spatial Dimensions of Vision Transformers---ICCV 2021](https://arxiv.org/abs/2103.16302)\n\n- Pytorch implementation of [CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification---ICCV 2021](https://arxiv.org/abs/2103.14899)\n\n- Pytorch implementation of [Transformer in Transformer---NeurIPS 2021](https://arxiv.org/abs/2103.00112)\n\n- Pytorch implementation of [DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n- Pytorch implementation of [Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n- Pytorch implementation of [ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n- Pytorch implementation of [Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n- Pytorch implementation of [Going deeper with Image Transformers---ICCV 2021 (Oral)](https://arxiv.org/abs/2103.17239)\n\n- Pytorch implementation of [Training data-efficient image transformers & distillation through attention---ICML 2021](https://arxiv.org/abs/2012.12877)\n\n- Pytorch implementation of [LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n- Pytorch implementation of [Container: Context Aggregation Network---NeurIPS 2021](https://arxiv.org/abs/2106.01401)\n\n- Pytorch implementation of [CMT: Convolutional Neural Networks Meet Vision Transformers---CVPR 2022](https://arxiv.org/abs/2107.06263)\n\n- Pytorch implementation of [Vision Transformer with Deformable Attention---CVPR 2022](https://arxiv.org/abs/2201.00520)\n\n- Pytorch implementation of [EfficientFormer: 
Vision Transformers at MobileNet Speed](https://arxiv.org/abs/2206.01191)\n\n- Pytorch implementation of [ConvNeXtV2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808)\n\n\n### 1. ResNet Usage\n#### 1.1. Paper\n[\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n#### 1.2. Overview\n![](./model/img/resnet.png)\n![](./model/img/resnet2.jpg)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.backbone.resnet import ResNet50,ResNet101,ResNet152\nimport torch\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnet50=ResNet50(1000)\n    # resnet101=ResNet101(1000)\n    # resnet152=ResNet152(1000)\n    out=resnet50(input)\n    print(out.shape)\n\n```\n\n\n### 2. ResNeXt Usage\n#### 2.1. Paper\n\n[\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n#### 2.2. Overview\n![](./model/img/resnext.png)\n\n#### 2.3. Usage Code\n```python\n\nfrom model.backbone.resnext import ResNeXt50,ResNeXt101,ResNeXt152\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnext50=ResNeXt50(1000)\n    # resnext101=ResNeXt101(1000)\n    # resnext152=ResNeXt152(1000)\n    out=resnext50(input)\n    print(out.shape)\n\n\n```\n\n\n\n### 3. MobileViT Usage\n#### 3.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2020.10.05](https://arxiv.org/abs/2103.02907)\n\n#### 3.2. Overview\n![](./model/img/mobileViT.jpg)\n\n#### 3.3. 
Usage Code\n```python\n\nfrom model.backbone.MobileViT import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n\n    ### mobilevit_xxs\n    mvit_xxs=mobilevit_xxs()\n    out=mvit_xxs(input)\n    print(out.shape)\n\n    ### mobilevit_xs\n    mvit_xs=mobilevit_xs()\n    out=mvit_xs(input)\n    print(out.shape)\n\n\n    ### mobilevit_s\n    mvit_s=mobilevit_s()\n    out=mvit_s(input)\n    print(out.shape)\n\n```\n\n\n\n\n\n### 4. ConvMixer Usage\n#### 4.1. Paper\n[Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n#### 4.2. Overview\n![](./model/img/ConvMixer.png)\n\n#### 4.3. Usage Code\n```python\n\nfrom model.backbone.ConvMixer import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    x=torch.randn(1,3,224,224)\n    convmixer=ConvMixer(dim=512,depth=12)\n    out=convmixer(x)\n    print(out.shape)  #[1, 1000]\n\n\n```\n\n### 5. ShuffleTransformer Usage\n#### 5.1. Paper\n[Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer](https://arxiv.org/pdf/2106.03650.pdf)\n\n#### 5.2. Usage Code\n```python\n\nfrom model.backbone.ShuffleTransformer import ShuffleTransformer\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    sft = ShuffleTransformer()\n    output=sft(input)\n    print(output.shape)\n\n\n```\n\n### 6. ConTNet Usage\n#### 6.1. Paper\n[ConTNet: Why not use convolution and transformer at the same time?](https://arxiv.org/abs/2104.13497)\n\n#### 6.2. 
Usage Code\n```python\n\n# assumes build_model is defined in model/backbone/ConTNet.py alongside ConTNet\nfrom model.backbone.ConTNet import build_model\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == \"__main__\":\n    model = build_model(use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)\n\n\n```\n\n### 7. HATNet Usage\n#### 7.1. Paper\n[Vision Transformers with Hierarchical Attention](https://arxiv.org/abs/2106.03180)\n\n#### 7.2. Usage Code\n```python\n\nfrom model.backbone.HATNet import HATNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    hat = HATNet(dims=[48, 96, 240, 384], head_dim=48, expansions=[8, 8, 4, 4],\n        grid_sizes=[8, 7, 7, 1], ds_ratios=[8, 4, 2, 1], depths=[2, 2, 6, 3])\n    output=hat(input)\n    print(output.shape)\n\n\n```\n\n### 8. CoaT Usage\n#### 8.1. Paper\n[Co-Scale Conv-Attentional Image Transformers](https://arxiv.org/abs/2104.06399)\n\n#### 8.2. Usage Code\n```python\n\nfrom model.backbone.CoaT import CoaT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4])\n    output=model(input)\n    print(output.shape) # torch.Size([1, 1000])\n\n```\n\n### 9. PVT Usage\n#### 9.1. Paper\n[PVT v2: Improved Baselines with Pyramid Vision Transformer](https://arxiv.org/pdf/2106.13797.pdf)\n\n#### 9.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.PVT import PyramidVisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n\n### 10. CPVT Usage\n#### 10.1. Paper\n[Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n#### 10.2. Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.CPVT import CPVTV2\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 11. PIT Usage\n#### 11.1. Paper\n[Rethinking Spatial Dimensions of Vision Transformers](https://arxiv.org/abs/2103.16302)\n\n#### 11.2. Usage Code\n```python\n\nfrom model.backbone.PIT import PoolingTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 12. CrossViT Usage\n#### 12.1. Paper\n[CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification](https://arxiv.org/abs/2103.14899)\n\n#### 12.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.CrossViT import VisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == \"__main__\":\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[240, 224],\n        patch_size=[12, 16],\n        embed_dim=[192, 384],\n        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n        num_heads=[6, 6],\n        mlp_ratio=[4, 4, 1],\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 13. TnT Usage\n#### 13.1. Paper\n[Transformer in Transformer](https://arxiv.org/abs/2103.00112)\n\n#### 13.2. Usage Code\n```python\n\nfrom model.backbone.TnT import TNT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = TNT(\n        img_size=224,\n        patch_size=16,\n        outer_dim=384,\n        inner_dim=24,\n        depth=12,\n        outer_num_heads=6,\n        inner_num_heads=4,\n        qkv_bias=False,\n        inner_stride=4)\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 14. DViT Usage\n#### 14.1. Paper\n[DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n#### 14.2. Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.DViT import DeepVisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384,\n        depth=[False] * 16,\n        apply_transform=[False] * 0 + [True] * 32,\n        num_heads=12,\n        mlp_ratio=3,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 15. CeiT Usage\n#### 15.1. Paper\n[Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n#### 15.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\n# assumes Image2Tokens is defined in the same module as CeIT\nfrom model.backbone.CeiT import CeIT, Image2Tokens\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CeIT(\n        hybrid_backbone=Image2Tokens(),\n        patch_size=4,\n        embed_dim=192,\n        depth=12,\n        num_heads=3,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 16. ConViT Usage\n#### 16.1. Paper\n[ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n#### 16.2. Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.ConViT import VisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        num_heads=16,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 17. CaiT Usage\n#### 17.1. Paper\n[Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239)\n\n#### 17.2. Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.CaiT import CaiT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CaiT(\n        img_size=224,\n        patch_size=16,\n        embed_dim=192,\n        depth=24,\n        num_heads=4,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 18. PatchConvnet Usage\n#### 18.1. Paper\n[Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n#### 18.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\n# assumes ConvStem and Conv_blocks_se are defined in the same module as PatchConvnet\nfrom model.backbone.PatchConvnet import PatchConvnet, ConvStem, Conv_blocks_se\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 19. DeiT Usage\n#### 19.1. Paper\n[Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)\n\n#### 19.2. Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.DeiT import DistilledVisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DistilledVisionTransformer(\n        patch_size=16,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 20. LeViT Usage\n#### 20.1. Paper\n[LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n#### 20.2. Usage Code\n```python\n\nfrom model.backbone.LeViT import *\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    for name in specification:\n        input=torch.randn(1,3,224,224)\n        model = globals()[name](fuse=True, pretrained=False)\n        model.eval()\n        output = model(input)\n        print(output.shape)\n\n```\n\n### 21. VOLO Usage\n#### 21.1. Paper\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n#### 21.2. 
Usage Code\n```python\n\nfrom model.backbone.VOLO import VOLO\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VOLO([4, 4, 8, 2],\n                 embed_dims=[192, 384, 384, 384],\n                 num_heads=[6, 12, 12, 12],\n                 mlp_ratios=[3, 3, 3, 3],\n                 downsamples=[True, False, False, False],\n                 outlook_attention=[True, False, False, False ],\n                 post_layers=['ca', 'ca'],\n                 )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 22. Container Usage\n#### 22.1. Paper\n[Container: Context Aggregation Network](https://arxiv.org/abs/2106.01401)\n\n#### 22.2. Usage Code\n```python\n\nfrom functools import partial  # needed for norm_layer below\nfrom model.backbone.Container import VisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[224, 56, 28, 14],\n        patch_size=[4, 2, 2, 2],\n        embed_dim=[64, 128, 320, 512],\n        depth=[3, 4, 8, 3],\n        num_heads=16,\n        mlp_ratio=[8, 8, 4, 4],\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6))\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 23. CMT Usage\n#### 23.1. Paper\n[CMT: Convolutional Neural Networks Meet Vision Transformers](https://arxiv.org/abs/2107.06263)\n\n#### 23.2. Usage Code\n```python\n\nfrom model.backbone.CMT import CMT_Tiny\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CMT_Tiny()\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 24. EfficientFormer Usage\n#### 24.1. Paper\n[EfficientFormer: Vision Transformers at MobileNet Speed](https://arxiv.org/abs/2206.01191)\n\n#### 24.2. 
Usage Code\n```python\n\n# assumes EfficientFormer_depth and EfficientFormer_width are defined in the same module\nfrom model.backbone.EfficientFormer import EfficientFormer, EfficientFormer_depth, EfficientFormer_width\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = EfficientFormer(\n        layers=EfficientFormer_depth['l1'],\n        embed_dims=EfficientFormer_width['l1'],\n        downsamples=[True, True, True, True],\n        vit_num=1,\n    )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 25. ConvNeXtV2 Usage\n#### 25.1. Paper\n[ConvNeXtV2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808)\n\n#### 25.2. Usage Code\n```python\n\nfrom model.backbone.convnextv2 import convnextv2_atto\nimport torch\nfrom torch import nn\n\nif __name__ == \"__main__\":\n    model = convnextv2_atto()\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)\n\n```\n\n\n\n\n\n# MLP Series\n\n- Pytorch implementation of [\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition---arXiv 2021.05.05\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n- Pytorch implementation of [\"MLP-Mixer: An all-MLP Architecture for Vision---arXiv 2021.05.17\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n- Pytorch implementation of [\"ResMLP: Feedforward networks for image classification with data-efficient training---arXiv 2021.05.07\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n- Pytorch implementation of [\"Pay Attention to MLPs---arXiv 2021.05.17\"](https://arxiv.org/abs/2105.08050)\n\n- Pytorch implementation of [\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?---arXiv 2021.09.12\"](https://arxiv.org/abs/2109.05422)\n\n- Pytorch implementation of [\"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n### 1. RepMLP Usage\n#### 1.1. Paper\n[\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n#### 1.2. Overview\n![](./model/img/repmlp.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom model.mlp.repmlp import RepMLP\nimport torch\nfrom torch import nn\n\nN=4 #batch size\nC=512 #input dim\nO=1024 #output dim\nH=14 #image height\nW=14 #image width\nh=7 #patch height\nw=7 #patch width\nfc1_fc2_reduction=1 #reduction ratio\nfc3_groups=8 # groups\nrepconv_kernels=[1,3,5,7] #kernel list\nrepmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)\nx=torch.randn(N,C,H,W)\nrepmlp.eval()\nfor module in repmlp.modules():\n    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):\n        nn.init.uniform_(module.running_mean, 0, 0.1)\n        nn.init.uniform_(module.running_var, 0, 0.1)\n        nn.init.uniform_(module.weight, 0, 0.1)\n        nn.init.uniform_(module.bias, 0, 0.1)\n\n#training result\nout=repmlp(x)\n#inference result\nrepmlp.switch_to_deploy()\ndeployout = repmlp(x)\n\nprint(((deployout-out)**2).sum())\n```\n\n### 2. MLP-Mixer Usage\n#### 2.1. Paper\n[\"MLP-Mixer: An all-MLP Architecture for Vision\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n#### 2.2. Overview\n![](./model/img/mlpmixer.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.mlp.mlp_mixer import MlpMixer\nimport torch\nmlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)\ninput=torch.randn(50,3,40,40)\noutput=mlp_mixer(input)\nprint(output.shape)\n```\n\n***\n\n### 3. ResMLP Usage\n#### 3.1. Paper\n[\"ResMLP: Feedforward networks for image classification with data-efficient training\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n#### 3.2. Overview\n![](./model/img/resmlp.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.mlp.resmlp import ResMLP\nimport torch\n\ninput=torch.randn(50,3,14,14)\nresmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)\nout=resmlp(input)\nprint(out.shape) #the last dimention is class_num\n```\n\n***\n\n### 4. gMLP Usage\n#### 4.1. 
Paper\n[\"Pay Attention to MLPs\"](https://arxiv.org/abs/2105.08050)\n\n#### 4.2. Overview\n![](./model/img/gMLP.jpg)\n\n#### 4.3. Usage Code\n```python\nfrom model.mlp.g_mlp import gMLP\nimport torch\n\nnum_tokens=10000\nbs=50\nlen_sen=49\nnum_layers=6\ninput=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen\ngmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)\noutput=gmlp(input)\nprint(output.shape)\n```\n\n***\n\n### 5. sMLP Usage\n#### 5.1. Paper\n[\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?\"](https://arxiv.org/abs/2109.05422)\n\n#### 5.2. Overview\n![](./model/img/sMLP.jpg)\n\n#### 5.3. Usage Code\n```python\nfrom model.mlp.sMLP_block import sMLPBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    smlp=sMLPBlock(h=224,w=224)\n    out=smlp(input)\n    print(out.shape)\n```\n\n### 6. vip-mlp Usage\n#### 6.1. Paper\n[\"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n#### 6.2. 
Usage Code\n```python\nimport importlib\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# 'vip-mlp' contains a hyphen, so it cannot be used in a plain import statement;\n# load it by name instead (assumes model/mlp/vip-mlp.py defines both classes)\nvip_mlp = importlib.import_module('model.mlp.vip-mlp')\nVisionPermutator = vip_mlp.VisionPermutator\nWeightedPermuteMLP = vip_mlp.WeightedPermuteMLP\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionPermutator(\n        layers=[4, 3, 8, 3],\n        embed_dims=[384, 384, 384, 384],\n        patch_size=14,\n        transitions=[False, False, False, False],\n        segment_dim=[16, 16, 16, 16],\n        mlp_ratios=[3, 3, 3, 3],\n        mlp_fn=WeightedPermuteMLP\n    )\n    output=model(input)\n    print(output.shape)\n```\n\n\n# Re-Parameter Series\n\n- Pytorch implementation of [\"RepVGG: Making VGG-style ConvNets Great Again---CVPR2021\"](https://arxiv.org/abs/2101.03697)\n\n- Pytorch implementation of [\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks---ICCV2019\"](https://arxiv.org/abs/1908.03930)\n\n- Pytorch implementation of [\"Diverse Branch Block: Building a Convolution as an Inception-like Unit---CVPR2021\"](https://arxiv.org/abs/2103.13425)\n\n\n***\n\n### 1. RepVGG Usage\n#### 1.1. Paper\n[\"RepVGG: Making VGG-style ConvNets Great Again\"](https://arxiv.org/abs/2101.03697)\n\n#### 1.2. Overview\n![](./model/img/repvgg.png)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.rep.repvgg import RepBlock\nimport torch\n\n\ninput=torch.randn(50,512,49,49)\nrepblock=RepBlock(512,512)\nrepblock.eval()\nout=repblock(input)\nrepblock._switch_to_deploy()\nout2=repblock(input)\nprint('difference between vgg and repvgg')\nprint(((out2-out)**2).sum())\n```\n\n\n\n***\n\n### 2. ACNet Usage\n#### 2.1. Paper\n[\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks\"](https://arxiv.org/abs/1908.03930)\n\n#### 2.2. Overview\n![](./model/img/acnet.png)\n\n#### 2.3. 
Usage Code\n```python\nfrom model.rep.acnet import ACNet\nimport torch\nfrom torch import nn\n\ninput=torch.randn(50,512,49,49)\nacnet=ACNet(512,512)\nacnet.eval()\nout=acnet(input)\nacnet._switch_to_deploy()\nout2=acnet(input)\nprint('difference:')\nprint(((out2-out)**2).sum())\n\n```\n\n\n\n***\n\n### 3. Diverse Branch Block Usage\n#### 3.1. Paper\n[\"Diverse Branch Block: Building a Convolution as an Inception-like Unit\"](https://arxiv.org/abs/2103.13425)\n\n#### 3.2. Overview\n![](./model/img/ddb.png)\n\n#### 3.3. Usage Code\n##### 3.3.1 Transform I\n```python\nfrom model.rep.ddb import transI_conv_bn\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n#conv+bn\nconv1=nn.Conv2d(64,64,3,padding=1)\nbn1=nn.BatchNorm2d(64)\nbn1.eval()\nout1=bn1(conv1(input))\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transI_conv_bn(conv1,bn1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.2 Transform II\n```python\nfrom model.rep.ddb import transII_conv_branch\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,3,padding=1)\nconv2=nn.Conv2d(64,64,3,padding=1)\nout1=conv1(input)+conv2(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transII_conv_branch(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.3 Transform III\n```python\nfrom model.rep.ddb import transIII_conv_sequential\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as 
F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,1,padding=0,bias=False)\nconv2=nn.Conv2d(64,64,3,padding=1,bias=False)\nout1=conv2(conv1(input))\n\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1,bias=False)\nconv_fuse.weight.data=transIII_conv_sequential(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.4 Transform IV\n```python\nfrom model.rep.ddb import transIV_conv_concat\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,32,3,padding=1)\nconv2=nn.Conv2d(64,32,3,padding=1)\nout1=torch.cat([conv1(input),conv2(input)],dim=1)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transIV_conv_concat(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.5 Transform V\n```python\nfrom model.rep.ddb import transV_avg\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\navg=nn.AvgPool2d(kernel_size=3,stride=1)\nout1=avg(input)\n\nconv=transV_avg(64,3)\nout2=conv(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n##### 3.3.6 Transform VI\n```python\nfrom model.rep.ddb import transVI_conv_scale\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1x1=nn.Conv2d(64,64,1)\nconv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))\nconv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))\nout1=conv1x1(input)+conv1x3(input)+conv3x1(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n\n\n# Convolution Series\n\n- Pytorch implementation of [\"MobileNets: Efficient Convolutional Neural Networks for Mobile 
Vision Applications---CVPR2017\"](https://arxiv.org/abs/1704.04861)\n\n- Pytorch implementation of [\"Efficientnet: Rethinking model scaling for convolutional neural networks---PMLR2019\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n- Pytorch implementation of [\"Involution: Inverting the Inherence of Convolution for Visual Recognition---CVPR2021\"](https://arxiv.org/abs/2103.06255)\n\n- Pytorch implementation of [\"Dynamic Convolution: Attention over Convolution Kernels---CVPR2020 Oral\"](https://arxiv.org/abs/1912.03458)\n\n- Pytorch implementation of [\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference---NeurIPS2019\"](https://arxiv.org/abs/1904.04971)\n\n***\n\n### 1. Depthwise Separable Convolution Usage\n#### 1.1. Paper\n[\"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications\"](https://arxiv.org/abs/1704.04861)\n\n#### 1.2. Overview\n![](./model/img/DepthwiseSeparableConv.png)\n\n#### 1.3. Usage Code\n```python\nfrom model.conv.DepthwiseSeparableConvolution import DepthwiseSeparableConvolution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\ndsconv=DepthwiseSeparableConvolution(3,64)\nout=dsconv(input)\nprint(out.shape)\n```\n\n***\n\n\n### 2. MBConv Usage\n#### 2.1. Paper\n[\"Efficientnet: Rethinking model scaling for convolutional neural networks\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n#### 2.2. Overview\n![](./model/img/MBConv.jpg)\n\n#### 2.3. Usage Code\n```python\nfrom model.conv.MBConv import MBConvBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=512,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n\n```\n\n***\n\n\n### 3. Involution Usage\n#### 3.1. Paper\n[\"Involution: Inverting the Inherence of Convolution for Visual Recognition\"](https://arxiv.org/abs/2103.06255)\n\n#### 3.2. 
Overview\n![](./model/img/Involution.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.conv.Involution import Involution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,4,64,64)\ninvolution=Involution(kernel_size=3,in_channel=4,stride=2)\nout=involution(input)\nprint(out.shape)\n```\n\n***\n\n\n### 4. DynamicConv Usage\n#### 4.1. Paper\n[\"Dynamic Convolution: Attention over Convolution Kernels\"](https://arxiv.org/abs/1912.03458)\n\n#### 4.2. Overview\n![](./model/img/DynamicConv.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.conv.DynamicConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape) # expected (batch,out_planes,H,W) = 2,64,64,64\n\n```\n\n***\n\n\n### 5. CondConv Usage\n#### 5.1. Paper\n[\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference\"](https://arxiv.org/abs/1904.04971)\n\n#### 5.2. Overview\n![](./model/img/CondConv.png)\n\n#### 5.3. 
Usage Code\n```python\nfrom model.conv.CondConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\n\n\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape)\n\n```\n\n\n\n## Recommended Projects\n\n-------\n\n🔥🔥🔥 Big news!!! As a supplement to this project, for more paper-level analysis see the newly open-sourced project **[FightingCV-Paper-Reading](https://github.com/xmu-xiaoma666/FightingCV-Paper-Reading)**, which collects and organizes paper analyses from the major top conferences and journals.\n\n\n\n🔥🔥🔥 Big news!!! We recently compiled a collection of AI-related video tutorials and must-read papers from around the web: **[FightingCV-Course](https://github.com/xmu-xiaoma666/FightingCV-Course)**\n\n\n🔥🔥🔥 Big news!!! We recently open-sourced **[YOLOAir](https://github.com/iscyy/yoloair)**, an object detection codebase that integrates a variety of YOLO models, including YOLOv5, YOLOv7, YOLOR, YOLOX, YOLOv4, YOLOv3, and other YOLO models, as well as a variety of existing attention mechanisms.\n\n\n🔥🔥🔥 **ECCV2022 paper list: [ECCV2022-Paper-List](https://github.com/xmu-xiaoma666/ECCV2022-Paper-List/blob/master/README.md)**\n\n\n<!-- ![image](https://user-images.githubusercontent.com/33897496/184842902-9acff374-b3e7-401a-80fd-9d484e40c637.png) -->\n"
  },
  {
    "path": "README_EN.md",
    "content": "\n<img src=\"./FightingCVimg/LOGO.gif\" height=\"200\" width=\"400\"/>\n\nEnglish | [简体中文](./README.md)\n\n# FightingCV Codebase For [***Attention***](#attention-series),[***Backbone***](#backbone-series), [***MLP***](#mlp-series), [***Re-parameter***](#re-parameter-series), [**Convolution**](#convolution-series)\n\n![](https://img.shields.io/badge/fightingcv-v0.0.1-brightgreen)\n![](https://img.shields.io/badge/python->=v3.0-blue)\n![](https://img.shields.io/badge/pytorch->=v1.4-red)\n\n<!--\n-------\n*If this project is helpful to you, welcome to give a*star***.* \n\n*Don't forget to*follow*me to learn about project updates.*\n\n-->\n\n-------\n\n\n🔥🔥🔥As a supplement to the project, a object detection codebase [YOLOAir](https://github.com/iscyy/yoloair) has recently been newly opened, which integrates various attention mechanisms in the object detection algorithm. The code is simple and easy to read. Welcome to play and star🌟!**\n\n\n<!-- ![image](https://user-images.githubusercontent.com/33897496/184842902-9acff374-b3e7-401a-80fd-9d484e40c637.png) -->\n\n\n\n-------\n\nHello, everyone, I'm Xiaoma 🚀🚀🚀\n\n***For beginners (like me):***\nRecently, I found a problem when reading the paper. Sometimes the core idea of the paper is very simple, and the core code may be just a dozen lines. However, when I open the source code of the author's release, I find that the proposed module is embedded in the task framework such as classification, detection and segmentation, resulting in redundant code. For me who is not familiar with the specific task framework, it is difficult to find the core code, resulting in some difficulties in understanding the paper and network ideas.\n\n***For advanced (like you):***\nIf the basic units conv, FC and RNN are regarded as small Lego blocks, and the structures transformer and RESNET are regarded as LEGO castles that have been built. 
The modules provided by this project are LEGO components with complete semantic information. They spare researchers from reinventing the wheel, so you only need to think about how to use these \"LEGO components\" to build more colorful works.\n\n***For experts (maybe like you):***\nMy ability is limited, so please go easy on the criticism!!!\n\n***For All:***\nThis project aims to build a codebase that deep learning beginners can understand and that also serves the research and industrial communities. As a supplement to the [FightingCV WeChat official account](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA), its purpose is, from the code perspective, to make sure that 🚀there are no hard-to-read papers in the world🚀.\n(We also welcome researchers to contribute the core code of their own work to this project to help the research community grow; the code author will be credited in the README.)\n\n\n<!--\n\n\n## WeChat Official Account & Discussion Group\n\n\n\nFollow our WeChat official account: **FightingCV**\n\n\n\nThe official account shares papers, algorithms, and code every day.\n\n\n\n\nRecent papers and analyses are shared in the group every day. You are welcome to join and learn together.\n\n(If you cannot join, add WeChat **775629340** and include the note **[company / school + research direction + ID]**.)\n\n![](./FightingCVimg/wechat.jpg)\n\nWe strongly recommend following the [Zhihu](https://www.zhihu.com/people/jason-14-58-38/posts) account and the **[FightingCV WeChat official account](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA)** to quickly learn about the latest high-quality resources.\n\n-->\n\n***\n\n# Contents\n\n- [Attention Series](#attention-series)\n    - [1. External Attention Usage](#1-external-attention-usage)\n\n    - [2. Self Attention Usage](#2-self-attention-usage)\n\n    - [3. Simplified Self Attention Usage](#3-simplified-self-attention-usage)\n\n    - [4. 
Squeeze-and-Excitation Attention Usage](#4-squeeze-and-excitation-attention-usage)\n\n    - [5. SK Attention Usage](#5-sk-attention-usage)\n\n    - [6. CBAM Attention Usage](#6-cbam-attention-usage)\n\n    - [7. BAM Attention Usage](#7-bam-attention-usage)\n    \n    - [8. ECA Attention Usage](#8-eca-attention-usage)\n\n    - [9. DANet Attention Usage](#9-danet-attention-usage)\n\n    - [10. Pyramid Split Attention (PSA) Usage](#10-Pyramid-Split-Attention-Usage)\n\n    - [11. Efficient Multi-Head Self-Attention(EMSA) Usage](#11-Efficient-Multi-Head-Self-Attention-Usage)\n\n    - [12. Shuffle Attention Usage](#12-Shuffle-Attention-Usage)\n    \n    - [13. MUSE Attention Usage](#13-MUSE-Attention-Usage)\n  \n    - [14. SGE Attention Usage](#14-SGE-Attention-Usage)\n\n    - [15. A2 Attention Usage](#15-A2-Attention-Usage)\n\n    - [16. AFT Attention Usage](#16-AFT-Attention-Usage)\n\n    - [17. Outlook Attention Usage](#17-Outlook-Attention-Usage)\n\n    - [18. ViP Attention Usage](#18-ViP-Attention-Usage)\n\n    - [19. CoAtNet Attention Usage](#19-CoAtNet-Attention-Usage)\n\n    - [20. HaloNet Attention Usage](#20-HaloNet-Attention-Usage)\n\n    - [21. Polarized Self-Attention Usage](#21-Polarized-Self-Attention-Usage)\n\n    - [22. CoTAttention Usage](#22-CoTAttention-Usage)\n\n    - [23. Residual Attention Usage](#23-Residual-Attention-Usage)\n  \n    - [24. S2 Attention Usage](#24-S2-Attention-Usage)\n\n    - [25. GFNet Attention Usage](#25-GFNet-Attention-Usage)\n\n    - [26. Triplet Attention Usage](#26-TripletAttention-Usage)\n\n    - [27. Coordinate Attention Usage](#27-Coordinate-Attention-Usage)\n\n    - [28. MobileViT Attention Usage](#28-MobileViT-Attention-Usage)\n\n    - [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)\n\n    - [30. UFO Attention Usage](#30-UFO-Attention-Usage)\n\n    - [31. ACmix Attention Usage](#31-Acmix-Attention-Usage)\n  \n    - [32. MobileViTv2 Attention Usage](#32-MobileViTv2-Attention-Usage)\n\n    - [33. 
DAT Attention Usage](#33-DAT-Attention-Usage)\n\n    - [34. CrossFormer Attention Usage](#34-CrossFormer-Attention-Usage)\n\n    - [35. MOATransformer Attention Usage](#35-MOATransformer-Attention-Usage)\n\n    - [36. CrissCrossAttention Attention Usage](#36-CrissCrossAttention-Attention-Usage)\n\n    - [37. Axial_attention Attention Usage](#37-Axial_attention-Attention-Usage)\n\n- [Backbone Series](#Backbone-series)\n\n    - [1. ResNet Usage](#1-ResNet-Usage)\n\n    - [2. ResNeXt Usage](#2-ResNeXt-Usage)\n\n    - [3. MobileViT Usage](#3-MobileViT-Usage)\n\n    - [4. ConvMixer Usage](#4-ConvMixer-Usage)\n\n    - [5. ShuffleTransformer Usage](#5-ShuffleTransformer-Usage)\n\n    - [6. ConTNet Usage](#6-ConTNet-Usage)\n\n    - [7. HATNet Usage](#7-HATNet-Usage)\n\n    - [8. CoaT Usage](#8-CoaT-Usage)\n\n    - [9. PVT Usage](#9-PVT-Usage)\n\n    - [10. CPVT Usage](#10-CPVT-Usage)\n\n    - [11. PIT Usage](#11-PIT-Usage)\n\n    - [12. CrossViT Usage](#12-CrossViT-Usage)\n\n    - [13. TnT Usage](#13-TnT-Usage)\n\n    - [14. DViT Usage](#14-DViT-Usage)\n\n    - [15. CeiT Usage](#15-CeiT-Usage)\n\n    - [16. ConViT Usage](#16-ConViT-Usage)\n\n    - [17. CaiT Usage](#17-CaiT-Usage)\n\n    - [18. PatchConvnet Usage](#18-PatchConvnet-Usage)\n\n    - [19. DeiT Usage](#19-DeiT-Usage)\n\n    - [20. LeViT Usage](#20-LeViT-Usage)\n\n    - [21. VOLO Usage](#21-VOLO-Usage)\n    \n    - [22. Container Usage](#22-Container-Usage)\n\n    - [23. CMT Usage](#23-CMT-Usage)\n\n\n- [MLP Series](#mlp-series)\n\n    - [1. RepMLP Usage](#1-RepMLP-Usage)\n\n    - [2. MLP-Mixer Usage](#2-MLP-Mixer-Usage)\n\n    - [3. ResMLP Usage](#3-ResMLP-Usage)\n\n    - [4. gMLP Usage](#4-gMLP-Usage)\n\n    - [5. sMLP Usage](#5-sMLP-Usage)\n\n    - [6. vip-mlp Usage](#6-vip-mlp-Usage)\n\n- [Re-Parameter(ReP) Series](#Re-Parameter-series)\n\n    - [1. RepVGG Usage](#1-RepVGG-Usage)\n\n    - [2. ACNet Usage](#2-ACNet-Usage)\n\n    - [3. 
Diverse Branch Block(DDB) Usage](#3-Diverse-Branch-Block-Usage)\n\n- [Convolution Series](#Convolution-series)\n\n    - [1. Depthwise Separable Convolution Usage](#1-Depthwise-Separable-Convolution-Usage)\n\n    - [2. MBConv Usage](#2-MBConv-Usage)\n\n    - [3. Involution Usage](#3-Involution-Usage)\n\n    - [4. DynamicConv Usage](#4-DynamicConv-Usage)\n\n    - [5. CondConv Usage](#5-CondConv-Usage)\n\n***\n\n\n# Attention Series\n\n- Pytorch implementation of [\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks---arXiv 2021.05.05\"](https://arxiv.org/abs/2105.02358)\n\n- Pytorch implementation of [\"Attention Is All You Need---NIPS2017\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n- Pytorch implementation of [\"Squeeze-and-Excitation Networks---CVPR2018\"](https://arxiv.org/abs/1709.01507)\n\n- Pytorch implementation of [\"Selective Kernel Networks---CVPR2019\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n- Pytorch implementation of [\"CBAM: Convolutional Block Attention Module---ECCV2018\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n- Pytorch implementation of [\"BAM: Bottleneck Attention Module---BMVC2018\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n- Pytorch implementation of [\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n- Pytorch implementation of [\"Dual Attention Network for Scene Segmentation---CVPR2019\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n- Pytorch implementation of [\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network---arXiv 2021.05.30\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n- Pytorch implementation of [\"ResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28\"](https://arxiv.org/abs/2105.13677)\n\n- Pytorch implementation of [\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS---ICASSP 
2021\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n- Pytorch implementation of [\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning---arXiv 2019.11.17\"](https://arxiv.org/abs/1911.09483)\n\n- Pytorch implementation of [\"Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks---arXiv 2019.05.23\"](https://arxiv.org/pdf/1905.09646.pdf)\n\n- Pytorch implementation of [\"A2-Nets: Double Attention Networks---NIPS2018\"](https://arxiv.org/pdf/1810.11579.pdf)\n\n\n- Pytorch implementation of [\"An Attention Free Transformer---ICLR2021 (Apple New Work)\"](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition---arXiv 2021.06.24\"](https://arxiv.org/abs/2106.13112) \n  [【论文解析】](https://zhuanlan.zhihu.com/p/385561050)\n\n\n- Pytorch implementation of [Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition---arXiv 2021.06.23](https://arxiv.org/abs/2106.12368) \n  [【论文解析】](https://mp.weixin.qq.com/s/5gonUQgBho_m2O54jyXF_Q)\n\n\n- Pytorch implementation of [CoAtNet: Marrying Convolution and Attention for All Data Sizes---arXiv 2021.06.09](https://arxiv.org/abs/2106.04803) \n  [【论文解析】](https://zhuanlan.zhihu.com/p/385578588)\n\n\n- Pytorch implementation of [Scaling Local Self-Attention for Parameter Efficient Visual Backbones---CVPR2021 Oral](https://arxiv.org/pdf/2103.12731.pdf)  [【论文解析】](https://zhuanlan.zhihu.com/p/388598744)\n\n\n\n- Pytorch implementation of [Polarized Self-Attention: Towards High-quality Pixel-wise Regression---arXiv 2021.07.02](https://arxiv.org/abs/2107.00782)  [【论文解析】](https://zhuanlan.zhihu.com/p/389770482) \n\n\n- Pytorch implementation of [Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292)  [【论文解析】](https://zhuanlan.zhihu.com/p/394795481) \n\n\n- Pytorch implementation of [Residual Attention: A Simple but Effective Method for Multi-Label 
Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n- Pytorch implementation of [S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) [【论文解析】](https://zhuanlan.zhihu.com/p/397003638) \n\n- Pytorch implementation of [Global Filter Networks for Image Classification---arXiv 2021.07.01](https://arxiv.org/abs/2107.00645) \n\n- Pytorch implementation of [Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n- Pytorch implementation of [Coordinate Attention for Efficient Mobile Network Design ---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n- Pytorch implementation of [UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n- Pytorch implementation of [Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n- Pytorch implementation of [On the Integration of Self-Attention and Convolution---ArXiv 2022.03.14](https://arxiv.org/pdf/2111.14556.pdf)\n\n- Pytorch implementation of [CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n- Pytorch implementation of [Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n- Pytorch implementation of [CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n- Pytorch implementation of [Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n***\n\n\n### 1. External Attention Usage\n#### 1.1. 
Paper\n[\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks\"](https://arxiv.org/abs/2105.02358)\n\n#### 1.2. Overview\n![](./model/img/External_Attention.png)\n\n#### 1.3. Usage Code\n```python\nfrom model.attention.ExternalAttention import ExternalAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nea = ExternalAttention(d_model=512,S=8)\noutput=ea(input)\nprint(output.shape)\n```\n\n***\n\n\n### 2. Self Attention Usage\n#### 2.1. Paper\n[\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n#### 2.2. Overview\n![](./model/img/SA.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.attention.SelfAttention import ScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nsa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n```\n\n***\n\n### 3. Simplified Self Attention Usage\n#### 3.1. Paper\n[None]()\n\n#### 3.2. Overview\n![](./model/img/SSA.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)\noutput=ssa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n### 4. Squeeze-and-Excitation Attention Usage\n#### 4.1. Paper\n[\"Squeeze-and-Excitation Networks\"](https://arxiv.org/abs/1709.01507)\n\n#### 4.2. Overview\n![](./model/img/SE.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.attention.SEAttention import SEAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SEAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n\n***\n\n### 5. SK Attention Usage\n#### 5.1. Paper\n[\"Selective Kernel Networks\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n#### 5.2. Overview\n![](./model/img/SK.png)\n\n#### 5.3. 
Usage Code\n```python\nfrom model.attention.SKAttention import SKAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SKAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n***\n\n### 6. CBAM Attention Usage\n#### 6.1. Paper\n[\"CBAM: Convolutional Block Attention Module\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n#### 6.2. Overview\n![](./model/img/CBAM1.png)\n\n![](./model/img/CBAM2.png)\n\n#### 6.3. Usage Code\n```python\nfrom model.attention.CBAM import CBAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nkernel_size=input.shape[2]\ncbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\noutput=cbam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 7. BAM Attention Usage\n#### 7.1. Paper\n[\"BAM: Bottleneck Attention Module\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n#### 7.2. Overview\n![](./model/img/BAM.png)\n\n#### 7.3. Usage Code\n```python\nfrom model.attention.BAM import BAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nbam = BAMBlock(channel=512,reduction=16,dia_val=2)\noutput=bam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 8. ECA Attention Usage\n#### 8.1. Paper\n[\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n#### 8.2. Overview\n![](./model/img/ECA.png)\n\n#### 8.3. Usage Code\n```python\nfrom model.attention.ECAAttention import ECAAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\neca = ECAAttention(kernel_size=3)\noutput=eca(input)\nprint(output.shape)\n\n```\n\n***\n\n### 9. DANet Attention Usage\n#### 9.1. Paper\n[\"Dual Attention Network for Scene Segmentation\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n#### 9.2. Overview\n![](./model/img/danet.png)\n\n#### 9.3. 
Usage Code\n```python\nfrom model.attention.DANet import DAModule\nimport torch\n\ninput=torch.randn(50,512,7,7)\ndanet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\nprint(danet(input).shape)\n\n```\n\n***\n\n### 10. Pyramid Split Attention Usage\n\n#### 10.1. Paper\n[\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n#### 10.2. Overview\n![](./model/img/psa.png)\n\n#### 10.3. Usage Code\n```python\nfrom model.attention.PSA import PSA\nimport torch\n\ninput=torch.randn(50,512,7,7)\npsa = PSA(channel=512,reduction=8)\noutput=psa(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 11. Efficient Multi-Head Self-Attention Usage\n\n#### 11.1. Paper\n[\"ResT: An Efficient Transformer for Visual Recognition\"](https://arxiv.org/abs/2105.13677)\n\n#### 11.2. Overview\n![](./model/img/EMSA.png)\n\n#### 11.3. Usage Code\n```python\n\nfrom model.attention.EMSA import EMSA\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,64,512)\nemsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\noutput=emsa(input,input,input)\nprint(output.shape)\n    \n```\n\n***\n\n\n### 12. Shuffle Attention Usage\n\n#### 12.1. Paper\n[\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n#### 12.2. Overview\n![](./model/img/ShuffleAttention.jpg)\n\n#### 12.3. Usage Code\n```python\n\nfrom model.attention.ShuffleAttention import ShuffleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,512,7,7)\nse = ShuffleAttention(channel=512,G=8)\noutput=se(input)\nprint(output.shape)\n\n    \n```\n\n\n***\n\n\n### 13. MUSE Attention Usage\n\n#### 13.1. Paper\n[\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning\"](https://arxiv.org/abs/1911.09483)\n\n#### 13.2. Overview\n![](./model/img/MUSE.png)\n\n#### 13.3. 
Usage Code\n```python\nfrom model.attention.MUSEAttention import MUSEAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,49,512)\nsa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 14. SGE Attention Usage\n\n#### 14.1. Paper\n[Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks](https://arxiv.org/pdf/1905.09646.pdf)\n\n#### 14.2. Overview\n![](./model/img/SGE.png)\n\n#### 14.3. Usage Code\n```python\nfrom model.attention.SGE import SpatialGroupEnhance\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nsge = SpatialGroupEnhance(groups=8)\noutput=sge(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 15. A2 Attention Usage\n\n#### 15.1. Paper\n[A2-Nets: Double Attention Networks](https://arxiv.org/pdf/1810.11579.pdf)\n\n#### 15.2. Overview\n![](./model/img/A2.png)\n\n#### 15.3. Usage Code\n```python\nfrom model.attention.A2Atttention import DoubleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\na2 = DoubleAttention(512,128,128,True)\noutput=a2(input)\nprint(output.shape)\n\n```\n\n\n\n### 16. AFT Attention Usage\n\n#### 16.1. Paper\n[An Attention Free Transformer](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n#### 16.2. Overview\n![](./model/img/AFT.jpg)\n\n#### 16.3. Usage Code\n```python\nfrom model.attention.AFT import AFT_FULL\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,49,512)\naft_full = AFT_FULL(d_model=512, n=49)\noutput=aft_full(input)\nprint(output.shape)\n\n```\n\n\n\n\n\n\n### 17. Outlook Attention Usage\n\n#### 17.1. Paper\n\n\n[VOLO: Vision Outlooker for Visual Recognition\"](https://arxiv.org/abs/2106.13112)\n\n\n#### 17.2. Overview\n![](./model/img/OutlookAttention.png)\n\n#### 17.3. 
Usage Code\n```python\nfrom model.attention.OutlookAttention import OutlookAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,28,28,512)\noutlook = OutlookAttention(dim=512)\noutput=outlook(input)\nprint(output.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 18. ViP Attention Usage\n\n#### 18.1. Paper\n\n\n[Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n\n#### 18.2. Overview\n![](./model/img/ViP.png)\n\n#### 18.3. Usage Code\n```python\n\nfrom model.attention.ViP import WeightedPermuteMLP\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(64,8,8,512)\nseg_dim=8\nvip=WeightedPermuteMLP(512,seg_dim)\nout=vip(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n### 19. CoAtNet Attention Usage\n\n#### 19.1. Paper\n\n\n[CoAtNet: Marrying Convolution and Attention for All Data Sizes\"](https://arxiv.org/abs/2106.04803) \n\n\n#### 19.2. Overview\nNone\n\n\n#### 19.3. Usage Code\n```python\n\nfrom model.attention.CoAtNet import CoAtNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=CoAtNet(in_ch=3,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 20. HaloNet Attention Usage\n\n#### 20.1. Paper\n\n\n[Scaling Local Self-Attention for Parameter Efficient Visual Backbones\"](https://arxiv.org/pdf/2103.12731.pdf) \n\n\n#### 20.2. Overview\n\n![](./model/img/HaloNet.png)\n\n#### 20.3. Usage Code\n```python\n\nfrom model.attention.HaloAttention import HaloAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,8,8)\nhalo = HaloAttention(dim=512,\n    block_size=2,\n    halo_size=1,)\noutput=halo(input)\nprint(output.shape)\n\n```\n\n\n***\n\n### 21. Polarized Self-Attention Usage\n\n#### 21.1. 
Paper\n\n[Polarized Self-Attention: Towards High-quality Pixel-wise Regression\"](https://arxiv.org/abs/2107.00782)  \n\n\n#### 21.2. Overview\n\n![](./model/img/PoSA.png)\n\n#### 21.3. Usage Code\n```python\n\nfrom model.attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention,SequentialPolarizedSelfAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,7,7)\npsa = SequentialPolarizedSelfAttention(channel=512)\noutput=psa(input)\nprint(output.shape)\n\n\n```\n\n\n***\n\n\n### 22. CoTAttention Usage\n\n#### 22.1. Paper\n\n[Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292) \n\n\n#### 22.2. Overview\n\n![](./model/img/CoT.png)\n\n#### 22.3. Usage Code\n```python\n\nfrom model.attention.CoTAttention import CoTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ncot = CoTAttention(dim=512,kernel_size=3)\noutput=cot(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n### 23. Residual Attention Usage\n\n#### 23.1. Paper\n\n[Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n#### 23.2. Overview\n\n![](./model/img/ResAtt.png)\n\n#### 23.3. Usage Code\n```python\n\nfrom model.attention.ResidualAttention import ResidualAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nresatt = ResidualAttention(channel=512,num_class=1000,la=0.2)\noutput=resatt(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n\n### 24. S2 Attention Usage\n\n#### 24.1. Paper\n\n[S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) \n\n\n#### 24.2. Overview\n\n![](./model/img/S2Attention.png)\n\n#### 24.3. 
Usage Code\n```python\nfrom model.attention.S2Attention import S2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ns2att = S2Attention(channels=512)\noutput=s2att(input)\nprint(output.shape)\n\n```\n\n***\n\n\n\n### 25. GFNet Attention Usage\n\n#### 25.1. Paper\n\n[Global Filter Networks for Image Classification---arXiv 2021.07.01](https://arxiv.org/abs/2107.00645) \n\n\n#### 25.2. Overview\n\n![](./model/img/GFNet.jpg)\n\n#### 25.3. Usage Code - Implemented by [Wenliang Zhao (Author)](https://scholar.google.com/citations?user=lyPWvuEAAAAJ&hl=en)\n\n```python\nfrom model.attention.gfnet import GFNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nx = torch.randn(1, 3, 224, 224)\ngfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)\nout = gfnet(x)\nprint(out.shape)\n\n```\n\n***\n\n\n### 26. TripletAttention Usage\n\n#### 26.1. Paper\n\n[Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n#### 26.2. Overview\n\n![](./model/img/triplet.png)\n\n#### 26.3. Usage Code - Implemented by [digantamisra98](https://github.com/digantamisra98)\n\n```python\nfrom model.attention.TripletAttention import TripletAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\ninput=torch.randn(50,512,7,7)\ntriplet = TripletAttention()\noutput=triplet(input)\nprint(output.shape)\n```\n\n\n***\n\n\n### 27. Coordinate Attention Usage\n\n#### 27.1. Paper\n\n[Coordinate Attention for Efficient Mobile Network Design---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n\n#### 27.2. Overview\n\n![](./model/img/CoordAttention.png)\n\n#### 27.3. 
Usage Code - Implemented by [Andrew-Qibin](https://github.com/Andrew-Qibin)\n\n```python\nfrom model.attention.CoordAttention import CoordAtt\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninp=torch.rand([2, 96, 56, 56])\ninp_dim, oup_dim = 96, 96\nreduction=32\n\ncoord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)\noutput=coord_attention(inp)\nprint(output.shape)\n```\n\n***\n\n\n### 28. MobileViT Attention Usage\n\n#### 28.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n\n#### 28.2. Overview\n\n![](./model/img/MobileViTAttention.png)\n\n#### 28.3. Usage Code\n\n```python\nfrom model.attention.MobileViTAttention import MobileViTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    m=MobileViTAttention()\n    input=torch.randn(1,3,49,49)\n    output=m(input)\n    print(output.shape)  #output:(1,3,49,49)\n    \n```\n\n***\n\n\n### 29. ParNet Attention Usage\n\n#### 29.1. Paper\n\n[Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n\n#### 29.2. Overview\n\n![](./model/img/ParNet.png)\n\n#### 29.3. Usage Code\n\n```python\nfrom model.attention.ParNetAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    pna = ParNetAttention(channel=512)\n    output=pna(input)\n    print(output.shape) #50,512,7,7\n    \n```\n\n***\n\n\n### 30. UFO Attention Usage\n\n#### 30.1. Paper\n\n[UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n\n#### 30.2. Overview\n\n![](./model/img/UFO.png)\n\n#### 30.3. 
Usage Code\n\n```python\nfrom model.attention.UFOAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=ufo(input,input,input)\n    print(output.shape) #[50, 49, 512]\n    \n```\n\n***\n\n### 31. ACmix Attention Usage\n\n#### 31.1. Paper\n\n[On the Integration of Self-Attention and Convolution](https://arxiv.org/pdf/2111.14556.pdf)\n\n#### 31.2. Usage Code\n\n```python\nfrom model.attention.ACmix import ACmix\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,256,7,7)\n    acmix = ACmix(in_planes=256, out_planes=256)\n    output=acmix(input)\n    print(output.shape)\n    \n```\n\n### 32. MobileViTv2 Attention Usage\n\n#### 32.1. Paper\n\n[Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n\n#### 32.2. Overview\n\n![](./model/img/MobileViTv2.png)\n\n#### 32.3. Usage Code\n\n```python\nfrom model.attention.MobileViTv2Attention import MobileViTv2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n    \n```\n\n### 33. DAT Attention Usage\n\n#### 33.1. Paper\n\n[Vision Transformer with Deformable Attention---CVPR2022](https://arxiv.org/abs/2201.00520)\n\n#### 33.2. 
Usage Code\n\n```python\nfrom model.attention.DAT import DAT\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DAT(\n        img_size=224,\n        patch_size=4,\n        num_classes=1000,\n        expansion=4,\n        dim_stem=96,\n        dims=[96, 192, 384, 768],\n        depths=[2, 2, 6, 2],\n        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],\n        heads=[3, 6, 12, 24],\n        window_sizes=[7, 7, 7, 7] ,\n        groups=[-1, -1, 3, 6],\n        use_pes=[False, False, True, True],\n        dwc_pes=[False, False, False, False],\n        strides=[-1, -1, 1, 1],\n        sr_ratios=[-1, -1, -1, -1],\n        offset_range_factor=[-1, -1, 2, 2],\n        no_offs=[False, False, False, False],\n        fixed_pes=[False, False, False, False],\n        use_dwc_mlps=[False, False, False, False],\n        use_conv_patches=False,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.2,\n    )\n    output=model(input)\n    print(output[0].shape)\n    \n```\n\n### 34. CrossFormer Attention Usage\n\n#### 34.1. Paper\n\n[CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n#### 34.2. 
Usage Code\n\n```python\nfrom model.attention.Crossformer import CrossFormer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CrossFormer(img_size=224,\n        patch_size=[4, 8, 16, 32],\n        in_chans= 3,\n        num_classes=1000,\n        embed_dim=48,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        group_size=[7, 7, 7, 7],\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        merge_size=[[2, 4], [2,4], [2, 4]]\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 35. MOATransformer Attention Usage\n\n#### 35.1. Paper\n\n[Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n#### 35.2. Usage Code\n\n```python\nfrom model.attention.MOATransformer import MOATransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = MOATransformer(\n        img_size=224,\n        patch_size=4,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6],\n        num_heads=[3, 6, 12],\n        window_size=14,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 36. CrissCrossAttention Attention Usage\n\n#### 36.1. Paper\n\n[CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n#### 36.2. 
Usage Code\n\n```python\nfrom model.attention.CrissCrossAttention import CrissCrossAttention\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 64, 7, 7)\n    model = CrissCrossAttention(64)\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n### 37. Axial_attention Attention Usage\n\n#### 37.1. Paper\n\n[Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n\n#### 37.2. Usage Code\n\n```python\nfrom model.attention.Axial_attention import AxialImageTransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 128, 7, 7)\n    model = AxialImageTransformer(\n        dim = 128,\n        depth = 12,\n        reversible = True\n    )\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n***\n\n\n# Backbone Series\n\n- Pytorch implementation of [\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n- Pytorch implementation of [\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n\n- Pytorch implementation of [Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer---ArXiv 2021.06.07](https://arxiv.org/abs/2106.03650)\n\n- Pytorch implementation of [ConTNet: Why not use convolution and transformer at the same time?---ArXiv 2021.04.27](https://arxiv.org/abs/2104.13497)\n\n- Pytorch implementation of [Vision Transformers with Hierarchical Attention---ArXiv 2022.06.15](https://arxiv.org/abs/2106.03180)\n\n- Pytorch implementation of [Co-Scale Conv-Attentional Image Transformers---ArXiv 2021.08.26](https://arxiv.org/abs/2104.06399)\n\n- Pytorch 
implementation of [Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n- Pytorch implementation of [Rethinking Spatial Dimensions of Vision Transformers---ICCV 2021](https://arxiv.org/abs/2103.16302)\n\n- Pytorch implementation of [CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification---ICCV 2021](https://arxiv.org/abs/2103.14899)\n\n- Pytorch implementation of [Transformer in Transformer---NeurIPS 2021](https://arxiv.org/abs/2103.00112)\n\n- Pytorch implementation of [DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n- Pytorch implementation of [Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n- Pytorch implementation of [ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n- Pytorch implementation of [Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n- Pytorch implementation of [Going deeper with Image Transformers---ICCV 2021 (Oral)](https://arxiv.org/abs/2103.17239)\n\n- Pytorch implementation of [Training data-efficient image transformers & distillation through attention---ICML 2021](https://arxiv.org/abs/2012.12877)\n\n- Pytorch implementation of [LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n- Pytorch implementation of [Container: Context Aggregation Network---NeurIPS 2021](https://arxiv.org/abs/2106.01401)\n\n- Pytorch implementation of [CMT: Convolutional Neural Networks Meet Vision Transformers---CVPR 2022](https://arxiv.org/abs/2107.06263)\n\n- Pytorch implementation of [Vision Transformer with Deformable Attention---CVPR 2022](https://arxiv.org/abs/2201.00520)\n\n\n### 1. ResNet Usage\n#### 1.1. 
Paper\n[\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n#### 1.2. Overview\n![](./model/img/resnet.png)\n![](./model/img/resnet2.jpg)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.backbone.resnet import ResNet50,ResNet101,ResNet152\nimport torch\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnet50=ResNet50(1000)\n    # resnet101=ResNet101(1000)\n    # resnet152=ResNet152(1000)\n    out=resnet50(input)\n    print(out.shape)\n\n```\n\n\n### 2. ResNeXt Usage\n#### 2.1. Paper\n\n[\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n#### 2.2. Overview\n![](./model/img/resnext.png)\n\n#### 2.3. Usage Code\n```python\n\nfrom model.backbone.resnext import ResNeXt50,ResNeXt101,ResNeXt152\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnext50=ResNeXt50(1000)\n    # resnext101=ResNeXt101(1000)\n    # resnext152=ResNeXt152(1000)\n    out=resnext50(input)\n    print(out.shape)\n\n\n```\n\n\n\n### 3. MobileViT Usage\n#### 3.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n#### 3.2. Overview\n![](./model/img/mobileViT.jpg)\n\n#### 3.3. Usage Code\n```python\n\nfrom model.backbone.MobileViT import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n\n    ### mobilevit_xxs\n    mvit_xxs=mobilevit_xxs()\n    out=mvit_xxs(input)\n    print(out.shape)\n\n    ### mobilevit_xs\n    mvit_xs=mobilevit_xs()\n    out=mvit_xs(input)\n    print(out.shape)\n\n\n    ### mobilevit_s\n    mvit_s=mobilevit_s()\n    out=mvit_s(input)\n    print(out.shape)\n\n```\n\n\n\n\n\n### 4. ConvMixer Usage\n#### 4.1. 
Paper\n[Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n#### 4.2. Overview\n![](./model/img/ConvMixer.png)\n\n#### 4.3. Usage Code\n```python\n\nfrom model.backbone.ConvMixer import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    x=torch.randn(1,3,224,224)\n    convmixer=ConvMixer(dim=512,depth=12)\n    out=convmixer(x)\n    print(out.shape)  #[1, 1000]\n\n\n```\n\n### 5. ShuffleTransformer Usage\n#### 5.1. Paper\n[Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer](https://arxiv.org/pdf/2106.03650.pdf)\n\n#### 5.2. Usage Code\n```python\n\nfrom model.backbone.ShuffleTransformer import ShuffleTransformer\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    sft = ShuffleTransformer()\n    output=sft(input)\n    print(output.shape)\n\n\n```\n\n### 6. ConTNet Usage\n#### 6.1. Paper\n[ConTNet: Why not use convolution and transformer at the same time?](https://arxiv.org/abs/2104.13497)\n\n#### 6.2. Usage Code\n```python\n\n# build_model is assumed to be defined in the same module as ConTNet\nfrom model.backbone.ConTNet import ConTNet, build_model\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == \"__main__\":\n    model = build_model(use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)\n\n\n```\n\n### 7. HATNet Usage\n#### 7.1. Paper\n[Vision Transformers with Hierarchical Attention](https://arxiv.org/abs/2106.03180)\n\n#### 7.2. 
Usage Code\n```python\n\nfrom model.backbone.HATNet import HATNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    hat = HATNet(dims=[48, 96, 240, 384], head_dim=48, expansions=[8, 8, 4, 4],\n        grid_sizes=[8, 7, 7, 1], ds_ratios=[8, 4, 2, 1], depths=[2, 2, 6, 3])\n    output=hat(input)\n    print(output.shape)\n\n\n```\n\n### 8. CoaT Usage\n#### 8.1. Paper\n[Co-Scale Conv-Attentional Image Transformers](https://arxiv.org/abs/2104.06399)\n\n#### 8.2. Usage Code\n```python\n\nfrom model.backbone.CoaT import CoaT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4])\n    output=model(input)\n    print(output.shape) # torch.Size([1, 1000])\n\n```\n\n### 9. PVT Usage\n#### 9.1. Paper\n[PVT v2: Improved Baselines with Pyramid Vision Transformer](https://arxiv.org/pdf/2106.13797.pdf)\n\n#### 9.2. Usage Code\n```python\n\nfrom model.backbone.PVT import PyramidVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n\n### 10. CPVT Usage\n#### 10.1. Paper\n[Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n#### 10.2. 
Usage Code\n```python\n\nfrom model.backbone.CPVT import CPVTV2\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 11. PIT Usage\n#### 11.1. Paper\n[Rethinking Spatial Dimensions of Vision Transformers](https://arxiv.org/abs/2103.16302)\n\n#### 11.2. Usage Code\n```python\n\nfrom model.backbone.PIT import PoolingTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 12. CrossViT Usage\n#### 12.1. Paper\n[CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification](https://arxiv.org/abs/2103.14899)\n\n#### 12.2. Usage Code\n```python\n\nfrom model.backbone.CrossViT import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == \"__main__\":\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[240, 224],\n        patch_size=[12, 16], \n        embed_dim=[192, 384], \n        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n        num_heads=[6, 6], \n        mlp_ratio=[4, 4, 1], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 13. TnT Usage\n#### 13.1. Paper\n[Transformer in Transformer](https://arxiv.org/abs/2103.00112)\n\n#### 13.2. 
Usage Code\n```python\n\nfrom model.backbone.TnT import TNT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = TNT(\n        img_size=224, \n        patch_size=16, \n        outer_dim=384, \n        inner_dim=24, \n        depth=12,\n        outer_num_heads=6, \n        inner_num_heads=4, \n        qkv_bias=False,\n        inner_stride=4)\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 14. DViT Usage\n#### 14.1. Paper\n[DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n#### 14.2. Usage Code\n```python\n\nfrom model.backbone.DViT import DeepVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, \n        depth=[False] * 16, \n        apply_transform=[False] * 0 + [True] * 32, \n        num_heads=12, \n        mlp_ratio=3, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 15. CeiT Usage\n#### 15.1. Paper\n[Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n#### 15.2. Usage Code\n```python\n\n# Image2Tokens is assumed to be defined in the same module as CeIT\nfrom model.backbone.CeiT import CeIT, Image2Tokens\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CeIT(\n        hybrid_backbone=Image2Tokens(),\n        patch_size=4, \n        embed_dim=192, \n        depth=12, \n        num_heads=3, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 16. ConViT Usage\n#### 16.1. Paper\n[ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n#### 16.2. 
Usage Code\n```python\n\nfrom model.backbone.ConViT import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        num_heads=16,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 17. CaiT Usage\n#### 17.1. Paper\n[Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239)\n\n#### 17.2. Usage Code\n```python\n\nfrom model.backbone.CaiT import CaiT\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CaiT(\n        img_size= 224,\n        patch_size=16, \n        embed_dim=192, \n        depth=24, \n        num_heads=4, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 18. PatchConvnet Usage\n#### 18.1. Paper\n[Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n#### 18.2. Usage Code\n```python\n\n# ConvStem and Conv_blocks_se are assumed to be defined in the same module as PatchConvnet\nfrom model.backbone.PatchConvnet import PatchConvnet, ConvStem, Conv_blocks_se\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 19. DeiT Usage\n#### 19.1. Paper\n[Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)\n\n#### 19.2. 
Usage Code\n```python\n\nfrom model.backbone.DeiT import DistilledVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DistilledVisionTransformer(\n        patch_size=16, \n        embed_dim=384, \n        depth=12, \n        num_heads=6, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 20. LeViT Usage\n#### 20.1. Paper\n[LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n#### 20.2. Usage Code\n```python\n\nfrom model.backbone.LeViT import *\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    for name in specification:\n        input=torch.randn(1,3,224,224)\n        model = globals()[name](fuse=True, pretrained=False)\n        model.eval()\n        output = model(input)\n        print(output.shape)\n\n```\n\n### 21. VOLO Usage\n#### 21.1. Paper\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n#### 21.2. Usage Code\n```python\n\nfrom model.backbone.VOLO import VOLO\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VOLO([4, 4, 8, 2],\n                 embed_dims=[192, 384, 384, 384],\n                 num_heads=[6, 12, 12, 12],\n                 mlp_ratios=[3, 3, 3, 3],\n                 downsamples=[True, False, False, False],\n                 outlook_attention=[True, False, False, False ],\n                 post_layers=['ca', 'ca'],\n                 )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 22. Container Usage\n#### 22.1. Paper\n[Container: Context Aggregation Network](https://arxiv.org/abs/2106.01401)\n\n#### 22.2. 
Usage Code\n```python\n\nfrom model.backbone.Container import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[224, 56, 28, 14], \n        patch_size=[4, 2, 2, 2], \n        embed_dim=[64, 128, 320, 512], \n        depth=[3, 4, 8, 3], \n        num_heads=16, \n        mlp_ratio=[8, 8, 4, 4], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6))\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 23. CMT Usage\n#### 23.1. Paper\n[CMT: Convolutional Neural Networks Meet Vision Transformers](https://arxiv.org/abs/2107.06263)\n\n#### 23.2. Usage Code\n```python\n\nfrom model.backbone.CMT import CMT_Tiny\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CMT_Tiny()\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n\n\n\n\n\n# MLP Series\n\n- Pytorch implementation of [\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition---arXiv 2021.05.05\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n- Pytorch implementation of [\"MLP-Mixer: An all-MLP Architecture for Vision---arXiv 2021.05.17\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n- Pytorch implementation of [\"ResMLP: Feedforward networks for image classification with data-efficient training---arXiv 2021.05.07\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n- Pytorch implementation of [\"Pay Attention to MLPs---arXiv 2021.05.17\"](https://arxiv.org/abs/2105.08050)\n\n\n- Pytorch implementation of [\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?---arXiv 2021.09.12\"](https://arxiv.org/abs/2109.05422)\n\n### 1. RepMLP Usage\n#### 1.1. Paper\n[\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n#### 1.2. Overview\n![](./model/img/repmlp.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom model.mlp.repmlp import RepMLP\nimport torch\nfrom torch import nn\n\nN=4 #batch size\nC=512 #input dim\nO=1024 #output dim\nH=14 #image height\nW=14 #image width\nh=7 #patch height\nw=7 #patch width\nfc1_fc2_reduction=1 #reduction ratio\nfc3_groups=8 # groups\nrepconv_kernels=[1,3,5,7] #kernel list\nrepmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)\nx=torch.randn(N,C,H,W)\nrepmlp.eval()\nfor module in repmlp.modules():\n    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):\n        nn.init.uniform_(module.running_mean, 0, 0.1)\n        nn.init.uniform_(module.running_var, 0, 0.1)\n        nn.init.uniform_(module.weight, 0, 0.1)\n        nn.init.uniform_(module.bias, 0, 0.1)\n\n#training result\nout=repmlp(x)\n#inference result\nrepmlp.switch_to_deploy()\ndeployout = repmlp(x)\n\nprint(((deployout-out)**2).sum())\n```\n\n### 2. MLP-Mixer Usage\n#### 2.1. Paper\n[\"MLP-Mixer: An all-MLP Architecture for Vision\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n#### 2.2. Overview\n![](./model/img/mlpmixer.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.mlp.mlp_mixer import MlpMixer\nimport torch\nmlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)\ninput=torch.randn(50,3,40,40)\noutput=mlp_mixer(input)\nprint(output.shape)\n```\n\n***\n\n### 3. ResMLP Usage\n#### 3.1. Paper\n[\"ResMLP: Feedforward networks for image classification with data-efficient training\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n#### 3.2. Overview\n![](./model/img/resmlp.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.mlp.resmlp import ResMLP\nimport torch\n\ninput=torch.randn(50,3,14,14)\nresmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)\nout=resmlp(input)\nprint(out.shape) #the last dimension is class_num\n```\n\n***\n\n### 4. gMLP Usage\n#### 4.1. 
Paper\n[\"Pay Attention to MLPs\"](https://arxiv.org/abs/2105.08050)\n\n#### 4.2. Overview\n![](./model/img/gMLP.jpg)\n\n#### 4.3. Usage Code\n```python\nfrom model.mlp.g_mlp import gMLP\nimport torch\n\nnum_tokens=10000\nbs=50\nlen_sen=49\nnum_layers=6\ninput=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen\ngmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)\noutput=gmlp(input)\nprint(output.shape)\n```\n\n***\n\n### 5. sMLP Usage\n#### 5.1. Paper\n[\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?\"](https://arxiv.org/abs/2109.05422)\n\n#### 5.2. Overview\n![](./model/img/sMLP.jpg)\n\n#### 5.3. Usage Code\n```python\nfrom model.mlp.sMLP_block import sMLPBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    smlp=sMLPBlock(h=224,w=224)\n    out=smlp(input)\n    print(out.shape)\n```\n\n### 6. vip-mlp Usage\n#### 6.1. Paper\n[\"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n#### 6.2. 
Usage Code\n```python\nimport importlib\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# 'vip-mlp' is not a valid Python identifier, so the module is loaded by name.\n# WeightedPermuteMLP is assumed to be defined in the same module.\nvip_mlp = importlib.import_module('model.mlp.vip-mlp')\nVisionPermutator = vip_mlp.VisionPermutator\nWeightedPermuteMLP = vip_mlp.WeightedPermuteMLP\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionPermutator(\n        layers=[4, 3, 8, 3], \n        embed_dims=[384, 384, 384, 384], \n        patch_size=14, \n        transitions=[False, False, False, False],\n        segment_dim=[16, 16, 16, 16], \n        mlp_ratios=[3, 3, 3, 3], \n        mlp_fn=WeightedPermuteMLP\n    )\n    output=model(input)\n    print(output.shape)\n```\n\n\n# Re-Parameter Series\n\n- Pytorch implementation of [\"RepVGG: Making VGG-style ConvNets Great Again---CVPR2021\"](https://arxiv.org/abs/2101.03697)\n\n- Pytorch implementation of [\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks---ICCV2019\"](https://arxiv.org/abs/1908.03930)\n\n- Pytorch implementation of [\"Diverse Branch Block: Building a Convolution as an Inception-like Unit---CVPR2021\"](https://arxiv.org/abs/2103.13425)\n\n\n***\n\n### 1. RepVGG Usage\n#### 1.1. Paper\n[\"RepVGG: Making VGG-style ConvNets Great Again\"](https://arxiv.org/abs/2101.03697)\n\n#### 1.2. Overview\n![](./model/img/repvgg.png)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.rep.repvgg import RepBlock\nimport torch\n\n\ninput=torch.randn(50,512,49,49)\nrepblock=RepBlock(512,512)\nrepblock.eval()\nout=repblock(input)\nrepblock._switch_to_deploy()\nout2=repblock(input)\nprint('difference between vgg and repvgg')\nprint(((out2-out)**2).sum())\n```\n\n\n\n***\n\n### 2. ACNet Usage\n#### 2.1. Paper\n[\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks\"](https://arxiv.org/abs/1908.03930)\n\n#### 2.2. Overview\n![](./model/img/acnet.png)\n\n#### 2.3. 
Usage Code\n```python\nfrom model.rep.acnet import ACNet\nimport torch\nfrom torch import nn\n\ninput=torch.randn(50,512,49,49)\nacnet=ACNet(512,512)\nacnet.eval()\nout=acnet(input)\nacnet._switch_to_deploy()\nout2=acnet(input)\nprint('difference:')\nprint(((out2-out)**2).sum())\n\n```\n\n\n\n***\n\n### 3. Diverse Branch Block Usage\n#### 3.1. Paper\n[\"Diverse Branch Block: Building a Convolution as an Inception-like Unit\"](https://arxiv.org/abs/2103.13425)\n\n#### 3.2. Overview\n![](./model/img/ddb.png)\n\n#### 3.3. Usage Code\n##### 3.3.1 Transform I\n```python\nfrom model.rep.ddb import transI_conv_bn\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n#conv+bn\nconv1=nn.Conv2d(64,64,3,padding=1)\nbn1=nn.BatchNorm2d(64)\nbn1.eval()\nout1=bn1(conv1(input))\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transI_conv_bn(conv1,bn1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.2 Transform II\n```python\nfrom model.rep.ddb import transII_conv_branch\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,3,padding=1)\nconv2=nn.Conv2d(64,64,3,padding=1)\nout1=conv1(input)+conv2(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transII_conv_branch(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.3 Transform III\n```python\nfrom model.rep.ddb import transIII_conv_sequential\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as 
F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,1,padding=0,bias=False)\nconv2=nn.Conv2d(64,64,3,padding=1,bias=False)\nout1=conv2(conv1(input))\n\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1,bias=False)\nconv_fuse.weight.data=transIII_conv_sequential(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.4 Transform IV\n```python\nfrom model.rep.ddb import transIV_conv_concat\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,32,3,padding=1)\nconv2=nn.Conv2d(64,32,3,padding=1)\nout1=torch.cat([conv1(input),conv2(input)],dim=1)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transIV_conv_concat(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.5 Transform V\n```python\nfrom model.rep.ddb import transV_avg\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\navg=nn.AvgPool2d(kernel_size=3,stride=1)\nout1=avg(input)\n\nconv=transV_avg(64,3)\nout2=conv(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n##### 3.3.6 Transform VI\n```python\nfrom model.rep.ddb import transVI_conv_scale\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1x1=nn.Conv2d(64,64,1)\nconv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))\nconv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))\nout1=conv1x1(input)+conv1x3(input)+conv3x1(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n\n\n# Convolution Series\n\n- Pytorch implementation of [\"MobileNets: Efficient Convolutional Neural Networks for Mobile 
Vision Applications---CVPR2017\"](https://arxiv.org/abs/1704.04861)\n\n- Pytorch implementation of [\"Efficientnet: Rethinking model scaling for convolutional neural networks---PMLR2019\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n- Pytorch implementation of [\"Involution: Inverting the Inherence of Convolution for Visual Recognition---CVPR2021\"](https://arxiv.org/abs/2103.06255)\n\n- Pytorch implementation of [\"Dynamic Convolution: Attention over Convolution Kernels---CVPR2020 Oral\"](https://arxiv.org/abs/1912.03458)\n\n- Pytorch implementation of [\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference---NeurIPS2019\"](https://arxiv.org/abs/1904.04971)\n\n***\n\n### 1. Depthwise Separable Convolution Usage\n#### 1.1. Paper\n[\"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications\"](https://arxiv.org/abs/1704.04861)\n\n#### 1.2. Overview\n![](./model/img/DepthwiseSeparableConv.png)\n\n#### 1.3. Usage Code\n```python\nfrom model.conv.DepthwiseSeparableConvolution import DepthwiseSeparableConvolution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\ndsconv=DepthwiseSeparableConvolution(3,64)\nout=dsconv(input)\nprint(out.shape)\n```\n\n***\n\n\n### 2. MBConv Usage\n#### 2.1. Paper\n[\"Efficientnet: Rethinking model scaling for convolutional neural networks\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n#### 2.2. Overview\n![](./model/img/MBConv.jpg)\n\n#### 2.3. Usage Code\n```python\nfrom model.conv.MBConv import MBConvBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=512,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n\n```\n\n***\n\n\n### 3. Involution Usage\n#### 3.1. Paper\n[\"Involution: Inverting the Inherence of Convolution for Visual Recognition\"](https://arxiv.org/abs/2103.06255)\n\n#### 3.2. 
Overview\n![](./model/img/Involution.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.conv.Involution import Involution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,4,64,64)\ninvolution=Involution(kernel_size=3,in_channel=4,stride=2)\nout=involution(input)\nprint(out.shape)\n```\n\n***\n\n\n### 4. DynamicConv Usage\n#### 4.1. Paper\n[\"Dynamic Convolution: Attention over Convolution Kernels\"](https://arxiv.org/abs/1912.03458)\n\n#### 4.2. Overview\n![](./model/img/DynamicConv.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.conv.DynamicConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape) # 2,64,64,64\n\n```\n\n***\n\n\n### 5. CondConv Usage\n#### 5.1. Paper\n[\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference\"](https://arxiv.org/abs/1904.04971)\n\n#### 5.2. Overview\n![](./model/img/CondConv.png)\n\n#### 5.3. Usage Code\n```python\nfrom model.conv.CondConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape)\n\n```\n\n***\n"
  },
  {
    "path": "README_pip.md",
    "content": "## pip Usage Guide\n\n### Installation\n\n Install directly with pip; the modules can then be used in other projects\n\n  ```shell\n  pip install fightingcv-attention\n  ```\n\n### Demo\n\n#### Using the pip package\n```python\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# Using the pip package\n\nfrom fightingcv_attention.attention.MobileViTv2Attention import *\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n```\n\n## The pip package fightingcv-attention includes the following modules\n\n# Contents\n\n- [Attention Series](#attention-series)\n    - [1. External Attention Usage](#1-external-attention-usage)\n\n    - [2. Self Attention Usage](#2-self-attention-usage)\n\n    - [3. Simplified Self Attention Usage](#3-simplified-self-attention-usage)\n\n    - [4. Squeeze-and-Excitation Attention Usage](#4-squeeze-and-excitation-attention-usage)\n\n    - [5. SK Attention Usage](#5-sk-attention-usage)\n\n    - [6. CBAM Attention Usage](#6-cbam-attention-usage)\n\n    - [7. BAM Attention Usage](#7-bam-attention-usage)\n    \n    - [8. ECA Attention Usage](#8-eca-attention-usage)\n\n    - [9. DANet Attention Usage](#9-danet-attention-usage)\n\n    - [10. Pyramid Split Attention (PSA) Usage](#10-Pyramid-Split-Attention-Usage)\n\n    - [11. Efficient Multi-Head Self-Attention(EMSA) Usage](#11-Efficient-Multi-Head-Self-Attention-Usage)\n\n    - [12. Shuffle Attention Usage](#12-Shuffle-Attention-Usage)\n    \n    - [13. MUSE Attention Usage](#13-MUSE-Attention-Usage)\n  \n    - [14. SGE Attention Usage](#14-SGE-Attention-Usage)\n\n    - [15. A2 Attention Usage](#15-A2-Attention-Usage)\n\n    - [16. AFT Attention Usage](#16-AFT-Attention-Usage)\n\n    - [17. Outlook Attention Usage](#17-Outlook-Attention-Usage)\n\n    - [18. ViP Attention Usage](#18-ViP-Attention-Usage)\n\n    - [19. CoAtNet Attention Usage](#19-CoAtNet-Attention-Usage)\n\n    - [20. HaloNet Attention Usage](#20-HaloNet-Attention-Usage)\n\n    - [21. 
Polarized Self-Attention Usage](#21-Polarized-Self-Attention-Usage)\n\n    - [22. CoTAttention Usage](#22-CoTAttention-Usage)\n\n    - [23. Residual Attention Usage](#23-Residual-Attention-Usage)\n  \n    - [24. S2 Attention Usage](#24-S2-Attention-Usage)\n\n    - [25. GFNet Attention Usage](#25-GFNet-Attention-Usage)\n\n    - [26. Triplet Attention Usage](#26-TripletAttention-Usage)\n\n    - [27. Coordinate Attention Usage](#27-Coordinate-Attention-Usage)\n\n    - [28. MobileViT Attention Usage](#28-MobileViT-Attention-Usage)\n\n    - [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)\n\n    - [30. UFO Attention Usage](#30-UFO-Attention-Usage)\n\n    - [31. ACmix Attention Usage](#31-Acmix-Attention-Usage)\n  \n    - [32. MobileViTv2 Attention Usage](#32-MobileViTv2-Attention-Usage)\n\n    - [33. DAT Attention Usage](#33-DAT-Attention-Usage)\n\n    - [34. CrossFormer Attention Usage](#34-CrossFormer-Attention-Usage)\n\n    - [35. MOATransformer Attention Usage](#35-MOATransformer-Attention-Usage)\n\n    - [36. CrissCrossAttention Attention Usage](#36-CrissCrossAttention-Attention-Usage)\n\n    - [37. Axial_attention Attention Usage](#37-Axial_attention-Attention-Usage)\n\n- [Backbone Series](#Backbone-series)\n\n    - [1. ResNet Usage](#1-ResNet-Usage)\n\n    - [2. ResNeXt Usage](#2-ResNeXt-Usage)\n\n    - [3. MobileViT Usage](#3-MobileViT-Usage)\n\n    - [4. ConvMixer Usage](#4-ConvMixer-Usage)\n\n    - [5. ShuffleTransformer Usage](#5-ShuffleTransformer-Usage)\n\n    - [6. ConTNet Usage](#6-ConTNet-Usage)\n\n    - [7. HATNet Usage](#7-HATNet-Usage)\n\n    - [8. CoaT Usage](#8-CoaT-Usage)\n\n    - [9. PVT Usage](#9-PVT-Usage)\n\n    - [10. CPVT Usage](#10-CPVT-Usage)\n\n    - [11. PIT Usage](#11-PIT-Usage)\n\n    - [12. CrossViT Usage](#12-CrossViT-Usage)\n\n    - [13. TnT Usage](#13-TnT-Usage)\n\n    - [14. DViT Usage](#14-DViT-Usage)\n\n    - [15. CeiT Usage](#15-CeiT-Usage)\n\n    - [16. ConViT Usage](#16-ConViT-Usage)\n\n    - [17. 
CaiT Usage](#17-CaiT-Usage)\n\n    - [18. PatchConvnet Usage](#18-PatchConvnet-Usage)\n\n    - [19. DeiT Usage](#19-DeiT-Usage)\n\n    - [20. LeViT Usage](#20-LeViT-Usage)\n\n    - [21. VOLO Usage](#21-VOLO-Usage)\n    \n    - [22. Container Usage](#22-Container-Usage)\n\n    - [23. CMT Usage](#23-CMT-Usage)\n\n    - [24. EfficientFormer Usage](#24-EfficientFormer-Usage)\n\n\n- [MLP Series](#mlp-series)\n\n    - [1. RepMLP Usage](#1-RepMLP-Usage)\n\n    - [2. MLP-Mixer Usage](#2-MLP-Mixer-Usage)\n\n    - [3. ResMLP Usage](#3-ResMLP-Usage)\n\n    - [4. gMLP Usage](#4-gMLP-Usage)\n\n    - [5. sMLP Usage](#5-sMLP-Usage)\n\n    - [6. vip-mlp Usage](#6-vip-mlp-Usage)\n\n- [Re-Parameter(ReP) Series](#Re-Parameter-series)\n\n    - [1. RepVGG Usage](#1-RepVGG-Usage)\n\n    - [2. ACNet Usage](#2-ACNet-Usage)\n\n    - [3. Diverse Branch Block(DDB) Usage](#3-Diverse-Branch-Block-Usage)\n\n- [Convolution Series](#Convolution-series)\n\n    - [1. Depthwise Separable Convolution Usage](#1-Depthwise-Separable-Convolution-Usage)\n\n    - [2. MBConv Usage](#2-MBConv-Usage)\n\n    - [3. Involution Usage](#3-Involution-Usage)\n\n    - [4. DynamicConv Usage](#4-DynamicConv-Usage)\n\n    - [5. 
CondConv Usage](#5-CondConv-Usage)\n\n***\n\n\n\n# Attention Series\n\n- Pytorch implementation of [\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks---arXiv 2021.05.05\"](https://arxiv.org/abs/2105.02358)\n\n- Pytorch implementation of [\"Attention Is All You Need---NIPS2017\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n- Pytorch implementation of [\"Squeeze-and-Excitation Networks---CVPR2018\"](https://arxiv.org/abs/1709.01507)\n\n- Pytorch implementation of [\"Selective Kernel Networks---CVPR2019\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n- Pytorch implementation of [\"CBAM: Convolutional Block Attention Module---ECCV2018\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n- Pytorch implementation of [\"BAM: Bottleneck Attention Module---BMVC2018\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n- Pytorch implementation of [\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n- Pytorch implementation of [\"Dual Attention Network for Scene Segmentation---CVPR2019\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n- Pytorch implementation of [\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network---arXiv 2021.05.30\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n- Pytorch implementation of [\"ResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28\"](https://arxiv.org/abs/2105.13677)\n\n- Pytorch implementation of [\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS---ICASSP 2021\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n- Pytorch implementation of [\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning---arXiv 2019.11.17\"](https://arxiv.org/abs/1911.09483)\n\n- Pytorch implementation of [\"Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks---arXiv 
2019.05.23\"](https://arxiv.org/pdf/1905.09646.pdf)\n\n- Pytorch implementation of [\"A2-Nets: Double Attention Networks---NIPS2018\"](https://arxiv.org/pdf/1810.11579.pdf)\n\n\n- Pytorch implementation of [\"An Attention Free Transformer---ICLR2021 (Apple New Work)\"](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition---arXiv 2021.06.24](https://arxiv.org/abs/2106.13112) \n  [Paper analysis](https://zhuanlan.zhihu.com/p/385561050)\n\n\n- Pytorch implementation of [Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition---arXiv 2021.06.23](https://arxiv.org/abs/2106.12368) \n  [Paper analysis](https://mp.weixin.qq.com/s/5gonUQgBho_m2O54jyXF_Q)\n\n\n- Pytorch implementation of [CoAtNet: Marrying Convolution and Attention for All Data Sizes---arXiv 2021.06.09](https://arxiv.org/abs/2106.04803) \n  [Paper analysis](https://zhuanlan.zhihu.com/p/385578588)\n\n\n- Pytorch implementation of [Scaling Local Self-Attention for Parameter Efficient Visual Backbones---CVPR2021 Oral](https://arxiv.org/pdf/2103.12731.pdf)  [Paper analysis](https://zhuanlan.zhihu.com/p/388598744)\n\n\n\n- Pytorch implementation of [Polarized Self-Attention: Towards High-quality Pixel-wise Regression---arXiv 2021.07.02](https://arxiv.org/abs/2107.00782)  [Paper analysis](https://zhuanlan.zhihu.com/p/389770482) \n\n\n- Pytorch implementation of [Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292)  [Paper analysis](https://zhuanlan.zhihu.com/p/394795481) \n\n\n- Pytorch implementation of [Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n- Pytorch implementation of [S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) [Paper analysis](https://zhuanlan.zhihu.com/p/397003638) \n\n- Pytorch implementation of [Global Filter Networks for Image Classification---arXiv 
2021.07.01](https://arxiv.org/abs/2107.00645) \n\n- Pytorch implementation of [Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n- Pytorch implementation of [Coordinate Attention for Efficient Mobile Network Design ---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n- Pytorch implementation of [UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n- Pytorch implementation of [Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n- Pytorch implementation of [On the Integration of Self-Attention and Convolution---ArXiv 2022.03.14](https://arxiv.org/pdf/2111.14556.pdf)\n\n- Pytorch implementation of [CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n- Pytorch implementation of [Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n- Pytorch implementation of [CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n- Pytorch implementation of [Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n***\n\n\n### 1. External Attention Usage\n#### 1.1. Paper\n[\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks\"](https://arxiv.org/abs/2105.02358)\n\n#### 1.2. Overview\n![](./model/img/External_Attention.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom fightingcv_attention.attention.ExternalAttention import ExternalAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nea = ExternalAttention(d_model=512,S=8)\noutput=ea(input)\nprint(output.shape)\n```\n\n***\n\n\n### 2. Self Attention Usage\n#### 2.1. Paper\n[\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n#### 2.2. Overview\n![](./model/img/SA.png)\n\n#### 2.3. Usage Code\n```python\nfrom fightingcv_attention.attention.SelfAttention import ScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nsa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n```\n\n***\n\n### 3. Simplified Self Attention Usage\n#### 3.1. Paper\n[None]()\n\n#### 3.2. Overview\n![](./model/img/SSA.png)\n\n#### 3.3. Usage Code\n```python\nfrom fightingcv_attention.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)\noutput=ssa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n### 4. Squeeze-and-Excitation Attention Usage\n#### 4.1. Paper\n[\"Squeeze-and-Excitation Networks\"](https://arxiv.org/abs/1709.01507)\n\n#### 4.2. Overview\n![](./model/img/SE.png)\n\n#### 4.3. Usage Code\n```python\nfrom fightingcv_attention.attention.SEAttention import SEAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SEAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n\n***\n\n### 5. SK Attention Usage\n#### 5.1. Paper\n[\"Selective Kernel Networks\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n#### 5.2. Overview\n![](./model/img/SK.png)\n\n#### 5.3. Usage Code\n```python\nfrom fightingcv_attention.attention.SKAttention import SKAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SKAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n***\n\n### 6. 
CBAM Attention Usage\n#### 6.1. Paper\n[\"CBAM: Convolutional Block Attention Module\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n#### 6.2. Overview\n![](./model/img/CBAM1.png)\n\n![](./model/img/CBAM2.png)\n\n#### 6.3. Usage Code\n```python\nfrom fightingcv_attention.attention.CBAM import CBAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nkernel_size=input.shape[2]\ncbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\noutput=cbam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 7. BAM Attention Usage\n#### 7.1. Paper\n[\"BAM: Bottleneck Attention Module\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n#### 7.2. Overview\n![](./model/img/BAM.png)\n\n#### 7.3. Usage Code\n```python\nfrom fightingcv_attention.attention.BAM import BAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nbam = BAMBlock(channel=512,reduction=16,dia_val=2)\noutput=bam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 8. ECA Attention Usage\n#### 8.1. Paper\n[\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n#### 8.2. Overview\n![](./model/img/ECA.png)\n\n#### 8.3. Usage Code\n```python\nfrom fightingcv_attention.attention.ECAAttention import ECAAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\neca = ECAAttention(kernel_size=3)\noutput=eca(input)\nprint(output.shape)\n\n```\n\n***\n\n### 9. DANet Attention Usage\n#### 9.1. Paper\n[\"Dual Attention Network for Scene Segmentation\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n#### 9.2. Overview\n![](./model/img/danet.png)\n\n#### 9.3. Usage Code\n```python\nfrom fightingcv_attention.attention.DANet import DAModule\nimport torch\n\ninput=torch.randn(50,512,7,7)\ndanet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\nprint(danet(input).shape)\n\n```\n\n***\n\n### 10. Pyramid Split Attention Usage\n\n#### 10.1. 
Paper\n[\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n#### 10.2. Overview\n![](./model/img/psa.png)\n\n#### 10.3. Usage Code\n```python\nfrom fightingcv_attention.attention.PSA import PSA\nimport torch\n\ninput=torch.randn(50,512,7,7)\npsa = PSA(channel=512,reduction=8)\noutput=psa(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 11. Efficient Multi-Head Self-Attention Usage\n\n#### 11.1. Paper\n[\"ResT: An Efficient Transformer for Visual Recognition\"](https://arxiv.org/abs/2105.13677)\n\n#### 11.2. Overview\n![](./model/img/EMSA.png)\n\n#### 11.3. Usage Code\n```python\n\nfrom fightingcv_attention.attention.EMSA import EMSA\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,64,512)\nemsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\noutput=emsa(input,input,input)\nprint(output.shape)\n    \n```\n\n***\n\n\n### 12. Shuffle Attention Usage\n\n#### 12.1. Paper\n[\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n#### 12.2. Overview\n![](./model/img/ShuffleAttention.jpg)\n\n#### 12.3. Usage Code\n```python\n\nfrom fightingcv_attention.attention.ShuffleAttention import ShuffleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,512,7,7)\nse = ShuffleAttention(channel=512,G=8)\noutput=se(input)\nprint(output.shape)\n\n    \n```\n\n\n***\n\n\n### 13. MUSE Attention Usage\n\n#### 13.1. Paper\n[\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning\"](https://arxiv.org/abs/1911.09483)\n\n#### 13.2. Overview\n![](./model/img/MUSE.png)\n\n#### 13.3. 
Usage Code\n```python\nfrom fightingcv_attention.attention.MUSEAttention import MUSEAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,49,512)\nsa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 14. SGE Attention Usage\n\n#### 14.1. Paper\n[Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks](https://arxiv.org/pdf/1905.09646.pdf)\n\n#### 14.2. Overview\n![](./model/img/SGE.png)\n\n#### 14.3. Usage Code\n```python\nfrom fightingcv_attention.attention.SGE import SpatialGroupEnhance\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nsge = SpatialGroupEnhance(groups=8)\noutput=sge(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 15. A2 Attention Usage\n\n#### 15.1. Paper\n[A2-Nets: Double Attention Networks](https://arxiv.org/pdf/1810.11579.pdf)\n\n#### 15.2. Overview\n![](./model/img/A2.png)\n\n#### 15.3. Usage Code\n```python\nfrom fightingcv_attention.attention.A2Atttention import DoubleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\na2 = DoubleAttention(512,128,128,True)\noutput=a2(input)\nprint(output.shape)\n\n```\n\n\n\n### 16. AFT Attention Usage\n\n#### 16.1. Paper\n[An Attention Free Transformer](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n#### 16.2. Overview\n![](./model/img/AFT.jpg)\n\n#### 16.3. Usage Code\n```python\nfrom fightingcv_attention.attention.AFT import AFT_FULL\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,49,512)\naft_full = AFT_FULL(d_model=512, n=49)\noutput=aft_full(input)\nprint(output.shape)\n\n```\n\n\n\n\n\n\n### 17. Outlook Attention Usage\n\n#### 17.1. Paper\n\n\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n\n#### 17.2. 
Overview\n![](./model/img/OutlookAttention.png)\n\n#### 17.3. Usage Code\n```python\nfrom fightingcv_attention.attention.OutlookAttention import OutlookAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,28,28,512)\noutlook = OutlookAttention(dim=512)\noutput=outlook(input)\nprint(output.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 18. ViP Attention Usage\n\n#### 18.1. Paper\n\n\n[Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition](https://arxiv.org/abs/2106.12368)\n\n\n#### 18.2. Overview\n![](./model/img/ViP.png)\n\n#### 18.3. Usage Code\n```python\n\nfrom fightingcv_attention.attention.ViP import WeightedPermuteMLP\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(64,8,8,512)\nseg_dim=8\nvip=WeightedPermuteMLP(512,seg_dim)\nout=vip(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n### 19. CoAtNet Attention Usage\n\n#### 19.1. Paper\n\n\n[CoAtNet: Marrying Convolution and Attention for All Data Sizes](https://arxiv.org/abs/2106.04803) \n\n\n#### 19.2. Overview\nNone\n\n\n#### 19.3. Usage Code\n```python\n\nfrom fightingcv_attention.attention.CoAtNet import CoAtNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=CoAtNet(in_ch=3,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 20. HaloNet Attention Usage\n\n#### 20.1. Paper\n\n\n[Scaling Local Self-Attention for Parameter Efficient Visual Backbones](https://arxiv.org/pdf/2103.12731.pdf) \n\n\n#### 20.2. Overview\n\n![](./model/img/HaloNet.png)\n\n#### 20.3. 
Usage Code\n```python\n\nfrom fightingcv_attention.attention.HaloAttention import HaloAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,8,8)\nhalo = HaloAttention(dim=512,\n    block_size=2,\n    halo_size=1,)\noutput=halo(input)\nprint(output.shape)\n\n```\n\n\n***\n\n### 21. Polarized Self-Attention Usage\n\n#### 21.1. Paper\n\n[Polarized Self-Attention: Towards High-quality Pixel-wise Regression](https://arxiv.org/abs/2107.00782)  \n\n\n#### 21.2. Overview\n\n![](./model/img/PoSA.png)\n\n#### 21.3. Usage Code\n```python\n\nfrom fightingcv_attention.attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention,SequentialPolarizedSelfAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,7,7)\npsa = SequentialPolarizedSelfAttention(channel=512)\noutput=psa(input)\nprint(output.shape)\n\n\n```\n\n\n***\n\n\n### 22. CoTAttention Usage\n\n#### 22.1. Paper\n\n[Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292) \n\n\n#### 22.2. Overview\n\n![](./model/img/CoT.png)\n\n#### 22.3. Usage Code\n```python\n\nfrom fightingcv_attention.attention.CoTAttention import CoTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ncot = CoTAttention(dim=512,kernel_size=3)\noutput=cot(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n### 23. Residual Attention Usage\n\n#### 23.1. Paper\n\n[Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n#### 23.2. Overview\n\n![](./model/img/ResAtt.png)\n\n#### 23.3. 
Usage Code\n```python\n\nfrom fightingcv_attention.attention.ResidualAttention import ResidualAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nresatt = ResidualAttention(channel=512,num_class=1000,la=0.2)\noutput=resatt(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n\n### 24. S2 Attention Usage\n\n#### 24.1. Paper\n\n[S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) \n\n\n#### 24.2. Overview\n\n![](./model/img/S2Attention.png)\n\n#### 24.3. Usage Code\n```python\nfrom fightingcv_attention.attention.S2Attention import S2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ns2att = S2Attention(channels=512)\noutput=s2att(input)\nprint(output.shape)\n\n```\n\n***\n\n\n\n### 25. GFNet Attention Usage\n\n#### 25.1. Paper\n\n[Global Filter Networks for Image Classification---arXiv 2021.07.01](https://arxiv.org/abs/2107.00645) \n\n\n#### 25.2. Overview\n\n![](./model/img/GFNet.jpg)\n\n#### 25.3. Usage Code - Implemented by [Wenliang Zhao (Author)](https://scholar.google.com/citations?user=lyPWvuEAAAAJ&hl=en)\n\n```python\nfrom fightingcv_attention.attention.gfnet import GFNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nx = torch.randn(1, 3, 224, 224)\ngfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)\nout = gfnet(x)\nprint(out.shape)\n\n```\n\n***\n\n\n### 26. TripletAttention Usage\n\n#### 26.1. Paper\n\n[Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n#### 26.2. Overview\n\n![](./model/img/triplet.png)\n\n#### 26.3. 
Usage Code - Implemented by [digantamisra98](https://github.com/digantamisra98)\n\n```python\nfrom fightingcv_attention.attention.TripletAttention import TripletAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\ninput=torch.randn(50,512,7,7)\ntriplet = TripletAttention()\noutput=triplet(input)\nprint(output.shape)\n```\n\n\n***\n\n\n### 27. Coordinate Attention Usage\n\n#### 27.1. Paper\n\n[Coordinate Attention for Efficient Mobile Network Design---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n\n#### 27.2. Overview\n\n![](./model/img/CoordAttention.png)\n\n#### 27.3. Usage Code - Implemented by [Andrew-Qibin](https://github.com/Andrew-Qibin)\n\n```python\nfrom fightingcv_attention.attention.CoordAttention import CoordAtt\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninp=torch.rand([2, 96, 56, 56])\ninp_dim, oup_dim = 96, 96\nreduction=32\n\ncoord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)\noutput=coord_attention(inp)\nprint(output.shape)\n```\n\n***\n\n\n### 28. MobileViT Attention Usage\n\n#### 28.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n\n#### 28.2. Overview\n\n![](./model/img/MobileViTAttention.png)\n\n#### 28.3. Usage Code\n\n```python\nfrom fightingcv_attention.attention.MobileViTAttention import MobileViTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    m=MobileViTAttention()\n    input=torch.randn(1,3,49,49)\n    output=m(input)\n    print(output.shape)  #output:(1,3,49,49)\n    \n```\n\n***\n\n\n### 29. ParNet Attention Usage\n\n#### 29.1. Paper\n\n[Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n\n#### 29.2. Overview\n\n![](./model/img/ParNet.png)\n\n#### 29.3. 
Usage Code\n\n```python\nfrom fightingcv_attention.attention.ParNetAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    pna = ParNetAttention(channel=512)\n    output=pna(input)\n    print(output.shape) #50,512,7,7\n    \n```\n\n***\n\n\n### 30. UFO Attention Usage\n\n#### 30.1. Paper\n\n[UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n\n#### 30.2. Overview\n\n![](./model/img/UFO.png)\n\n#### 30.3. Usage Code\n\n```python\nfrom fightingcv_attention.attention.UFOAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=ufo(input,input,input)\n    print(output.shape) #[50, 49, 512]\n    \n```\n\n***\n\n### 31. ACmix Attention Usage\n\n#### 31.1. Paper\n\n[On the Integration of Self-Attention and Convolution](https://arxiv.org/pdf/2111.14556.pdf)\n\n#### 31.2. Usage Code\n\n```python\nfrom fightingcv_attention.attention.ACmix import ACmix\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,256,7,7)\n    acmix = ACmix(in_planes=256, out_planes=256)\n    output=acmix(input)\n    print(output.shape)\n    \n```\n\n### 32. MobileViTv2 Attention Usage\n\n#### 32.1. Paper\n\n[Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n\n#### 32.2. Overview\n\n![](./model/img/MobileViTv2.png)\n\n#### 32.3. 
Usage Code\n\n```python\nfrom fightingcv_attention.attention.MobileViTv2Attention import MobileViTv2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n    \n```\n\n### 33. DAT Attention Usage\n\n#### 33.1. Paper\n\n[Vision Transformer with Deformable Attention---CVPR2022](https://arxiv.org/abs/2201.00520)\n\n#### 33.2. Usage Code\n\n```python\nfrom fightingcv_attention.attention.DAT import DAT\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DAT(\n        img_size=224,\n        patch_size=4,\n        num_classes=1000,\n        expansion=4,\n        dim_stem=96,\n        dims=[96, 192, 384, 768],\n        depths=[2, 2, 6, 2],\n        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],\n        heads=[3, 6, 12, 24],\n        window_sizes=[7, 7, 7, 7] ,\n        groups=[-1, -1, 3, 6],\n        use_pes=[False, False, True, True],\n        dwc_pes=[False, False, False, False],\n        strides=[-1, -1, 1, 1],\n        sr_ratios=[-1, -1, -1, -1],\n        offset_range_factor=[-1, -1, 2, 2],\n        no_offs=[False, False, False, False],\n        fixed_pes=[False, False, False, False],\n        use_dwc_mlps=[False, False, False, False],\n        use_conv_patches=False,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.2,\n    )\n    output=model(input)\n    print(output[0].shape)\n    \n```\n\n### 34. CrossFormer Attention Usage\n\n#### 34.1. Paper\n\n[CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n#### 34.2. 
Usage Code\n\n```python\nfrom fightingcv_attention.attention.Crossformer import CrossFormer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CrossFormer(img_size=224,\n        patch_size=[4, 8, 16, 32],\n        in_chans= 3,\n        num_classes=1000,\n        embed_dim=48,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        group_size=[7, 7, 7, 7],\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        merge_size=[[2, 4], [2,4], [2, 4]]\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 35. MOATransformer Attention Usage\n\n#### 35.1. Paper\n\n[Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n#### 35.2. Usage Code\n\n```python\nfrom fightingcv_attention.attention.MOATransformer import MOATransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = MOATransformer(\n        img_size=224,\n        patch_size=4,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6],\n        num_heads=[3, 6, 12],\n        window_size=14,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 36. CrissCrossAttention Attention Usage\n\n#### 36.1. Paper\n\n[CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n#### 36.2. 
Usage Code\n\n```python\nfrom fightingcv_attention.attention.CrissCrossAttention import CrissCrossAttention\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 64, 7, 7)\n    model = CrissCrossAttention(64)\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n### 37. Axial_attention Attention Usage\n\n#### 37.1. Paper\n\n[Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n\n#### 37.2. Usage Code\n\n```python\nfrom fightingcv_attention.attention.Axial_attention import AxialImageTransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 128, 7, 7)\n    model = AxialImageTransformer(\n        dim = 128,\n        depth = 12,\n        reversible = True\n    )\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n***\n\n\n# Backbone Series\n\n- Pytorch implementation of [\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n- Pytorch implementation of [\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2020.10.05](https://arxiv.org/abs/2103.02907)\n\n- Pytorch implementation of [Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n\n- Pytorch implementation of [Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer---ArXiv 2021.06.07](https://arxiv.org/abs/2106.03650)\n\n- Pytorch implementation of [ConTNet: Why not use convolution and transformer at the same time?---ArXiv 2021.04.27](https://arxiv.org/abs/2104.13497)\n\n- Pytorch implementation of [Vision Transformers with Hierarchical Attention---ArXiv 2022.06.15](https://arxiv.org/abs/2106.03180)\n\n- Pytorch implementation of [Co-Scale Conv-Attentional Image Transformers---ArXiv 
2021.08.26](https://arxiv.org/abs/2104.06399)\n\n- Pytorch implementation of [Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n- Pytorch implementation of [Rethinking Spatial Dimensions of Vision Transformers---ICCV 2021](https://arxiv.org/abs/2103.16302)\n\n- Pytorch implementation of [CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification---ICCV 2021](https://arxiv.org/abs/2103.14899)\n\n- Pytorch implementation of [Transformer in Transformer---NeurIPS 2021](https://arxiv.org/abs/2103.00112)\n\n- Pytorch implementation of [DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n- Pytorch implementation of [Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n- Pytorch implementation of [ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n- Pytorch implementation of [Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n- Pytorch implementation of [Going deeper with Image Transformers---ICCV 2021 (Oral)](https://arxiv.org/abs/2103.17239)\n\n- Pytorch implementation of [Training data-efficient image transformers & distillation through attention---ICML 2021](https://arxiv.org/abs/2012.12877)\n\n- Pytorch implementation of [LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n- Pytorch implementation of [Container: Context Aggregation Network---NeurIPS 2021](https://arxiv.org/abs/2106.01401)\n\n- Pytorch implementation of [CMT: Convolutional Neural Networks Meet Vision Transformers---CVPR 2022](https://arxiv.org/abs/2107.06263)\n\n- Pytorch implementation of [Vision Transformer with Deformable Attention---CVPR 
2022](https://arxiv.org/abs/2201.00520)\n\n- Pytorch implementation of [EfficientFormer: Vision Transformers at MobileNet Speed](https://arxiv.org/abs/2206.01191)\n\n\n### 1. ResNet Usage\n#### 1.1. Paper\n[\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n#### 1.2. Overview\n![](./model/img/resnet.png)\n![](./model/img/resnet2.jpg)\n\n#### 1.3. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.resnet import ResNet50,ResNet101,ResNet152\nimport torch\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnet50=ResNet50(1000)\n    # resnet101=ResNet101(1000)\n    # resnet152=ResNet152(1000)\n    out=resnet50(input)\n    print(out.shape)\n\n```\n\n\n### 2. ResNeXt Usage\n#### 2.1. Paper\n\n[\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n#### 2.2. Overview\n![](./model/img/resnext.png)\n\n#### 2.3. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.resnext import ResNeXt50,ResNeXt101,ResNeXt152\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnext50=ResNeXt50(1000)\n    # resnext101=ResNeXt101(1000)\n    # resnext152=ResNeXt152(1000)\n    out=resnext50(input)\n    print(out.shape)\n\n\n```\n\n\n\n### 3. MobileViT Usage\n#### 3.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2020.10.05](https://arxiv.org/abs/2103.02907)\n\n#### 3.2. Overview\n![](./model/img/mobileViT.jpg)\n\n#### 3.3. 
Usage Code\n```python\n\nfrom fightingcv_attention.backbone.MobileViT import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n\n    ### mobilevit_xxs\n    mvit_xxs=mobilevit_xxs()\n    out=mvit_xxs(input)\n    print(out.shape)\n\n    ### mobilevit_xs\n    mvit_xs=mobilevit_xs()\n    out=mvit_xs(input)\n    print(out.shape)\n\n\n    ### mobilevit_s\n    mvit_s=mobilevit_s()\n    out=mvit_s(input)\n    print(out.shape)\n\n```\n\n\n\n\n\n### 4. ConvMixer Usage\n#### 4.1. Paper\n[Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n#### 4.2. Overview\n![](./model/img/ConvMixer.png)\n\n#### 4.3. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.ConvMixer import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    x=torch.randn(1,3,224,224)\n    convmixer=ConvMixer(dim=512,depth=12)\n    out=convmixer(x)\n    print(out.shape)  #[1, 1000]\n\n\n```\n\n### 5. ShuffleTransformer Usage\n#### 5.1. Paper\n[Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer](https://arxiv.org/pdf/2106.03650.pdf)\n\n#### 5.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.ShuffleTransformer import ShuffleTransformer\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    sft = ShuffleTransformer()\n    output=sft(input)\n    print(output.shape)\n\n\n```\n\n### 6. ConTNet Usage\n#### 6.1. Paper\n[ConTNet: Why not use convolution and transformer at the same time?](https://arxiv.org/abs/2104.13497)\n\n#### 6.2. 
Usage Code\n```python\n\n# build_model is assumed to be exported by the same module as ConTNet\nfrom fightingcv_attention.backbone.ConTNet import ConTNet, build_model\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == \"__main__\":\n    model = build_model(use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)\n\n\n```\n\n### 7 HATNet Usage\n#### 7.1. Paper\n[Vision Transformers with Hierarchical Attention](https://arxiv.org/abs/2106.03180)\n\n#### 7.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.HATNet import HATNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    hat = HATNet(dims=[48, 96, 240, 384], head_dim=48, expansions=[8, 8, 4, 4],\n        grid_sizes=[8, 7, 7, 1], ds_ratios=[8, 4, 2, 1], depths=[2, 2, 6, 3])\n    output=hat(input)\n    print(output.shape)\n\n\n```\n\n### 8 CoaT Usage\n#### 8.1. Paper\n[Co-Scale Conv-Attentional Image Transformers](https://arxiv.org/abs/2104.06399)\n\n#### 8.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.CoaT import CoaT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4])\n    output=model(input)\n    print(output.shape) # torch.Size([1, 1000])\n\n```\n\n### 9 PVT Usage\n#### 9.1. Paper\n[PVT v2: Improved Baselines with Pyramid Vision Transformer](https://arxiv.org/pdf/2106.13797.pdf)\n\n#### 9.2. 
Usage Code\n```python\n\nfrom fightingcv_attention.backbone.PVT import PyramidVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n\n### 10 CPVT Usage\n#### 10.1. Paper\n[Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n#### 10.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.CPVT import CPVTV2\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 11 PIT Usage\n#### 11.1. Paper\n[Rethinking Spatial Dimensions of Vision Transformers](https://arxiv.org/abs/2103.16302)\n\n#### 11.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.PIT import PoolingTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 12 CrossViT Usage\n#### 12.1. Paper\n[CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification](https://arxiv.org/abs/2103.14899)\n\n#### 12.2. 
Usage Code\n```python\n\nfrom fightingcv_attention.backbone.CrossViT import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == \"__main__\":\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[240, 224],\n        patch_size=[12, 16], \n        embed_dim=[192, 384], \n        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n        num_heads=[6, 6], \n        mlp_ratio=[4, 4, 1], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 13 TnT Usage\n#### 13.1. Paper\n[Transformer in Transformer](https://arxiv.org/abs/2103.00112)\n\n#### 13.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.TnT import TNT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = TNT(\n        img_size=224, \n        patch_size=16, \n        outer_dim=384, \n        inner_dim=24, \n        depth=12,\n        outer_num_heads=6, \n        inner_num_heads=4, \n        qkv_bias=False,\n        inner_stride=4)\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 14 DViT Usage\n#### 14.1. Paper\n[DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n#### 14.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.DViT import DeepVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, \n        depth=[False] * 16, \n        apply_transform=[False] * 0 + [True] * 32, \n        num_heads=12, \n        mlp_ratio=3, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 15 CeiT Usage\n#### 15.1. Paper\n[Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n#### 15.2. 
Usage Code\n```python\n\n# Image2Tokens is assumed to be defined in the same module as CeIT\nfrom fightingcv_attention.backbone.CeiT import CeIT, Image2Tokens\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CeIT(\n        hybrid_backbone=Image2Tokens(),\n        patch_size=4, \n        embed_dim=192, \n        depth=12, \n        num_heads=3, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 16 ConViT Usage\n#### 16.1. Paper\n[ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n#### 16.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.ConViT import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        num_heads=16,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 17 CaiT Usage\n#### 17.1. Paper\n[Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239)\n\n#### 17.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.CaiT import CaiT\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CaiT(\n        img_size= 224,\n        patch_size=16, \n        embed_dim=192, \n        depth=24, \n        num_heads=4, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 18 PatchConvnet Usage\n#### 18.1. Paper\n[Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n#### 18.2. 
Usage Code\n```python\n\n# ConvStem and Conv_blocks_se are assumed to be defined in the same module as PatchConvnet\nfrom fightingcv_attention.backbone.PatchConvnet import PatchConvnet, ConvStem, Conv_blocks_se\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 19 DeiT Usage\n#### 19.1. Paper\n[Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)\n\n#### 19.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.DeiT import DistilledVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DistilledVisionTransformer(\n        patch_size=16, \n        embed_dim=384, \n        depth=12, \n        num_heads=6, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 20 LeViT Usage\n#### 20.1. Paper\n[LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n#### 20.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.LeViT import *\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    for name in specification:\n        input=torch.randn(1,3,224,224)\n        model = globals()[name](fuse=True, pretrained=False)\n        model.eval()\n        output = model(input)\n        print(output.shape)\n\n```\n\n### 21 VOLO Usage\n#### 21.1. Paper\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n#### 21.2. 
Usage Code\n```python\n\nfrom fightingcv_attention.backbone.VOLO import VOLO\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VOLO([4, 4, 8, 2],\n                 embed_dims=[192, 384, 384, 384],\n                 num_heads=[6, 12, 12, 12],\n                 mlp_ratios=[3, 3, 3, 3],\n                 downsamples=[True, False, False, False],\n                 outlook_attention=[True, False, False, False ],\n                 post_layers=['ca', 'ca'],\n                 )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 22 Container Usage\n#### 22.1. Paper\n[Container: Context Aggregation Network](https://arxiv.org/abs/2106.01401)\n\n#### 22.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.Container import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[224, 56, 28, 14], \n        patch_size=[4, 2, 2, 2], \n        embed_dim=[64, 128, 320, 512], \n        depth=[3, 4, 8, 3], \n        num_heads=16, \n        mlp_ratio=[8, 8, 4, 4], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6))\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 23 CMT Usage\n#### 23.1. Paper\n[CMT: Convolutional Neural Networks Meet Vision Transformers](https://arxiv.org/abs/2107.06263)\n\n#### 23.2. Usage Code\n```python\n\nfrom fightingcv_attention.backbone.CMT import CMT_Tiny\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CMT_Tiny()\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 24 EfficientFormer Usage\n#### 24.1. Paper\n[EfficientFormer: Vision Transformers at MobileNet Speed](https://arxiv.org/abs/2206.01191)\n\n#### 24.2. 
Usage Code\n```python\n\n# EfficientFormer_depth and EfficientFormer_width are assumed to be defined in the same module as EfficientFormer\nfrom fightingcv_attention.backbone.EfficientFormer import EfficientFormer, EfficientFormer_depth, EfficientFormer_width\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = EfficientFormer(\n        layers=EfficientFormer_depth['l1'],\n        embed_dims=EfficientFormer_width['l1'],\n        downsamples=[True, True, True, True],\n        vit_num=1,\n    )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n\n\n\n\n\n# MLP Series\n\n- Pytorch implementation of [\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition---arXiv 2021.05.05\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n- Pytorch implementation of [\"MLP-Mixer: An all-MLP Architecture for Vision---arXiv 2021.05.17\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n- Pytorch implementation of [\"ResMLP: Feedforward networks for image classification with data-efficient training---arXiv 2021.05.07\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n- Pytorch implementation of [\"Pay Attention to MLPs---arXiv 2021.05.17\"](https://arxiv.org/abs/2105.08050)\n\n\n- Pytorch implementation of [\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?---arXiv 2021.09.12\"](https://arxiv.org/abs/2109.05422)\n\n### 1. RepMLP Usage\n#### 1.1. Paper\n[\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n#### 1.2. Overview\n![](./model/img/repmlp.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom fightingcv_attention.mlp.repmlp import RepMLP\nimport torch\nfrom torch import nn\n\nN=4 #batch size\nC=512 #input dim\nO=1024 #output dim\nH=14 #image height\nW=14 #image width\nh=7 #patch height\nw=7 #patch width\nfc1_fc2_reduction=1 #reduction ratio\nfc3_groups=8 # groups\nrepconv_kernels=[1,3,5,7] #kernel list\nrepmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)\nx=torch.randn(N,C,H,W)\nrepmlp.eval()\nfor module in repmlp.modules():\n    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):\n        nn.init.uniform_(module.running_mean, 0, 0.1)\n        nn.init.uniform_(module.running_var, 0, 0.1)\n        nn.init.uniform_(module.weight, 0, 0.1)\n        nn.init.uniform_(module.bias, 0, 0.1)\n\n#training result\nout=repmlp(x)\n#inference result\nrepmlp.switch_to_deploy()\ndeployout = repmlp(x)\n\nprint(((deployout-out)**2).sum())\n```\n\n### 2. MLP-Mixer Usage\n#### 2.1. Paper\n[\"MLP-Mixer: An all-MLP Architecture for Vision\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n#### 2.2. Overview\n![](./model/img/mlpmixer.png)\n\n#### 2.3. Usage Code\n```python\nfrom fightingcv_attention.mlp.mlp_mixer import MlpMixer\nimport torch\nmlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)\ninput=torch.randn(50,3,40,40)\noutput=mlp_mixer(input)\nprint(output.shape)\n```\n\n***\n\n### 3. ResMLP Usage\n#### 3.1. Paper\n[\"ResMLP: Feedforward networks for image classification with data-efficient training\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n#### 3.2. Overview\n![](./model/img/resmlp.png)\n\n#### 3.3. Usage Code\n```python\nfrom fightingcv_attention.mlp.resmlp import ResMLP\nimport torch\n\ninput=torch.randn(50,3,14,14)\nresmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)\nout=resmlp(input)\nprint(out.shape) #the last dimension is class_num\n```\n\n***\n\n### 4. 
gMLP Usage\n#### 4.1. Paper\n[\"Pay Attention to MLPs\"](https://arxiv.org/abs/2105.08050)\n\n#### 4.2. Overview\n![](./model/img/gMLP.jpg)\n\n#### 4.3. Usage Code\n```python\nfrom fightingcv_attention.mlp.g_mlp import gMLP\nimport torch\n\nnum_tokens=10000\nbs=50\nlen_sen=49\nnum_layers=6\ninput=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen\ngmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)\noutput=gmlp(input)\nprint(output.shape)\n```\n\n***\n\n### 5. sMLP Usage\n#### 5.1. Paper\n[\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?\"](https://arxiv.org/abs/2109.05422)\n\n#### 5.2. Overview\n![](./model/img/sMLP.jpg)\n\n#### 5.3. Usage Code\n```python\nfrom fightingcv_attention.mlp.sMLP_block import sMLPBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    smlp=sMLPBlock(h=224,w=224)\n    out=smlp(input)\n    print(out.shape)\n```\n\n### 6. vip-mlp Usage\n#### 6.1. Paper\n[\"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n#### 6.2. 
Usage Code\n```python\nimport importlib\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# 'vip-mlp' contains a hyphen, so it cannot be loaded with a plain import statement\nvip_mlp = importlib.import_module('fightingcv_attention.mlp.vip-mlp')\nVisionPermutator = vip_mlp.VisionPermutator\n# WeightedPermuteMLP is assumed to be defined in the same module\nWeightedPermuteMLP = vip_mlp.WeightedPermuteMLP\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionPermutator(\n        layers=[4, 3, 8, 3], \n        embed_dims=[384, 384, 384, 384], \n        patch_size=14, \n        transitions=[False, False, False, False],\n        segment_dim=[16, 16, 16, 16], \n        mlp_ratios=[3, 3, 3, 3], \n        mlp_fn=WeightedPermuteMLP\n    )\n    output=model(input)\n    print(output.shape)\n```\n\n\n# Re-Parameter Series\n\n- Pytorch implementation of [\"RepVGG: Making VGG-style ConvNets Great Again---CVPR2021\"](https://arxiv.org/abs/2101.03697)\n\n- Pytorch implementation of [\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks---ICCV2019\"](https://arxiv.org/abs/1908.03930)\n\n- Pytorch implementation of [\"Diverse Branch Block: Building a Convolution as an Inception-like Unit---CVPR2021\"](https://arxiv.org/abs/2103.13425)\n\n\n***\n\n### 1. RepVGG Usage\n#### 1.1. Paper\n[\"RepVGG: Making VGG-style ConvNets Great Again\"](https://arxiv.org/abs/2101.03697)\n\n#### 1.2. Overview\n![](./model/img/repvgg.png)\n\n#### 1.3. Usage Code\n```python\n\nfrom fightingcv_attention.rep.repvgg import RepBlock\nimport torch\n\n\ninput=torch.randn(50,512,49,49)\nrepblock=RepBlock(512,512)\nrepblock.eval()\nout=repblock(input)\nrepblock._switch_to_deploy()\nout2=repblock(input)\nprint('difference between vgg and repvgg')\nprint(((out2-out)**2).sum())\n```\n\n\n\n***\n\n### 2. ACNet Usage\n#### 2.1. Paper\n[\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks\"](https://arxiv.org/abs/1908.03930)\n\n#### 2.2. Overview\n![](./model/img/acnet.png)\n\n#### 2.3. 
Usage Code\n```python\nfrom fightingcv_attention.rep.acnet import ACNet\nimport torch\nfrom torch import nn\n\ninput=torch.randn(50,512,49,49)\nacnet=ACNet(512,512)\nacnet.eval()\nout=acnet(input)\nacnet._switch_to_deploy()\nout2=acnet(input)\nprint('difference:')\nprint(((out2-out)**2).sum())\n\n```\n\n\n\n***\n\n### 3. Diverse Branch Block Usage\n#### 3.1. Paper\n[\"Diverse Branch Block: Building a Convolution as an Inception-like Unit\"](https://arxiv.org/abs/2103.13425)\n\n#### 3.2. Overview\n![](./model/img/ddb.png)\n\n#### 3.3. Usage Code\n##### 3.3.1 Transform I\n```python\nfrom fightingcv_attention.rep.ddb import transI_conv_bn\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n#conv+bn\nconv1=nn.Conv2d(64,64,3,padding=1)\nbn1=nn.BatchNorm2d(64)\nbn1.eval()\nout1=bn1(conv1(input))\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transI_conv_bn(conv1,bn1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.2 Transform II\n```python\nfrom fightingcv_attention.rep.ddb import transII_conv_branch\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,3,padding=1)\nconv2=nn.Conv2d(64,64,3,padding=1)\nout1=conv1(input)+conv2(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transII_conv_branch(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.3 Transform III\n```python\nfrom fightingcv_attention.rep.ddb import transIII_conv_sequential\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as 
F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,1,padding=0,bias=False)\nconv2=nn.Conv2d(64,64,3,padding=1,bias=False)\nout1=conv2(conv1(input))\n\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1,bias=False)\nconv_fuse.weight.data=transIII_conv_sequential(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.4 Transform IV\n```python\nfrom fightingcv_attention.rep.ddb import transIV_conv_concat\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,32,3,padding=1)\nconv2=nn.Conv2d(64,32,3,padding=1)\nout1=torch.cat([conv1(input),conv2(input)],dim=1)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transIV_conv_concat(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.5 Transform V\n```python\nfrom fightingcv_attention.rep.ddb import transV_avg\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\navg=nn.AvgPool2d(kernel_size=3,stride=1)\nout1=avg(input)\n\nconv=transV_avg(64,3)\nout2=conv(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n##### 3.3.6 Transform VI\n```python\nfrom fightingcv_attention.rep.ddb import transVI_conv_scale\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1x1=nn.Conv2d(64,64,1)\nconv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))\nconv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))\nout1=conv1x1(input)+conv1x3(input)+conv3x1(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n\n\n# Convolution Series\n\n- Pytorch implementation of [\"MobileNets: 
Efficient Convolutional Neural Networks for Mobile Vision Applications---CVPR2017\"](https://arxiv.org/abs/1704.04861)\n\n- Pytorch implementation of [\"Efficientnet: Rethinking model scaling for convolutional neural networks---PMLR2019\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n- Pytorch implementation of [\"Involution: Inverting the Inherence of Convolution for Visual Recognition---CVPR2021\"](https://arxiv.org/abs/2103.06255)\n\n- Pytorch implementation of [\"Dynamic Convolution: Attention over Convolution Kernels---CVPR2020 Oral\"](https://arxiv.org/abs/1912.03458)\n\n- Pytorch implementation of [\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference---NeurIPS2019\"](https://arxiv.org/abs/1904.04971)\n\n***\n\n### 1. Depthwise Separable Convolution Usage\n#### 1.1. Paper\n[\"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications\"](https://arxiv.org/abs/1704.04861)\n\n#### 1.2. Overview\n![](./model/img/DepthwiseSeparableConv.png)\n\n#### 1.3. Usage Code\n```python\nfrom fightingcv_attention.conv.DepthwiseSeparableConvolution import DepthwiseSeparableConvolution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\ndsconv=DepthwiseSeparableConvolution(3,64)\nout=dsconv(input)\nprint(out.shape)\n```\n\n***\n\n\n### 2. MBConv Usage\n#### 2.1. Paper\n[\"Efficientnet: Rethinking model scaling for convolutional neural networks\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n#### 2.2. Overview\n![](./model/img/MBConv.jpg)\n\n#### 2.3. Usage Code\n```python\nfrom fightingcv_attention.conv.MBConv import MBConvBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=512,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n\n```\n\n***\n\n\n### 3. Involution Usage\n#### 3.1. 
Paper\n[\"Involution: Inverting the Inherence of Convolution for Visual Recognition\"](https://arxiv.org/abs/2103.06255)\n\n#### 3.2. Overview\n![](./model/img/Involution.png)\n\n#### 3.3. Usage Code\n```python\nfrom fightingcv_attention.conv.Involution import Involution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,4,64,64)\ninvolution=Involution(kernel_size=3,in_channel=4,stride=2)\nout=involution(input)\nprint(out.shape)\n```\n\n***\n\n\n### 4. DynamicConv Usage\n#### 4.1. Paper\n[\"Dynamic Convolution: Attention over Convolution Kernels\"](https://arxiv.org/abs/1912.03458)\n\n#### 4.2. Overview\n![](./model/img/DynamicConv.png)\n\n#### 4.3. Usage Code\n```python\nfrom fightingcv_attention.conv.DynamicConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape) # 2,64,64,64 (channels follow out_planes)\n\n```\n\n***\n\n\n### 5. CondConv Usage\n#### 5.1. Paper\n[\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference\"](https://arxiv.org/abs/1904.04971)\n\n#### 5.2. Overview\n![](./model/img/CondConv.png)\n\n#### 5.3. Usage Code\n```python\nfrom fightingcv_attention.conv.CondConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\n\n\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape)\n\n```\n"
  },
  {
    "path": "main.py",
    "content": "from model.attention.MobileViTv2Attention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n "
  },
  {
    "path": "model/.vscode/settings.json",
    "content": "{\n    \"python.pythonPath\": \"D:\\\\Anaconda\\\\python.exe\"\n}"
  },
  {
    "path": "model/__init__.py",
    "content": "\ndef test():\n    print(\"hello world\")\n\nif __name__ == '__main__':\n    test()"
  },
  {
    "path": "model/analysis/Attention.md",
    "content": "## Content\n  \n- [1. External Attention](#1-external-attention)\n\n- [2. Self Attention](#2-self-attention)\n\n- [3. Squeeze-and-Excitation(SE) Attention](#3-squeeze-and-excitationse-attention)\n\n- [4. Selective Kernel(SK) Attention](#4-selective-kernelsk-attention)\n\n- [5. CBAM Attention](#5-cbam-attention)\n\n- [6. BAM Attention](#6-bam-attention)\n\n- [7. ECA Attention](#7-eca-attention)\n\n- [8. DANet Attention](#8-danet-attention)\n\n- [9. Pyramid Split Attention(PSA)](#9-pyramid-split-attentionpsa)\n\n- [10. Efficient Multi-Head Self-Attention(EMSA)](#10-efficient-multi-head-self-attentionemsa)\n\n- [Write at the end](#write-at-the-end)\n\n\n\n## 1. External Attention\n\n### 1.1. Citation\n\nBeyond Self-attention: External Attention using Two Linear Layers for Visual Tasks.---arXiv 2021.05.05\n\nAddress: [https://arxiv.org/abs/2105.02358](https://arxiv.org/abs/2105.02358)\n\n### 1.2. Model Structure\n\n![](./img/External_Attention.png)\n\n### 1.3. Brief\nThis paper was posted on arXiv in May 2021. It addresses two pain points of Self-Attention (SA): (1) its O(n^2) computational complexity; (2) SA computes attention between different positions within a single sample and ignores the relationships between different samples. The paper therefore uses two cascaded MLP structures as memory units, which reduces the computational complexity to O(n). Because these two memory units are learned from the entire training set, they also implicitly model the relationships between different samples.\n\n### 1.4. Usage\n\n```python\nfrom attention.ExternalAttention import ExternalAttention\nimport torch\n\n\ninput=torch.randn(50,49,512)\nea = ExternalAttention(d_model=512,S=8)\noutput=ea(input)\nprint(output.shape)\n```\n\n\n\n\n\n## 2. Self Attention\n\n### 2.1. Citation\n\nAttention Is All You Need---NeurIPS2017\n\nAddress: [https://arxiv.org/abs/1706.03762](https://arxiv.org/abs/1706.03762)\n\n### 2.2. 
Model Structure\n\n![](./img/SA.png)\n\n### 2.3. Brief\nThis paper was published by Google at NeurIPS 2017. It has been highly influential across CV, NLP, and multi-modal research, with more than 22,000 citations at the time of writing. The Self-Attention proposed in Transformer is a form of attention that computes weights between different positions of a feature in order to update that feature. First, the input feature is projected by FC layers into three features Q, K, and V. Q and K are then multiplied (dot product) to obtain the attention map, and the attention map is multiplied with V to obtain the weighted feature. Finally, an FC layer projects the result into a new feature. (There are many very good explanations of Transformer and Self-Attention online, so no detailed introduction is given here.)\n\n### 2.4. Usage\n\n```python\nfrom attention.SelfAttention import ScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nsa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n```\n\n\n\n\n\n## 3. Squeeze-and-Excitation(SE) Attention\n\n### 3.1. Citation\n\nSqueeze-and-Excitation Networks---CVPR2018\n\nAddress: [https://arxiv.org/abs/1709.01507](https://arxiv.org/abs/1709.01507)\n\n### 3.2. Model Structure\n\n![](./img/SE.png)\n\n### 3.3. Brief\nThis is a CVPR 2018 paper that is also very influential, with more than 7,000 citations at the time of writing. It focuses on channel attention, and its simple, effective design set off a wave of channel-attention work. In the spirit of \"the greatest truths are the simplest\", the idea of this paper is very simple. 
First, AdaptiveAvgPool is applied over the spatial dimensions; then two FC layers learn the channel attention, and a Sigmoid normalizes it to produce the Channel Attention Map; finally, the Channel Attention Map is multiplied with the original features to obtain the weighted features.\n\n### 3.4. Usage\n\n```python\nfrom attention.SEAttention import SEAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SEAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 4. Selective Kernel(SK) Attention\n\n### 4.1. Citation\n\nSelective Kernel Networks---CVPR2019\n\nAddress: [https://arxiv.org/pdf/1903.06586.pdf](https://arxiv.org/pdf/1903.06586.pdf)\n\n### 4.2. Model Structure\n\n![](./img/SK.png)\n\n### 4.3. Brief\nThis is a CVPR 2019 paper that builds on the ideas of SENet. In a traditional CNN, each convolutional layer uses kernels of a single size, which limits the expressive power of the model; \"wider\" structures such as Inception have shown that learning with several different kernels does improve expressive power. Borrowing from SENet, the authors dynamically compute channel weights for each convolution kernel and dynamically fuse the outputs of the kernels.\n\nIn my view, the reason this method can still be called lightweight is that the parameters are shared when channel attention is applied to the features of the different kernels (i.e., because the features are fused before attention is computed, the outputs of the different kernels share the parameters of one SE module).\n\nThe method in this paper has three parts: Split, Fuse, and Select. 
Split is a multi-branch operation that convolves the input with different kernels to obtain different features. Fuse uses the SE structure to obtain the channel attention matrices (N kernels yield N attention matrices, and this step shares parameters across all features), which gives the SE-weighted features of each kernel. Select adds these features together.\n\n### 4.4. Usage\n\n```python\nfrom attention.SKAttention import SKAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nsk = SKAttention(channel=512,reduction=8)\noutput=sk(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 5. CBAM Attention\n\n### 5.1. Citation\n\nCBAM: Convolutional Block Attention Module---ECCV2018\n\nAddress: [https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n### 5.2. Model Structure\n\n![](./img/CBAM1.png)\n\n![](./img/CBAM2.png)\n\n### 5.3. Brief\nThis is an ECCV 2018 paper. It uses both Channel Attention and Spatial Attention and connects the two in series (the paper also runs ablations on the parallel arrangement and on both serial orders).\n\nFor Channel Attention, the overall structure is similar to SE, but the authors argue that AvgPool and MaxPool capture different information, so they apply AvgPool and MaxPool separately to the original features over the spatial dimensions and then extract channel attention with an SE-style structure. Note that the parameters are shared between the two branches. The two results are then added and normalized to obtain the attention matrix.\n\nSpatial Attention is similar to Channel Attention. 
After the two kinds of pooling are applied along the channel dimension, the two features are concatenated and a 7x7 convolution extracts the Spatial Attention (7x7 is used because spatial attention requires a sufficiently large kernel). A normalization then gives the spatial attention matrix.\n\n### 5.4. Usage\n\n```python\nfrom attention.CBAM import CBAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nkernel_size=input.shape[2]\ncbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\noutput=cbam(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 6. BAM Attention\n\n### 6.1. Citation\n\nBAM: Bottleneck Attention Module---BMVC2018\n\nAddress: [https://arxiv.org/pdf/1807.06514.pdf](https://arxiv.org/pdf/1807.06514.pdf)\n\n### 6.2. Model Structure\n\n![](./img/BAM.png)\n\n### 6.3. Brief\nThis is concurrent work by the same authors as CBAM. It is very similar to CBAM and also uses dual attention. The difference is that CBAM applies the two attentions in series, while BAM adds the two attention matrices directly.\n\nFor Channel Attention, the structure is essentially the same as SE. For Spatial Attention, pooling is again performed along the channel dimension, then two 3x3 dilated convolutions are applied, and finally a 1x1 convolution produces the Spatial Attention matrix.\n\nFinally, the Channel Attention and Spatial Attention matrices are added (using broadcasting) and normalized, which gives an attention matrix that combines spatial and channel information.\n\n### 6.4. Usage\n\n```python\nfrom attention.BAM import BAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nbam = BAMBlock(channel=512,reduction=16,dia_val=2)\noutput=bam(input)\nprint(output.shape)\n```\n\n\n\n\n\n## 7. ECA Attention\n\n### 7.1. 
Citation\n\nECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020\n\nAddress: [https://arxiv.org/pdf/1910.03151.pdf](https://arxiv.org/pdf/1910.03151.pdf)\n\n### 7.2. Model Structure\n\n![](./img/ECA.png)\n\n### 7.3. Brief\nThis is a CVPR 2020 paper.\n\nAs shown in the figure above, SE uses two fully connected layers to compute channel attention, while ECA needs only one convolution. The authors argue, first, that computing attention between all pairs of channels is unnecessary and, second, that two fully connected layers introduce too many parameters and too much computation.\n\nTherefore, after AvgPool, the authors use only a one-dimensional convolution with a receptive field of k (equivalent to computing attention only over the k adjacent channels), which greatly reduces the parameters and computation. In other words, SE is a global attention, while ECA is a local attention.\n\n### 7.4. Usage\n\n```python\nfrom attention.ECAAttention import ECAAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\neca = ECAAttention(kernel_size=3)\noutput=eca(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 8. DANet Attention\n\n### 8.1. Citation\n\nDual Attention Network for Scene Segmentation---CVPR2019\n\nAddress: [https://arxiv.org/pdf/1809.02983.pdf](https://arxiv.org/pdf/1809.02983.pdf)\n\n### 8.2. Model Structure\n\n![](./img/danet.png)\n\n\n### 8.3. Brief\nThis is a CVPR 2019 paper. The idea is very simple: apply self-attention to scene segmentation. Whereas standard self-attention models the attention between positions, this paper extends it with an additional channel attention branch. The channel branch works the same way as self-attention, except that it drops the three Linear layers that generate Q, K, and V. 
Finally, the features after the two attentions are summed element-wise.\n\n### 8.4. Usage\n\n```python\nfrom attention.DANet import DAModule\nimport torch\n\ninput=torch.randn(50,512,7,7)\ndanet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\nprint(danet(input).shape)\n```\n\n\n\n \n\n## 9. Pyramid Split Attention(PSA)\n\n### 9.1. Citation\n\nEPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network---arXiv 2021.05.30\n\nAddress: [https://arxiv.org/pdf/2105.14447.pdf](https://arxiv.org/pdf/2105.14447.pdf)\n\n### 9.2. Model Structure\n\n![](./img/psa.png)\n\n\n### 9.3. Brief\nThis paper was uploaded to arXiv by Shenzhen University on May 30, 2021. Its goal is to capture and exploit spatial information at different scales to enrich the feature space. The network structure is relatively simple and consists of four steps. First, the original feature is split into n groups along the channel dimension, and each group is convolved at a different scale to obtain the new feature W1. Second, SE is applied to the original features to obtain different attention maps. Third, softmax is applied across the groups. Fourth, the resulting attention is multiplied with the feature W1.\n\n### 9.4. Usage\n\n```python\nfrom attention.PSA import PSA\nimport torch\n\ninput=torch.randn(50,512,7,7)\npsa = PSA(channel=512,reduction=8)\noutput=psa(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 10. Efficient Multi-Head Self-Attention(EMSA)\n\n### 10.1. Citation\n\nResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28\n\nAddress: [https://arxiv.org/abs/2105.13677](https://arxiv.org/abs/2105.13677)\n\n### 10.2. Model Structure\n\n![](./img/EMSA.png)\n\n### 10.3. Brief\n\nThis paper was uploaded to arXiv by Nanjing University on May 28, 2021. 
This paper addresses two pain points of SA: (1) the computational complexity of Self-Attention is quadratic in n; (2) each head sees only part of q, k, and v, so if the dimensions of q, k, and v are too small, continuous information cannot be captured, which hurts performance. The proposed idea is also very simple: in SA, before the FC layers, a convolution reduces the spatial dimension, which yields K and V that are smaller along the spatial dimension.\n\n### 10.4. Usage\n\n```python\nfrom attention.EMSA import EMSA\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,64,512)\nemsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\noutput=emsa(input,input,input)\nprint(output.shape)\n```\n\n\n\n \n\n## Write at the end\nThe attention work collected in this project is not yet comprehensive. As we read more papers, we will keep improving the project. Stars are welcome. If you find inaccurate statements or incorrect code in this article, please point them out."
  },
  {
    "path": "model/analysis/注意力机制.md",
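    "content": "## Content\n  \n- [1. External Attention](#1-external-attention)\n\n- [2. Self Attention](#2-self-attention)\n\n- [3. Squeeze-and-Excitation(SE) Attention](#3-squeeze-and-excitationse-attention)\n\n- [4. Selective Kernel(SK) Attention](#4-selective-kernelsk-attention)\n\n- [5. CBAM Attention](#5-cbam-attention)\n\n- [6. BAM Attention](#6-bam-attention)\n\n- [7. ECA Attention](#7-eca-attention)\n\n- [8. DANet Attention](#8-danet-attention)\n\n- [9. Pyramid Split Attention(PSA)](#9-pyramid-split-attentionpsa)\n\n- [10. Efficient Multi-Head Self-Attention(EMSA)](#10-efficient-multi-head-self-attentionemsa)\n\n- [Write at the end](#write-at-the-end)\n\n\n\n## 1. External Attention\n\n### 1.1. Citation\n\nBeyond Self-attention: External Attention using Two Linear Layers for Visual Tasks.---arXiv 2021.05.05\n\nPaper: [https://arxiv.org/abs/2105.02358](https://arxiv.org/abs/2105.02358)\n\n### 1.2. Model Structure\n\n![](./img/External_Attention.png)\n\n### 1.3. Brief\n\nThis paper was posted on arXiv in May 2021. It addresses two pain points of Self-Attention (SA): (1) its O(n^2) computational complexity; (2) SA computes attention between different positions within a single sample and ignores the relationships between different samples. The paper therefore uses two cascaded MLP structures as memory units, which reduces the computational complexity to O(n). Because these two memory units are learned from the entire training set, they also implicitly model the relationships between different samples.\n\n### 1.4. Usage\n\n```python\nfrom attention.ExternalAttention import ExternalAttention\nimport torch\n\n\ninput=torch.randn(50,49,512)\nea = ExternalAttention(d_model=512,S=8)\noutput=ea(input)\nprint(output.shape)\n```\n\n\n\n\n\n## 2. Self Attention\n\n### 2.1. Citation\n\nAttention Is All You Need---NeurIPS2017\n\nPaper: [https://arxiv.org/abs/1706.03762](https://arxiv.org/abs/1706.03762)\n\n### 2.2. Model Structure\n\n![](./img/SA.png)\n\n### 2.3. Brief\n\nThis paper was published by Google at NeurIPS 2017. It has been highly influential across CV, NLP, and multi-modal research, with more than 22,000 citations at the time of writing. The Self-Attention proposed in Transformer is a form of attention that computes weights between different positions of a feature in order to update that feature. First, the input feature is projected by FC layers into three features Q, K, and V. Q and K are then multiplied (dot product) to obtain the attention map, and the attention map is multiplied with V to obtain the weighted feature. Finally, an FC layer projects the result into a new feature. (There are many very good explanations of Transformer and Self-Attention online, so no detailed introduction is given here.)\n\n### 2.4. 
Usage\n\n```python\nfrom attention.SelfAttention import ScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nsa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n```\n\n\n\n\n\n## 3. Squeeze-and-Excitation(SE) Attention\n\n### 3.1. Citation\n\nSqueeze-and-Excitation Networks---CVPR2018\n\nPaper: [https://arxiv.org/abs/1709.01507](https://arxiv.org/abs/1709.01507)\n\n### 3.2. Model Structure\n\n![](./img/SE.png)\n\n### 3.3. Brief\n\nThis is a CVPR 2018 paper that is also very influential, with more than 7,000 citations at the time of writing. It focuses on channel attention, and its simple, effective design set off a wave of channel-attention work. In the spirit of \"the greatest truths are the simplest\", the idea of this paper is very simple: first, AdaptiveAvgPool is applied over the spatial dimensions; then two FC layers learn the channel attention, and a Sigmoid normalizes it to produce the Channel Attention Map; finally, the Channel Attention Map is multiplied with the original features to obtain the weighted features.\n\n### 3.4. Usage\n\n```python\nfrom attention.SEAttention import SEAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SEAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 4. Selective Kernel(SK) Attention\n\n### 4.1. Citation\n\nSelective Kernel Networks---CVPR2019\n\nPaper: [https://arxiv.org/pdf/1903.06586.pdf](https://arxiv.org/pdf/1903.06586.pdf)\n\n### 4.2. Model Structure\n\n![](./img/SK.png)\n\n### 4.3. Brief\n\nThis is a CVPR 2019 paper that builds on the ideas of SENet. In a traditional CNN, each convolutional layer uses kernels of a single size, which limits the expressive power of the model; \"wider\" structures such as Inception have shown that learning with several different kernels does improve expressive power. Borrowing from SENet, the authors dynamically compute channel weights for each convolution kernel and dynamically fuse the outputs of the kernels.\n\nIn my view, the reason this method can still be called lightweight is that the parameters are shared when channel attention is applied to the features of the different kernels (i.e., because the features are fused before attention is computed, the outputs of the different kernels share the parameters of one SE module).\n\nThe method has three parts: Split, Fuse, and Select. Split is a multi-branch operation that convolves the input with different kernels to obtain different features. Fuse uses the SE structure to obtain the channel attention matrices (N kernels yield N attention matrices, and this step shares parameters across all features), which gives the SE-weighted features of each kernel. Select adds these features together.\n\n### 4.4. Usage\n\n```python\nfrom attention.SKAttention import SKAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nsk = SKAttention(channel=512,reduction=8)\noutput=sk(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 5. CBAM Attention\n\n### 5.1. 
Citation\n\nCBAM: Convolutional Block Attention Module---ECCV2018\n\nPaper: [https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n### 5.2. Model Structure\n\n![](./img/CBAM1.png)\n\n![](./img/CBAM2.png)\n\n### 5.3. Brief\n\nThis is an ECCV 2018 paper. It uses both Channel Attention and Spatial Attention and connects the two in series (the paper also runs ablations on the parallel arrangement and on both serial orders).\n\nFor Channel Attention, the overall structure is similar to SE, but the authors argue that AvgPool and MaxPool capture different information, so they apply AvgPool and MaxPool separately to the original features over the spatial dimensions and then extract channel attention with an SE-style structure. Note that the parameters are shared between the two branches. The two results are then added and normalized to obtain the attention matrix.\n\nSpatial Attention is similar to Channel Attention. After the two kinds of pooling are applied along the channel dimension, the two features are concatenated and a 7x7 convolution extracts the Spatial Attention (7x7 is used because spatial attention requires a sufficiently large kernel). A normalization then gives the spatial attention matrix.\n\n### 5.4. Usage\n\n```python\nfrom attention.CBAM import CBAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nkernel_size=input.shape[2]\ncbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\noutput=cbam(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 6. BAM Attention\n\n### 6.1. Citation\n\nBAM: Bottleneck Attention Module---BMVC2018\n\nPaper: [https://arxiv.org/pdf/1807.06514.pdf](https://arxiv.org/pdf/1807.06514.pdf)\n\n### 6.2. Model Structure\n\n![](./img/BAM.png)\n\n### 6.3. Brief\n\nThis is concurrent work by the same authors as CBAM. It is very similar to CBAM and also uses dual attention. The difference is that CBAM applies the two attentions in series, while BAM adds the two attention matrices directly.\n\nFor Channel Attention, the structure is essentially the same as SE. For Spatial Attention, pooling is again performed along the channel dimension, then two 3x3 dilated convolutions are applied, and finally a 1x1 convolution produces the Spatial Attention matrix.\n\nFinally, the Channel Attention and Spatial Attention matrices are added (using broadcasting) and normalized, which gives an attention matrix that combines spatial and channel information.\n\n### 6.4. Usage\n\n```python\nfrom attention.BAM import BAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nbam = BAMBlock(channel=512,reduction=16,dia_val=2)\noutput=bam(input)\nprint(output.shape)\n```\n\n\n\n\n\n## 7. ECA Attention\n\n### 7.1. 
Citation\n\nECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020\n\nPaper: [https://arxiv.org/pdf/1910.03151.pdf](https://arxiv.org/pdf/1910.03151.pdf)\n\n### 7.2. Model Structure\n\n![](./img/ECA.png)\n\n### 7.3. Brief\n\nThis is a CVPR 2020 paper.\n\nAs shown in the figure above, SE uses two fully connected layers to compute channel attention, while ECA needs only one convolution. The authors argue, first, that computing attention between all pairs of channels is unnecessary and, second, that two fully connected layers introduce too many parameters and too much computation.\n\nTherefore, after AvgPool, the authors use only a one-dimensional convolution with a receptive field of k (equivalent to computing attention only over the k adjacent channels), which greatly reduces the parameters and computation. In other words, SE is a global attention, while ECA is a local attention.\n\n### 7.4. Usage\n\n```python\nfrom attention.ECAAttention import ECAAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\neca = ECAAttention(kernel_size=3)\noutput=eca(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 8. DANet Attention\n\n### 8.1. Citation\n\nDual Attention Network for Scene Segmentation---CVPR2019\n\nPaper: [https://arxiv.org/pdf/1809.02983.pdf](https://arxiv.org/pdf/1809.02983.pdf)\n\n### 8.2. Model Structure\n\n![](./img/danet.png)\n\n\n### 8.3. Brief\n\nThis is a CVPR 2019 paper. The idea is very simple: apply self-attention to scene segmentation. Whereas standard self-attention models the attention between positions, this paper extends it with an additional channel attention branch. The channel branch works the same way as self-attention, except that it drops the three Linear layers that generate Q, K, and V. Finally, the features after the two attentions are summed element-wise.\n\n### 8.4. Usage\n\n```python\nfrom attention.DANet import DAModule\nimport torch\n\ninput=torch.randn(50,512,7,7)\ndanet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\nprint(danet(input).shape)\n```\n\n\n\n \n\n## 9. Pyramid Split Attention(PSA)\n\n### 9.1. Citation\n\nEPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network---arXiv 2021.05.30\n\nPaper: [https://arxiv.org/pdf/2105.14447.pdf](https://arxiv.org/pdf/2105.14447.pdf)\n\n### 9.2. Model Structure\n\n![](./img/psa.png)\n\n\n\n### 9.3. Brief\n\nThis paper was uploaded to arXiv by Shenzhen University on May 30, 2021. Its goal is to capture and exploit spatial information at different scales to enrich the feature space. The network structure is relatively simple and consists of four steps. First, the original feature is split into n groups along the channel dimension, and each group is convolved at a different scale to obtain the new feature W1. Second, SE is applied to the original features to obtain different attention maps. Third, softmax is applied across the groups. Fourth, the resulting attention is multiplied with the feature W1.\n\n### 9.4. 
Usage\n\n```python\nfrom attention.PSA import PSA\nimport torch\n\ninput=torch.randn(50,512,7,7)\npsa = PSA(channel=512,reduction=8)\noutput=psa(input)\nprint(output.shape)\n```\n\n\n\n \n\n## 10. Efficient Multi-Head Self-Attention(EMSA)\n\n### 10.1. Citation\n\nResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28\n\nPaper: [https://arxiv.org/abs/2105.13677](https://arxiv.org/abs/2105.13677)\n\n### 10.2. Model Structure\n\n![](./img/EMSA.png)\n\n### 10.3. Brief\n\nThis paper was uploaded to arXiv by Nanjing University on May 28, 2021. It addresses two pain points of SA: (1) the computational complexity of Self-Attention is quadratic in n; (2) each head sees only part of q, k, and v, so if the dimensions of q, k, and v are too small, continuous information cannot be captured, which hurts performance. The proposed idea is also very simple: in SA, before the FC layers, a convolution reduces the spatial dimension, which yields K and V that are smaller along the spatial dimension.\n\n### 10.4. Usage\n\n```python\nfrom attention.EMSA import EMSA\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,64,512)\nemsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\noutput=emsa(input,input,input)\nprint(output.shape)\n```\n\n\n\n \n\n## Write at the end\n\nThe attention work collected in this project is not yet comprehensive. As we read more papers, we will keep improving the project. Stars are welcome. If you find inaccurate statements or incorrect code in this article, please point them out."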
  },
  {
    "path": "model/analysis/重参数机制.md",
    "content": "[toc]\n\n## Preface\n\nI recently read Xiaohan Ding's series of papers on re-parameterization and found the idea very elegant: all of the cost is paid during training, and at test time both the number of parameters and the amount of computation are reduced. Several write-ups of these papers already exist online. To help more readers understand re-parameterization in greater depth, this article walks through the re-parameterization papers of recent years together with code.\n\nAs I understand it, re-parameterization compresses the trained network structure at test time. For example, the sum of the outputs of three parallel convolutions (with the same kernel size) equals the output of a single convolution whose kernel is the sum of the three kernels. So three convolutions can be used during training to increase the model's learning capacity, and at deployment they can be losslessly merged into one convolution, which reduces parameters and computation.\n\n\n\n## Implementation\n\n[https://github.com/xmu-xiaoma666/External-Attention-pytorch](https://github.com/xmu-xiaoma666/External-Attention-pytorch)\n\n(Feel free to ***star*** and ***fork*** this work; if you have any questions, please raise them in an ***issue***.)\n\n\n\n## Background\n\nWe first introduce some basic properties of convolution. The re-parameterization operations proposed in these papers are all based on these properties.\n\nAn ordinary convolution can be defined by the following formula:\n$$\nO=I*F+REP(b)\n$$\nwhere $*$ is the convolution operation, $O \\in \\mathbb{R}^{D \\times H' \\times W'}$ is the output feature, $I \\in \\mathbb{R}^{C \\times H \\times W}$ is the input feature, $F \\in \\mathbb{R}^{D \\times C \\times K \\times K}$ is the convolution kernel, $b \\in \\mathbb{R}^{D}$ is the bias, and $REP(b) \\in \\mathbb{R}^{D \\times H' \\times W'}$ is the bias after broadcasting.\n\nConvolution has the following two properties:\n\n### 1) Homogeneity\n\n$$\nI*(pF)=p(I*F), \\forall p \\in \\mathbb{R}\n$$\n\nThis property says that convolving the feature with a kernel scaled by a constant gives the same result as scaling the convolution output by that constant.\n\n### 2) Additivity\n\n$$\nI*F^{(1)}+I*F^{(2)}=I*(F^{(1)}+F^{(2)})\n$$\n\nThis property says that adding the outputs of two parallel convolutions is equivalent to adding the two kernels first and then convolving once.\n\n\n\n# 1. ICCV2019-ACNet\n\n## 1.1. Paper\n\nACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks\n\nPaper: [https://arxiv.org/abs/1908.03930](https://arxiv.org/abs/1908.03930)\n\n## 1.2. Network Architecture\n\n![](https://pic4.zhimg.com/80/v2-0fd24ba4c9648ddb4865da704e71def4_720w.png)\n\n![](https://pic1.zhimg.com/80/v2-3c064c20caa7c36e9bef1c3ba61dc62c_720w.png)\n\n## 1.3. 
Explanation\n\nAs the figures under \"Network Architecture\" show, the main contribution of this paper is to convert three parallel convolutions (1x3, 3x1, 3x3) into a single 3x3 convolution.\n\n### First, the case without BatchNorm:\n\nTo convert the 1x3 and 3x1 kernels into 3x3 kernels, zero-pad each of them along its missing dimension to a size of 3x3. Convolving with the sum of the three kernels (1x3, 3x1, 3x3) is then equivalent to convolving the feature map with each kernel separately and adding the results.\n\n### Next, fusing BatchNorm into the convolution kernel:\n\nThe convolution is:\n$$\nO=I*F+REP(b)\n$$\nBatchNorm is:\n$$\nBN(x)=\\gamma\\frac{(x-mean)}{\\sqrt{var}}+\\beta\n$$\nSubstituting the convolution into BatchNorm gives:\n$$\nBN(Conv(I))=\\gamma\\frac{(I*F+REP(b)-mean)}{\\sqrt{var}}+\\beta\n$$\nwhich simplifies to:\n$$\nBN(Conv(I))=I*(\\frac{F\\gamma}{\\sqrt{var}})+\\frac{\\gamma(REP(b)-mean)}{\\sqrt{var}}+\\beta\n$$\nSo the weight and bias of the new kernel are:\n$$\nF_{new}=\\frac{F\\gamma}{\\sqrt{var}} \\\\\nREP(b_{new})=\\frac{\\gamma(REP(b)-mean)}{\\sqrt{var}}+\\beta\n$$\n\n### Finally, fusing a multi-branch structure with BN:\n\nStep 1: fuse the parameters of each BN layer into its convolution kernel.\n\nStep 2: once the BN parameters are fused, the structure with BN layers becomes a structure without BN layers; adding the three new kernels gives the fused kernel.\n\n## 1.4. Usage\n\n```python\nfrom rep.acnet import ACNet\nimport torch\nfrom torch import nn\n\ninput=torch.randn(50,512,49,49)\nacnet=ACNet(512,512)\nacnet.eval()\nout=acnet(input)\nacnet._switch_to_deploy()\nout2=acnet(input)\nprint('difference:')\nprint(((out2-out)**2).sum())\n```\n\n\n\n# 2. CVPR2021-RepVGG\n\n## 2.1. Paper\n\nRepVGG: Making VGG-style ConvNets Great Again\n\nPaper: [https://arxiv.org/abs/2101.03697](https://arxiv.org/abs/2101.03697)\n\n## 2.2. Network Architecture\n\n![](https://pic3.zhimg.com/80/v2-7b404f792d4f7ac4d2d104a38b827989_720w.png)\n\n## 2.3. 
Explanation\n\nThe core of this paper is to convert a parallel structure consisting of a 3x3 convolution with BN, a 1x1 convolution with BN, and a residual branch into a single 3x3 convolution.\n\nFirst, the 1x1 convolution with BN and the 3x3 convolution with BN are fused into one 3x3 kernel. This is very similar to the conversion in the first paper, ACNet: pad the 1x1 kernel to 3x3 and then apply the same operations as in ACNet.\n\nThe remaining question is how to turn the residual branch into a 3x3 convolution as well. The residual branch is in fact a 1x1 depthwise convolution whose weights are all 1. If the depthwise convolution can be converted into an ordinary convolution, the problem is solved. The figure below illustrates how to convert a depthwise convolution into an ordinary convolution:\n\n![img](https://pic3.zhimg.com/v2-7deaae69ee14aff87a0210f6c9965a66_b.jpg)\n\n(Source: https://zhuanlan.zhihu.com/p/352239591)\n\nIn short, the weight for the channel being operated on is set to 1 and all others are set to 0.\n\nWith c input channels and c output channels, view the parameters as a c x c matrix; the depthwise kernel then corresponds to the identity matrix (ones on the diagonal, zeros elsewhere).\n\n![img](https://pic1.zhimg.com/80/v2-48a4eb12a20d1a499d0fb7c2110caae0_720w.jpg)\n\n(Source: https://zhuanlan.zhihu.com/p/352239591)\n\nIn this way the residual branch can also be converted into a 1x1 convolution, which is then merged with the 3x3 convolution as described above to obtain the re-parameterized RepVGG structure.\n\n## 2.4. Usage\n\n```python\nfrom rep.repvgg import RepBlock\nimport torch\n\ninput=torch.randn(50,512,49,49)\nrepblock=RepBlock(512,512)\nrepblock.eval()\nout=repblock(input)\nrepblock._switch_to_deploy()\nout2=repblock(input)\nprint('difference between vgg and repvgg')\nprint(((out2-out)**2).sum())\n```\n\n\n\n\n\n# 3. CVPR2021-Diverse Branch Block\n\n## 3.1. Paper\n\nDiverse Branch Block: Building a Convolution as an Inception-like Unit\n\nPaper: [https://arxiv.org/abs/2103.13425](https://arxiv.org/abs/2103.13425)\n\n## 3.2. Network Architecture\n\n![](https://pic4.zhimg.com/80/v2-8b304977c8b933f6ac9f31850d248f37_720w.png)\n\n## 3.3. Explanation\n\nInception-like multi-branch structures increase the expressive power of a model and improve performance, but they also bring extra parameters and memory usage. This paper therefore proposes to use a multi-branch structure during training and to convert it into a single-branch model for testing and deployment, so that the model enjoys the performance gain of the multi-branch structure \"for free\" at test time. How to losslessly compress a multi-branch structure into a single branch is the contribution of this paper.\n\n### 3.3.1. Transform I: Conv+BN->Conv\n\nThe principle was explained in detail in the ACNet section above; the code is:\n\n```python\ndef transI_conv_bn(conv, bn):\n\n    std = (bn.running_var + bn.eps).sqrt()\n    gamma=bn.weight\n\n    weight=conv.weight*((gamma/std).reshape(-1, 1, 1, 1))\n    if(conv.bias is not None):\n        bias=gamma/std*conv.bias-gamma/std*bn.running_mean+bn.bias\n    else:\n        bias=bn.bias-gamma/std*bn.running_mean\n    return weight,bias\n```\n\n### 3.3.2. 
Transform II: Parallel Conv->Conv\n\nThis follows from the additivity property in the Background section; the code is:\n\n```python\ndef transII_conv_branch(conv1, conv2):\n    weight=conv1.weight.data+conv2.weight.data\n    bias=conv1.bias.data+conv2.bias.data\n    return weight,bias\n```\n\n\n\n### 3.3.3 Transform III: 1x1Conv + 3x3Conv->3x3Conv\n\nA 1x1 convolution does not mix information spatially across the feature map (every spatial position is multiplied by the same weights), so a 1x1 Conv is effectively a fully connected layer over channels. In essence, we can therefore apply this 1x1 convolution directly to the kernel of the 3x3 convolution that follows it, and the resulting kernel is the fused kernel. (This explanation may not be entirely clear; for details see: https://zhuanlan.zhihu.com/p/360939086)\n\nThe code is:\n\n```python\ndef transIII_conv_sequential(conv1, conv2):\n    weight=F.conv2d(conv2.weight.data,conv1.weight.data.permute(1,0,2,3))\n    return weight\n```\n\n\n\n### 3.3.4 Transform IV: Concat Conv->Conv\n\nConcatenating the outputs of several convolutions is equivalent to concatenating their kernel weights along the output-channel dimension; the code is:\n\n```python\ndef transIV_conv_concat(conv1, conv2):\n    print(conv1.bias.data.shape)\n    print(conv2.bias.data.shape)\n    weight=torch.cat([conv1.weight.data,conv2.weight.data],0)\n    bias=torch.cat([conv1.bias.data,conv2.bias.data],0)\n    return weight,bias\n```\n\n\n\n### 3.3.5 Transform V: AvgPooling->Conv\n\nAvgPool averages the values in its receptive field, so the equivalent convolution has every kernel value equal to 1/(kernel area); the code is:\n\n```python\ndef transV_avg(channel,kernel):\n    conv=nn.Conv2d(channel,channel,kernel,bias=False)\n    conv.weight.data[:]=0\n    for i in range(channel):\n        conv.weight.data[i,i,:,:]=1/(kernel*kernel)\n    return conv\n```\n\n\n\n### 3.3.6 Transform VI: 1x1Conv+1x3Conv+3x1Conv (parallel)->Conv\n\nThis operation is the idea behind ACNet; see the ACNet section above for details. The code is:\n\n```python\ndef transVI_conv_scale(conv1, conv2, conv3):\n    weight=F.pad(conv1.weight.data,(1,1,1,1))+F.pad(conv2.weight.data,(0,0,1,1))+F.pad(conv3.weight.data,(1,1,0,0))\n    bias=conv1.bias.data+conv2.bias.data+conv3.bias.data\n    return weight,bias\n```\n\n\n\n\n\n## 3.4. Usage\n\n### 3.4.1. 
Transform I: Conv+BN->Conv\n\n```python\nfrom rep.ddb import transI_conv_bn\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n#conv+bn\nconv1=nn.Conv2d(64,64,3,padding=1)\nbn1=nn.BatchNorm2d(64)\nbn1.eval()\nout1=bn1(conv1(input))\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transI_conv_bn(conv1,bn1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n\n```\n\n\n\n### 3.4.2. Transform II: Parallel Conv->Conv\n\n```python\nfrom rep.ddb import transII_conv_branch\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,3,padding=1)\nconv2=nn.Conv2d(64,64,3,padding=1)\nout1=conv1(input)+conv2(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transII_conv_branch(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n### 3.4.3 Transform III: 1x1Conv + 3x3Conv->3x3Conv\n\n```python\nfrom rep.ddb import transIII_conv_sequential\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,1,padding=0,bias=False)\nconv2=nn.Conv2d(64,64,3,padding=1,bias=False)\nout1=conv2(conv1(input))\n\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1,bias=False)\nconv_fuse.weight.data=transIII_conv_sequential(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n### 3.4.4 Transform IV: Concat Conv->Conv\n\n```python\nfrom rep.ddb import transIV_conv_concat\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as 
F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,32,3,padding=1)\nconv2=nn.Conv2d(64,32,3,padding=1)\nout1=torch.cat([conv1(input),conv2(input)],dim=1)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transIV_conv_concat(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n### 3.4.5 Transform V: AvgPooling->Conv\n\n```python\nfrom rep.ddb import transV_avg\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\navg=nn.AvgPool2d(kernel_size=3,stride=1)\nout1=avg(input)\n\nconv=transV_avg(64,3)\nout2=conv(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n### 3.4.6 Transform VI: 1x1Conv+1x3Conv+3x1Conv (parallel)->Conv\n\n```python\nfrom rep.ddb import transVI_conv_scale\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1x1=nn.Conv2d(64,64,1)\nconv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))\nconv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))\nout1=conv1x1(input)+conv1x3(input)+conv3x1(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n"
  },
  {
    "path": "model/attention/A2Atttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\n\n\n\nclass DoubleAttention(nn.Module):\n\n    def __init__(self, in_channels,c_m,c_n,reconstruct = True):\n        super().__init__()\n        self.in_channels=in_channels\n        self.reconstruct = reconstruct\n        self.c_m=c_m\n        self.c_n=c_n\n        self.convA=nn.Conv2d(in_channels,c_m,1)\n        self.convB=nn.Conv2d(in_channels,c_n,1)\n        self.convV=nn.Conv2d(in_channels,c_n,1)\n        if self.reconstruct:\n            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size = 1)\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        b, c, h,w=x.shape\n        assert c==self.in_channels\n        A=self.convA(x) #b,c_m,h,w\n        B=self.convB(x) #b,c_n,h,w\n        V=self.convV(x) #b,c_n,h,w\n        tmpA=A.view(b,self.c_m,-1)\n        attention_maps=F.softmax(B.view(b,self.c_n,-1))\n        attention_vectors=F.softmax(V.view(b,self.c_n,-1))\n        # step 1: feature gating\n        global_descriptors=torch.bmm(tmpA,attention_maps.permute(0,2,1)) #b.c_m,c_n\n        # step 2: feature distribution\n        tmpZ = global_descriptors.matmul(attention_vectors) #b,c_m,h*w\n        tmpZ=tmpZ.view(b,self.c_m,h,w) #b,c_m,h,w\n        if self.reconstruct:\n            tmpZ=self.conv_reconstruct(tmpZ)\n\n        return tmpZ 
\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    a2 = DoubleAttention(512,128,128,True)\n    output=a2(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/ACmixAttention.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef position(H, W, is_cuda=True):\n    if is_cuda:\n        loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)\n        loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)\n    else:\n        loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)\n        loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)\n    loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)\n    return loc\n\n\ndef stride(x, stride):\n    b, c, h, w = x.shape\n    return x[:, :, ::stride, ::stride]\n\ndef init_rate_half(tensor):\n    if tensor is not None:\n        tensor.data.fill_(0.5)\n\ndef init_rate_0(tensor):\n    if tensor is not None:\n        tensor.data.fill_(0.)\n\n\nclass ACmix(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):\n        super(ACmix, self).__init__()\n        self.in_planes = in_planes\n        self.out_planes = out_planes\n        self.head = head\n        self.kernel_att = kernel_att\n        self.kernel_conv = kernel_conv\n        self.stride = stride\n        self.dilation = dilation\n        self.rate1 = torch.nn.Parameter(torch.Tensor(1))\n        self.rate2 = torch.nn.Parameter(torch.Tensor(1))\n        self.head_dim = self.out_planes // self.head\n\n        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)\n        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)\n        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)\n        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)\n\n        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2\n        self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)\n        self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)\n        self.softmax = torch.nn.Softmax(dim=1)\n\n        self.fc = 
nn.Conv2d(3*self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)\n        self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes, kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1, stride=stride)\n\n        self.reset_parameters()\n    \n    def reset_parameters(self):\n        init_rate_half(self.rate1)\n        init_rate_half(self.rate2)\n        kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)\n        for i in range(self.kernel_conv * self.kernel_conv):\n            kernel[i, i//self.kernel_conv, i%self.kernel_conv] = 1.\n        kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)\n        self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)\n        # zero-initialize the bias in place (init_rate_0 returns None, so its result must not be assigned)\n        init_rate_0(self.dep_conv.bias)\n\n    def forward(self, x):\n        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)\n        scaling = float(self.head_dim) ** -0.5\n        b, c, h, w = q.shape\n        h_out, w_out = h//self.stride, w//self.stride\n\n        pe = self.conv_p(position(h, w, x.is_cuda))\n\n        q_att = q.view(b*self.head, self.head_dim, h, w) * scaling\n        k_att = k.view(b*self.head, self.head_dim, h, w)\n        v_att = v.view(b*self.head, self.head_dim, h, w)\n\n        if self.stride > 1:\n            q_att = stride(q_att, self.stride)\n            q_pe = stride(pe, self.stride)\n        else:\n            q_pe = pe\n\n        unfold_k = self.unfold(self.pad_att(k_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out) # b*head, head_dim, k_att^2, h_out, w_out\n        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out) # 1, head_dim, k_att^2, h_out, w_out\n        \n        att = (q_att.unsqueeze(2)*(unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, 
w_out) -> (b*head, k_att^2, h_out, w_out)\n        att = self.softmax(att)\n\n        out_att = self.unfold(self.pad_att(v_att)).view(b*self.head, self.head_dim, self.kernel_att*self.kernel_att, h_out, w_out)\n        out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)\n\n        f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h*w), k.view(b, self.head, self.head_dim, h*w), v.view(b, self.head, self.head_dim, h*w)], 1))\n        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])\n        \n        out_conv = self.dep_conv(f_conv)\n\n        return self.rate1 * out_att + self.rate2 * out_conv\n\nif __name__ == '__main__':\n    input=torch.randn(50,256,7,7)\n    acmix = ACmix(in_planes=256, out_planes=256)\n    output=acmix(input)\n    print(output.shape)\n"
  },
  {
    "path": "model/attention/AFT.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass AFT_FULL(nn.Module):\n\n    def __init__(self, d_model,n=49,simple=False):\n\n        super(AFT_FULL, self).__init__()\n        self.fc_q = nn.Linear(d_model, d_model)\n        self.fc_k = nn.Linear(d_model, d_model)\n        self.fc_v = nn.Linear(d_model,d_model)\n        if(simple):\n            self.position_biases=torch.zeros((n,n))\n        else:\n            self.position_biases=nn.Parameter(torch.ones((n,n)))\n        self.d_model = d_model\n        self.n=n\n        self.sigmoid=nn.Sigmoid()\n\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, input):\n\n        bs, n,dim = input.shape\n\n        q = self.fc_q(input) #bs,n,dim\n        k = self.fc_k(input).view(1,bs,n,dim) #1,bs,n,dim\n        v = self.fc_v(input).view(1,bs,n,dim) #1,bs,n,dim\n        \n        numerator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1))*v,dim=2) #n,bs,dim\n        denominator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1)),dim=2) #n,bs,dim\n\n        out=(numerator/denominator) #n,bs,dim\n        out=self.sigmoid(q)*(out.permute(1,0,2)) #bs,n,dim\n\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    aft_full = AFT_FULL(d_model=512, n=49)\n    output=aft_full(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/Axial_attention.py",
    "content": "import torch\nfrom torch import nn\nfrom operator import itemgetter\n# from axial_attention.reversible import ReversibleSequence\nfrom torch.autograd.function import Function\nfrom torch.utils.checkpoint import get_device_states, set_device_states\n\n# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html\nclass Deterministic(nn.Module):\n    def __init__(self, net):\n        super().__init__()\n        self.net = net\n        self.cpu_state = None\n        self.cuda_in_fwd = None\n        self.gpu_devices = None\n        self.gpu_states = None\n\n    def record_rng(self, *args):\n        self.cpu_state = torch.get_rng_state()\n        if torch.cuda._initialized:\n            self.cuda_in_fwd = True\n            self.gpu_devices, self.gpu_states = get_device_states(*args)\n\n    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):\n        if record_rng:\n            self.record_rng(*args)\n\n        if not set_rng:\n            return self.net(*args, **kwargs)\n\n        rng_devices = []\n        if self.cuda_in_fwd:\n            rng_devices = self.gpu_devices\n\n        with torch.random.fork_rng(devices=rng_devices, enabled=True):\n            torch.set_rng_state(self.cpu_state)\n            if self.cuda_in_fwd:\n                set_device_states(self.gpu_devices, self.gpu_states)\n            return self.net(*args, **kwargs)\n\n# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py\n# once multi-GPU is confirmed working, refactor and send PR back to source\nclass ReversibleBlock(nn.Module):\n    def __init__(self, f, g):\n        super().__init__()\n        self.f = Deterministic(f)\n        self.g = Deterministic(g)\n\n    def forward(self, x, f_args = {}, g_args = {}):\n        x1, x2 = torch.chunk(x, 2, dim = 1)\n        y1, y2 = None, None\n\n        with torch.no_grad():\n            y1 = x1 + self.f(x2, 
record_rng=self.training, **f_args)\n            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)\n\n        return torch.cat([y1, y2], dim = 1)\n\n    def backward_pass(self, y, dy, f_args = {}, g_args = {}):\n        y1, y2 = torch.chunk(y, 2, dim = 1)\n        del y\n\n        dy1, dy2 = torch.chunk(dy, 2, dim = 1)\n        del dy\n\n        with torch.enable_grad():\n            y1.requires_grad = True\n            gy1 = self.g(y1, set_rng=True, **g_args)\n            torch.autograd.backward(gy1, dy2)\n\n        with torch.no_grad():\n            x2 = y2 - gy1\n            del y2, gy1\n\n            dx1 = dy1 + y1.grad\n            del dy1\n            y1.grad = None\n\n        with torch.enable_grad():\n            x2.requires_grad = True\n            fx2 = self.f(x2, set_rng=True, **f_args)\n            torch.autograd.backward(fx2, dx1, retain_graph=True)\n\n        with torch.no_grad():\n            x1 = y1 - fx2\n            del y1, fx2\n\n            dx2 = dy2 + x2.grad\n            del dy2\n            x2.grad = None\n\n            x = torch.cat([x1, x2.detach()], dim = 1)\n            dx = torch.cat([dx1, dx2], dim = 1)\n\n        return x, dx\n\nclass IrreversibleBlock(nn.Module):\n    def __init__(self, f, g):\n        super().__init__()\n        self.f = f\n        self.g = g\n\n    def forward(self, x, f_args, g_args):\n        x1, x2 = torch.chunk(x, 2, dim = 1)\n        y1 = x1 + self.f(x2, **f_args)\n        y2 = x2 + self.g(y1, **g_args)\n        return torch.cat([y1, y2], dim = 1)\n\nclass _ReversibleFunction(Function):\n    @staticmethod\n    def forward(ctx, x, blocks, kwargs):\n        ctx.kwargs = kwargs\n        for block in blocks:\n            x = block(x, **kwargs)\n        ctx.y = x.detach()\n        ctx.blocks = blocks\n        return x\n\n    @staticmethod\n    def backward(ctx, dy):\n        y = ctx.y\n        kwargs = ctx.kwargs\n        for block in ctx.blocks[::-1]:\n            y, dy = block.backward_pass(y, dy, 
**kwargs)\n        return dy, None, None\n\nclass ReversibleSequence(nn.Module):\n    def __init__(self, blocks, ):\n        super().__init__()\n        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])\n\n    def forward(self, x, arg_route = (True, True), **kwargs):\n        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)\n        block_kwargs = {'f_args': f_args, 'g_args': g_args}\n        x = torch.cat((x, x), dim = 1)\n        x = _ReversibleFunction.apply(x, self.blocks, block_kwargs)\n        return torch.stack(x.chunk(2, dim = 1)).mean(dim = 0)\n\n\n# helper functions\n\ndef exists(val):\n    return val is not None\n\ndef map_el_ind(arr, ind):\n    return list(map(itemgetter(ind), arr))\n\ndef sort_and_return_indices(arr):\n    indices = [ind for ind in range(len(arr))]\n    arr = zip(arr, indices)\n    arr = sorted(arr)\n    return map_el_ind(arr, 0), map_el_ind(arr, 1)\n\n# calculates the permutation to bring the input tensor to something attend-able\n# also calculates the inverse permutation to bring the tensor back to its original shape\n\ndef calculate_permutations(num_dimensions, emb_dim):\n    total_dimensions = num_dimensions + 2\n    emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)\n    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]\n\n    permutations = []\n\n    for axial_dim in axial_dims:\n        last_two_dims = [axial_dim, emb_dim]\n        dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)\n        permutation = [*dims_rest, *last_two_dims]\n        permutations.append(permutation)\n      \n    return permutations\n\n# helper classes\n\nclass ChanLayerNorm(nn.Module):\n    def __init__(self, dim, eps = 1e-5):\n        super().__init__()\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))\n        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))\n\n    def forward(self, x):\n        std = torch.var(x, dim = 1, 
unbiased = False, keepdim = True).sqrt()\n        mean = torch.mean(x, dim = 1, keepdim = True)\n        return (x - mean) / (std + self.eps) * self.g + self.b\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.fn = fn\n        self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x):\n        x = self.norm(x)\n        return self.fn(x)\n\nclass Sequential(nn.Module):\n    def __init__(self, blocks):\n        super().__init__()\n        self.blocks = blocks\n\n    def forward(self, x):\n        for f, g in self.blocks:\n            x = x + f(x)\n            x = x + g(x)\n        return x\n\nclass PermuteToFrom(nn.Module):\n    def __init__(self, permutation, fn):\n        super().__init__()\n        self.fn = fn\n        _, inv_permutation = sort_and_return_indices(permutation)\n        self.permutation = permutation\n        self.inv_permutation = inv_permutation\n\n    def forward(self, x, **kwargs):\n        axial = x.permute(*self.permutation).contiguous()\n\n        shape = axial.shape\n        *_, t, d = shape\n\n        # merge all but axial dimension\n        axial = axial.reshape(-1, t, d)\n\n        # attention\n        axial = self.fn(axial, **kwargs)\n\n        # restore to original shape and permutation\n        axial = axial.reshape(*shape)\n        axial = axial.permute(*self.inv_permutation).contiguous()\n        return axial\n\n# axial pos emb\n\nclass AxialPositionalEmbedding(nn.Module):\n    def __init__(self, dim, shape, emb_dim_index = 1):\n        super().__init__()\n        parameters = []\n        total_dimensions = len(shape) + 2\n        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]\n\n        self.num_axials = len(shape)\n\n        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):\n            shape = [1] * total_dimensions\n            shape[emb_dim_index] = dim\n            shape[axial_dim_index] = axial_dim\n            
parameter = nn.Parameter(torch.randn(*shape))\n            setattr(self, f'param_{i}', parameter)\n\n    def forward(self, x):\n        for i in range(self.num_axials):\n            x = x + getattr(self, f'param_{i}')\n        return x\n\n# attention\n\nclass SelfAttention(nn.Module):\n    def __init__(self, dim, heads, dim_heads = None):\n        super().__init__()\n        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads\n        dim_hidden = self.dim_heads * heads\n\n        self.heads = heads\n        self.to_q = nn.Linear(dim, dim_hidden, bias = False)\n        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias = False)\n        self.to_out = nn.Linear(dim_hidden, dim)\n\n    def forward(self, x, kv = None):\n        kv = x if kv is None else kv\n        q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))\n\n        b, t, d, h, e = *q.shape, self.heads, self.dim_heads\n\n        merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)\n        q, k, v = map(merge_heads, (q, k, v))\n\n        dots = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)\n        dots = dots.softmax(dim=-1)\n        out = torch.einsum('bij,bje->bie', dots, v)\n\n        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)\n        out = self.to_out(out)\n        return out\n\n# axial attention class\n\nclass AxialAttention(nn.Module):\n    def __init__(self, dim, num_dimensions = 2, heads = 8, dim_heads = None, dim_index = -1, sum_axial_out = True):\n        assert (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'\n        super().__init__()\n        self.dim = dim\n        self.total_dimensions = num_dimensions + 2\n        self.dim_index = dim_index if dim_index > 0 else (dim_index + self.total_dimensions)\n\n        attentions = []\n        for permutation in calculate_permutations(num_dimensions, dim_index):\n            attentions.append(PermuteToFrom(permutation, SelfAttention(dim, heads, 
dim_heads)))\n\n        self.axial_attentions = nn.ModuleList(attentions)\n        self.sum_axial_out = sum_axial_out\n\n    def forward(self, x):\n        assert len(x.shape) == self.total_dimensions, 'input tensor does not have the correct number of dimensions'\n        assert x.shape[self.dim_index] == self.dim, 'input tensor does not have the correct input dimension'\n\n        if self.sum_axial_out:\n            return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))\n\n        out = x\n        for axial_attn in self.axial_attentions:\n            out = axial_attn(out)\n        return out\n\n# axial image transformer\n\nclass AxialImageTransformer(nn.Module):\n    def __init__(self, dim, depth, heads = 8, dim_heads = None, dim_index = 1, reversible = True, axial_pos_emb_shape = None):\n        super().__init__()\n        permutations = calculate_permutations(2, dim_index)\n\n        get_ff = lambda: nn.Sequential(\n            ChanLayerNorm(dim),\n            nn.Conv2d(dim, dim * 4, 3, padding = 1),\n            nn.LeakyReLU(inplace=True),\n            nn.Conv2d(dim * 4, dim, 3, padding = 1)\n        )\n\n        self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index) if exists(axial_pos_emb_shape) else nn.Identity()\n\n        layers = nn.ModuleList([])\n        for _ in range(depth):\n            attn_functions = nn.ModuleList([PermuteToFrom(permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads))) for permutation in permutations])\n            conv_functions = nn.ModuleList([get_ff(), get_ff()])\n            layers.append(attn_functions)\n            layers.append(conv_functions)            \n\n        execute_type = ReversibleSequence if reversible else Sequential\n        self.layers = execute_type(layers)\n\n    def forward(self, x):\n        x = self.pos_emb(x)\n        return self.layers(x)\n\n# input=torch.randn(3, 128, 7, 7)\n\n# attn = AxialAttention(\n#     dim = 3,               # embedding dimension\n# 
    dim_index = 1,         # where is the embedding dimension\n#     dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied\n#     heads = 1,             # number of heads for multi-head attention\n#     num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)\n#     sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true\n# )\n# print(attn(input).shape) # (1, 3, 256, 256)\n\nif __name__ == '__main__':\n    input=torch.randn(3, 128, 7, 7)\n    model = AxialImageTransformer(\n        dim = 128,\n        depth = 12,\n        reversible = True\n    )\n    outputs = model(input)\n    print(outputs.shape)"
  },
  {
    "path": "model/attention/BAM.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\nclass Flatten(nn.Module):\n    def forward(self,x):\n        return x.view(x.shape[0],-1)\n\nclass ChannelAttention(nn.Module):\n    def __init__(self,channel,reduction=16,num_layers=3):\n        super().__init__()\n        self.avgpool=nn.AdaptiveAvgPool2d(1)\n        gate_channels=[channel]\n        gate_channels+=[channel//reduction]*num_layers\n        gate_channels+=[channel]\n\n\n        self.ca=nn.Sequential()\n        self.ca.add_module('flatten',Flatten())\n        for i in range(len(gate_channels)-2):\n            self.ca.add_module('fc%d'%i,nn.Linear(gate_channels[i],gate_channels[i+1]))\n            self.ca.add_module('bn%d'%i,nn.BatchNorm1d(gate_channels[i+1]))\n            self.ca.add_module('relu%d'%i,nn.ReLU())\n        self.ca.add_module('last_fc',nn.Linear(gate_channels[-2],gate_channels[-1]))\n        \n\n    def forward(self, x) :\n        res=self.avgpool(x)\n        res=self.ca(res)\n        res=res.unsqueeze(-1).unsqueeze(-1).expand_as(x)\n        return res\n\nclass SpatialAttention(nn.Module):\n    def __init__(self,channel,reduction=16,num_layers=3,dia_val=2):\n        super().__init__()\n        self.sa=nn.Sequential()\n        self.sa.add_module('conv_reduce1',nn.Conv2d(kernel_size=1,in_channels=channel,out_channels=channel//reduction))\n        self.sa.add_module('bn_reduce1',nn.BatchNorm2d(channel//reduction))\n        self.sa.add_module('relu_reduce1',nn.ReLU())\n        for i in range(num_layers):\n            self.sa.add_module('conv_%d'%i,nn.Conv2d(kernel_size=3,in_channels=channel//reduction,out_channels=channel//reduction,padding=1,dilation=dia_val))\n            self.sa.add_module('bn_%d'%i,nn.BatchNorm2d(channel//reduction))\n            self.sa.add_module('relu_%d'%i,nn.ReLU())\n        self.sa.add_module('last_conv',nn.Conv2d(channel//reduction,1,kernel_size=1))\n\n    def forward(self, x) :\n        res=self.sa(x)\n        
res=res.expand_as(x)\n        return res\n\n\n\n\nclass BAMBlock(nn.Module):\n\n    def __init__(self, channel=512,reduction=16,dia_val=2):\n        super().__init__()\n        self.ca=ChannelAttention(channel=channel,reduction=reduction)\n        self.sa=SpatialAttention(channel=channel,reduction=reduction,dia_val=dia_val)\n        self.sigmoid=nn.Sigmoid()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        sa_out=self.sa(x)\n        ca_out=self.ca(x)\n        weight=self.sigmoid(sa_out+ca_out)\n        out=(1+weight)*x\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    bam = BAMBlock(channel=512,reduction=16,dia_val=2)\n    output=bam(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/CBAM.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass ChannelAttention(nn.Module):\n    def __init__(self,channel,reduction=16):\n        super().__init__()\n        self.maxpool=nn.AdaptiveMaxPool2d(1)\n        self.avgpool=nn.AdaptiveAvgPool2d(1)\n        self.se=nn.Sequential(\n            nn.Conv2d(channel,channel//reduction,1,bias=False),\n            nn.ReLU(),\n            nn.Conv2d(channel//reduction,channel,1,bias=False)\n        )\n        self.sigmoid=nn.Sigmoid()\n    \n    def forward(self, x) :\n        max_result=self.maxpool(x)\n        avg_result=self.avgpool(x)\n        max_out=self.se(max_result)\n        avg_out=self.se(avg_result)\n        output=self.sigmoid(max_out+avg_out)\n        return output\n\nclass SpatialAttention(nn.Module):\n    def __init__(self,kernel_size=7):\n        super().__init__()\n        self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2)\n        self.sigmoid=nn.Sigmoid()\n    \n    def forward(self, x) :\n        max_result,_=torch.max(x,dim=1,keepdim=True)\n        avg_result=torch.mean(x,dim=1,keepdim=True)\n        result=torch.cat([max_result,avg_result],1)\n        output=self.conv(result)\n        output=self.sigmoid(output)\n        return output\n\n\n\nclass CBAMBlock(nn.Module):\n\n    def __init__(self, channel=512,reduction=16,kernel_size=49):\n        super().__init__()\n        self.ca=ChannelAttention(channel=channel,reduction=reduction)\n        self.sa=SpatialAttention(kernel_size=kernel_size)\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n  
              init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        residual=x\n        out=x*self.ca(x)\n        out=out*self.sa(out)\n        return out+residual\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    kernel_size=input.shape[2]\n    cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\n    output=cbam(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/CoAtNet.py",
    "content": "from torch import nn, sqrt\nimport torch\nimport sys\nfrom math import sqrt\nsys.path.append('.')\nfrom model.conv.MBConv import MBConvBlock\nfrom model.attention.SelfAttention import ScaledDotProductAttention\n\nclass CoAtNet(nn.Module):\n    def __init__(self,in_ch,image_size,out_chs=[64,96,192,384,768]):\n        super().__init__()\n        self.out_chs=out_chs\n        self.maxpool2d=nn.MaxPool2d(kernel_size=2,stride=2)\n        self.maxpool1d = nn.MaxPool1d(kernel_size=2, stride=2)\n\n        self.s0=nn.Sequential(\n            nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1),\n            nn.ReLU(),\n            nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1)\n        )\n        self.mlp0=nn.Sequential(\n            nn.Conv2d(in_ch,out_chs[0],kernel_size=1),\n            nn.ReLU(),\n            nn.Conv2d(out_chs[0],out_chs[0],kernel_size=1)\n        )\n        \n        self.s1=MBConvBlock(ksize=3,input_filters=out_chs[0],output_filters=out_chs[0],image_size=image_size//2)\n        self.mlp1=nn.Sequential(\n            nn.Conv2d(out_chs[0],out_chs[1],kernel_size=1),\n            nn.ReLU(),\n            nn.Conv2d(out_chs[1],out_chs[1],kernel_size=1)\n        )\n\n        self.s2=MBConvBlock(ksize=3,input_filters=out_chs[1],output_filters=out_chs[1],image_size=image_size//4)\n        self.mlp2=nn.Sequential(\n            nn.Conv2d(out_chs[1],out_chs[2],kernel_size=1),\n            nn.ReLU(),\n            nn.Conv2d(out_chs[2],out_chs[2],kernel_size=1)\n        )\n\n        self.s3=ScaledDotProductAttention(out_chs[2],out_chs[2]//8,out_chs[2]//8,8)\n        self.mlp3=nn.Sequential(\n            nn.Linear(out_chs[2],out_chs[3]),\n            nn.ReLU(),\n            nn.Linear(out_chs[3],out_chs[3])\n        )\n\n        self.s4=ScaledDotProductAttention(out_chs[3],out_chs[3]//8,out_chs[3]//8,8)\n        self.mlp4=nn.Sequential(\n            nn.Linear(out_chs[3],out_chs[4]),\n            nn.ReLU(),\n            nn.Linear(out_chs[4],out_chs[4])\n        
)\n\n\n    def forward(self, x) :\n        B,C,H,W=x.shape\n        #stage0\n        y=self.mlp0(self.s0(x))\n        y=self.maxpool2d(y)\n        #stage1\n        y=self.mlp1(self.s1(y))\n        y=self.maxpool2d(y)\n        #stage2\n        y=self.mlp2(self.s2(y))\n        y=self.maxpool2d(y)\n        #stage3\n        y=y.reshape(B,self.out_chs[2],-1).permute(0,2,1) #B,N,C\n        y=self.mlp3(self.s3(y,y,y))\n        y=self.maxpool1d(y.permute(0,2,1)).permute(0,2,1)\n        #stage4\n        y=self.mlp4(self.s4(y,y,y))\n        y=self.maxpool1d(y.permute(0,2,1))\n        N=y.shape[-1]\n        y=y.reshape(B,self.out_chs[4],int(sqrt(N)),int(sqrt(N)))\n\n        return y\n\nif __name__ == '__main__':\n    x=torch.randn(1,3,224,224)\n    coatnet=CoAtNet(3,224)\n    y=coatnet(x)\n    print(y.shape)\n    "
  },
  {
    "path": "model/attention/CoTAttention.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\n\nclass CoTAttention(nn.Module):\n\n    def __init__(self, dim=512,kernel_size=3):\n        super().__init__()\n        self.dim=dim\n        self.kernel_size=kernel_size\n\n        # grouped k x k convolution over the keys (static context)\n        self.key_embed=nn.Sequential(\n            nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),\n            nn.BatchNorm2d(dim),\n            nn.ReLU()\n        )\n        self.value_embed=nn.Sequential(\n            nn.Conv2d(dim,dim,1,bias=False),\n            nn.BatchNorm2d(dim)\n        )\n\n        factor=4\n        self.attention_embed=nn.Sequential(\n            nn.Conv2d(2*dim,2*dim//factor,1,bias=False),\n            nn.BatchNorm2d(2*dim//factor),\n            nn.ReLU(),\n            nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1)\n        )\n\n\n    def forward(self, x):\n        bs,c,h,w=x.shape\n        k1=self.key_embed(x) #bs,c,h,w\n        v=self.value_embed(x).view(bs,c,-1) #bs,c,h*w\n\n        y=torch.cat([k1,x],dim=1) #bs,2c,h,w\n        att=self.attention_embed(y) #bs,c*k*k,h,w\n        att=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)\n        att=att.mean(2,keepdim=False).view(bs,c,-1) #bs,c,h*w\n        # softmax over spatial positions, then weight the values (dynamic context)\n        k2=F.softmax(att,dim=-1)*v\n        k2=k2.view(bs,c,h,w)\n\n\n        return k1+k2\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    cot = CoTAttention(dim=512,kernel_size=3)\n    output=cot(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/CoordAttention.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass h_sigmoid(nn.Module):\n    def __init__(self, inplace=True):\n        super(h_sigmoid, self).__init__()\n        self.relu = nn.ReLU6(inplace=inplace)\n\n    def forward(self, x):\n        return self.relu(x + 3) / 6\n\nclass h_swish(nn.Module):\n    def __init__(self, inplace=True):\n        super(h_swish, self).__init__()\n        self.sigmoid = h_sigmoid(inplace=inplace)\n\n    def forward(self, x):\n        return x * self.sigmoid(x)\n\nclass CoordAtt(nn.Module):\n    def __init__(self, inp, oup, reduction=32):\n        super(CoordAtt, self).__init__()\n        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))\n        self.pool_w = nn.AdaptiveAvgPool2d((1, None))\n\n        mip = max(8, inp // reduction)\n\n        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)\n        self.bn1 = nn.BatchNorm2d(mip)\n        self.act = h_swish()\n        \n        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)\n        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)\n        \n\n    def forward(self, x):\n        identity = x\n        \n        n,c,h,w = x.size()\n        x_h = self.pool_h(x)\n        x_w = self.pool_w(x).permute(0, 1, 3, 2)\n\n        y = torch.cat([x_h, x_w], dim=2)\n        y = self.conv1(y)\n        y = self.bn1(y)\n        y = self.act(y) \n        \n        x_h, x_w = torch.split(y, [h, w], dim=2)\n        x_w = x_w.permute(0, 1, 3, 2)\n\n        a_h = self.conv_h(x_h).sigmoid()\n        a_w = self.conv_w(x_w).sigmoid()\n\n        out = identity * a_w * a_h\n\n        return out"
  },
  {
    "path": "model/attention/CrissCrossAttention.py",
    "content": "'''\nThis code is borrowed from Serge-weihao/CCNet-Pure-Pytorch\n'''\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Softmax\n\n\ndef INF(B,H,W):\n     return -torch.diag(torch.tensor(float(\"inf\")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)\n\n\nclass CrissCrossAttention(nn.Module):\n    \"\"\" Criss-Cross Attention Module\"\"\"\n    def __init__(self, in_dim):\n        super(CrissCrossAttention,self).__init__()\n        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)\n        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)\n        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)\n        self.softmax = Softmax(dim=3)\n        self.INF = INF\n        self.gamma = nn.Parameter(torch.zeros(1))\n\n\n    def forward(self, x):\n        m_batchsize, _, height, width = x.size()\n        proj_query = self.query_conv(x)\n        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)\n        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)\n        proj_key = self.key_conv(x)\n        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)\n        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)\n        proj_value = self.value_conv(x)\n        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)\n        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)\n        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)\n        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)\n        concate = self.softmax(torch.cat([energy_H, 
energy_W], 3))\n\n        # split the joint softmax back into its column (H) and row (W) parts\n        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)\n        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)\n        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)\n        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)\n        # gamma is initialised to zero, so the module starts out as an identity mapping\n        return self.gamma*(out_H + out_W) + x\n\nif __name__ == '__main__':\n    input=torch.randn(3, 64, 7, 7)\n    model = CrissCrossAttention(64)\n    outputs = model(input)\n    print(outputs.shape)"
  },
  {
    "path": "model/attention/Crossformer.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass DynamicPosBias(nn.Module):\n    def __init__(self, dim, num_heads, residual):\n        super().__init__()\n        self.residual = residual\n        self.num_heads = num_heads\n        self.pos_dim = dim // 4\n        self.pos_proj = nn.Linear(2, self.pos_dim)\n        self.pos1 = nn.Sequential(\n            nn.LayerNorm(self.pos_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(self.pos_dim, self.pos_dim),\n        )\n        self.pos2 = nn.Sequential(\n            nn.LayerNorm(self.pos_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(self.pos_dim, self.pos_dim)\n        )\n        self.pos3 = nn.Sequential(\n            nn.LayerNorm(self.pos_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(self.pos_dim, self.num_heads)\n        )\n    def forward(self, biases):\n        if self.residual:\n            pos = self.pos_proj(biases) # 2Wh-1 * 2Ww-1, heads\n            pos = pos + self.pos1(pos)\n            pos = pos + self.pos2(pos)\n            pos = self.pos3(pos)\n        else:\n            pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))\n        return pos\n\n    def flops(self, N):\n        flops = N * 2 * 
self.pos_dim\n        flops += N * self.pos_dim * self.pos_dim\n        flops += N * self.pos_dim * self.pos_dim\n        flops += N * self.pos_dim * self.num_heads\n        return flops\n\nclass Attention(nn.Module):\n    r\"\"\" Multi-head self attention module with dynamic position bias.\n\n    Args:\n        dim (int): Number of input channels.\n        group_size (tuple[int]): The height and width of the group.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n        position_bias (bool, optional): If True, add the dynamic position bias to the attention logits. Default: True\n    \"\"\"\n\n    def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,\n                 position_bias=True):\n\n        super().__init__()\n        self.dim = dim\n        self.group_size = group_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.position_bias = position_bias\n\n        if position_bias:\n            self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)\n            \n            # generate mother-set\n            position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])\n            position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])\n            biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Wh-1, 2Ww-1\n            biases = biases.flatten(1).transpose(0, 1).float()\n            self.register_buffer(\"biases\", biases)\n\n            # get pair-wise relative position index for each token inside the group\n            coords_h = torch.arange(self.group_size[0])\n            coords_w = 
torch.arange(self.group_size[1])\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0\n            relative_coords[:, :, 1] += self.group_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1\n            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_groups*B, N, C)\n            mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        if self.position_bias:\n            pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads\n            # select position bias\n            relative_position_bias = pos[self.relative_position_index.view(-1)].view(\n                self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, 
Wh*Ww, Wh*Ww\n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 group with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        if self.position_bias:\n            flops += self.pos.flops(N)\n        return flops\n\n\nclass CrossFormerBlock(nn.Module):\n    r\"\"\" CrossFormer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        group_size (int): Group size.\n        lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. 
Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, group_size=7, lsda_flag=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.group_size = group_size\n        self.lsda_flag = lsda_flag\n        self.mlp_ratio = mlp_ratio\n        self.num_patch_size = num_patch_size\n        if min(self.input_resolution) <= self.group_size:\n            # if group size is larger than input resolution, we don't partition groups\n            self.lsda_flag = 0\n            self.group_size = min(self.input_resolution)\n\n        self.norm1 = norm_layer(dim)\n\n        self.attn = Attention(\n            dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,\n            position_bias=True)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        attn_mask = None\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size %d, %d, %d\" % (L, H, W)\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # group embeddings\n        G = self.group_size\n        if self.lsda_flag == 0: # 0 for SDA\n            x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)\n        else: # 1 for LDA\n            x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)\n        x = x.reshape(B * H * W // G**2, G**2, C)\n\n        # multi-head self-attention\n        x = self.attn(x, mask=self.attn_mask)  # nW*B, G*G, C\n\n        # ungroup embeddings\n        x = x.reshape(B, H // G, W // G, G, G, C)\n        if self.lsda_flag == 0:\n            x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)\n        else:\n            x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # LSDA\n        nW = H * W / self.group_size / self.group_size\n        flops += nW * self.attn.flops(self.group_size * self.group_size)\n        # mlp\n        flops += 2 * 
H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reductions = nn.ModuleList()\n        self.patch_size = patch_size\n        self.norm = norm_layer(dim)\n\n        for i, ps in enumerate(patch_size):\n            if i == len(patch_size) - 1:\n                out_dim = 2 * dim // 2 ** i\n            else:\n                out_dim = 2 * dim // 2 ** (i + 1)\n            stride = 2\n            padding = (ps - stride) // 2\n            self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps, \n                                                stride=stride, padding=padding))\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = self.norm(x)\n        x = x.view(B, H, W, C).permute(0, 3, 1, 2)\n\n        xs = []\n        for i in range(len(self.reductions)):\n            tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)\n            xs.append(tmp_x)\n        x = torch.cat(xs, dim=2)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        for i, ps in 
enumerate(self.patch_size):\n            if i == len(self.patch_size) - 1:\n                out_dim = 2 * self.dim // 2 ** i\n            else:\n                out_dim = 2 * self.dim // 2 ** (i + 1)\n            flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim\n        return flops\n\n\nclass Stage(nn.Module):\n    \"\"\" CrossFormer blocks for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        group_size (int): variable G in the paper, one group has GxG embeddings\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, group_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n                 patch_size_end=[4], num_patch_size=None):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList()\n        for i in range(depth):\n            lsda_flag = 0 if (i % 2 == 0) else 1\n            self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, group_size=group_size,\n                                 lsda_flag=lsda_flag,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer,\n                                 num_patch_size=num_patch_size))\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, \n                                         patch_size=patch_size_end, num_input_patch_size=num_patch_size)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return 
f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (list[int]): Patch token sizes; the first entry is also used as the stride. Default: [4].\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        # patch_size = to_2tuple(patch_size)\n        # use the height for the first entry and the width for the second so non-square inputs work\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[0]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.projs = nn.ModuleList()\n        for i, ps in enumerate(patch_size):\n            if i == len(patch_size) - 1:\n                dim = embed_dim // 2 ** i\n            else:\n                dim = embed_dim // 2 ** (i + 1)\n            stride = patch_size[0]\n            padding = (ps - patch_size[0]) // 2\n            self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert 
H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        xs = []\n        for i in range(len(self.projs)):\n            tx = self.projs[i](x).flatten(2).transpose(1, 2)\n            xs.append(tx)  # B Ph*Pw C\n        x = torch.cat(xs, dim=2)\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = 0\n        for i, ps in enumerate(self.patch_size):\n            if i == len(self.patch_size) - 1:\n                dim = self.embed_dim // 2 ** i\n            else:\n                dim = self.embed_dim // 2 ** (i + 1)\n            flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass CrossFormer(nn.Module):\n    r\"\"\" CrossFormer\n        A PyTorch impl of : `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention`  -\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each stage.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        group_size (int): Group size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. 
Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 group_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n    
    # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n\n        num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]\n        for i_layer in range(self.num_layers):\n            patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else None\n            num_patch_size = num_patch_sizes[i_layer]\n            layer = Stage(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               # accept either one group size for all stages (the int default) or one per stage\n                               group_size=group_size[i_layer] if isinstance(group_size, (list, tuple)) else group_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint,\n                               patch_size_end=patch_size_end,\n                               num_patch_size=num_patch_size)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias 
is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CrossFormer(img_size=224,\n        patch_size=[4, 8, 16, 32],\n        in_chans= 3,\n        num_classes=1000,\n        embed_dim=48,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        group_size=[7, 7, 7, 7],\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        merge_size=[[2, 4], [2,4], [2, 4]]\n    )\n    output=model(input)\n    print(output.shape)"
  },
  {
    "path": "model/attention/DANet.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom model.attention.SelfAttention import ScaledDotProductAttention\nfrom model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention\n\nclass PositionAttentionModule(nn.Module):\n\n    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):\n        super().__init__()\n        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)\n        self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)\n    \n    def forward(self,x):\n        bs,c,h,w=x.shape\n        y=self.cnn(x)\n        y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,c\n        y=self.pa(y,y,y) #bs,h*w,c\n        return y\n\n\nclass ChannelAttentionModule(nn.Module):\n    \n    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):\n        super().__init__()\n        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)\n        self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)\n    \n    def forward(self,x):\n        bs,c,h,w=x.shape\n        y=self.cnn(x)\n        y=y.view(bs,c,-1) #bs,c,h*w\n        y=self.pa(y,y,y) #bs,c,h*w\n        return y\n\n\n\n\nclass DAModule(nn.Module):\n\n    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):\n        super().__init__()\n        self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7)\n        self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7)\n    \n    def forward(self,input):\n        bs,c,h,w=input.shape\n        p_out=self.position_attention_module(input)\n        c_out=self.channel_attention_module(input)\n        p_out=p_out.permute(0,2,1).view(bs,c,h,w)\n        c_out=c_out.view(bs,c,h,w)\n        return p_out+c_out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\n    print(danet(input).shape)\n"
  },
  {
    "path": "model/attention/DAT.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n# Vision Transformer with Deformable Attention\n# Modified by Zhuofan Xia \n# --------------------------------------------------------\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport einops\nfrom timm.models.layers import to_2tuple, trunc_normal_\nfrom timm.models.layers import DropPath, to_2tuple\n\nclass LocalAttention(nn.Module):\n\n    def __init__(self, dim, heads, window_size, attn_drop, proj_drop):\n        \n        super().__init__()\n\n        window_size = to_2tuple(window_size)\n\n        self.proj_qkv = nn.Linear(dim, 3 * dim)\n        self.heads = heads\n        assert dim % heads == 0\n        head_dim = dim // heads\n        self.scale = head_dim ** -0.5\n        self.proj_out = nn.Linear(dim, dim)\n        self.window_size = window_size\n        self.proj_drop = nn.Dropout(proj_drop, inplace=True)\n        self.attn_drop = nn.Dropout(attn_drop, inplace=True)\n\n        Wh, Ww = self.window_size\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * Wh - 1) * (2 * Ww - 1), heads)\n        )\n        trunc_normal_(self.relative_position_bias_table, std=0.01)\n\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, 
:, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n    def forward(self, x, mask=None):\n\n        B, C, H, W = x.size()\n        r1, r2 = H // self.window_size[0], W // self.window_size[1]\n        \n        x_total = einops.rearrange(x, 'b c (r1 h1) (r2 w1) -> b (r1 r2) (h1 w1) c', h1=self.window_size[0], w1=self.window_size[1]) # B x Nr x Ws x C\n        \n        x_total = einops.rearrange(x_total, 'b m n c -> (b m) n c')\n\n        qkv = self.proj_qkv(x_total) # B' x N x 3C\n        q, k, v = torch.chunk(qkv, 3, dim=2)\n\n        q = q * self.scale\n        q, k, v = [einops.rearrange(t, 'b n (h c1) -> b h n c1', h=self.heads) for t in [q, k, v]]\n        attn = torch.einsum('b h m c, b h n c -> b h m n', q, k)\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn_bias = relative_position_bias\n        attn = attn + attn_bias.unsqueeze(0)\n\n        if mask is not None:\n            # attn =(b * nW) h w w\n            # mask =nW ww ww\n            nW, ww, _ = mask.size()\n            attn = einops.rearrange(attn, '(b n) h w1 w2 -> b n h w1 w2', n=nW, h=self.heads, w1=ww, w2=ww) + mask.reshape(1, nW, 1, ww, ww)\n            attn = einops.rearrange(attn, 'b n h w1 w2 -> (b n) h w1 w2')\n        attn = self.attn_drop(attn.softmax(dim=3))\n\n        x = torch.einsum('b h m n, b h n c -> b h m c', attn, v)\n        x = einops.rearrange(x, 'b h n c1 -> b n (h c1)')\n        x = self.proj_drop(self.proj_out(x)) # B' x N x C\n        x = einops.rearrange(x, '(b r1 
r2) (h1 w1) c -> b c (r1 h1) (r2 w1)', r1=r1, r2=r2, h1=self.window_size[0], w1=self.window_size[1]) # B x C x H x W\n        \n        return x, None, None\n\nclass ShiftWindowAttention(LocalAttention):\n\n    def __init__(self, dim, heads, window_size, attn_drop, proj_drop, shift_size, fmap_size):\n        \n        super().__init__(dim, heads, window_size, attn_drop, proj_drop)\n\n        self.fmap_size = to_2tuple(fmap_size)\n        self.shift_size = shift_size\n\n        assert 0 < self.shift_size < min(self.window_size), \"wrong shift size.\"\n\n        img_mask = torch.zeros(*self.fmap_size)  # H W\n        h_slices = (slice(0, -self.window_size[0]),\n                    slice(-self.window_size[0], -self.shift_size),\n                    slice(-self.shift_size, None))\n        w_slices = (slice(0, -self.window_size[1]),\n                    slice(-self.window_size[1], -self.shift_size),\n                    slice(-self.shift_size, None))\n        cnt = 0\n        for h in h_slices:\n            for w in w_slices:\n                img_mask[h, w] = cnt\n                cnt += 1\n        mask_windows = einops.rearrange(img_mask, '(r1 h1) (r2 w1) -> (r1 r2) (h1 w1)', h1=self.window_size[0],w1=self.window_size[1])\n        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW ww ww\n        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        self.register_buffer(\"attn_mask\", attn_mask)\n      \n    def forward(self, x):\n\n        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))\n        sw_x, _, _ = super().forward(shifted_x, self.attn_mask)\n        x = torch.roll(sw_x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))\n\n        return x, None, None\n    \n\nclass DAttentionBaseline(nn.Module):\n\n    def __init__(\n        self, q_size, kv_size, n_heads, n_head_channels, n_groups,\n        attn_drop, proj_drop, stride, \n        
offset_range_factor, use_pe, dwc_pe,\n        no_off, fixed_pe, stage_idx\n    ):\n\n        super().__init__()\n        self.dwc_pe = dwc_pe\n        self.n_head_channels = n_head_channels\n        self.scale = self.n_head_channels ** -0.5\n        self.n_heads = n_heads\n        self.q_h, self.q_w = q_size\n        self.kv_h, self.kv_w = kv_size\n        self.nc = n_head_channels * n_heads\n        self.n_groups = n_groups\n        self.n_group_channels = self.nc // self.n_groups\n        self.n_group_heads = self.n_heads // self.n_groups\n        self.use_pe = use_pe\n        self.fixed_pe = fixed_pe\n        self.no_off = no_off\n        self.offset_range_factor = offset_range_factor\n        \n        ksizes = [9, 7, 5, 3]\n        kk = ksizes[stage_idx]\n\n        self.conv_offset = nn.Sequential(\n            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, kk//2, groups=self.n_group_channels),\n            LayerNormProxy(self.n_group_channels),\n            nn.GELU(),\n            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)\n        )\n\n        self.proj_q = nn.Conv2d(\n            self.nc, self.nc,\n            kernel_size=1, stride=1, padding=0\n        )\n\n        self.proj_k = nn.Conv2d(\n            self.nc, self.nc,\n            kernel_size=1, stride=1, padding=0\n        )\n\n        self.proj_v = nn.Conv2d(\n            self.nc, self.nc,\n            kernel_size=1, stride=1, padding=0\n        )\n\n        self.proj_out = nn.Conv2d(\n            self.nc, self.nc,\n            kernel_size=1, stride=1, padding=0\n        )\n\n        self.proj_drop = nn.Dropout(proj_drop, inplace=True)\n        self.attn_drop = nn.Dropout(attn_drop, inplace=True)\n\n        if self.use_pe:\n            if self.dwc_pe:\n                self.rpe_table = nn.Conv2d(self.nc, self.nc, \n                                           kernel_size=3, stride=1, padding=1, groups=self.nc)\n            elif self.fixed_pe:\n                
self.rpe_table = nn.Parameter(\n                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)\n                )\n                trunc_normal_(self.rpe_table, std=0.01)\n            else:\n                self.rpe_table = nn.Parameter(\n                    torch.zeros(self.n_heads, self.kv_h * 2 - 1, self.kv_w * 2 - 1)\n                )\n                trunc_normal_(self.rpe_table, std=0.01)\n        else:\n            self.rpe_table = None\n    \n    @torch.no_grad()\n    def _get_ref_points(self, H_key, W_key, B, dtype, device):\n        \n        ref_y, ref_x = torch.meshgrid(\n            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device), \n            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device)\n        )\n        ref = torch.stack((ref_y, ref_x), -1)\n        ref[..., 1].div_(W_key).mul_(2).sub_(1)\n        ref[..., 0].div_(H_key).mul_(2).sub_(1)\n        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2\n        \n        return ref\n\n    def forward(self, x):\n\n        B, C, H, W = x.size()\n        dtype, device = x.dtype, x.device\n        \n        q = self.proj_q(x)\n        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)\n        offset = self.conv_offset(q_off) # B * g 2 Hg Wg\n        Hk, Wk = offset.size(2), offset.size(3)\n        n_sample = Hk * Wk\n        \n        if self.offset_range_factor > 0:\n            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], device=device).reshape(1, 2, 1, 1)\n            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)\n            \n        offset = einops.rearrange(offset, 'b p h w -> b h w p')\n        reference = self._get_ref_points(Hk, Wk, B, dtype, device)\n            \n        if self.no_off:\n            offset = torch.zeros_like(offset)  # disable deformation: sample exactly at the reference points\n            \n        if self.offset_range_factor >= 0:\n            pos = offset + reference\n        
else:\n            pos = (offset + reference).tanh()\n        \n        x_sampled = F.grid_sample(\n            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), \n            grid=pos[..., (1, 0)], # y, x -> x, y\n            mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg\n            \n        x_sampled = x_sampled.reshape(B, C, 1, n_sample)\n\n        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)\n        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)\n        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)\n        \n        attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns\n        attn = attn.mul(self.scale)\n        \n        if self.use_pe:\n            \n            if self.dwc_pe:\n                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)\n            elif self.fixed_pe:\n                rpe_table = self.rpe_table\n                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)\n                # n_sample is a local variable; the module has no self.n_sample attribute\n                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)\n            else:\n                rpe_table = self.rpe_table\n                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)\n                \n                q_grid = self._get_ref_points(H, W, B, dtype, device)\n                \n                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)\n                \n                attn_bias = F.grid_sample(\n                    input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),\n                    grid=displacement[..., (1, 0)],\n                    mode='bilinear', align_corners=True\n                ) # B * g, h_g, HW, Ns\n                \n                attn_bias = attn_bias.reshape(B * self.n_heads, H 
* W, n_sample)\n                \n                attn = attn + attn_bias\n\n        attn = F.softmax(attn, dim=2)\n        attn = self.attn_drop(attn)\n        \n        out = torch.einsum('b m n, b c n -> b c m', attn, v)\n        \n        if self.use_pe and self.dwc_pe:\n            out = out + residual_lepe\n        out = out.reshape(B, C, H, W)\n        \n        y = self.proj_drop(self.proj_out(out))\n        \n        return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)\n\nclass TransformerMLP(nn.Module):\n\n    def __init__(self, channels, expansion, drop):\n        \n        super().__init__()\n        \n        self.dim1 = channels\n        self.dim2 = channels * expansion\n        self.chunk = nn.Sequential()\n        self.chunk.add_module('linear1', nn.Linear(self.dim1, self.dim2))\n        self.chunk.add_module('act', nn.GELU())\n        self.chunk.add_module('drop1', nn.Dropout(drop, inplace=True))\n        self.chunk.add_module('linear2', nn.Linear(self.dim2, self.dim1))\n        self.chunk.add_module('drop2', nn.Dropout(drop, inplace=True))\n    \n    def forward(self, x):\n\n        _, _, H, W = x.size()\n        x = einops.rearrange(x, 'b c h w -> b (h w) c')\n        x = self.chunk(x)\n        x = einops.rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)\n        return x\n\nclass LayerNormProxy(nn.Module):\n    \n    def __init__(self, dim):\n        \n        super().__init__()\n        self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x):\n\n        x = einops.rearrange(x, 'b c h w -> b h w c')\n        x = self.norm(x)\n        return einops.rearrange(x, 'b h w c -> b c h w')\n\nclass TransformerMLPWithConv(nn.Module):\n\n    def __init__(self, channels, expansion, drop):\n        \n        super().__init__()\n        \n        self.dim1 = channels\n        self.dim2 = channels * expansion\n        self.linear1 = nn.Conv2d(self.dim1, self.dim2, 1, 1, 0)\n        self.drop1 = nn.Dropout(drop, 
inplace=True)\n        self.act = nn.GELU()\n        self.linear2 = nn.Conv2d(self.dim2, self.dim1, 1, 1, 0) \n        self.drop2 = nn.Dropout(drop, inplace=True)\n        self.dwc = nn.Conv2d(self.dim2, self.dim2, 3, 1, 1, groups=self.dim2)\n    \n    def forward(self, x):\n        \n        x = self.drop1(self.act(self.dwc(self.linear1(x))))\n        x = self.drop2(self.linear2(x))\n        \n        return x\n\nclass TransformerStage(nn.Module):\n\n    def __init__(self, fmap_size, window_size, ns_per_pt,\n                 dim_in, dim_embed, depths, stage_spec, n_groups, \n                 use_pe, sr_ratio, \n                 heads, stride, offset_range_factor, stage_idx,\n                 dwc_pe, no_off, fixed_pe,\n                 attn_drop, proj_drop, expansion, drop, drop_path_rate, use_dwc_mlp):\n\n        super().__init__()\n        fmap_size = to_2tuple(fmap_size)\n        self.depths = depths\n        hc = dim_embed // heads\n        assert dim_embed == heads * hc\n        self.proj = nn.Conv2d(dim_in, dim_embed, 1, 1, 0) if dim_in != dim_embed else nn.Identity()\n\n        self.layer_norms = nn.ModuleList(\n            [LayerNormProxy(dim_embed) for _ in range(2 * depths)]\n        )\n        self.mlps = nn.ModuleList(\n            [\n                TransformerMLPWithConv(dim_embed, expansion, drop) \n                if use_dwc_mlp else TransformerMLP(dim_embed, expansion, drop)\n                for _ in range(depths)\n            ]\n        )\n        self.attns = nn.ModuleList()\n        self.drop_path = nn.ModuleList()\n        for i in range(depths):\n            if stage_spec[i] == 'L':\n                self.attns.append(\n                    LocalAttention(dim_embed, heads, window_size, attn_drop, proj_drop)\n                )\n            elif stage_spec[i] == 'D':\n                self.attns.append(\n                    DAttentionBaseline(fmap_size, fmap_size, heads, \n                    hc, n_groups, attn_drop, proj_drop, \n                   
 stride, offset_range_factor, use_pe, dwc_pe, \n                    no_off, fixed_pe, stage_idx)\n                )\n            elif stage_spec[i] == 'S':\n                shift_size = math.ceil(window_size / 2)\n                self.attns.append(\n                    ShiftWindowAttention(dim_embed, heads, window_size, attn_drop, proj_drop, shift_size, fmap_size)\n                )\n            else:\n                raise NotImplementedError(f'Spec={stage_spec[i]} is not supported.')\n            \n            self.drop_path.append(DropPath(drop_path_rate[i]) if drop_path_rate[i] > 0.0 else nn.Identity())\n        \n    def forward(self, x):\n        \n        x = self.proj(x)\n        \n        positions = []\n        references = []\n        for d in range(self.depths):\n\n            x0 = x\n            x, pos, ref = self.attns[d](self.layer_norms[2 * d](x))\n            x = self.drop_path[d](x) + x0\n            x0 = x\n            x = self.mlps[d](self.layer_norms[2 * d + 1](x))\n            x = self.drop_path[d](x) + x0\n            positions.append(pos)\n            references.append(ref)\n\n        return x, positions, references\n\nclass DAT(nn.Module):\n\n    def __init__(self, img_size=224, patch_size=4, num_classes=1000, expansion=4,\n                 dim_stem=96, dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], \n                 heads=[3, 6, 12, 24], \n                 window_sizes=[7, 7, 7, 7],\n                 drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, \n                 strides=[-1,-1,-1,-1], offset_range_factor=[1, 2, 3, 4], \n                 stage_spec=[['L', 'D'], ['L', 'D'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']], \n                 groups=[-1, -1, 3, 6],\n                 use_pes=[False, False, False, False], \n                 dwc_pes=[False, False, False, False],\n                 sr_ratios=[8, 4, 2, 1], \n                 fixed_pes=[False, False, False, False],\n                 no_offs=[False, False, False, False],\n     
            ns_per_pts=[4, 4, 4, 4],\n                 use_dwc_mlps=[False, False, False, False],\n                 use_conv_patches=False,\n                 **kwargs):\n        super().__init__()\n\n        self.patch_proj = nn.Sequential(\n            nn.Conv2d(3, dim_stem, 7, patch_size, 3),\n            LayerNormProxy(dim_stem)\n        ) if use_conv_patches else nn.Sequential(\n            nn.Conv2d(3, dim_stem, patch_size, patch_size, 0),\n            LayerNormProxy(dim_stem)\n        ) \n\n        img_size = img_size // patch_size\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n        \n        self.stages = nn.ModuleList()\n        for i in range(4):\n            dim1 = dim_stem if i == 0 else dims[i - 1] * 2\n            dim2 = dims[i]\n            self.stages.append(\n                TransformerStage(img_size, window_sizes[i], ns_per_pts[i],\n                dim1, dim2, depths[i], stage_spec[i], groups[i], use_pes[i], \n                sr_ratios[i], heads[i], strides[i], \n                offset_range_factor[i], i,\n                dwc_pes[i], no_offs[i], fixed_pes[i],\n                attn_drop_rate, drop_rate, expansion, drop_rate, \n                dpr[sum(depths[:i]):sum(depths[:i + 1])],\n                use_dwc_mlps[i])\n            )\n            img_size = img_size // 2\n\n        self.down_projs = nn.ModuleList()\n        for i in range(3):\n            self.down_projs.append(\n                nn.Sequential(\n                    nn.Conv2d(dims[i], dims[i + 1], 3, 2, 1, bias=False),\n                    LayerNormProxy(dims[i + 1])\n                ) if use_conv_patches else nn.Sequential(\n                    nn.Conv2d(dims[i], dims[i + 1], 2, 2, 0, bias=False),\n                    LayerNormProxy(dims[i + 1])\n                )\n            )\n           \n        self.cls_norm = LayerNormProxy(dims[-1]) \n        self.cls_head = nn.Linear(dims[-1], num_classes)\n        \n        self.reset_parameters()\n    
\n    def reset_parameters(self):\n\n        # iterate over modules: self.parameters() yields tensors, so the isinstance check below never matched\n        for m in self.modules():\n            if isinstance(m, (nn.Linear, nn.Conv2d)):\n                nn.init.kaiming_normal_(m.weight)\n                if m.bias is not None:  # some convs are built with bias=False\n                    nn.init.zeros_(m.bias)\n                \n    @torch.no_grad()\n    def load_pretrained(self, state_dict):\n        \n        new_state_dict = {}\n        for state_key, state_value in state_dict.items():\n            keys = state_key.split('.')\n            m = self\n            for key in keys:\n                if key.isdigit():\n                    m = m[int(key)]\n                else:\n                    m = getattr(m, key)\n            if m.shape == state_value.shape:\n                new_state_dict[state_key] = state_value\n            else:\n                # Ignore different shapes\n                if 'relative_position_index' in keys:\n                    new_state_dict[state_key] = m.data\n                if 'q_grid' in keys:\n                    new_state_dict[state_key] = m.data\n                if 'reference' in keys:\n                    new_state_dict[state_key] = m.data\n                # Bicubic Interpolation\n                if 'relative_position_bias_table' in keys:\n                    n, c = state_value.size()\n                    l = int(math.sqrt(n))\n                    assert n == l ** 2\n                    L = int(math.sqrt(m.shape[0]))\n                    pre_interp = state_value.reshape(1, l, l, c).permute(0, 3, 1, 2)\n                    post_interp = F.interpolate(pre_interp, (L, L), mode='bicubic')\n                    new_state_dict[state_key] = post_interp.reshape(c, L ** 2).permute(1, 0)\n                if 'rpe_table' in keys:\n                    c, h, w = state_value.size()\n                    C, H, W = m.data.size()\n                    pre_interp = state_value.unsqueeze(0)\n                    post_interp = F.interpolate(pre_interp, (H, W), mode='bicubic')\n                    new_state_dict[state_key] = post_interp.squeeze(0)\n  
      \n        self.load_state_dict(new_state_dict, strict=False)\n    \n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table', 'rpe_table'}\n    \n    def forward(self, x):\n        \n        x = self.patch_proj(x)\n        positions = []\n        references = []\n        for i in range(4):\n            x, pos, ref = self.stages[i](x)\n            if i < 3:\n                x = self.down_projs[i](x)\n            positions.append(pos)\n            references.append(ref)\n        x = self.cls_norm(x)\n        x = F.adaptive_avg_pool2d(x, 1)\n        x = torch.flatten(x, 1)\n        x = self.cls_head(x)\n        \n        return x, positions, references\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DAT(\n        img_size=224,\n        patch_size=4,\n        num_classes=1000,\n        expansion=4,\n        dim_stem=96,\n        dims=[96, 192, 384, 768],\n        depths=[2, 2, 6, 2],\n        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],\n        heads=[3, 6, 12, 24],\n        window_sizes=[7, 7, 7, 7] ,\n        groups=[-1, -1, 3, 6],\n        use_pes=[False, False, True, True],\n        dwc_pes=[False, False, False, False],\n        strides=[-1, -1, 1, 1],\n        sr_ratios=[-1, -1, -1, -1],\n        offset_range_factor=[-1, -1, 2, 2],\n        no_offs=[False, False, False, False],\n        fixed_pes=[False, False, False, False],\n        use_dwc_mlps=[False, False, False, False],\n        use_conv_patches=False,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.2,\n    )\n    output=model(input)\n    print(output[0].shape)"
  },
  {
    "path": "model/attention/ECAAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom collections import OrderedDict\n\n\n\nclass ECAAttention(nn.Module):\n\n    def __init__(self, kernel_size=3):\n        super().__init__()\n        self.gap=nn.AdaptiveAvgPool2d(1)\n        self.conv=nn.Conv1d(1,1,kernel_size=kernel_size,padding=(kernel_size-1)//2)\n        self.sigmoid=nn.Sigmoid()\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        y=self.gap(x) #bs,c,1,1\n        y=y.squeeze(-1).permute(0,2,1) #bs,1,c\n        y=self.conv(y) #bs,1,c\n        y=self.sigmoid(y) #bs,1,c\n        y=y.permute(0,2,1).unsqueeze(-1) #bs,c,1,1\n        return x*y.expand_as(x)\n\n        \n\n\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    eca = ECAAttention(kernel_size=3)\n    output=eca(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/EMSA.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass EMSA(nn.Module):\n\n    def __init__(self, d_model, d_k, d_v, h,dropout=.1,H=7,W=7,ratio=3,apply_transform=True):\n\n        super(EMSA, self).__init__()\n        self.H=H\n        self.W=W\n        self.fc_q = nn.Linear(d_model, h * d_k)\n        self.fc_k = nn.Linear(d_model, h * d_k)\n        self.fc_v = nn.Linear(d_model, h * d_v)\n        self.fc_o = nn.Linear(h * d_v, d_model)\n        self.dropout=nn.Dropout(dropout)\n\n        self.ratio=ratio\n        if(self.ratio>1):\n            self.sr=nn.Sequential()\n            self.sr_conv=nn.Conv2d(d_model,d_model,kernel_size=ratio+1,stride=ratio,padding=ratio//2,groups=d_model)\n            self.sr_ln=nn.LayerNorm(d_model)\n\n        self.apply_transform=apply_transform and h>1\n        if(self.apply_transform):\n            self.transform=nn.Sequential()\n            self.transform.add_module('conv',nn.Conv2d(h,h,kernel_size=1,stride=1))\n            self.transform.add_module('softmax',nn.Softmax(-1))\n            self.transform.add_module('in',nn.InstanceNorm2d(h))\n\n        self.d_model = d_model\n        self.d_k = d_k\n        self.d_v = d_v\n        self.h = h\n\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):\n\n        b_s, nq ,c = queries.shape\n        
nk = keys.shape[1]\n\n        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)\n\n        if(self.ratio>1):\n            x=queries.permute(0,2,1).view(b_s,c,self.H,self.W) #bs,c,H,W\n            x=self.sr_conv(x) #bs,c,h,w\n            x=x.contiguous().view(b_s,c,-1).permute(0,2,1) #bs,n',c\n            x=self.sr_ln(x)\n            k = self.fc_k(x).view(b_s, -1, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, n')\n            v = self.fc_v(x).view(b_s, -1, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, n', d_v)\n        else:\n            k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)\n            v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)\n\n        if(self.apply_transform):\n            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')\n            att = self.transform(att) # (b_s, h, nq, n')\n        else:\n            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')\n            att = torch.softmax(att, -1) # (b_s, h, nq, n')\n\n\n        if attention_weights is not None:\n            att = att * attention_weights\n        if attention_mask is not None:\n            att = att.masked_fill(attention_mask, -np.inf)\n        \n        att=self.dropout(att)\n\n        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)\n        out = self.fc_o(out)  # (b_s, nq, d_model)\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,64,512)\n    emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\n    output=emsa(input,input,input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/ExternalAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass ExternalAttention(nn.Module):\n\n    def __init__(self, d_model,S=64):\n        super().__init__()\n        self.mk=nn.Linear(d_model,S,bias=False)\n        self.mv=nn.Linear(S,d_model,bias=False)\n        self.softmax=nn.Softmax(dim=1)\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, queries):\n        attn=self.mk(queries) #bs,n,S\n        attn=self.softmax(attn) #bs,n,S\n        attn=attn/torch.sum(attn,dim=2,keepdim=True) #bs,n,S\n        out=self.mv(attn) #bs,n,d_model\n\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ea = ExternalAttention(d_model=512,S=8)\n    output=ea(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/HaloAttention.py",
    "content": "import torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\n\nfrom einops import rearrange, repeat\n\n# relative positional embedding\n\ndef to(x):\n    return {'device': x.device, 'dtype': x.dtype}\n\ndef pair(x):\n    return (x, x) if not isinstance(x, tuple) else x\n\ndef expand_dim(t, dim, k):\n    t = t.unsqueeze(dim = dim)\n    expand_shape = [-1] * len(t.shape)\n    expand_shape[dim] = k\n    return t.expand(*expand_shape)\n\ndef rel_to_abs(x):\n    b, l, m = x.shape\n    r = (m + 1) // 2\n\n    col_pad = torch.zeros((b, l, 1), **to(x))\n    x = torch.cat((x, col_pad), dim = 2)\n    flat_x = rearrange(x, 'b l c -> b (l c)')\n    flat_pad = torch.zeros((b, m - l), **to(x))\n    flat_x_padded = torch.cat((flat_x, flat_pad), dim = 1)\n    final_x = flat_x_padded.reshape(b, l + 1, m)\n    final_x = final_x[:, :l, -r:]\n    return final_x\n\ndef relative_logits_1d(q, rel_k):\n    b, h, w, _ = q.shape\n    r = (rel_k.shape[0] + 1) // 2\n\n    logits = einsum('b x y d, r d -> b x y r', q, rel_k)\n    logits = rearrange(logits, 'b x y r -> (b x) y r')\n    logits = rel_to_abs(logits)\n\n    logits = logits.reshape(b, h, w, r)\n    logits = expand_dim(logits, dim = 2, k = r)\n    return logits\n\nclass RelPosEmb(nn.Module):\n    def __init__(\n        self,\n        block_size,\n        rel_size,\n        dim_head\n    ):\n        super().__init__()\n        height = width = rel_size\n        scale = dim_head ** -0.5\n\n        self.block_size = block_size\n        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)\n        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)\n\n    def forward(self, q):\n        block = self.block_size\n\n        q = rearrange(q, 'b (x y) c -> b x y c', x = block)\n        rel_logits_w = relative_logits_1d(q, self.rel_width)\n        rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')\n\n        q = rearrange(q, 'b x y d -> b y x d')\n 
       rel_logits_h = relative_logits_1d(q, self.rel_height)\n        rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')\n        return rel_logits_w + rel_logits_h\n\n# classes\n\nclass HaloAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        dim,\n        block_size,\n        halo_size,\n        dim_head = 64,\n        heads = 8\n    ):\n        super().__init__()\n        assert halo_size > 0, 'halo size must be greater than 0'\n\n        self.dim = dim\n        self.heads = heads\n        self.scale = dim_head ** -0.5\n\n        self.block_size = block_size\n        self.halo_size = halo_size\n\n        inner_dim = dim_head * heads\n\n        self.rel_pos_emb = RelPosEmb(\n            block_size = block_size,\n            rel_size = block_size + (halo_size * 2),\n            dim_head = dim_head\n        )\n\n        self.to_q  = nn.Linear(dim, inner_dim, bias = False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)\n        self.to_out = nn.Linear(inner_dim, dim)\n\n    def forward(self, x):\n        b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device\n        assert h % block == 0 and w % block == 0, 'fmap dimensions must be divisible by the block size'\n        assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'\n\n        # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values\n\n        q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = block, p2 = block)\n\n        kv_inp = F.unfold(x, kernel_size = block + halo * 2, stride = block, padding = halo)\n        kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c = c)\n\n        # derive queries, keys, values\n\n        q = self.to_q(q_inp)\n        k, v = self.to_kv(kv_inp).chunk(2, dim = -1)\n\n        # split heads\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) 
n d', h = heads), (q, k, v))\n\n        # scale\n\n        q *= self.scale\n\n        # attention\n\n        sim = einsum('b i d, b j d -> b i j', q, k)\n\n        # add relative positional bias\n\n        sim += self.rel_pos_emb(q)\n\n        # mask out padding (in the paper, they claim to not need masks, but what about padding?)\n        # F.unfold zero-pads, so after .bool() the mask is True at real pixels and False at padding\n\n        mask = torch.ones(1, 1, h, w, device = device)\n        mask = F.unfold(mask, kernel_size = block + (halo * 2), stride = block, padding = halo)\n        mask = repeat(mask, '() j i -> (b i h) () j', b = b, h = heads)\n        mask = mask.bool()\n\n        # fill the padded (False) positions, not the real pixels\n        max_neg_value = -torch.finfo(sim.dtype).max\n        sim.masked_fill_(~mask, max_neg_value)\n\n        # attention\n\n        attn = sim.softmax(dim = -1)\n\n        # aggregate\n\n        out = einsum('b i j, b j d -> b i d', attn, v)\n\n        # merge and combine heads\n\n        out = rearrange(out, '(b h) n d -> b n (h d)', h = heads)\n        out = self.to_out(out)\n\n        # merge blocks back to original feature map\n\n        out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b = b, h = (h // block), w = (w // block), p1 = block, p2 = block)\n        return out\n\nif __name__ == '__main__':\n    input=torch.randn(1,512,8,8)\n    halo = HaloAttention(dim=512,\n        block_size=2,\n        halo_size=1,)\n    output=halo(input)\n    print(output.shape)"
  },
  {
    "path": "model/attention/MOATransformer.py",
    "content": "\n# --------------------------------------------------------\n# Adopted from Swin Transformer\n# Modified by Krushi Patel\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom einops.layers.torch import Rearrange, Reduce\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\n\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 
5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.query_size = self.window_size\n        self.key_size = self.window_size[0] * 2\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        \n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        \n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        \n        relative_coords = relative_coords.permute(1, 2, 
0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        \n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\n        \n        attn = attn + relative_position_bias.unsqueeze(0)\n\n      \n        attn = self.softmax(attn)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        #return f'dim={self.dim}, window_size={self.window_size}, 
num_heads={self.num_heads}'\n        return f'dim={self.dim}, num_heads={self.num_heads}'\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\nclass GlobalAttention(nn.Module):\n    r\"\"\" MOA - multi-head self attention (W-MSA) module with relative position bias.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, input_resolution,num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.query_size = self.window_size[0]\n       \n        self.key_size = self.window_size[0] + 2\n        h,w = input_resolution\n        self.seq_len = h//self.query_size\n    \n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.reduction = 32\n        self.pre_conv = nn.Conv2d(dim, int(dim//self.reduction), 1)\n     \n      \n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * self.seq_len - 1) * (2 * self.seq_len - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n        #print(self.relative_position_bias_table.shape)\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.seq_len)\n        coords_w = torch.arange(self.seq_len)\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n     \n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        \n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        \n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n\n        relative_coords[:, :, 0] += self.seq_len - 1  # shift to start from 0\n\n        relative_coords[:, :, 1] += self.seq_len - 1\n        relative_coords[:, :, 0] *= 2 * self.seq_len - 1\n       \n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n      \n        self.register_buffer(\"relative_position_index\", relative_position_index)\n        \n       \n\n        self.queryembedding = Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = 
self.query_size, p2 = self.query_size)\n\n        self.keyembedding = nn.Unfold(kernel_size=(self.key_size, self.key_size), stride = 14, padding=1)\n   \n        self.query_dim = int(dim//self.reduction) * self.query_size * self.query_size\n        self.key_dim = int(dim//self.reduction) * self.key_size * self.key_size\n                \n        self.q = nn.Linear(self.query_dim, self.dim,bias=qkv_bias)\n        self.kv = nn.Linear(self.key_dim, 2*self.dim,bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim,dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        #trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, H, W):\n        \"\"\"\n        Args:\n            x: input features with shape of (B, H*W, C)\n            H (int): Height of the feature map\n            W (int): Width of the feature map\n        \"\"\"\n\n        #B, H, W, C = x.shape\n        B,_, C = x.shape  \n          \n        x = x.reshape(-1, C, H, W)    \n        x = self.pre_conv(x)\n        query = self.queryembedding(x).view(B,-1,self.query_dim)\n        query = self.q(query)\n        B,N,C = query.size()\n        \n        q = query.reshape(B,N,self.num_heads, C//self.num_heads).permute(0,2,1,3)\n        key = self.keyembedding(x).view(B,-1,self.key_dim)\n        kv = self.kv(key).reshape(B,N,2,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)\n        k = kv[0]\n        v = kv[1]\n        \n        \n        q = q * self.scale\n        \n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.seq_len * self.seq_len, self.seq_len * self.seq_len, -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\n       \n        attn = attn + relative_position_bias.unsqueeze(0)\n 
     \n        attn = self.softmax(attn)\n        \n        attn = self.attn_drop(attn)\n     \n        x = (attn @ v).transpose(1, 2).reshape(B, N, C) \n     \n        x = self.proj(x)\n      \n        x = self.proj_drop(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass LocalTransformerBlock(nn.Module):\n    r\"\"\" Local Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.mlp_ratio = mlp_ratio\n   \n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n           \n            self.window_size = min(self.input_resolution)\n   \n       \n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n\n\n    def forward(self, x):\n        H, W = self.input_resolution\n     \n        B, L, C = x.shape\n      \n        assert L == H * W, \"input feature has wrong size\"\n       \n        shortcut = x\n        x = self.norm1(x)\n        \n        x = x.view(B, H, W, C)\n     \n\n    \n        x_windows = window_partition(x, self.window_size)  # nW*B, window_size, window_size, C \n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C     \n        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C    \n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n        x = x.view(B, H * W, C)\n\n\n       \n\n \n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    \"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input 
channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, drop_path_global=0., use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n        self.window_size = window_size\n       \n        self.drop_path_gl = DropPath(drop_path_global) if drop_path_global > 0. 
else nn.Identity()\n        # build blocks\n        self.blocks = nn.ModuleList([\n            LocalTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n           if min(self.input_resolution) >= self.window_size:\n                 self.glb_attn = GlobalAttention(dim, to_2tuple(window_size), self.input_resolution, num_heads = num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n                 self.post_conv = nn.Conv2d(dim, dim, 3, padding=1)\n                 self.norm1 = norm_layer(dim)\n                 self.norm2 = norm_layer(dim)\n                \n           else:\n                 self.post_conv = None\n                 self.glb_attn = None\n                 self.norm1 = None\n                 self.norm2 = None\n                \n           self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n           \n            \n        else:\n            self.downsample = None\n            \n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n                \n        if self.downsample is not None:\n        \n           if min(self.input_resolution) >= self.window_size:\n                shortcut = x\n                x = self.norm1(x)\n                H, W = self.input_resolution\n                B,_,C = x.size()\n         \n 
               no_window = int(H*W/self.window_size**2)   \n                local_attn = x.view(B,no_window,self.window_size, self.window_size,C)\n             \n                glb_attn = self.glb_attn(x, H, W)\n                glb_attn = glb_attn.view(B,no_window,1,1,C)\n                x = torch.add(local_attn, glb_attn).view(B,C,H,W)\n              \n\n                x = shortcut.view(B,C,H,W) + self.drop_path_gl(x)\n                x = self.norm2(x.view(B,H*W,C))\n                post_conv = self.drop_path_gl(self.post_conv(x.view(B,C,H,W))).view(B, H*W, C)\n                x = x + post_conv\n                \n           x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass MOATransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. 
Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        dpr_global = [x.item() for x in torch.linspace(0, 0.2, len(depths)-1)]\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                      
           patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               drop_path_global = (dpr_global[i_layer]) if (i_layer < self.num_layers -1) else 0,\n                               use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n \n        for layer in self.layers:\n            x = layer(x)\n           \n        x = self.norm(x)  # B L C\n    
    x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n\n# --------------------------------------------------------\n# Adopted from Swin Transformer\n# Modified by Krushi Patel\n# print(sum(p.numel() for p in model.parameters() if p.requires_grad), 'parameters')\n# --------------------------------------------------------\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = MOATransformer(\n        img_size=224,\n        patch_size=4,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6],\n        num_heads=[3, 6, 12],\n        window_size=14,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False\n    )\n    output=model(input)\n    print(output.shape)"
  },
  {
    "path": "model/attention/MUSEAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass Depth_Pointwise_Conv1d(nn.Module):\n    def __init__(self,in_ch,out_ch,k):\n        super().__init__()\n        if(k==1):\n            self.depth_conv=nn.Identity()\n        else:\n            self.depth_conv=nn.Conv1d(\n                in_channels=in_ch,\n                out_channels=in_ch,\n                kernel_size=k,\n                groups=in_ch,\n                padding=k//2\n                )\n        self.pointwise_conv=nn.Conv1d(\n            in_channels=in_ch,\n            out_channels=out_ch,\n            kernel_size=1,\n            groups=1\n        )\n    def forward(self,x):\n        out=self.pointwise_conv(self.depth_conv(x))\n        return out\n    \n\n\nclass MUSEAttention(nn.Module):\n\n    def __init__(self, d_model, d_k, d_v, h,dropout=.1):\n\n\n        super(MUSEAttention, self).__init__()\n        self.fc_q = nn.Linear(d_model, h * d_k)\n        self.fc_k = nn.Linear(d_model, h * d_k)\n        self.fc_v = nn.Linear(d_model, h * d_v)\n        self.fc_o = nn.Linear(h * d_v, d_model)\n        self.dropout=nn.Dropout(dropout)\n\n        self.conv1=Depth_Pointwise_Conv1d(h * d_v, d_model,1)\n        self.conv3=Depth_Pointwise_Conv1d(h * d_v, d_model,3)\n        self.conv5=Depth_Pointwise_Conv1d(h * d_v, d_model,5)\n        self.dy_paras=nn.Parameter(torch.ones(3))\n        self.softmax=nn.Softmax(-1)\n\n        self.d_model = d_model\n        self.d_k = d_k\n        self.d_v = d_v\n        self.h = h\n\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            
elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):\n\n        #Self Attention\n        b_s, nq = queries.shape[:2]\n        nk = keys.shape[1]\n\n        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)\n        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)\n        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)\n\n        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)\n        if attention_weights is not None:\n            att = att * attention_weights\n        if attention_mask is not None:\n            att = att.masked_fill(attention_mask, -np.inf)\n        att = torch.softmax(att, -1)\n        att=self.dropout(att)\n\n        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)\n        out = self.fc_o(out)  # (b_s, nq, d_model)\n\n        v2=v.permute(0,1,3,2).contiguous().view(b_s,-1,nk) #bs,dim,n\n        # Normalize the branch weights into a local tensor; re-wrapping them in nn.Parameter\n        # on every forward pass would replace the parameter the optimizer is tracking\n        dy_paras=self.softmax(self.dy_paras)\n        out2=dy_paras[0]*self.conv1(v2)+dy_paras[1]*self.conv3(v2)+dy_paras[2]*self.conv5(v2)\n        out2=out2.permute(0,2,1) #bs,n,dim\n\n        out=out+out2\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=sa(input,input,input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/MobileViTAttention.py",
    "content": "from torch import nn\nimport torch\nfrom einops import rearrange\n\n\nclass PreNorm(nn.Module):\n    def __init__(self,dim,fn):\n        super().__init__()\n        self.ln=nn.LayerNorm(dim)\n        self.fn=fn\n    def forward(self,x,**kwargs):\n        return self.fn(self.ln(x),**kwargs)\n\nclass FeedForward(nn.Module):\n    def __init__(self,dim,mlp_dim,dropout) :\n        super().__init__()\n        self.net=nn.Sequential(\n            nn.Linear(dim,mlp_dim),\n            nn.SiLU(),\n            nn.Dropout(dropout),\n            nn.Linear(mlp_dim,dim),\n            nn.Dropout(dropout)\n        )\n    def forward(self,x):\n        return self.net(x)\n\nclass Attention(nn.Module):\n    def __init__(self,dim,heads,head_dim,dropout):\n        super().__init__()\n        inner_dim=heads*head_dim\n        project_out=not(heads==1 and head_dim==dim)\n\n        self.heads=heads\n        self.scale=head_dim**-0.5\n\n        self.attend=nn.Softmax(dim=-1)\n        self.to_qkv=nn.Linear(dim,inner_dim*3,bias=False)\n        \n        self.to_out=nn.Sequential(\n            nn.Linear(inner_dim,dim),\n            nn.Dropout(dropout)\n        ) if project_out else nn.Identity()\n\n    def forward(self,x):\n        qkv=self.to_qkv(x).chunk(3,dim=-1)\n        q,k,v=map(lambda t:rearrange(t,'b p n (h d) -> b p h n d',h=self.heads),qkv)\n        dots=torch.matmul(q,k.transpose(-1,-2))*self.scale\n        attn=self.attend(dots)\n        out=torch.matmul(attn,v)\n        out=rearrange(out,'b p h n d -> b p n (h d)')\n        return self.to_out(out)\n\n\n\n\n\nclass Transformer(nn.Module):\n    def __init__(self,dim,depth,heads,head_dim,mlp_dim,dropout=0.):\n        super().__init__()\n        self.layers=nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(nn.ModuleList([\n                PreNorm(dim,Attention(dim,heads,head_dim,dropout)),\n                PreNorm(dim,FeedForward(dim,mlp_dim,dropout))\n            ]))\n\n\n    def 
forward(self,x):\n        out=x\n        for att,ffn in self.layers:\n            out=out+att(out)\n            out=out+ffn(out)\n        return out\n\nclass MobileViTAttention(nn.Module):\n    def __init__(self,in_channel=3,dim=512,kernel_size=3,patch_size=7):\n        super().__init__()\n        self.ph,self.pw=patch_size,patch_size\n        self.conv1=nn.Conv2d(in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)\n        self.conv2=nn.Conv2d(in_channel,dim,kernel_size=1)\n\n        self.trans=Transformer(dim=dim,depth=3,heads=8,head_dim=64,mlp_dim=1024)\n\n        self.conv3=nn.Conv2d(dim,in_channel,kernel_size=1)\n        self.conv4=nn.Conv2d(2*in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)\n\n    def forward(self,x):\n        # x: bs,c,h,w (h and w must be divisible by the patch size)\n\n        ## Local Representation\n        y=self.conv2(self.conv1(x)) #bs,dim,h,w\n\n        ## Global Representation\n        _,_,h,w=y.shape\n        y=rearrange(y,'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim',ph=self.ph,pw=self.pw) #bs,ph*pw,nh*nw,dim\n        y=self.trans(y)\n        y=rearrange(y,'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)',ph=self.ph,pw=self.pw,nh=h//self.ph,nw=w//self.pw) #bs,dim,h,w\n\n        ## Fusion\n        y=self.conv3(y) #bs,c,h,w\n        y=torch.cat([x,y],1) #bs,2*c,h,w\n        y=self.conv4(y) #bs,c,h,w\n\n        return y\n\n\nif __name__ == '__main__':\n    m=MobileViTAttention()\n    input=torch.randn(1,3,49,49)\n    output=m(input)\n    print(output.shape)\n    "
  },
  {
    "path": "model/attention/MobileViTv2Attention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass MobileViTv2Attention(nn.Module):\n    '''\n    Separable self-attention (MobileViTv2): linear in the number of tokens\n    '''\n\n    def __init__(self, d_model):\n        '''\n        :param d_model: Dimensionality of the input and output tokens\n        '''\n        super(MobileViTv2Attention, self).__init__()\n        self.fc_i = nn.Linear(d_model,1)\n        self.fc_k = nn.Linear(d_model, d_model)\n        self.fc_v = nn.Linear(d_model, d_model)\n        self.fc_o = nn.Linear(d_model, d_model)\n\n        self.d_model = d_model\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, input):\n        '''\n        Computes separable self-attention over the token dimension\n        :param input: Input tokens (b_s, nq, d_model)\n        :return: Output tokens (b_s, nq, d_model)\n        '''\n        i = self.fc_i(input) #(bs,nq,1)\n        weight_i = torch.softmax(i, dim=1) #bs,nq,1\n        context_score = weight_i * self.fc_k(input) #bs,nq,d_model\n        context_vector = torch.sum(context_score,dim=1,keepdim=True) #bs,1,d_model\n        v = self.fc_v(input) * context_vector #bs,nq,d_model\n        out = self.fc_o(v) #bs,nq,d_model\n\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    
print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/OutlookAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nimport math\nfrom torch.nn import functional as F\n\nclass OutlookAttention(nn.Module):\n\n    def __init__(self,dim,num_heads=1,kernel_size=3,padding=1,stride=1,qkv_bias=False,\n                    attn_drop=0.1):\n        super().__init__()\n        self.dim=dim\n        self.num_heads=num_heads\n        self.head_dim=dim//num_heads\n        self.kernel_size=kernel_size\n        self.padding=padding\n        self.stride=stride\n        self.scale=self.head_dim**(-0.5)\n\n        self.v_pj=nn.Linear(dim,dim,bias=qkv_bias)\n        self.attn=nn.Linear(dim,kernel_size**4*num_heads)\n\n        self.attn_drop=nn.Dropout(attn_drop)\n        self.proj=nn.Linear(dim,dim)\n        self.proj_drop=nn.Dropout(attn_drop)\n\n        # sliding-window patch extraction (a manual convolution); keyword arguments are required here\n        # because nn.Unfold's positional order is (kernel_size, dilation, padding, stride)\n        self.unflod=nn.Unfold(kernel_size=kernel_size,padding=padding,stride=stride)\n        self.pool=nn.AvgPool2d(kernel_size=stride,stride=stride,ceil_mode=True) \n\n    def forward(self, x) :\n        B,H,W,C=x.shape\n\n        #project the input to the value features v\n        v=self.v_pj(x).permute(0,3,1,2) #B,C,H,W\n        h,w=math.ceil(H/self.stride),math.ceil(W/self.stride)\n        v=self.unflod(v).reshape(B,self.num_heads,self.head_dim,self.kernel_size*self.kernel_size,h*w).permute(0,1,4,3,2) #B,num_head,H*W,kxk,head_dim\n\n        #generate the attention map\n        attn=self.pool(x.permute(0,3,1,2)).permute(0,2,3,1) #B,H,W,C\n        attn=self.attn(attn).reshape(B,h*w,self.num_heads,self.kernel_size*self.kernel_size \\\n                    ,self.kernel_size*self.kernel_size).permute(0,2,1,3,4) #B,num_head,H*W,kxk,kxk\n        attn=self.scale*attn\n        attn=attn.softmax(-1)\n        attn=self.attn_drop(attn)\n\n        #compute the attention-weighted features\n        out=(attn @ v).permute(0,1,4,3,2).reshape(B,C*self.kernel_size*self.kernel_size,h*w) #B,dimxkxk,H*W\n        out=F.fold(out,output_size=(H,W),kernel_size=self.kernel_size,\n                    padding=self.padding,stride=self.stride) #B,C,H,W\n        
out=self.proj(out.permute(0,2,3,1)) #B,H,W,C\n        out=self.proj_drop(out)\n\n        return out\n\n        \n\nif __name__ == '__main__':\n    input=torch.randn(50,28,28,512)\n    outlook = OutlookAttention(dim=512)\n    output=outlook(input)\n    print(output.shape)\n\n\n\n\n"
  },
  {
    "path": "model/attention/PSA.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass PSA(nn.Module):\n\n    def __init__(self, channel=512,reduction=4,S=4):\n        super().__init__()\n        self.S=S\n\n        # nn.ModuleList (rather than a plain list) registers the branches, so their parameters\n        # appear in .parameters() and state_dict() and follow .to(device)\n        self.convs=nn.ModuleList()\n        for i in range(S):\n            self.convs.append(nn.Conv2d(channel//S,channel//S,kernel_size=2*(i+1)+1,padding=i+1))\n\n        self.se_blocks=nn.ModuleList()\n        for i in range(S):\n            self.se_blocks.append(nn.Sequential(\n                nn.AdaptiveAvgPool2d(1),\n                nn.Conv2d(channel//S, channel // (S*reduction),kernel_size=1, bias=False),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(channel // (S*reduction), channel//S,kernel_size=1, bias=False),\n                nn.Sigmoid()\n            ))\n        \n        self.softmax=nn.Softmax(dim=1)\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        b, c, h, w = x.size()\n\n        #Step1:SPC module\n        SPC_in=x.view(b,self.S,c//self.S,h,w) #bs,s,ci,h,w\n        # stack the branch outputs instead of writing into the view in place, which would\n        # overwrite the caller's input and can break autograd\n        SPC_out=torch.stack([conv(SPC_in[:,idx,:,:,:]) for idx,conv in enumerate(self.convs)],dim=1) #bs,s,ci,h,w\n\n        #Step2:SE weight\n        se_out=[]\n        for idx,se in enumerate(self.se_blocks):\n            se_out.append(se(SPC_out[:,idx,:,:,:]))\n        SE_out=torch.stack(se_out,dim=1)\n        SE_out=SE_out.expand_as(SPC_out)\n\n        #Step3:Softmax\n        softmax_out=self.softmax(SE_out)\n\n       
 #Step4:SPA\n        PSA_out=SPC_out*softmax_out\n        PSA_out=PSA_out.view(b,-1,h,w)\n\n        return PSA_out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    psa = PSA(channel=512,reduction=8)\n    output=psa(input)\n    a=output.view(-1).sum()\n    a.backward()\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/ParNetAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass ParNetAttention(nn.Module):\n\n    def __init__(self, channel=512):\n        super().__init__()\n        self.sse = nn.Sequential(\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(channel,channel,kernel_size=1),\n            nn.Sigmoid()\n        )\n\n        self.conv1x1=nn.Sequential(\n            nn.Conv2d(channel,channel,kernel_size=1),\n            nn.BatchNorm2d(channel)\n        )\n        self.conv3x3=nn.Sequential(\n            nn.Conv2d(channel,channel,kernel_size=3,padding=1),\n            nn.BatchNorm2d(channel)\n        )\n        self.silu=nn.SiLU()\n        \n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        x1=self.conv1x1(x)\n        x2=self.conv3x3(x)\n        x3=self.sse(x)*x\n        y=self.silu(x1+x2+x3)\n        return y\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    pna = ParNetAttention(channel=512)\n    output=pna(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/PolarizedSelfAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass ParallelPolarizedSelfAttention(nn.Module):\n\n    def __init__(self, channel=512):\n        super().__init__()\n        self.ch_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))\n        self.ch_wq=nn.Conv2d(channel,1,kernel_size=(1,1))\n        self.softmax_channel=nn.Softmax(1)\n        self.softmax_spatial=nn.Softmax(-1)\n        self.ch_wz=nn.Conv2d(channel//2,channel,kernel_size=(1,1))\n        self.ln=nn.LayerNorm(channel)\n        self.sigmoid=nn.Sigmoid()\n        self.sp_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))\n        self.sp_wq=nn.Conv2d(channel,channel//2,kernel_size=(1,1))\n        self.agp=nn.AdaptiveAvgPool2d((1,1))\n\n    def forward(self, x):\n        b, c, h, w = x.size()\n\n        #Channel-only Self-Attention\n        channel_wv=self.ch_wv(x) #bs,c//2,h,w\n        channel_wq=self.ch_wq(x) #bs,1,h,w\n        channel_wv=channel_wv.reshape(b,c//2,-1) #bs,c//2,h*w\n        channel_wq=channel_wq.reshape(b,-1,1) #bs,h*w,1\n        channel_wq=self.softmax_channel(channel_wq)\n        channel_wz=torch.matmul(channel_wv,channel_wq).unsqueeze(-1) #bs,c//2,1,1\n        channel_weight=self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b,c,1).permute(0,2,1))).permute(0,2,1).reshape(b,c,1,1) #bs,c,1,1\n        channel_out=channel_weight*x\n\n        #Spatial-only Self-Attention\n        spatial_wv=self.sp_wv(x) #bs,c//2,h,w\n        spatial_wq=self.sp_wq(x) #bs,c//2,h,w\n        spatial_wq=self.agp(spatial_wq) #bs,c//2,1,1\n        spatial_wv=spatial_wv.reshape(b,c//2,-1) #bs,c//2,h*w\n        spatial_wq=spatial_wq.permute(0,2,3,1).reshape(b,1,c//2) #bs,1,c//2\n        spatial_wq=self.softmax_spatial(spatial_wq)\n        spatial_wz=torch.matmul(spatial_wq,spatial_wv) #bs,1,h*w\n        spatial_weight=self.sigmoid(spatial_wz.reshape(b,1,h,w)) #bs,1,h,w\n        spatial_out=spatial_weight*x\n        out=spatial_out+channel_out\n        
return out\n\n\n\n\n\n\nclass SequentialPolarizedSelfAttention(nn.Module):\n\n    def __init__(self, channel=512):\n        super().__init__()\n        self.ch_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))\n        self.ch_wq=nn.Conv2d(channel,1,kernel_size=(1,1))\n        self.softmax_channel=nn.Softmax(1)\n        self.softmax_spatial=nn.Softmax(-1)\n        self.ch_wz=nn.Conv2d(channel//2,channel,kernel_size=(1,1))\n        self.ln=nn.LayerNorm(channel)\n        self.sigmoid=nn.Sigmoid()\n        self.sp_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))\n        self.sp_wq=nn.Conv2d(channel,channel//2,kernel_size=(1,1))\n        self.agp=nn.AdaptiveAvgPool2d((1,1))\n\n    def forward(self, x):\n        b, c, h, w = x.size()\n\n        #Channel-only Self-Attention\n        channel_wv=self.ch_wv(x) #bs,c//2,h,w\n        channel_wq=self.ch_wq(x) #bs,1,h,w\n        channel_wv=channel_wv.reshape(b,c//2,-1) #bs,c//2,h*w\n        channel_wq=channel_wq.reshape(b,-1,1) #bs,h*w,1\n        channel_wq=self.softmax_channel(channel_wq)\n        channel_wz=torch.matmul(channel_wv,channel_wq).unsqueeze(-1) #bs,c//2,1,1\n        channel_weight=self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b,c,1).permute(0,2,1))).permute(0,2,1).reshape(b,c,1,1) #bs,c,1,1\n        channel_out=channel_weight*x\n\n        #Spatial-only Self-Attention\n        spatial_wv=self.sp_wv(channel_out) #bs,c//2,h,w\n        spatial_wq=self.sp_wq(channel_out) #bs,c//2,h,w\n        spatial_wq=self.agp(spatial_wq) #bs,c//2,1,1\n        spatial_wv=spatial_wv.reshape(b,c//2,-1) #bs,c//2,h*w\n        spatial_wq=spatial_wq.permute(0,2,3,1).reshape(b,1,c//2) #bs,1,c//2\n        spatial_wq=self.softmax_spatial(spatial_wq)\n        spatial_wz=torch.matmul(spatial_wq,spatial_wv) #bs,1,h*w\n        spatial_weight=self.sigmoid(spatial_wz.reshape(b,1,h,w)) #bs,1,h,w\n        spatial_out=spatial_weight*channel_out\n        return spatial_out\n\n\n\n\nif __name__ == '__main__':\n    
input=torch.randn(1,512,7,7)\n    psa = SequentialPolarizedSelfAttention(channel=512)\n    output=psa(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/ResidualAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass ResidualAttention(nn.Module):\n\n    def __init__(self, channel=512 , num_class=1000,la=0.2):\n        super().__init__()\n        self.la=la\n        self.fc=nn.Conv2d(in_channels=channel,out_channels=num_class,kernel_size=1,stride=1,bias=False)\n\n    def forward(self, x):\n        b,c,h,w=x.shape\n        y_raw=self.fc(x).flatten(2) #b,num_class,hxw\n        y_avg=torch.mean(y_raw,dim=2) #b,num_class\n        y_max=torch.max(y_raw,dim=2)[0] #b,num_class\n        score=y_avg+self.la*y_max\n        return score\n\n        \n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    resatt = ResidualAttention(channel=512,num_class=1000,la=0.2)\n    output=resatt(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/S2Attention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\ndef spatial_shift1(x):\n    b,w,h,c = x.size()\n    x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]\n    x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]\n    x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]\n    x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]\n    return x\n\n\ndef spatial_shift2(x):\n    b,w,h,c = x.size()\n    x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]\n    x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]\n    x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]\n    x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]\n    return x\n\n\nclass SplitAttention(nn.Module):\n    def __init__(self,channel=512,k=3):\n        super().__init__()\n        self.channel=channel\n        self.k=k\n        self.mlp1=nn.Linear(channel,channel,bias=False)\n        self.gelu=nn.GELU()\n        self.mlp2=nn.Linear(channel,channel*k,bias=False)\n        self.softmax=nn.Softmax(1)\n    \n    def forward(self,x_all):\n        b,k,h,w,c=x_all.shape\n        x_all=x_all.reshape(b,k,-1,c) #bs,k,n,c\n        a=torch.sum(torch.sum(x_all,1),1) #bs,c\n        hat_a=self.mlp2(self.gelu(self.mlp1(a))) #bs,kc\n        hat_a=hat_a.reshape(b,self.k,c) #bs,k,c\n        bar_a=self.softmax(hat_a) #bs,k,c\n        attention=bar_a.unsqueeze(-2) # #bs,k,1,c\n        out=attention*x_all # #bs,k,n,c\n        out=torch.sum(out,1).reshape(b,h,w,c)\n        return out\n\n\nclass S2Attention(nn.Module):\n\n    def __init__(self, channels=512 ):\n        super().__init__()\n        self.mlp1 = nn.Linear(channels,channels*3)\n        self.mlp2 = nn.Linear(channels,channels)\n        # pass the channel count through; the default (512) would fail for any other width\n        self.split_attention = SplitAttention(channel=channels,k=3)\n\n    def forward(self, x):\n        b,c,w,h = x.size()\n        x=x.permute(0,2,3,1)\n        x = self.mlp1(x)\n        x1 = spatial_shift1(x[:,:,:,:c])\n        x2 = spatial_shift2(x[:,:,:,c:c*2])\n        x3 = x[:,:,:,c*2:]\n        x_all=torch.stack([x1,x2,x3],1)\n        a = self.split_attention(x_all)\n        x = 
self.mlp2(a)\n        x=x.permute(0,3,1,2)\n        return x\n\n        \n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    s2att = S2Attention(channels=512)\n    output=s2att(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/SEAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass SEAttention(nn.Module):\n\n    def __init__(self, channel=512,reduction=16):\n        super().__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel, bias=False),\n            nn.Sigmoid()\n        )\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y.expand_as(x)\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    se = SEAttention(channel=512,reduction=8)\n    output=se(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/SGE.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass SpatialGroupEnhance(nn.Module):\n\n    def __init__(self, groups):\n        super().__init__()\n        self.groups=groups\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.weight=nn.Parameter(torch.zeros(1,groups,1,1))\n        self.bias=nn.Parameter(torch.zeros(1,groups,1,1))\n        self.sig=nn.Sigmoid()\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        b, c, h,w=x.shape\n        x=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,w\n        xn=x*self.avg_pool(x) #bs*g,dim//g,h,w\n        xn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,w\n        t=xn.view(b*self.groups,-1) #bs*g,h*w\n\n        t=t-t.mean(dim=1,keepdim=True) #bs*g,h*w\n        std=t.std(dim=1,keepdim=True)+1e-5\n        t=t/std #bs*g,h*w\n        t=t.view(b,self.groups,h,w) #bs,g,h,w\n        \n        t=t*self.weight+self.bias #bs,g,h,w\n        t=t.view(b*self.groups,1,h,w) #bs*g,1,h,w\n        x=x*self.sig(t)\n        x=x.view(b,c,h,w)\n\n        return x \n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    sge = SpatialGroupEnhance(groups=8)\n    output=sge(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/SKAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom collections import OrderedDict\n\n\n\nclass SKAttention(nn.Module):\n\n    def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):\n        super().__init__()\n        self.d=max(L,channel//reduction)\n        self.convs=nn.ModuleList([])\n        for k in kernels:\n            self.convs.append(\n                nn.Sequential(OrderedDict([\n                    ('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),\n                    ('bn',nn.BatchNorm2d(channel)),\n                    ('relu',nn.ReLU())\n                ]))\n            )\n        self.fc=nn.Linear(channel,self.d)\n        self.fcs=nn.ModuleList([])\n        for i in range(len(kernels)):\n            self.fcs.append(nn.Linear(self.d,channel))\n        self.softmax=nn.Softmax(dim=0)\n\n\n\n    def forward(self, x):\n        bs, c, _, _ = x.size()\n        conv_outs=[]\n        ### split\n        for conv in self.convs:\n            conv_outs.append(conv(x))\n        feats=torch.stack(conv_outs,0)#k,bs,channel,h,w\n\n        ### fuse\n        U=sum(conv_outs) #bs,c,h,w\n\n        ### reduction channel\n        S=U.mean(-1).mean(-1) #bs,c\n        Z=self.fc(S) #bs,d\n\n        ### calculate attention weight\n        weights=[]\n        for fc in self.fcs:\n            weight=fc(Z)\n            weights.append(weight.view(bs,c,1,1)) #bs,channel,1,1\n        attention_weights=torch.stack(weights,0)#k,bs,channel,1,1\n        attention_weights=self.softmax(attention_weights)#k,bs,channel,1,1\n\n        ### select: weight each branch by its attention and sum over branches\n        V=(attention_weights*feats).sum(0)\n        return V\n\n        \n\n\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    se = SKAttention(channel=512,reduction=8)\n    output=se(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/SelfAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass ScaledDotProductAttention(nn.Module):\n    '''\n    Scaled dot-product attention\n    '''\n\n    def __init__(self, d_model, d_k, d_v, h,dropout=.1):\n        '''\n        :param d_model: Output dimensionality of the model\n        :param d_k: Dimensionality of queries and keys\n        :param d_v: Dimensionality of values\n        :param h: Number of heads\n        '''\n        super(ScaledDotProductAttention, self).__init__()\n        self.fc_q = nn.Linear(d_model, h * d_k)\n        self.fc_k = nn.Linear(d_model, h * d_k)\n        self.fc_v = nn.Linear(d_model, h * d_v)\n        self.fc_o = nn.Linear(h * d_v, d_model)\n        self.dropout=nn.Dropout(dropout)\n\n        self.d_model = d_model\n        self.d_k = d_k\n        self.d_v = d_v\n        self.h = h\n\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):\n        '''\n        Computes\n        :param queries: Queries (b_s, nq, d_model)\n        :param keys: Keys (b_s, nk, d_model)\n        :param values: Values (b_s, nk, d_model)\n        :param attention_mask: Mask over attention values (b_s, h, nq, nk). 
True indicates masking.\n        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).\n        :return:\n        '''\n        b_s, nq = queries.shape[:2]\n        nk = keys.shape[1]\n\n        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)\n        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)\n        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)\n\n        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)\n        if attention_weights is not None:\n            att = att * attention_weights\n        if attention_mask is not None:\n            att = att.masked_fill(attention_mask, -np.inf)\n        att = torch.softmax(att, -1)\n        att=self.dropout(att)\n\n        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)\n        out = self.fc_o(out)  # (b_s, nq, d_model)\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=sa(input,input,input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/ShuffleAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn.parameter import Parameter\n\n\nclass ShuffleAttention(nn.Module):\n\n    def __init__(self, channel=512,reduction=16,G=8):\n        super().__init__()\n        self.G=G\n        self.channel=channel\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))\n        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))\n        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))\n        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))\n        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))\n        self.sigmoid=nn.Sigmoid()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n\n    @staticmethod\n    def channel_shuffle(x, groups):\n        b, c, h, w = x.shape\n        x = x.reshape(b, groups, -1, h, w)\n        x = x.permute(0, 2, 1, 3, 4)\n\n        # flatten\n        x = x.reshape(b, -1, h, w)\n\n        return x\n\n    def forward(self, x):\n        b, c, h, w = x.size()\n        #group into subfeatures\n        x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w\n\n        #channel_split\n        x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w\n\n        #channel attention\n        x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1\n        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1\n        
x_channel=x_0*self.sigmoid(x_channel)\n\n        #spatial attention\n        x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w\n        x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w\n        x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w\n\n        # concatenate along channel axis\n        out=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,w\n        out=out.contiguous().view(b,-1,h,w)\n\n        # channel shuffle\n        out = self.channel_shuffle(out, 2)\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    se = ShuffleAttention(channel=512,G=8)\n    output=se(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/SimAM.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass SimAM(torch.nn.Module):\n    def __init__(self, channels = None, e_lambda = 1e-4):\n        super(SimAM, self).__init__()\n\n        self.activaton = nn.Sigmoid()\n        self.e_lambda = e_lambda\n\n    def __repr__(self):\n        s = self.__class__.__name__ + '('\n        s += ('lambda=%f)' % self.e_lambda)\n        return s\n\n    @staticmethod\n    def get_module_name():\n        return \"simam\"\n\n    def forward(self, x):\n\n        b, c, h, w = x.size()\n        \n        n = w * h - 1\n\n        x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)\n        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5\n\n        return x * self.activaton(y)\n\nif __name__ == '__main__':\n    input=torch.randn(3, 64, 7, 7)\n    model = SimAM(64)\n    outputs = model(input)\n    print(outputs.shape)"
  },
  {
    "path": "model/attention/SimplifiedSelfAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import init\n\n\n\nclass SimplifiedScaledDotProductAttention(nn.Module):\n    '''\n    Scaled dot-product attention\n    '''\n\n    def __init__(self, d_model, h,dropout=.1):\n        '''\n        :param d_model: Output dimensionality of the model\n        :param d_k: Dimensionality of queries and keys\n        :param d_v: Dimensionality of values\n        :param h: Number of heads\n        '''\n        super(SimplifiedScaledDotProductAttention, self).__init__()\n\n        self.d_model = d_model\n        self.d_k = d_model//h\n        self.d_v = d_model//h\n        self.h = h\n\n        self.fc_o = nn.Linear(h * self.d_v, d_model)\n        self.dropout=nn.Dropout(dropout)\n\n\n\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):\n        '''\n        Computes\n        :param queries: Queries (b_s, nq, d_model)\n        :param keys: Keys (b_s, nk, d_model)\n        :param values: Values (b_s, nk, d_model)\n        :param attention_mask: Mask over attention values (b_s, h, nq, nk). 
True indicates masking.\n        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).\n        :return:\n        '''\n        b_s, nq = queries.shape[:2]\n        nk = keys.shape[1]\n\n        q = queries.view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)\n        k = keys.view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)\n        v = values.view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)\n\n        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)\n        if attention_weights is not None:\n            att = att * attention_weights\n        if attention_mask is not None:\n            att = att.masked_fill(attention_mask, -np.inf)\n        att = torch.softmax(att, -1)\n        att=self.dropout(att)\n\n        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)\n        out = self.fc_o(out)  # (b_s, nq, d_model)\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)\n    output=ssa(input,input,input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/TripletAttention.py",
    "content": "import torch\nimport torch.nn as nn\n\nclass BasicConv(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):\n        super(BasicConv, self).__init__()\n        self.out_channels = out_planes\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)\n        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None\n        self.relu = nn.ReLU() if relu else None\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu is not None:\n            x = self.relu(x)\n        return x\n\nclass ZPool(nn.Module):\n    def forward(self, x):\n        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)\n\nclass AttentionGate(nn.Module):\n    def __init__(self):\n        super(AttentionGate, self).__init__()\n        kernel_size = 7\n        self.compress = ZPool()\n        self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)\n    def forward(self, x):\n        x_compress = self.compress(x)\n        x_out = self.conv(x_compress)\n        scale = torch.sigmoid_(x_out) \n        return x * scale\n\nclass TripletAttention(nn.Module):\n    def __init__(self, no_spatial=False):\n        super(TripletAttention, self).__init__()\n        self.cw = AttentionGate()\n        self.hc = AttentionGate()\n        self.no_spatial=no_spatial\n        if not no_spatial:\n            self.hw = AttentionGate()\n    def forward(self, x):\n        x_perm1 = x.permute(0,2,1,3).contiguous()\n        x_out1 = self.cw(x_perm1)\n        x_out11 = x_out1.permute(0,2,1,3).contiguous()\n        x_perm2 = x.permute(0,3,2,1).contiguous()\n        x_out2 = self.hc(x_perm2)\n        x_out21 = 
x_out2.permute(0,3,2,1).contiguous()\n        if not self.no_spatial:\n            x_out = self.hw(x)\n            x_out = 1/3 * (x_out + x_out11 + x_out21)\n        else:\n            x_out = 1/2 * (x_out11 + x_out21)\n        return x_out\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    triplet = TripletAttention()\n    output=triplet(input)\n    print(output.shape)\n    "
  },
  {
    "path": "model/attention/UFOAttention.py",
    "content": "import numpy as np\nimport torch\nfrom torch import nn\nfrom torch.functional import norm\nfrom torch.nn import init\n\n\ndef XNorm(x,gamma):\n    norm_tensor=torch.norm(x,2,-1,True)\n    return x*gamma/norm_tensor\n\n\nclass UFOAttention(nn.Module):\n    '''\n    Scaled dot-product attention\n    '''\n\n    def __init__(self, d_model, d_k, d_v, h,dropout=.1):\n        '''\n        :param d_model: Output dimensionality of the model\n        :param d_k: Dimensionality of queries and keys\n        :param d_v: Dimensionality of values\n        :param h: Number of heads\n        '''\n        super(UFOAttention, self).__init__()\n        self.fc_q = nn.Linear(d_model, h * d_k)\n        self.fc_k = nn.Linear(d_model, h * d_k)\n        self.fc_v = nn.Linear(d_model, h * d_v)\n        self.fc_o = nn.Linear(h * d_v, d_model)\n        self.dropout=nn.Dropout(dropout)\n        self.gamma=nn.Parameter(torch.randn((1,h,1,1)))\n\n        self.d_model = d_model\n        self.d_k = d_k\n        self.d_v = d_v\n        self.h = h\n\n        self.init_weights()\n\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n    def forward(self, queries, keys, values):\n        b_s, nq = queries.shape[:2]\n        nk = keys.shape[1]\n\n        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)\n        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)\n        v = 
self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)\n\n        kv=torch.matmul(k, v) #bs,h,c,c\n        kv_norm=XNorm(kv,self.gamma) #bs,h,c,c\n        q_norm=XNorm(q,self.gamma) #bs,h,n,c\n        out=torch.matmul(q_norm,kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)\n        out = self.fc_o(out)  # (b_s, nq, d_model)\n\n        \n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=ufo(input,input,input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/attention/ViP.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass MLP(nn.Module):\n    def __init__(self,in_features,hidden_features,out_features,act_layer=nn.GELU,drop=0.1):\n        super().__init__()\n        self.fc1=nn.Linear(in_features,hidden_features)\n        self.act=act_layer()\n        self.fc2=nn.Linear(hidden_features,out_features)\n        self.drop=nn.Dropout(drop)\n\n    def forward(self, x) :\n        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))\n\nclass WeightedPermuteMLP(nn.Module):\n    def __init__(self,dim,seg_dim=8, qkv_bias=False, proj_drop=0.):\n        super().__init__()\n        self.seg_dim=seg_dim\n\n        self.mlp_c=nn.Linear(dim,dim,bias=qkv_bias)\n        self.mlp_h=nn.Linear(dim,dim,bias=qkv_bias)\n        self.mlp_w=nn.Linear(dim,dim,bias=qkv_bias)\n\n        self.reweighting=MLP(dim,dim//4,dim*3)\n\n        self.proj=nn.Linear(dim,dim)\n        self.proj_drop=nn.Dropout(proj_drop)\n    \n    def forward(self,x) :\n        B,H,W,C=x.shape\n\n        c_embed=self.mlp_c(x)\n\n        S=C//self.seg_dim\n        h_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,2,1,4).reshape(B,self.seg_dim,W,H*S)\n        h_embed=self.mlp_h(h_embed).reshape(B,self.seg_dim,W,H,S).permute(0,3,2,1,4).reshape(B,H,W,C)\n\n        w_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,1,2,4).reshape(B,self.seg_dim,H,W*S)\n        w_embed=self.mlp_w(w_embed).reshape(B,self.seg_dim,H,W,S).permute(0,2,3,1,4).reshape(B,H,W,C)\n\n        weight=(c_embed+h_embed+w_embed).permute(0,3,1,2).flatten(2).mean(2)\n        weight=self.reweighting(weight).reshape(B,C,3).permute(2,0,1).softmax(0).unsqueeze(2).unsqueeze(2)\n\n        x=c_embed*weight[0]+w_embed*weight[1]+h_embed*weight[2]\n\n        x=self.proj_drop(self.proj(x))\n\n        return x\n\n\n\nif __name__ == '__main__':\n    input=torch.randn(64,8,8,512)\n    seg_dim=8\n    vip=WeightedPermuteMLP(512,seg_dim)\n    out=vip(input)\n    print(out.shape)\n    "
  },
  {
    "path": "model/attention/gfnet.py",
    "content": "import torch\nfrom torch import nn\nimport math\nfrom timm.models.layers import DropPath, to_2tuple\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\nclass GlobalFilter(nn.Module):\n    def __init__(self, dim, h=14, w=8):\n        super().__init__()\n        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)\n        self.w = w\n        self.h = h\n\n    def forward(self, x, spatial_size=None):\n        B, N, C = x.shape\n        if spatial_size is None:\n            a = b = int(math.sqrt(N))\n        else:\n            a, b = spatial_size\n\n        x = x.view(B, a, b, C)\n\n        x = x.to(torch.float32)\n\n        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')\n        weight = torch.view_as_complex(self.complex_weight)\n        x = x * weight\n        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')\n\n        x = x.reshape(B, N, C)\n        return x\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = 
out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass Block(nn.Module):\n    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.filter = GlobalFilter(dim, h=h, w=w)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))\n        return x\n\n\nclass GFNet(nn.Module):\n    def __init__(self, embed_dim=384, img_size=224, patch_size=16, mlp_ratio=4, depth=4, num_classes=1000):\n        super().__init__()\n        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)\n        self.embedding = nn.Linear((patch_size ** 2) * 3, embed_dim)\n\n        h = img_size // patch_size\n        w = h // 2 + 1\n\n\n        self.blocks = nn.ModuleList([\n            Block(dim=embed_dim, mlp_ratio=mlp_ratio, h=h, w=w)\n            for i in range(depth)\n        ])\n\n        self.head = nn.Linear(embed_dim, num_classes)\n        self.softmax = nn.Softmax(1)\n\n    def forward(self, x):\n        x = self.patch_embed(x)\n        for blk in self.blocks:\n            x = blk(x)\n        x = x.mean(dim=1)\n        x = self.softmax(self.head(x))\n        return x\n\nif __name__ == 
'__main__':\n    x = torch.randn(1, 3, 224, 224)\n    gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)\n    out = gfnet(x)\n    print(out.shape)\n\n    \n"
  },
  {
    "path": "model/backbone/CMT.py",
    "content": "## Author: Jianyuan Guo (jyguo@pku.edu.cn)\n\nimport math\nimport logging\nfrom functools import partial\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.helpers import load_pretrained\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.resnet import resnet26d, resnet50d\nfrom timm.models.registry import register_model\n\n_logger = logging.getLogger(__name__)\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\n# A memory-efficient implementation of Swish function\nclass SwishImplementation(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, i):\n        result = i * torch.sigmoid(i)\n        ctx.save_for_backward(i)\n        return result\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        i = ctx.saved_tensors[0]\n        sigmoid_i = torch.sigmoid(i)\n        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))\n\n\nclass MemoryEfficientSwish(nn.Module):\n    def forward(self, x):\n        return SwishImplementation.apply(x)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.conv1 = nn.Sequential(\n            nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True),\n            nn.GELU(),\n            nn.BatchNorm2d(hidden_features, eps=1e-5),\n        )\n        self.proj = 
nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, groups=hidden_features)\n        self.proj_act = nn.GELU()\n        self.proj_bn = nn.BatchNorm2d(hidden_features, eps=1e-5)\n        self.conv2 = nn.Sequential(\n            nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True),\n            nn.BatchNorm2d(out_features, eps=1e-5),\n        )\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.permute(0, 2, 1).reshape(B, C, H, W)\n        x = self.conv1(x)\n        x = self.drop(x)\n        x = self.proj(x) + x\n        x = self.proj_act(x)\n        x = self.proj_bn(x)\n        x = self.conv2(x)\n        x = x.flatten(2).permute(0, 2, 1)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, \n                 attn_drop=0., proj_drop=0., qk_ratio=1, sr_ratio=1):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.qk_dim = dim // qk_ratio\n\n        self.q = nn.Linear(dim, self.qk_dim, bias=qkv_bias)\n        self.k = nn.Linear(dim, self.qk_dim, bias=qkv_bias)\n        self.v = nn.Linear(dim, dim, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        \n        self.sr_ratio = sr_ratio\n        # Exactly same as PVTv1\n        if self.sr_ratio > 1:\n            self.sr = nn.Sequential(\n                nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio, groups=dim, bias=True),\n                nn.BatchNorm2d(dim, eps=1e-5),\n            )\n\n    def forward(self, x, H, W, relative_pos):\n        B, N, C = x.shape\n        q = self.q(x).reshape(B, N, self.num_heads, self.qk_dim // self.num_heads).permute(0, 2, 1, 3)\n        \n        if self.sr_ratio > 1:\n            x_ = 
x.permute(0, 2, 1).reshape(B, C, H, W)\n            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)\n            k = self.k(x_).reshape(B, -1, self.num_heads, self.qk_dim // self.num_heads).permute(0, 2, 1, 3)\n            v = self.v(x_).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        else:\n            k = self.k(x).reshape(B, N, self.num_heads, self.qk_dim // self.num_heads).permute(0, 2, 1, 3)\n            v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale + relative_pos\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qk_ratio=1, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, \n            attn_drop=attn_drop, proj_drop=drop, qk_ratio=qk_ratio, sr_ratio=sr_ratio)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.proj = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)\n        \n    def forward(self, x, H, W, relative_pos):\n        B, N, C = x.shape\n        cnn_feat = x.permute(0, 2, 1).reshape(B, C, H, W)\n        x = self.proj(cnn_feat) + cnn_feat\n        x = x.flatten(2).permute(0, 2, 1)\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W, relative_pos))\n        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        \n        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \\\n            f\"img_size {img_size} should be divided by patch_size {patch_size}.\"\n        \n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = nn.LayerNorm(embed_dim)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        \n        H, W = H // self.patch_size[0], W // self.patch_size[1]\n        return x, (H, W)\n\n\nclass CMT(nn.Module):\n    def __init__(self, img_size=224, 
in_chans=3, num_classes=1000, embed_dims=[46,92,184,368], stem_channel=16, fc_dim=1280,\n                 num_heads=[1,2,4,8], mlp_ratios=[3.6,3.6,3.6,3.6], qkv_bias=True, qk_scale=None, representation_size=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,\n                 depths=[2,2,10,2], qk_ratio=1, sr_ratios=[8,4,2,1], dp=0.1):\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dims[-1]\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        \n        # Use in_chans rather than a hard-coded 3 so that non-RGB inputs are supported\n        self.stem_conv1 = nn.Conv2d(in_chans, stem_channel, kernel_size=3, stride=2, padding=1, bias=True)\n        self.stem_relu1 = nn.GELU()\n        self.stem_norm1 = nn.BatchNorm2d(stem_channel, eps=1e-5)\n        \n        self.stem_conv2 = nn.Conv2d(stem_channel, stem_channel, kernel_size=3, stride=1, padding=1, bias=True)\n        self.stem_relu2 = nn.GELU()\n        self.stem_norm2 = nn.BatchNorm2d(stem_channel, eps=1e-5)\n        \n        self.stem_conv3 = nn.Conv2d(stem_channel, stem_channel, kernel_size=3, stride=1, padding=1, bias=True)\n        self.stem_relu3 = nn.GELU()\n        self.stem_norm3 = nn.BatchNorm2d(stem_channel, eps=1e-5)\n\n        self.patch_embed_a = PatchEmbed(\n            img_size=img_size//2, patch_size=2, in_chans=stem_channel, embed_dim=embed_dims[0])\n        self.patch_embed_b = PatchEmbed(\n            img_size=img_size//4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])\n        self.patch_embed_c = PatchEmbed(\n            img_size=img_size//8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])\n        self.patch_embed_d = PatchEmbed(\n            img_size=img_size//16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])\n\n        self.relative_pos_a = nn.Parameter(torch.randn(\n            num_heads[0], self.patch_embed_a.num_patches, 
self.patch_embed_a.num_patches//sr_ratios[0]//sr_ratios[0]))\n        self.relative_pos_b = nn.Parameter(torch.randn(\n            num_heads[1], self.patch_embed_b.num_patches, self.patch_embed_b.num_patches//sr_ratios[1]//sr_ratios[1]))\n        self.relative_pos_c = nn.Parameter(torch.randn(\n            num_heads[2], self.patch_embed_c.num_patches, self.patch_embed_c.num_patches//sr_ratios[2]//sr_ratios[2]))\n        self.relative_pos_d = nn.Parameter(torch.randn(\n            num_heads[3], self.patch_embed_d.num_patches, self.patch_embed_d.num_patches//sr_ratios[3]//sr_ratios[3]))\n        \n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        cur = 0\n        self.blocks_a = nn.ModuleList([\n            Block(\n                dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,\n                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur+i],\n                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[0])\n            for i in range(depths[0])])\n        cur += depths[0]\n        self.blocks_b = nn.ModuleList([\n            Block(\n                dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,\n                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur+i],\n                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[1])\n            for i in range(depths[1])])\n        cur += depths[1]\n        self.blocks_c = nn.ModuleList([\n            Block(\n                dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,\n                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur+i],\n                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[2])\n            for i in range(depths[2])])\n        cur += depths[2]\n        self.blocks_d = 
nn.ModuleList([\n            Block(\n                dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,\n                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur+i],\n                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[3])\n            for i in range(depths[3])])\n\n        # Representation layer; pre_logits is applied to the pooled fc_dim features\n        if representation_size:\n            self.num_features = representation_size\n            self.pre_logits = nn.Sequential(OrderedDict([\n                ('fc', nn.Linear(fc_dim, representation_size)),\n                ('act', nn.Tanh())\n            ]))\n        else:\n            self.pre_logits = nn.Identity()\n\n        # Classifier head\n        self._fc = nn.Conv2d(embed_dims[-1], fc_dim, kernel_size=1)\n        self._bn = nn.BatchNorm2d(fc_dim, eps=1e-5)\n        self._swish = MemoryEfficientSwish()\n        self._avg_pooling = nn.AdaptiveAvgPool2d(1)\n        self._drop = nn.Dropout(dp)\n        # The head consumes the output of pre_logits: representation_size if set, else fc_dim\n        self.head = nn.Linear(representation_size or fc_dim, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.Conv2d):\n            nn.init.kaiming_normal_(m.weight, mode='fan_out')\n            if isinstance(m, nn.Conv2d) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.BatchNorm2d):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n            \n    def update_temperature(self):\n        for m in self.modules():\n            if isinstance(m, Attention):\n                
m.update_temperature()\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        # The head consumes the fc_dim-wide output of self._fc (see __init__), not embed_dims[-1]\n        self.head = nn.Linear(self._fc.out_channels, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        x = self.stem_conv1(x)\n        x = self.stem_relu1(x)\n        x = self.stem_norm1(x)\n        \n        x = self.stem_conv2(x)\n        x = self.stem_relu2(x)\n        x = self.stem_norm2(x)\n        \n        x = self.stem_conv3(x)\n        x = self.stem_relu3(x)\n        x = self.stem_norm3(x)\n        \n        x, (H, W) = self.patch_embed_a(x)\n        for i, blk in enumerate(self.blocks_a):\n            x = blk(x, H, W, self.relative_pos_a)\n            \n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        x, (H, W) = self.patch_embed_b(x)\n        for i, blk in enumerate(self.blocks_b):\n            x = blk(x, H, W, self.relative_pos_b)\n            \n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        x, (H, W) = self.patch_embed_c(x)\n        for i, blk in enumerate(self.blocks_c):\n            x = blk(x, H, W, self.relative_pos_c)\n            \n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        x, (H, W) = self.patch_embed_d(x)\n        for i, blk in enumerate(self.blocks_d):\n            x = blk(x, H, W, self.relative_pos_d)\n\n        B, N, C = x.shape\n        x = self._fc(x.permute(0, 2, 1).reshape(B, C, H, W))\n        x = self._bn(x)\n        x = self._swish(x)\n        x = self._avg_pooling(x).flatten(start_dim=1)\n        x = self._drop(x)\n        x = self.pre_logits(x)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\ndef 
resize_pos_embed(posemb, posemb_new):\n    # Rescale the grid of position embeddings when loading from state_dict. Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)\n    ntok_new = posemb_new.shape[1]\n    if True:\n        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]\n        ntok_new -= 1\n    else:\n        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))\n    gs_new = int(math.sqrt(ntok_new))\n    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n    return posemb\n\n\ndef checkpoint_filter_fn(state_dict, model):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    if 'model' in state_dict:\n        # For deit models\n        state_dict = state_dict['model']\n    for k, v in state_dict.items():\n        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:\n            # For old models that I trained prior to conv based patchification\n            O, I, H, W = model.patch_embed.proj.weight.shape\n            v = v.reshape(O, -1, H, W)\n        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:\n            # To resize pos embedding when using model at different size from pretrained weights\n            v = resize_pos_embed(v, model.pos_embed)\n        out_dict[k] = v\n    return out_dict\n\n\ndef _create_cmt_model(pretrained=False, distilled=False, **kwargs):\n    default_cfg = _cfg()\n    default_num_classes = 
default_cfg['num_classes']\n    default_img_size = default_cfg['input_size'][-1]\n\n    num_classes = kwargs.pop('num_classes', default_num_classes)\n    img_size = kwargs.pop('img_size', default_img_size)\n    repr_size = kwargs.pop('representation_size', None)\n    if repr_size is not None and num_classes != default_num_classes:\n        # Remove representation layer if fine-tuning. This may not always be the desired action,\n        # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?\n        _logger.warning(\"Removing representation layer for fine-tuning.\")\n        repr_size = None\n\n    model = CMT(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)\n    model.default_cfg = default_cfg\n\n    if pretrained:\n        load_pretrained(\n            model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),\n            filter_fn=partial(checkpoint_filter_fn, model=model))\n    return model\n\n\n@register_model\ndef cmt_ti(pretrained=False, **kwargs):\n    \"\"\" \n    CMT-Tiny\n    \"\"\"\n    model_kwargs = dict(qkv_bias=True, **kwargs)\n    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)\n    return model\n\n@register_model\ndef cmt_xs(pretrained=False, **kwargs):\n    \"\"\" \n    CMT-XS: dim x 0.9, depth x 0.8, input 192\n    \"\"\"\n    model_kwargs = dict(\n        qkv_bias=True, embed_dims=[52,104,208,416], stem_channel=16, num_heads=[1,2,4,8],\n        depths=[3,3,12,3], mlp_ratios=[3.77,3.77,3.77,3.77], qk_ratio=1, sr_ratios=[8,4,2,1], **kwargs)\n    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)\n    return model\n\n@register_model\ndef cmt_s(pretrained=False, **kwargs):\n    \"\"\" \n    CMT-Small\n    \"\"\"\n    model_kwargs = dict(\n        qkv_bias=True, embed_dims=[64,128,256,512], stem_channel=32, num_heads=[1,2,4,8],\n        depths=[3,3,16,3], mlp_ratios=[4,4,4,4], qk_ratio=1, sr_ratios=[8,4,2,1], **kwargs)\n    model = 
_create_cmt_model(pretrained=pretrained, **model_kwargs)\n    return model\n\n@register_model\ndef cmt_b(pretrained=False, **kwargs):\n    \"\"\" \n    CMT-Base\n    \"\"\"\n    model_kwargs = dict(\n        qkv_bias=True, embed_dims=[76,152,304,608], stem_channel=38, num_heads=[1,2,4,8],\n        depths=[4,4,20,4], mlp_ratios=[4,4,4,4], qk_ratio=1, sr_ratios=[8,4,2,1], dp=0.3, **kwargs)\n    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)\n    return model\n\n@register_model\ndef CMT_Tiny(pretrained=False, **kwargs):\n    \"\"\" \n    CMT-Tiny (same configuration as cmt_ti; used by the demo below)\n    \"\"\"\n    model_kwargs = dict(qkv_bias=True, **kwargs)\n    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)\n    return model\n\nif __name__ == '__main__':\n    # Smoke test: push one random 224x224 RGB image through CMT-Tiny\n    x = torch.randn(1, 3, 224, 224)\n    model = CMT_Tiny()\n    output = model(x)\n    # Print the shape of the full logits tensor, batch dimension included\n    print(output.shape)"
  },
  {
    "path": "model/backbone/CPVT.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\nfrom timm.models.vision_transformer import Block as TimmBlock\nfrom timm.models.vision_transformer import Attention as TimmAttention\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass GroupAttention(nn.Module):\n    \"\"\"\n    LSA: self attention within a group\n    \"\"\"\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1):\n        assert ws != 1\n        super(GroupAttention, self).__init__()\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n\n        self.dim = dim\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.ws = ws\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        h_group, w_group = H // self.ws, W // self.ws\n\n        total_groups = h_group * w_group\n\n        x = x.reshape(B, 
h_group, self.ws, w_group, self.ws, C).transpose(2, 3)\n\n        qkv = self.qkv(x).reshape(B, total_groups, -1, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)\n        # B, hw, ws*ws, 3, n_head, head_dim -> 3, B, hw, n_head, ws*ws, head_dim\n        q, k, v = qkv[0], qkv[1], qkv[2]  # B, hw, n_head, ws*ws, head_dim\n        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, hw, n_head, ws*ws, ws*ws\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(\n            attn)  # attn @ v-> B, hw, n_head, ws*ws, head_dim -> (t(2,3)) B, hw, ws*ws, n_head,  head_dim\n        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, C)\n        x = attn.transpose(2, 3).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    \"\"\"\n    GSA: using a  key to summarize the information for a group to be efficient.\n    \"\"\"\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):\n        super().__init__()\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n\n        self.dim = dim\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        self.sr_ratio = sr_ratio\n        if sr_ratio > 1:\n            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)\n            self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n\n        if self.sr_ratio > 1:\n            x_ = 
x.permute(0, 2, 1).reshape(B, C, H, W)\n            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)\n            x_ = self.norm(x_)\n            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        else:\n            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        k, v = kv[0], kv[1]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x, H, W):\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass SBlock(TimmBlock):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):\n        super(SBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,\n                                     drop_path, act_layer, norm_layer)\n\n    def forward(self, x, H, W):\n        return super(SBlock, self).forward(x)\n\n\nclass GroupBlock(TimmBlock):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=1):\n        super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,\n                                         drop_path, act_layer, norm_layer)\n        del self.attn\n        if ws == 1:\n            self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, sr_ratio)\n        else:\n            self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, ws)\n\n    def forward(self, x, H, W):\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n    
    self.img_size = img_size\n        self.patch_size = patch_size\n        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \\\n            f\"img_size {img_size} should be divided by patch_size {patch_size}.\"\n        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = nn.LayerNorm(embed_dim)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        H, W = H // self.patch_size[0], W // self.patch_size[1]\n\n        return x, (H, W)\n\n\n# borrow from PVT https://github.com/whai362/PVT.git\nclass PyramidVisionTransformer(nn.Module):\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],\n                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,\n                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block):\n        super().__init__()\n        self.num_classes = num_classes\n        self.depths = depths\n\n        # patch_embed\n        self.patch_embeds = nn.ModuleList()\n        self.pos_embeds = nn.ParameterList()\n        self.pos_drops = nn.ModuleList()\n        self.blocks = nn.ModuleList()\n\n        for i in range(len(depths)):\n            if i == 0:\n                self.patch_embeds.append(PatchEmbed(img_size, patch_size, in_chans, embed_dims[i]))\n            else:\n                self.patch_embeds.append(\n                    PatchEmbed(img_size // patch_size // 2 ** (i - 1), 2, embed_dims[i - 1], embed_dims[i]))\n            patch_num = self.patch_embeds[-1].num_patches + 1 if i == len(embed_dims) - 1 else self.patch_embeds[\n                
-1].num_patches\n            self.pos_embeds.append(nn.Parameter(torch.zeros(1, patch_num, embed_dims[i])))\n            self.pos_drops.append(nn.Dropout(p=drop_rate))\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        cur = 0\n        for k in range(len(depths)):\n            _block = nn.ModuleList([block_cls(\n                dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,\n                sr_ratio=sr_ratios[k])\n                for i in range(depths[k])])\n            self.blocks.append(_block)\n            cur += depths[k]\n\n        self.norm = norm_layer(embed_dims[-1])\n\n        # cls_token\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))\n\n        # classification head\n        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()\n\n        # init weights\n        for pos_emb in self.pos_embeds:\n            trunc_normal_(pos_emb, std=.02)\n        self.apply(self._init_weights)\n\n    def reset_drop_path(self, drop_path_rate):\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]\n        cur = 0\n        for k in range(len(self.depths)):\n            for i in range(self.depths[k]):\n                self.blocks[k][i].drop_path.drop_prob = dpr[cur + i]\n            cur += self.depths[k]\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return 
{'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        # self.embed_dim is never set on this class; read the last stage width from its patch projection\n        in_features = self.patch_embeds[-1].proj.out_channels\n        self.head = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        for i in range(len(self.depths)):\n            x, (H, W) = self.patch_embeds[i](x)\n            if i == len(self.depths) - 1:\n                cls_tokens = self.cls_token.expand(B, -1, -1)\n                x = torch.cat((cls_tokens, x), dim=1)\n            x = x + self.pos_embeds[i]\n            x = self.pos_drops[i](x)\n            for blk in self.blocks[i]:\n                x = blk(x, H, W)\n            if i < len(self.depths) - 1:\n                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        x = self.norm(x)\n\n        return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n\n        return x\n\n\n# PEG from https://arxiv.org/abs/2102.10882\nclass PosCNN(nn.Module):\n    def __init__(self, in_chans, embed_dim=768, s=1):\n        super(PosCNN, self).__init__()\n        self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim), )\n        self.s = s\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        feat_token = x\n        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)\n        if self.s == 1:\n            x = self.proj(cnn_feat) + cnn_feat\n        else:\n            x = self.proj(cnn_feat)\n        x = x.flatten(2).transpose(1, 2)\n        return x\n\n    def no_weight_decay(self):\n        return ['proj.%d.weight' % i for i in range(4)]\n\n\nclass CPVTV2(PyramidVisionTransformer):\n    \"\"\"\n    Use useful results from CPVT. 
PEG and GAP.\n    Therefore, cls token is no longer required.\n    PEG is used to encode the absolute position on the fly, which greatly affects the performance when input resolution\n    changes during the training (such as segmentation, detection)\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],\n                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,\n                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block):\n        super(CPVTV2, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, mlp_ratios,\n                                     qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, depths,\n                                     sr_ratios, block_cls)\n        del self.pos_embeds\n        del self.cls_token\n        self.pos_block = nn.ModuleList(\n            [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims]\n        )\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        import math\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n        elif isinstance(m, nn.BatchNorm2d):\n            m.weight.data.fill_(1.0)\n            m.bias.data.zero_()\n\n    def no_weight_decay(self):\n        
return set(['cls_token'] + ['pos_block.' + n for n, p in self.pos_block.named_parameters()])\n\n    def forward_features(self, x):\n        B = x.shape[0]\n\n        for i in range(len(self.depths)):\n            x, (H, W) = self.patch_embeds[i](x)\n            x = self.pos_drops[i](x)\n            for j, blk in enumerate(self.blocks[i]):\n                x = blk(x, H, W)\n                if j == 0:\n                    x = self.pos_block[i](x, H, W)  # PEG here\n            if i < len(self.depths) - 1:\n                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        x = self.norm(x)\n\n        return x.mean(dim=1)  # GAP here\n\n\nclass PCPVT(CPVTV2):\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256],\n                 num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,\n                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=SBlock):\n        super(PCPVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads,\n                                    mlp_ratios, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate,\n                                    norm_layer, depths, sr_ratios, block_cls)\n\n\nclass ALTGVT(PCPVT):\n    \"\"\"\n    alias Twins-SVT\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256],\n                 num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,\n                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=GroupBlock, wss=[7, 7, 7]):\n        super(ALTGVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads,\n                                     mlp_ratios, qkv_bias, qk_scale, 
drop_rate, attn_drop_rate, drop_path_rate,\n                                     norm_layer, depths, sr_ratios, block_cls)\n        del self.blocks\n        self.wss = wss\n        # transformer encoder\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        cur = 0\n        self.blocks = nn.ModuleList()\n        for k in range(len(depths)):\n            _block = nn.ModuleList([block_cls(\n                dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,\n                sr_ratio=sr_ratios[k], ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])])\n            self.blocks.append(_block)\n            cur += depths[k]\n        self.apply(self._init_weights)\n\n\ndef _conv_filter(state_dict, patch_size=16):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    for k, v in state_dict.items():\n        if 'patch_embed.proj.weight' in k:\n            v = v.reshape((v.shape[0], 3, patch_size, patch_size))\n        out_dict[k] = v\n\n    return out_dict\n\n\n@register_model\ndef pcpvt_small_v0(pretrained=False, **kwargs):\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef pcpvt_base_v0(pretrained=False, **kwargs):\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],\n        **kwargs)\n    model.default_cfg 
= _cfg()\n    return model\n\n\n@register_model\ndef pcpvt_large_v0(pretrained=False, **kwargs):\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],\n        **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\nif __name__ == '__main__':\n    # Smoke test: one random 224x224 RGB image (named x to avoid shadowing the built-in input)\n    x = torch.randn(1, 3, 224, 224)\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])\n    output = model(x)\n    print(output.shape)"
  },
  {
    "path": "model/backbone/CaiT.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n\nimport torch\nimport torch.nn as nn\nfrom functools import partial\n\nfrom timm.models.vision_transformer import Mlp, PatchEmbed , _cfg\nfrom timm.models.registry import register_model\nfrom timm.models.layers import trunc_normal_, DropPath\n\n\n__all__ = [\n    'cait_M48', 'cait_M36',\n    'cait_S36', 'cait_S24','cait_S24_224',\n    'cait_XS24','cait_XXS24','cait_XXS24_224',\n    'cait_XXS36','cait_XXS36_224'\n]\n\nclass Class_Attention(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    # with slight modifications to do CA \n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.k = nn.Linear(dim, dim, bias=qkv_bias)\n        self.v = nn.Linear(dim, dim, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    \n    def forward(self, x ):\n        \n        B, N, C = x.shape\n        q = self.q(x[:,0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n\n        q = q * self.scale\n        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n\n        attn = (q @ k.transpose(-2, -1)) \n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)\n        x_cls = self.proj(x_cls)\n        x_cls = self.proj_drop(x_cls)\n        \n        return x_cls     \n        \nclass LayerScale_Block_CA(nn.Module):\n    
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    # with slight modifications to add CA and LayerScale\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block = Class_Attention,\n                 Mlp_block=Mlp,init_values=1e-4):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention_block(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n\n    \n    def forward(self, x, x_cls):\n        \n        u = torch.cat((x_cls,x),dim=1)\n        \n        \n        x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))\n        \n        x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))\n        \n        return x_cls \n        \n        \nclass Attention_talking_head(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        \n        self.num_heads = num_heads\n        \n        head_dim = dim // num_heads\n        \n        self.scale = qk_scale or head_dim ** 
-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        \n        self.proj = nn.Linear(dim, dim)\n        \n        self.proj_l = nn.Linear(num_heads, num_heads)\n        self.proj_w = nn.Linear(num_heads, num_heads)\n        \n        self.proj_drop = nn.Dropout(proj_drop)\n\n\n    \n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0] * self.scale , qkv[1], qkv[2] \n    \n        attn = (q @ k.transpose(-2, -1)) \n        \n        attn = self.proj_l(attn.permute(0,2,3,1)).permute(0,3,1,2)\n                \n        attn = attn.softmax(dim=-1)\n  \n        attn = self.proj_w(attn.permute(0,2,3,1)).permute(0,3,1,2)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n    \nclass LayerScale_Block(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    # with slight modifications to add layerScale\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,Attention_block = Attention_talking_head,\n                 Mlp_block=Mlp,init_values=1e-4):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention_block(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)\n\n    def forward(self, x):        \n        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n        return x \n    \n    \n    \n    \nclass CaiT(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    # with slight modifications to adapt to our cait models\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None,\n                 block_layers = LayerScale_Block,\n                 block_layers_token = LayerScale_Block_CA,\n                 Patch_layer=PatchEmbed,act_layer=nn.GELU,\n                 Attention_block = Attention_talking_head,Mlp_block=Mlp,\n                init_scale=1e-4,\n                Attention_block_token_only=Class_Attention,\n                Mlp_block_token_only= Mlp, \n                depth_token_only=2,\n                mlp_ratio_clstk = 4.0):\n        super().__init__()\n        \n\n            \n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  \n\n        self.patch_embed = Patch_layer(\n                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        \n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, 
embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [drop_path_rate for i in range(depth)] \n        self.blocks = nn.ModuleList([\n            block_layers(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                act_layer=act_layer,Attention_block=Attention_block,Mlp_block=Mlp_block,init_values=init_scale)\n            for i in range(depth)])\n        \n\n        self.blocks_token_only = nn.ModuleList([\n            block_layers_token(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,\n                act_layer=act_layer,Attention_block=Attention_block_token_only,\n                Mlp_block=Mlp_block_token_only,init_values=init_scale)\n            for i in range(depth_token_only)])\n            \n        self.norm = norm_layer(embed_dim)\n\n\n        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        
x = self.patch_embed(x)\n\n        cls_tokens = self.cls_token.expand(B, -1, -1)  \n        \n        x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        for i , blk in enumerate(self.blocks):\n            x = blk(x)\n            \n        for i , blk in enumerate(self.blocks_token_only):\n            cls_tokens = blk(x,cls_tokens)\n\n        x = torch.cat((cls_tokens, x), dim=1)\n            \n                \n        x = self.norm(x)\n        return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        \n        x = self.head(x)\n\n        return x \n        \n@register_model\ndef cait_XXS24_224(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 224,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/XXS24_224.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n        \n    return model \n\n@register_model\ndef cait_XXS24(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 384,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/XXS24_384.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        
checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n        \n    return model \n@register_model\ndef cait_XXS36_224(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 224,patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/XXS36_224.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n        \n    return model \n\n@register_model\ndef cait_XXS36(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 384,patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/XXS36_384.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n        \n    return model \n\n@register_model\ndef cait_XS24(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 384,patch_size=16, embed_dim=288, depth=24, 
num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/XS24_384.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n        \n    return model \n\n\n\n\n@register_model\ndef cait_S24_224(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 224,patch_size=16, embed_dim=384, depth=24, num_heads=8, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/S24_224.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n        \n    return model \n\n@register_model\ndef cait_S24(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 384,patch_size=16, embed_dim=384, depth=24, num_heads=8, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/S24_384.pth\",\n            map_location=\"cpu\", 
check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n        \n    return model \n\n@register_model\ndef cait_S36(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 384,patch_size=16, embed_dim=384, depth=36, num_heads=8, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-6,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/S36_384.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n\n    return model \n\n@register_model\ndef cait_M36(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 384, patch_size=16, embed_dim=768, depth=36, num_heads=16, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-6,\n        depth_token_only=2,**kwargs)\n    \n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/M36_384.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.'+k]\n            \n        model.load_state_dict(checkpoint_no_module)\n\n    return model \n\n\n@register_model\ndef cait_M48(pretrained=False, **kwargs):\n    model = CaiT(\n        img_size= 448 , patch_size=16, 
embed_dim=768, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-6,\n        depth_token_only=2, **kwargs)\n\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/M48_448.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        # The checkpoint stores its weights under keys with a 'module.' prefix;\n        # remap them onto this model's own (unprefixed) parameter names.\n        checkpoint_no_module = {}\n        for k in model.state_dict().keys():\n            checkpoint_no_module[k] = checkpoint[\"model\"]['module.' + k]\n\n        model.load_state_dict(checkpoint_no_module)\n\n    return model\n\n\nif __name__ == '__main__':\n    # Smoke test: build the XXS24 configuration at 224x224 and run one forward pass.\n    x = torch.randn(1, 3, 224, 224)  # named x to avoid shadowing the builtin input()\n    model = CaiT(\n        img_size=224,\n        patch_size=16,\n        embed_dim=192,\n        depth=24,\n        num_heads=4,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2\n    )\n    output = model(x)\n    print(output.shape)  # (batch, num_classes), i.e. torch.Size([1, 1000]) with the defaults"
  },
  {
    "path": "model/backbone/CeiT.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import default_cfgs, _cfg\n\n\n__all__ = [\n    'ceit_tiny_patch16_224', 'ceit_small_patch16_224', 'ceit_base_patch16_224',\n    'ceit_tiny_patch16_384', 'ceit_small_patch16_384',\n]\n\n\nclass Image2Tokens(nn.Module):\n    def __init__(self, in_chans=3, out_chans=64, kernel_size=7, stride=2):\n        super(Image2Tokens, self).__init__()\n        self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_size, stride=stride,\n                              padding=kernel_size // 2, bias=False)\n        self.bn = nn.BatchNorm2d(out_chans)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.bn(x)\n        x = self.maxpool(x)\n        return x\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass LocallyEnhancedFeedForward(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,\n                 kernel_size=3, with_bn=True):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or 
in_features\n        # pointwise\n        self.conv1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1, padding=0)\n        # depthwise\n        self.conv2 = nn.Conv2d(\n            hidden_features, hidden_features, kernel_size=kernel_size, stride=1,\n            padding=(kernel_size - 1) // 2, groups=hidden_features\n        )\n        # pointwise\n        self.conv3 = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1, padding=0)\n        self.act = act_layer()\n        # self.drop = nn.Dropout(drop)\n\n        self.with_bn = with_bn\n        if self.with_bn:\n            self.bn1 = nn.BatchNorm2d(hidden_features)\n            self.bn2 = nn.BatchNorm2d(hidden_features)\n            self.bn3 = nn.BatchNorm2d(out_features)\n\n    def forward(self, x):\n        b, n, k = x.size()\n        cls_token, tokens = torch.split(x, [1, n - 1], dim=1)\n        x = tokens.reshape(b, int(math.sqrt(n - 1)), int(math.sqrt(n - 1)), k).permute(0, 3, 1, 2)\n        if self.with_bn:\n            x = self.conv1(x)\n            x = self.bn1(x)\n            x = self.act(x)\n            x = self.conv2(x)\n            x = self.bn2(x)\n            x = self.act(x)\n            x = self.conv3(x)\n            x = self.bn3(x)\n        else:\n            x = self.conv1(x)\n            x = self.act(x)\n            x = self.conv2(x)\n            x = self.act(x)\n            x = self.conv3(x)\n\n        tokens = x.flatten(2).permute(0, 2, 1)\n        out = torch.cat((cls_token, tokens), dim=1)\n        return out\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n    
    self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.attention_map = None\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        # self.attention_map = attn\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass AttentionLCA(Attention):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super(AttentionLCA, self).__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)\n        self.dim = dim\n        self.qkv_bias = qkv_bias\n        \n    def forward(self, x):\n\n        q_weight = self.qkv.weight[:self.dim, :]\n        q_bias = None if not self.qkv_bias else self.qkv.bias[:self.dim]\n        kv_weight = self.qkv.weight[self.dim:, :]\n        kv_bias = None if not self.qkv_bias else self.qkv.bias[self.dim:]\n        \n        B, N, C = x.shape\n        _, last_token = torch.split(x, [N-1, 1], dim=1)\n        \n        q = F.linear(last_token, q_weight, q_bias)\\\n             .reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        kv = F.linear(x, kv_weight, kv_bias)\\\n              .reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        k, v = kv[0], kv[1]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        # self.attention_map = attn\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, 1, C)\n        x = self.proj(x)\n        x = 
self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3, with_bn=True, \n                 feedforward_type='leff'):\n        super().__init__()\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.norm1 = norm_layer(dim)\n        self.feedforward_type = feedforward_type\n\n        if feedforward_type == 'leff':\n            self.attn = Attention(\n                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n            self.leff = LocallyEnhancedFeedForward(\n                in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,\n                kernel_size=kernel_size, with_bn=with_bn,\n            )\n        else:  # LCA\n            self.attn = AttentionLCA(\n                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n            self.feedforward = Mlp(\n                in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop\n            )\n\n    def forward(self, x):\n        if self.feedforward_type == 'leff':\n            x = x + self.drop_path(self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.leff(self.norm2(x)))\n            return x, x[:, 0]\n        else:  # LCA\n            _, last_token = torch.split(x, [x.size(1)-1, 1], dim=1)\n            x = last_token + self.drop_path(self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.feedforward(self.norm2(x)))\n            return x\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature 
Map Embedding\n    Extract feature map from CNN, flatten, project to embedding dim.\n    \"\"\"\n    def __init__(self, backbone, img_size=224, patch_size=16, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))\n                if isinstance(o, (list, tuple)):\n                    o = o[-1]  # last feature if backbone outputs list/tuple of features\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            feature_dim = self.backbone.feature_info.channels()[-1]\n        print('feature_size is {}, feature_dim is {}, patch_size is {}'.format(\n            feature_size, feature_dim, patch_size\n        ))\n        self.num_patches = (feature_size[0] // patch_size) * (feature_size[1] // patch_size)\n        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        x = self.backbone(x)\n        if isinstance(x, (list, tuple)):\n            x = x[-1]  # last feature if backbone outputs list/tuple of features\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\n\nclass 
CeIT(nn.Module):\n    def __init__(self,\n                 img_size=224,\n                 patch_size=16,\n                 in_chans=3,\n                 num_classes=1000,\n                 embed_dim=768,\n                 depth=12,\n                 num_heads=12,\n                 mlp_ratio=4.,\n                 qkv_bias=False,\n                 qk_scale=None,\n                 drop_rate=0.,\n                 attn_drop_rate=0.,\n                 drop_path_rate=0.,\n                 hybrid_backbone=None,\n                 norm_layer=nn.LayerNorm,\n                 leff_local_size=3,\n                 leff_with_bn=True):\n        \"\"\"\n        args:\n            - img_size (:obj:`int`): input image size\n            - patch_size (:obj:`int`): patch size\n            - in_chans (:obj:`int`): input channels\n            - num_classes (:obj:`int`): number of classes\n            - embed_dim (:obj:`int`): embedding dimensions for tokens\n            - depth (:obj:`int`): depth of encoder\n            - num_heads (:obj:`int`): number of heads in multi-head self-attention\n            - mlp_ratio (:obj:`float`): expand ratio in feedforward\n            - qkv_bias (:obj:`bool`): whether to add bias for mlp of qkv\n            - qk_scale (:obj:`float`): scale ratio for qk, default is head_dim ** -0.5\n            - drop_rate (:obj:`float`): dropout rate in feedforward module after linear operation\n                and projection drop rate in attention\n            - attn_drop_rate (:obj:`float`): dropout rate for attention\n            - drop_path_rate (:obj:`float`): drop_path rate after attention\n            - hybrid_backbone (:obj:`nn.Module`): backbone e.g. 
resnet\n            - norm_layer (:obj:`nn.Module`): normalization type\n            - leff_local_size (:obj:`int`): kernel size in LocallyEnhancedFeedForward\n            - leff_with_bn (:obj:`bool`): whether add bn in LocallyEnhancedFeedForward\n        \"\"\"\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n\n        self.i2t = HybridEmbed(\n            hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.i2t.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                kernel_size=leff_local_size, with_bn=leff_with_bn)\n            for i in range(depth)])\n\n        # without droppath\n        self.lca = Block(\n            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0., norm_layer=norm_layer,\n            feedforward_type = 'lca'\n        )\n        self.pos_layer_embed = nn.Parameter(torch.zeros(1, depth, embed_dim))\n\n        self.norm = norm_layer(embed_dim)\n\n        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here\n        # self.repr = nn.Linear(embed_dim, representation_size)\n        # self.repr_act = nn.Tanh()\n\n        # Classifier head\n        
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        x = self.i2t(x)\n\n        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        cls_token_list = []\n        for blk in self.blocks:\n            x, curr_cls_token = blk(x)\n            cls_token_list.append(curr_cls_token)\n\n        all_cls_token = torch.stack(cls_token_list, dim=1)  # B*D*K\n        all_cls_token = all_cls_token + self.pos_layer_embed\n        # attention over cls tokens\n        last_cls_token = self.lca(all_cls_token)\n        last_cls_token = self.norm(last_cls_token)\n\n        return last_cls_token.view(B, -1)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\n@register_model\ndef ceit_tiny_patch16_224(pretrained=False, **kwargs):\n    \"\"\"\n    convolutional + pooling stem\n    local enhanced feedforward\n    attention 
over cls_tokens\n    \"\"\"\n    i2t = Image2Tokens()\n    model = CeIT(\n        hybrid_backbone=i2t,\n        patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef ceit_small_patch16_224(pretrained=False, **kwargs):\n    \"\"\"\n    convolutional + pooling stem\n    local enhanced feedforward\n    attention over cls_tokens\n    \"\"\"\n    i2t = Image2Tokens()\n    model = CeIT(\n        hybrid_backbone=i2t,\n        patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef ceit_base_patch16_224(pretrained=False, **kwargs):\n    \"\"\"\n    convolutional + pooling stem\n    local enhanced feedforward\n    attention over cls_tokens\n    \"\"\"\n    i2t = Image2Tokens()\n    model = CeIT(\n        hybrid_backbone=i2t,\n        patch_size=4, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef ceit_tiny_patch16_384(pretrained=False, **kwargs):\n    \"\"\"\n    convolutional + pooling stem\n    local enhanced feedforward\n    attention over cls_tokens\n    \"\"\"\n    i2t = Image2Tokens()\n    model = CeIT(\n        hybrid_backbone=i2t, img_size=384,\n        patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\n@register_model\ndef ceit_small_patch16_384(pretrained=False, **kwargs):\n    \"\"\"\n    convolutional + pooling stem\n    local enhanced feedforward\n    attention over cls_tokens\n    \"\"\"\n    i2t = Image2Tokens()\n    model = CeIT(\n        
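hybrid_backbone=i2t, img_size=384,\n        patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    return model\n\n\nif __name__ == '__main__':\n    # Smoke test: build the tiny configuration and run one forward pass.\n    # Image2Tokens downsamples 224x224 to a 56x56 feature map (stride-2 conv, then\n    # stride-2 max-pool); patch_size=4 then yields 14x14 = 196 patch tokens.\n    x = torch.randn(1, 3, 224, 224)  # named x to avoid shadowing the builtin input()\n    model = CeIT(\n        hybrid_backbone=Image2Tokens(),\n        patch_size=4,\n        embed_dim=192,\n        depth=12,\n        num_heads=3,\n        mlp_ratio=4,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n    )\n    output = model(x)\n    print(output.shape)  # (batch, num_classes), i.e. torch.Size([1, 1000]) with the defaults"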
  },
  {
    "path": "model/backbone/CoaT.py",
    "content": "\"\"\" \nCoaT architecture.\n\nModified from timm/models/vision_transformer.py\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\n\nfrom einops import rearrange\nfrom functools import partial\nfrom torch import nn, einsum\n\n__all__ = [\n    \"coat_tiny\",\n    \"coat_mini\",\n    \"coat_small\",\n    \"coat_lite_tiny\",\n    \"coat_lite_mini\",\n    \"coat_lite_small\"\n]\n\n\ndef _cfg_coat(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\nclass Mlp(nn.Module):\n    \"\"\" Feed-forward network (FFN, a.k.a. MLP) class. \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass ConvRelPosEnc(nn.Module):\n    \"\"\" Convolutional relative position encoding. 
\"\"\"\n    def __init__(self, Ch, h, window):\n        \"\"\"\n        Initialization.\n            Ch: Channels per head.\n            h: Number of heads.\n            window: Window size(s) in convolutional relative positional encoding. It can have two forms:\n                    1. An integer window size, which assigns the same window size to all attention heads in ConvRelPosEnc.\n                    2. A dict mapping window size to #attention head splits (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})\n                       It applies a different window size to each attention head split.\n        \"\"\"\n        super().__init__()\n\n        if isinstance(window, int):\n            window = {window: h}                                                         # Set the same window size for all attention heads.\n            self.window = window\n        elif isinstance(window, dict):\n            self.window = window\n        else:\n            raise ValueError(f'window must be an int or a dict, got {type(window).__name__}')\n        \n        self.conv_list = nn.ModuleList()\n        self.head_splits = []\n        for cur_window, cur_head_split in window.items():\n            dilation = 1                                                                 # Use dilation=1 by default.\n            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2         # Determine padding size. 
Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338\n            cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch,\n                kernel_size=(cur_window, cur_window), \n                padding=(padding_size, padding_size),\n                dilation=(dilation, dilation),                          \n                groups=cur_head_split*Ch,\n            )\n            self.conv_list.append(cur_conv)\n            self.head_splits.append(cur_head_split)\n        self.channel_splits = [x*Ch for x in self.head_splits]\n\n    def forward(self, q, v, size):\n        B, h, N, Ch = q.shape\n        H, W = size\n        assert N == 1 + H * W\n\n        # Convolutional relative position encoding.\n        q_img = q[:,:,1:,:]                                                              # Shape: [B, h, H*W, Ch].\n        v_img = v[:,:,1:,:]                                                              # Shape: [B, h, H*W, Ch].\n        \n        v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W)               # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].\n        v_img_list = torch.split(v_img, self.channel_splits, dim=1)                      # Split according to channels.\n        conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)]\n        conv_v_img = torch.cat(conv_v_img_list, dim=1)\n        conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h)          # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].\n\n        EV_hat_img = q_img * conv_v_img\n        zero = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device)\n        EV_hat = torch.cat((zero, EV_hat_img), dim=2)                                # Shape: [B, h, N, Ch].\n\n        return EV_hat\n\n\nclass FactorAtt_ConvRelPosEnc(nn.Module):\n    \"\"\" Factorized attention with convolutional relative position encoding class. 
\"\"\"\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)                                           # Note: attn_drop is actually not used.\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        # Shared convolutional relative position encoding.\n        self.crpe = shared_crpe\n\n    def forward(self, x, size):\n        B, N, C = x.shape\n\n        # Generate Q, K, V.\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # Shape: [3, B, h, N, Ch].\n        q, k, v = qkv[0], qkv[1], qkv[2]                                                 # Shape: [B, h, N, Ch].\n\n        # Factorized attention.\n        k_softmax = k.softmax(dim=2)                                                     # Softmax on dim N.\n        k_softmax_T_dot_v = einsum('b h n k, b h n v -> b h k v', k_softmax, v)          # Shape: [B, h, Ch, Ch].\n        factor_att        = einsum('b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v)  # Shape: [B, h, N, Ch].\n\n        # Convolutional relative position encoding.\n        crpe = self.crpe(q, v, size=size)                                                # Shape: [B, h, N, Ch].\n\n        # Merge and reshape.\n        x = self.scale * factor_att + crpe\n        x = x.transpose(1, 2).reshape(B, N, C)                                           # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C].\n\n        # Output projection.\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x                                                                         # Shape: [B, N, C].\n\n\nclass ConvPosEnc(nn.Module):\n    
\"\"\" Convolutional Position Encoding. \n        Note: This module is similar to the conditional position encoding in CPVT.\n    \"\"\"\n    def __init__(self, dim, k=3):\n        super(ConvPosEnc, self).__init__()\n        self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) \n    \n    def forward(self, x, size):\n        B, N, C = x.shape\n        H, W = size\n        assert N == 1 + H * W\n\n        # Extract CLS token and image tokens.\n        cls_token, img_tokens = x[:, :1], x[:, 1:]                                       # Shape: [B, 1, C], [B, H*W, C].\n        \n        # Depthwise convolution.\n        feat = img_tokens.transpose(1, 2).view(B, C, H, W)\n        x = self.proj(feat) + feat\n        x = x.flatten(2).transpose(1, 2)\n\n        # Combine with CLS token.\n        x = torch.cat((cls_token, x), dim=1)\n\n        return x\n\n\nclass SerialBlock(nn.Module):\n    \"\"\" Serial block class.\n        Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. \"\"\"\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 shared_cpe=None, shared_crpe=None):\n        super().__init__()\n\n        # Conv-Attention.\n        self.cpe = shared_cpe\n\n        self.norm1 = norm_layer(dim)\n        self.factoratt_crpe = FactorAtt_ConvRelPosEnc(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, \n            shared_crpe=shared_crpe)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n        # MLP.\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x, size):\n        # Conv-Attention.\n        x = self.cpe(x, size)                  # Apply convolutional position encoding.\n        cur = self.norm1(x)\n        cur = self.factoratt_crpe(cur, size)   # Apply factorized attention and convolutional relative position encoding.\n        x = x + self.drop_path(cur) \n\n        # MLP. \n        cur = self.norm2(x)\n        cur = self.mlp(cur)\n        x = x + self.drop_path(cur)\n\n        return x\n\n\nclass ParallelBlock(nn.Module):\n    \"\"\" Parallel block class. \"\"\"\n    def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 shared_cpes=None, shared_crpes=None):\n        super().__init__()\n\n        # Conv-Attention.\n        self.cpes = shared_cpes\n\n        self.norm12 = norm_layer(dims[1])\n        self.norm13 = norm_layer(dims[2])\n        self.norm14 = norm_layer(dims[3])\n        self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(\n            dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, \n            shared_crpe=shared_crpes[1]\n        )\n        self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(\n            dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, \n            shared_crpe=shared_crpes[2]\n        )\n        self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(\n            dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, \n            shared_crpe=shared_crpes[3]\n        )\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n        # MLP.\n        self.norm22 = norm_layer(dims[1])\n        self.norm23 = norm_layer(dims[2])\n        self.norm24 = norm_layer(dims[3])\n        assert dims[1] == dims[2] == dims[3]                              # In parallel block, we assume dimensions are the same and share the linear transformation.\n        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]\n        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])\n        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def upsample(self, x, output_size, size):\n        \"\"\" Feature map up-sampling. \"\"\"\n        return self.interpolate(x, output_size=output_size, size=size)\n\n    def downsample(self, x, output_size, size):\n        \"\"\" Feature map down-sampling. \"\"\"\n        return self.interpolate(x, output_size=output_size, size=size)\n\n    def interpolate(self, x, output_size, size):\n        \"\"\" Feature map interpolation. 
\"\"\"\n        B, N, C = x.shape\n        H, W = size\n        assert N == 1 + H * W\n\n        cls_token  = x[:, :1, :]\n        img_tokens = x[:, 1:, :]\n        \n        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)\n        img_tokens = F.interpolate(img_tokens, size=output_size, mode='bilinear')  # FIXME: May have alignment issue.\n        img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)\n        \n        out = torch.cat((cls_token, img_tokens), dim=1)\n\n        return out\n\n    def forward(self, x1, x2, x3, x4, sizes):\n        _, (H2, W2), (H3, W3), (H4, W4) = sizes\n        \n        # Conv-Attention.\n        x2 = self.cpes[1](x2, size=(H2, W2))  # Note: x1 is ignored.\n        x3 = self.cpes[2](x3, size=(H3, W3))\n        x4 = self.cpes[3](x4, size=(H4, W4))\n        \n        cur2 = self.norm12(x2)\n        cur3 = self.norm13(x3)\n        cur4 = self.norm14(x4)\n        cur2 = self.factoratt_crpe2(cur2, size=(H2,W2))\n        cur3 = self.factoratt_crpe3(cur3, size=(H3,W3))\n        cur4 = self.factoratt_crpe4(cur4, size=(H4,W4))\n        upsample3_2 = self.upsample(cur3, output_size=(H2,W2), size=(H3,W3))\n        upsample4_3 = self.upsample(cur4, output_size=(H3,W3), size=(H4,W4))\n        upsample4_2 = self.upsample(cur4, output_size=(H2,W2), size=(H4,W4))\n        downsample2_3 = self.downsample(cur2, output_size=(H3,W3), size=(H2,W2))\n        downsample3_4 = self.downsample(cur3, output_size=(H4,W4), size=(H3,W3))\n        downsample2_4 = self.downsample(cur2, output_size=(H4,W4), size=(H2,W2))\n        cur2 = cur2  + upsample3_2   + upsample4_2\n        cur3 = cur3  + upsample4_3   + downsample2_3\n        cur4 = cur4  + downsample3_4 + downsample2_4\n        x2 = x2 + self.drop_path(cur2) \n        x3 = x3 + self.drop_path(cur3) \n        x4 = x4 + self.drop_path(cur4) \n\n        # MLP. 
\n        cur2 = self.norm22(x2)\n        cur3 = self.norm23(x3)\n        cur4 = self.norm24(x4)\n        cur2 = self.mlp2(cur2)\n        cur3 = self.mlp3(cur3)\n        cur4 = self.mlp4(cur4)\n        x2 = x2 + self.drop_path(cur2)\n        x3 = x3 + self.drop_path(cur3)\n        x4 = x4 + self.drop_path(cur4) \n\n        return x1, x2, x3, x4\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding \"\"\"\n    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n\n        self.patch_size = patch_size\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = nn.LayerNorm(embed_dim)\n\n    def forward(self, x):\n        _, _, H, W = x.shape\n        out_H, out_W = H // self.patch_size[0], W // self.patch_size[1]\n\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        out = self.norm(x)\n        \n        return out, (out_H, out_W)\n\n\nclass CoaT(nn.Module):\n    \"\"\" CoaT class. 
\"\"\"\n    def __init__(self, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], \n                 serial_depths=[0, 0, 0, 0], parallel_depth=0,\n                 num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),\n                 return_interm_layers=False, out_features=None, crpe_window={3:2, 5:3, 7:3},\n                 **kwargs):\n        super().__init__()\n        self.return_interm_layers = return_interm_layers\n        self.out_features = out_features\n        self.num_classes = num_classes\n\n        # Patch embeddings.\n        self.patch_embed1 = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])\n        self.patch_embed2 = PatchEmbed(patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])\n        self.patch_embed3 = PatchEmbed(patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])\n        self.patch_embed4 = PatchEmbed(patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])\n\n        # Class tokens.\n        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))\n        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))\n        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))\n        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))\n\n        # Convolutional position encodings.\n        self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)\n        self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)\n        self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)\n        self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)\n\n        # Convolutional relative position encodings.\n        self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window)\n        self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window)\n        self.crpe3 = 
ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window)\n        self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window)\n\n        # Enable stochastic depth.\n        dpr = drop_path_rate\n        \n        # Serial blocks 1.\n        self.serial_blocks1 = nn.ModuleList([\n            SerialBlock(\n                dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, \n                shared_cpe=self.cpe1, shared_crpe=self.crpe1\n            )\n            for _ in range(serial_depths[0])]\n        )\n\n        # Serial blocks 2.\n        self.serial_blocks2 = nn.ModuleList([\n            SerialBlock(\n                dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, \n                shared_cpe=self.cpe2, shared_crpe=self.crpe2\n            )\n            for _ in range(serial_depths[1])]\n        )\n\n        # Serial blocks 3.\n        self.serial_blocks3 = nn.ModuleList([\n            SerialBlock(\n                dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, \n                shared_cpe=self.cpe3, shared_crpe=self.crpe3\n            )\n            for _ in range(serial_depths[2])]\n        )\n\n        # Serial blocks 4.\n        self.serial_blocks4 = nn.ModuleList([\n            SerialBlock(\n                dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, \n                shared_cpe=self.cpe4, 
shared_crpe=self.crpe4\n            )\n            for _ in range(serial_depths[3])]\n        )\n\n        # Parallel blocks.\n        self.parallel_depth = parallel_depth\n        if self.parallel_depth > 0:\n            self.parallel_blocks = nn.ModuleList([\n                ParallelBlock(\n                    dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, \n                    shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4],\n                    shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4]\n                )\n                for _ in range(parallel_depth)]\n            )\n\n        # Classification head(s).\n        if not self.return_interm_layers:\n            self.norm1 = norm_layer(embed_dims[0])\n            self.norm2 = norm_layer(embed_dims[1])\n            self.norm3 = norm_layer(embed_dims[2])\n            self.norm4 = norm_layer(embed_dims[3])\n\n            if self.parallel_depth > 0:                                  # CoaT series: Aggregate features of last three scales for classification.\n                assert embed_dims[1] == embed_dims[2] == embed_dims[3]\n                self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)\n                self.head = nn.Linear(embed_dims[3], num_classes)\n            else:\n                self.head = nn.Linear(embed_dims[3], num_classes)        # CoaT-Lite series: Use feature of last scale for classification.\n\n        # Initialize weights.\n        trunc_normal_(self.cls_token1, std=.02)\n        trunc_normal_(self.cls_token2, std=.02)\n        trunc_normal_(self.cls_token3, std=.02)\n        trunc_normal_(self.cls_token4, std=.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if 
isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        # Note: self.embed_dim is never defined; the head input width is embed_dims[3], i.e. the last dim of cls_token4.\n        self.head = nn.Linear(self.cls_token4.shape[-1], num_classes) if num_classes > 0 else nn.Identity()\n\n    def insert_cls(self, x, cls_token):\n        \"\"\" Insert CLS token. \"\"\"\n        cls_tokens = cls_token.expand(x.shape[0], -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n        return x\n\n    def remove_cls(self, x):\n        \"\"\" Remove CLS token. \"\"\"\n        return x[:, 1:, :]\n\n    def forward_features(self, x0):\n        B = x0.shape[0]\n\n        # Serial blocks 1.\n        x1, (H1, W1) = self.patch_embed1(x0)\n        x1 = self.insert_cls(x1, self.cls_token1)\n        for blk in self.serial_blocks1:\n            x1 = blk(x1, size=(H1, W1))\n        x1_nocls = self.remove_cls(x1)\n        x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()\n        \n        # Serial blocks 2.\n        x2, (H2, W2) = self.patch_embed2(x1_nocls)\n        x2 = self.insert_cls(x2, self.cls_token2)\n        for blk in self.serial_blocks2:\n            x2 = blk(x2, size=(H2, W2))\n        x2_nocls = self.remove_cls(x2)\n        x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()\n\n        # Serial blocks 3.\n        x3, (H3, W3) = self.patch_embed3(x2_nocls)\n        x3 = self.insert_cls(x3, self.cls_token3)\n        for blk in self.serial_blocks3:\n            x3 = blk(x3, size=(H3, W3))\n        x3_nocls = self.remove_cls(x3)\n        x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 
2).contiguous()\n\n        # Serial blocks 4.\n        x4, (H4, W4) = self.patch_embed4(x3_nocls)\n        x4 = self.insert_cls(x4, self.cls_token4)\n        for blk in self.serial_blocks4:\n            x4 = blk(x4, size=(H4, W4))\n        x4_nocls = self.remove_cls(x4)\n        x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()\n\n        # Only serial blocks: Early return.\n        if self.parallel_depth == 0:\n            if self.return_interm_layers:   # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).\n                feat_out = {}   \n                if 'x1_nocls' in self.out_features:\n                    feat_out['x1_nocls'] = x1_nocls\n                if 'x2_nocls' in self.out_features:\n                    feat_out['x2_nocls'] = x2_nocls\n                if 'x3_nocls' in self.out_features:\n                    feat_out['x3_nocls'] = x3_nocls\n                if 'x4_nocls' in self.out_features:\n                    feat_out['x4_nocls'] = x4_nocls\n                return feat_out\n            else:                           # Return features for classification.\n                x4 = self.norm4(x4)\n                x4_cls = x4[:, 0]\n                return x4_cls\n\n        # Parallel blocks.\n        for blk in self.parallel_blocks:\n            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])\n\n        if self.return_interm_layers:       # Return intermediate features for down-stream tasks (e.g. 
Deformable DETR and Detectron2).\n            feat_out = {}   \n            if 'x1_nocls' in self.out_features:\n                x1_nocls = self.remove_cls(x1)\n                x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()\n                feat_out['x1_nocls'] = x1_nocls\n            if 'x2_nocls' in self.out_features:\n                x2_nocls = self.remove_cls(x2)\n                x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()\n                feat_out['x2_nocls'] = x2_nocls\n            if 'x3_nocls' in self.out_features:\n                x3_nocls = self.remove_cls(x3)\n                x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()\n                feat_out['x3_nocls'] = x3_nocls\n            if 'x4_nocls' in self.out_features:\n                x4_nocls = self.remove_cls(x4)\n                x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()\n                feat_out['x4_nocls'] = x4_nocls\n            return feat_out\n        else:\n            x2 = self.norm2(x2)\n            x3 = self.norm3(x3)\n            x4 = self.norm4(x4)\n            x2_cls = x2[:, :1]              # Shape: [B, 1, C].\n            x3_cls = x3[:, :1]\n            x4_cls = x4[:, :1]\n            merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1)       # Shape: [B, 3, C].\n            merged_cls = self.aggregate(merged_cls).squeeze(dim=1)        # Shape: [B, C].\n            return merged_cls\n\n    def forward(self, x):\n        if self.return_interm_layers:       # Return intermediate features (for down-stream tasks).\n            return self.forward_features(x)\n        else:                               # Return features for classification.\n            x = self.forward_features(x) \n            x = self.head(x)\n            return x\n\n\n# CoaT.\n@register_model\ndef coat_tiny(**kwargs):\n    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], 
serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)\n    model.default_cfg = _cfg_coat()\n    return model\n\n@register_model\ndef coat_mini(**kwargs):\n    model = CoaT(patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)\n    model.default_cfg = _cfg_coat()\n    return model\n\n@register_model\ndef coat_small(**kwargs):\n    model = CoaT(patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)\n    model.default_cfg = _cfg_coat()\n    return model\n\n# CoaT-Lite.\n@register_model\ndef coat_lite_tiny(**kwargs):\n    model = CoaT(patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)\n    model.default_cfg = _cfg_coat()\n    return model\n\n@register_model\ndef coat_lite_mini(**kwargs):\n    model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)\n    model.default_cfg = _cfg_coat()\n    return model\n\n@register_model\ndef coat_lite_small(**kwargs):\n    model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)\n    model.default_cfg = _cfg_coat()\n    return model\n\n@register_model\ndef coat_lite_medium(**kwargs):\n    model = CoaT(patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8], parallel_depth=0, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)\n    model.default_cfg = _cfg_coat()\n    return model\n\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4])\n    output=model(input)\n    
print(output.shape) # torch.Size([1, 1000])"
  },
  {
    "path": "model/backbone/ConTNet.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch\n\nfrom einops.layers.torch import Rearrange\nfrom einops import rearrange\n\nimport numpy as np\n\nfrom typing import Any, List\nimport math\nimport warnings\nfrom collections import OrderedDict\n\n__all__ = ['ConTBlock', 'ConTNet']\n\n\nr\"\"\" The following trunc_normal method is pasted from timm https://github.com/rwightman/pytorch-image-models/tree/master/timm      \n\"\"\"\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    \n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1. + math.erf(x / math.sqrt(2.))) / 2.\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
\"\n                      \"The distribution of values may be incorrect.\",\n                      stacklevel=2)\n\n    with torch.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\ndef trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n    # type: (Tensor, float, float, float, float) -> Tensor\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> nn.init.trunc_normal_(w)\n    \"\"\"\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n\ndef fixed_padding(inputs, kernel_size, dilation):\n    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)\n    pad_total = kernel_size_effective - 1\n    pad_beg = pad_total // 2\n    pad_end = pad_total - pad_beg\n    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))\n    return padded_inputs\n\nclass ConvBN(nn.Sequential):\n    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1, bn=True):\n        padding = (kernel_size - 1) // 2\n        if bn:\n            super(ConvBN, self).__init__(OrderedDict([\n                ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,\n                                padding=padding, groups=groups, bias=False)),\n                ('bn', nn.BatchNorm2d(out_planes))\n            ]))\n        else:\n            super(ConvBN, self).__init__(OrderedDict([\n                ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,\n                                padding=padding, groups=groups, bias=False)),\n            ]))\n\nclass MHSA(nn.Module):\n    r\"\"\"\n    Build a Multi-Head Self-Attention:\n        - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    \"\"\"\n    def __init__(self,\n                 planes,\n                 head_num,\n                 dropout,\n                 patch_size,\n                 qkv_bias,\n                 relative):\n        super(MHSA, self).__init__()\n     
   self.head_num = head_num \n        head_dim = planes // head_num\n        self.qkv = nn.Linear(planes, 3*planes, bias=qkv_bias)\n        self.relative = relative\n        self.patch_size = patch_size\n        self.scale = head_dim ** -0.5\n        \n        if self.relative:\n            # print('### relative position embedding ###')\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros((2 * patch_size - 1) * (2 * patch_size - 1), head_num))  \n            coords_w = coords_h = torch.arange(patch_size)\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  \n            coords_flatten = torch.flatten(coords, 1)  \n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  \n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  \n            relative_coords[:, :, 0] += patch_size - 1 \n            relative_coords[:, :, 1] += patch_size - 1\n            relative_coords[:, :, 0] *= 2 * patch_size - 1\n            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n            trunc_normal_(self.relative_position_bias_table, std=.02)\n\n        self.attn_drop = nn.Dropout(p=dropout)\n        self.proj = nn.Linear(planes, planes)\n        self.proj_drop = nn.Dropout(p=dropout)\n\n    def forward(self, x):\n        B, N, C, H = *x.shape, self.head_num\n        # print(x.shape)\n        qkv = self.qkv(x).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4) # x: (3, B, H, N, C//H)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # x: (B, H, N, C//N) \n\n        q = q * self.scale \n        attn = (q @ k.transpose(-2, -1)) # attn: (B, H, N, N)\n\n        if self.relative:\n            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n                self.patch_size ** 2, self.patch_size ** 2, -1)  \n            
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() \n            attn = attn + relative_position_bias.unsqueeze(0)\n\n        attn = attn.softmax(dim=-1) \n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C) \n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\nclass MLP(nn.Module):\n    r\"\"\"\n    Build a Multi-Layer Perceptron\n        - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    \"\"\"\n    def __init__(self,\n                 planes,\n                 mlp_dim,\n                 dropout):\n        super(MLP, self).__init__()\n\n        self.fc1 = nn.Linear(planes, mlp_dim)\n        self.act = nn.GELU()\n        self.fc2 = nn.Linear(mlp_dim, planes)\n        self.drop = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n\n        return x\n        \n\nclass STE(nn.Module):\n    r\"\"\"\n    Build a Standard Transformer Encoder(STE)\n    input: Tensor (b, c, h, w)\n    output: Tensor (b, c, h, w)\n    \"\"\"\n    def __init__(self,\n                 planes: int,\n                 mlp_dim: int,\n                 head_num: int,\n                 dropout: float,\n                 patch_size: int,\n                 relative: bool,\n                 qkv_bias: bool,\n                 pre_norm: bool,\n                 **kwargs):\n        super(STE, self).__init__()\n        self.patch_size = patch_size\n        self.pre_norm = pre_norm\n        self.relative = relative\n\n        self.flatten = nn.Sequential(\n                Rearrange('b c pnh pnw psh psw -> (b pnh pnw) psh psw c'),\n            )\n        if not relative:\n            self.pe = nn.ParameterList(\n                    [nn.Parameter(torch.zeros(1, patch_size, 1, planes//2)), nn.Parameter(torch.zeros(1, 1, patch_size, 
planes//2))]\n                )\n        self.attn = MHSA(planes, head_num, dropout, patch_size, qkv_bias=qkv_bias, relative=relative)\n        self.mlp = MLP(planes, mlp_dim, dropout=dropout)\n        self.norm1 = nn.LayerNorm(planes)\n        self.norm2 = nn.LayerNorm(planes)\n\n    def forward(self, x):\n        bs, c, h, w = x.shape \n        patch_size = self.patch_size\n        patch_num_h, patch_num_w = h // patch_size, w // patch_size\n    \n        x = (\n            x.unfold(2, self.patch_size, self.patch_size)\n                .unfold(3, self.patch_size, self.patch_size)\n        ) # x: (b, c, patch_num, patch_num, patch_size, patch_size)\n        x = self.flatten(x) # x: (b, patch_size, patch_size, c)\n        ### add 2d position embedding ###\n        if not self.relative:\n            x_h, x_w = x.split(c // 2, dim=3)\n            x = torch.cat((x_h + self.pe[0], x_w + self.pe[1]), dim=3) # x: (b, patch_size, patch_size, c)\n        \n        x = rearrange(x, 'b psh psw c -> b (psh psw) c')\n\n        if self.pre_norm:\n            x = x + self.attn(self.norm1(x))\n            x = x + self.mlp(self.norm2(x))\n        else:\n            x = self.norm1(x + self.attn(x))\n            x = self.norm2(x + self.mlp(x))\n\n        x = rearrange(x, '(b pnh pnw) (psh psw) c -> b c (pnh psh) (pnw psw)', pnh=patch_num_h, pnw=patch_num_w, psh=patch_size, psw=patch_size) \n\n        return x\n\nclass ConTBlock(nn.Module):\n    r\"\"\"\n    Build a ConTBlock\n    \"\"\"\n    def __init__(self,\n                 planes: int,\n                 out_planes: int,\n                 mlp_dim: int,\n                 head_num: int,\n                 dropout: float,\n                 patch_size: List[int],\n                 downsample: nn.Module = None, \n                 stride: int=1,\n                 last_dropout: float=0.3,\n                 **kwargs):\n        super(ConTBlock, self).__init__()\n        self.downsample = downsample\n        self.identity = nn.Identity()\n 
       self.dropout = nn.Identity()\n        \n        self.bn = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.ste1 = STE(planes=planes, mlp_dim=mlp_dim, head_num=head_num, dropout=dropout, patch_size=patch_size[0], **kwargs)\n        self.ste2 = STE(planes=planes, mlp_dim=mlp_dim, head_num=head_num, dropout=dropout, patch_size=patch_size[1], **kwargs)\n        \n        if stride == 1 and downsample is not None: \n            self.dropout = nn.Dropout(p=last_dropout)\n            kernel_size = 1\n        else:\n            kernel_size = 3\n\n        self.out_conv = ConvBN(planes, out_planes, kernel_size, stride, bn=False)\n\n    def forward(self, x):\n        x_preact = self.relu(self.bn(x))\n        identity = self.identity(x)\n\n        if self.downsample is not None:\n            identity = self.downsample(x_preact)\n\n        residual = self.ste1(x_preact)\n        residual = self.ste2(residual)\n        residual = self.out_conv(residual)\n        out = self.dropout(residual+identity)\n\n        return out\n\nclass ConTNet(nn.Module):\n    r\"\"\"\n    Build a ConTNet backbone\n    \"\"\"\n    def __init__(self, \n                 block,\n                 layers: List[int],\n                 mlp_dim: List[int],\n                 head_num: List[int],\n                 dropout: List[float],\n                 in_channels: int=3,\n                 inplanes: int=64,\n                 num_classes: int=1000,\n                 init_weights: bool=True,\n                 first_embedding: bool=False,\n                 tweak_C: bool=False,\n                 **kwargs):\n        r\"\"\"\n        Args:\n            block: ConT Block\n            layers: number of blocks at each layer\n            mlp_dim: dimension of mlp in each stage\n            head_num: number of head in each stage\n            dropout: dropout in the last two stage\n            relative: if True, relative Position Embedding is used\n            groups: nunmber of group 
at each conv layer in the network\n            depthwise: if True, depthwise convolution is adopted\n            in_channels: number of channels of the input image\n            inplanes: number of channels of the first convolution layer\n            num_classes: number of classes for classification task\n                         only useful when `with_classifier` is True\n            with_avgpool: if True, an average pooling is added at the end of resnet stage5\n            with_classifier: if True, FC layer is registered for classification task\n            first_embedding: if True, a conv layer with both stride and kernel size of 4 is placed at the top\n            tweak_C: if True, the ResNet-C stem replaces the original first layer\n        \"\"\"\n    \n        super(ConTNet, self).__init__()\n        self.inplanes = inplanes\n        self.block = block\n\n        # build the top layer\n        if tweak_C:\n            self.layer0 = nn.Sequential(OrderedDict([\n                ('conv_bn1', ConvBN(in_channels, inplanes//2, kernel_size=3, stride=2)),\n                ('relu1', nn.ReLU(inplace=True)),\n                ('conv_bn2', ConvBN(inplanes//2, inplanes//2, kernel_size=3, stride=1)),\n                ('relu2', nn.ReLU(inplace=True)),\n                ('conv_bn3', ConvBN(inplanes//2, inplanes, kernel_size=3, stride=1)),\n                ('relu3', nn.ReLU(inplace=True)),\n                ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n            ]))\n        elif first_embedding:\n            self.layer0 = nn.Sequential(OrderedDict([\n                ('conv', nn.Conv2d(in_channels, inplanes, kernel_size=4, stride=4)),\n                ('norm', nn.LayerNorm(inplanes))\n            ]))\n        else:\n            self.layer0 = nn.Sequential(OrderedDict([\n                ('conv', ConvBN(in_channels, inplanes, kernel_size=7, stride=2, bn=False)),\n                # ('relu', nn.ReLU(inplace=True)),\n                ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, 
padding=1))\n            ]))\n\n        # build cont layers\n        self.cont_layers = []\n        self.out_channels = OrderedDict()\n\n        for i in range(len(layers)):\n            stride = 2\n            patch_size = [7,14]\n            if i == len(layers)-1:\n                stride, patch_size[1] = 1, 7 # the last stage does not conduct downsampling\n            cont_layer = self._make_layer(inplanes * 2**i, layers[i], stride=stride, mlp_dim=mlp_dim[i], head_num=head_num[i], dropout=dropout[i], patch_size=patch_size, **kwargs)\n            layer_name = 'layer{}'.format(i + 1)\n            self.add_module(layer_name, cont_layer)\n            self.cont_layers.append(layer_name)\n            self.out_channels[layer_name] = 2 * inplanes * 2**i\n\n        self.last_out_channels = next(reversed(self.out_channels.values()))\n        self.fc = nn.Linear(self.last_out_channels, num_classes) \n\n        if init_weights:\n            self._initialize_weights()\n\n    def _make_layer(self,\n                    planes: int,\n                    blocks: int,\n                    stride: int,\n                    mlp_dim: int,\n                    head_num: int,\n                    dropout: float,\n                    patch_size: List[int],\n                    use_avgdown: bool=False,\n                    **kwargs):\n        \n        layers = OrderedDict()\n        for i in range(0, blocks-1):\n            layers[f'{self.block.__name__}{i}'] = self.block(\n                planes, planes, mlp_dim, head_num, dropout, patch_size, **kwargs)\n\n        downsample = None\n        if stride != 1:\n            if use_avgdown:\n                downsample = nn.Sequential(OrderedDict([\n                    ('avgpool', nn.AvgPool2d(kernel_size=2, stride=2)),\n                    ('conv', ConvBN(planes, planes * 2, kernel_size=1, stride=1, bn=False))]))\n            else:\n                downsample = ConvBN(planes, planes * 2, kernel_size=1, \n                                     
   stride=2, bn=False)\n        else:\n            downsample = ConvBN(planes, planes * 2, kernel_size=1, stride=1, bn=False)\n\n        layers[f'{self.block.__name__}{blocks-1}'] = self.block(\n            planes, planes*2, mlp_dim, head_num, dropout, patch_size, downsample, stride, **kwargs) \n\n        return nn.Sequential(layers)\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, nn.Linear):\n                trunc_normal_(m.weight, std=.02)\n                if isinstance(m, nn.Linear) and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    \n    def forward(self, x):   \n        x = self.layer0(x)  \n\n        for _, layer_name in enumerate(self.cont_layers):\n            cont_layer = getattr(self, layer_name)\n            x = cont_layer(x)\n        \n        x = x.mean([2, 3])  \n        x = self.fc(x)\n\n        return x\n\ndef create_ConTNet_Ti(kwargs):\n    return ConTNet(block=ConTBlock, \n                   mlp_dim=[196, 392, 768, 768], \n                   head_num=[1, 2, 4, 8],\n                   dropout=[0,0,0,0], \n                   inplanes=48, \n                   layers=[1,1,1,1], \n                   last_dropout=0, \n                   **kwargs)\n\ndef create_ConTNet_S(kwargs):\n    return ConTNet(block=ConTBlock, \n                   mlp_dim=[256, 512, 1024, 1024], \n                   head_num=[1, 2, 4, 8],\n                   dropout=[0,0,0,0], \n                   inplanes=64, \n                   layers=[1,1,1,1], \n                   last_dropout=0, \n                   **kwargs)\n\ndef create_ConTNet_M(kwargs):\n    return ConTNet(block=ConTBlock, \n              
     mlp_dim=[256, 512, 1024, 1024], \n                   head_num=[1, 2, 4, 8],\n                   dropout=[0,0,0,0], \n                   inplanes=64, \n                   layers=[2,2,2,2], \n                   last_dropout=0, \n                   **kwargs)\n\ndef create_ConTNet_B(kwargs):\n    return ConTNet(block=ConTBlock, \n                   mlp_dim=[256, 512, 1024, 1024], \n                   head_num=[1, 2, 4, 8],\n                   dropout=[0,0,0.1,0.1], \n                   inplanes=64, \n                   layers=[3,4,6,3], \n                   last_dropout=0.2, \n                   **kwargs)\n\ndef build_model(use_avgdown, relative, qkv_bias, pre_norm):\n    kwargs = dict(use_avgdown=use_avgdown, relative=relative, qkv_bias=qkv_bias, pre_norm=pre_norm)\n    return create_ConTNet_Ti(kwargs)\n\nif __name__ == \"__main__\":\n    model = build_model(use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)\n"
  },
  {
    "path": "model/backbone/ConViT.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n#\n# This source code is licensed under the CC-by-NC license found in the\n# LICENSE file in the root directory of this source tree.\n#\n\n'''These modules are adapted from those of timm, see\nhttps://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n'''\n\nimport torch\nimport torch.nn as nn\nfrom functools import partial\nimport torch.nn.functional as F\nfrom timm.models.helpers import load_pretrained\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\n\nimport torch\nimport torch.nn as nn\nimport matplotlib.pyplot as plt\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n        self.apply(self._init_weights)\n        \n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n            \n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass GPSA(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,\n                 locality_strength=1., 
use_local_init=True):\n        super().__init__()\n        self.num_heads = num_heads\n        self.dim = dim\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)       \n        self.v = nn.Linear(dim, dim, bias=qkv_bias)       \n        \n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.pos_proj = nn.Linear(3, num_heads)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.locality_strength = locality_strength\n        self.gating_param = nn.Parameter(torch.ones(self.num_heads))\n        self.apply(self._init_weights)\n        if use_local_init:\n            self.local_init(locality_strength=locality_strength)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        \n    def forward(self, x):\n        B, N, C = x.shape\n        if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N:\n            self.get_rel_indices(N)\n\n        attn = self.get_attention(x)\n        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def get_attention(self, x):\n        B, N, C = x.shape        \n        qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k = qk[0], qk[1]\n        pos_score = self.rel_indices.expand(B, -1, -1,-1)\n        pos_score = self.pos_proj(pos_score).permute(0,3,1,2) \n        patch_score = (q @ k.transpose(-2, -1)) * self.scale\n        patch_score = 
patch_score.softmax(dim=-1)\n        pos_score = pos_score.softmax(dim=-1)\n\n        gating = self.gating_param.view(1,-1,1,1)\n        attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score\n        attn /= attn.sum(dim=-1).unsqueeze(-1)\n        attn = self.attn_drop(attn)\n        return attn\n\n    def get_attention_map(self, x, return_map = False):\n\n        attn_map = self.get_attention(x).mean(0) # average over batch\n        distances = self.rel_indices.squeeze()[:,:,-1]**.5\n        dist = torch.einsum('nm,hnm->h', (distances, attn_map))\n        dist /= distances.size(0)\n        if return_map:\n            return dist, attn_map\n        else:\n            return dist\n    \n    def local_init(self, locality_strength=1.):\n        \n        self.v.weight.data.copy_(torch.eye(self.dim))\n        locality_distance = 1 #max(1,1/locality_strength**.5)\n        \n        kernel_size = int(self.num_heads**.5)\n        center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2\n        for h1 in range(kernel_size):\n            for h2 in range(kernel_size):\n                position = h1+kernel_size*h2\n                self.pos_proj.weight.data[position,2] = -1\n                self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance\n                self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance\n        self.pos_proj.weight.data *= locality_strength\n\n    def get_rel_indices(self, num_patches):\n        img_size = int(num_patches**.5)\n        rel_indices   = torch.zeros(1, num_patches, num_patches, 3)\n        ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1)\n        indx = ind.repeat(img_size,img_size)\n        indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1)\n        indd = indx**2 + indy**2\n        rel_indices[:,:,:,2] = indd.unsqueeze(0)\n        rel_indices[:,:,:,1] = indy.unsqueeze(0)\n        rel_indices[:,:,:,0] = 
indx.unsqueeze(0)\n        device = self.qk.weight.device\n        self.rel_indices = rel_indices.to(device)\n\n \nclass MHSA(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.apply(self._init_weights)\n        \n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_attention_map(self, x, return_map = False):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n        attn_map = (q @ k.transpose(-2, -1)) * self.scale\n        attn_map = attn_map.softmax(dim=-1).mean(0)\n\n        img_size = int(N**.5)\n        ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1)\n        indx = ind.repeat(img_size,img_size)\n        indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1)\n        indd = indx**2 + indy**2\n        distances = indd**.5\n        # follow the input's device instead of hard-coding 'cuda', so this also runs on CPU\n        distances = distances.to(attn_map.device)\n\n        dist = torch.einsum('nm,hnm->h', (distances, attn_map))\n        dist /= N\n        \n        if return_map:\n            return dist, attn_map\n        else:\n            return dist\n\n            \n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = 
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n    \nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads,  mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.use_gpsa = use_gpsa\n        if self.use_gpsa:\n            self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs)\n        else:\n            self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n    \n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding, from timm\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.apply(self._init_weights)\n    def forward(self, x):\n        B, C, H, W = x.shape\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" CNN Feature Map Embedding, from timm\n    \"\"\"\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = 
img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Linear(feature_dim, embed_dim)\n        # HybridEmbed defines no _init_weights, so initialize only the projection here\n        # (this also leaves the backbone weights untouched)\n        trunc_normal_(self.proj.weight, std=.02)\n        nn.init.constant_(self.proj.bias, 0)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=48, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,\n                 local_up_to_layer=10, locality_strength=1., use_pos_embed=True):\n        super().__init__()\n        self.num_classes = num_classes\n        self.local_up_to_layer = local_up_to_layer\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.locality_strength = locality_strength\n        self.use_pos_embed = use_pos_embed\n\n        if hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            self.patch_embed = PatchEmbed(\n               
 img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n        self.num_patches = num_patches\n        \n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        if self.use_pos_embed:\n            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.pos_embed, std=.02)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                use_gpsa=True,\n                locality_strength=locality_strength)\n            if i<local_up_to_layer else\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\n                use_gpsa=False)\n            for i in range(depth)])\n        self.norm = norm_layer(embed_dim)\n\n        # Classifier head\n        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.cls_token, std=.02)\n        self.head.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    
@torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        x = self.patch_embed(x)\n\n        cls_tokens = self.cls_token.expand(B, -1, -1)\n\n        if self.use_pos_embed:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        for u,blk in enumerate(self.blocks):\n            if u == self.local_up_to_layer :\n                x = torch.cat((cls_tokens, x), dim=1)\n            x = blk(x)\n\n        x = self.norm(x)\n        return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n    \n    \n@register_model\ndef convit_tiny(pretrained=False, **kwargs):\n    num_heads = 4\n    kwargs['embed_dim'] *= num_heads\n    model = VisionTransformer(\n        num_heads=num_heads,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/convit/convit_tiny.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint)\n    return model\n\n@register_model\ndef convit_small(pretrained=False, **kwargs):\n    num_heads = 9\n    kwargs['embed_dim'] *= num_heads\n    model = VisionTransformer(\n        num_heads=num_heads,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/convit/convit_small.pth\",\n            map_location=\"cpu\", check_hash=True\n     
   )\n        model.load_state_dict(checkpoint)\n    return model\n\n@register_model\ndef convit_base(pretrained=False, **kwargs):\n    num_heads = 16\n    kwargs['embed_dim'] *= num_heads\n    model = VisionTransformer(\n        num_heads=num_heads,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/convit/convit_base.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint)\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        num_heads=16,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n"
  },
  {
    "path": "model/backbone/Container.py",
    "content": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport math\nfrom timm.models.vision_transformer import VisionTransformer, _cfg\nfrom timm.models.registry import register_model\nfrom timm.models.layers import trunc_normal_, DropPath, to_2tuple\nimport pdb\n\n__all__ = [\n    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',\n    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',\n    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',\n    'deit_base_distilled_patch16_384', 'container_light'\n]\n\nclass Mlp(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass CMlp(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = 
self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n    \n    \nclass Attention(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n    \nclass Attention_pure(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    
def forward(self, x):\n        B, N, C = x.shape\n        C = int(C // 3)\n        qkv = x.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj_drop(x)\n        return x\n    \n\nclass MixBlock(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n        self.norm1 = nn.BatchNorm2d(dim)\n        self.conv1 = nn.Conv2d(dim, 3 * dim, 1)\n        self.conv2 = nn.Conv2d(dim, dim, 1)\n        self.conv = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)\n        self.attn = Attention_pure(\n            dim,\n            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = nn.BatchNorm2d(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.sa_weight = nn.Parameter(torch.Tensor([0.0]))\n\n\n    def forward(self, x):\n        x = x + self.pos_embed(x)\n        B, _, H, W = x.shape\n        residual = x\n        x = self.norm1(x)\n        qkv = self.conv1(x)\n        \n        conv = qkv[:, 2 * self.dim:, :, :]\n        conv = self.conv(conv)\n        \n        sa = qkv.flatten(2).transpose(1, 2)\n        sa = self.attn(sa)\n        sa = sa.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        \n        x = residual + self.drop_path(self.conv2(torch.sigmoid(self.sa_weight) * sa + (1 - torch.sigmoid(self.sa_weight)) * conv))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n    \n\nclass CBlock(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n        self.norm1 = nn.BatchNorm2d(dim)\n        self.conv1 = nn.Conv2d(dim, dim, 1)\n        self.conv2 = nn.Conv2d(dim, dim, 1)\n        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = nn.BatchNorm2d(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x + self.pos_embed(x)\n        x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n    \nclass Block(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n    \n\n    \nclass PatchEmbed(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n        self.norm = nn.LayerNorm(embed_dim)\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        B, C, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        return x\n    \nclass HybridEmbed(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = 
img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature\n                # map for all networks, the feature metadata has reliable channel and stride info, but using\n                # stride to calc feature dim requires info about padding of each stage that isn't captured.\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))\n                if isinstance(o, (list, tuple)):\n                    o = o[-1]  # last feature if backbone outputs list/tuple of features\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            if hasattr(self.backbone, 'feature_info'):\n                feature_dim = self.backbone.feature_info.channels()[-1]\n            else:\n                feature_dim = self.backbone.num_features\n        self.num_patches = feature_size[0] * feature_size[1]\n        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)\n\n    def forward(self, x):\n        x = self.backbone(x)\n        if isinstance(x, (list, tuple)):\n            x = x[-1]  # last feature if backbone outputs list/tuple of features\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\n        \nclass VisionTransformer(nn.Module):\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n    \"\"\" Vision Transformer\n    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -\n        https://arxiv.org/abs/2010.11929\n    \"\"\"\n    def __init__(self, img_size=[224, 56, 28, 14], patch_size=[4, 2, 2, 2], 
in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512], depth=[3, 4, 8, 3],\n                 num_heads=12, mlp_ratio=[8, 8, 4, 4], qkv_bias=True, qk_scale=None, representation_size=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):\n        \"\"\"\n        Args:\n            img_size (list[int]): input resolution of each of the four stages\n            patch_size (list[int]): patch size of each stage's patch embedding\n            in_chans (int): number of input channels\n            num_classes (int): number of classes for classification head\n            embed_dim (list[int]): embedding dimension of each stage\n            depth (list[int]): number of blocks in each stage\n            num_heads (int): number of attention heads (only used by the attention in the last-stage MixBlock)\n            mlp_ratio (list[float]): ratio of mlp hidden dim to embedding dim, one value per stage\n            qkv_bias (bool): enable bias for qkv if True\n            qk_scale (float): override default qk scale of head_dim ** -0.5 if set\n            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set\n            drop_rate (float): dropout rate\n            attn_drop_rate (float): attention dropout rate\n            drop_path_rate (float): stochastic depth rate\n            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module\n            norm_layer (nn.Module): normalization layer\n        \"\"\"\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) \n        self.embed_dim = embed_dim\n        self.depth = depth\n        if hybrid_backbone is not None:\n            self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            self.patch_embed1 = PatchEmbed(\n                img_size=img_size[0], patch_size=patch_size[0], 
in_chans=in_chans, embed_dim=embed_dim[0])\n            self.patch_embed2 = PatchEmbed(\n                img_size=img_size[1], patch_size=patch_size[1], in_chans=embed_dim[0], embed_dim=embed_dim[1])\n            self.patch_embed3 = PatchEmbed(\n                img_size=img_size[2], patch_size=patch_size[2], in_chans=embed_dim[1], embed_dim=embed_dim[2])\n            self.patch_embed4 = PatchEmbed(\n                img_size=img_size[3], patch_size=patch_size[3], in_chans=embed_dim[2], embed_dim=embed_dim[3])\n        num_patches1 = self.patch_embed1.num_patches\n        num_patches2 = self.patch_embed2.num_patches\n        num_patches3 = self.patch_embed3.num_patches\n        num_patches4 = self.patch_embed4.num_patches\n    \n        self.pos_drop = nn.Dropout(p=drop_rate)\n        self.mixture =True\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]  # stochastic depth decay rule\n        self.blocks1 = nn.ModuleList([\n            CBlock(\n                dim=embed_dim[0], num_heads=num_heads, mlp_ratio=mlp_ratio[0], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)\n            for i in range(depth[0])])\n        self.blocks2 = nn.ModuleList([\n            CBlock(\n                dim=embed_dim[1], num_heads=num_heads, mlp_ratio=mlp_ratio[1], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer)\n            for i in range(depth[1])])\n        self.blocks3 = nn.ModuleList([\n            CBlock(\n                dim=embed_dim[2], num_heads=num_heads, mlp_ratio=mlp_ratio[2], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)\n            for i in range(depth[2])])\n        self.blocks4 = nn.ModuleList([\n            MixBlock(\n                
dim=embed_dim[3], num_heads=num_heads, mlp_ratio=mlp_ratio[3], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)\n            for i in range(depth[3])])\n        self.norm = nn.BatchNorm2d(embed_dim[-1])\n\n        # Representation layer\n        if representation_size:\n            # OrderedDict is not imported at module level, so import it locally\n            from collections import OrderedDict\n            # NOTE: forward_features applies pre_logits to a 4D feature map and the head still\n            # expects embed_dim[-1] inputs, so this branch likely needs further changes to be usable\n            self.num_features = representation_size\n            self.pre_logits = nn.Sequential(OrderedDict([\n                ('fc', nn.Linear(embed_dim[-1], representation_size)),\n                ('act', nn.Tanh())\n            ]))\n        else:\n            self.pre_logits = nn.Identity()\n\n        # Classifier head\n        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()\n        \n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        # embed_dim is a per-stage list; the head consumes the last stage's features\n        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        x = self.patch_embed1(x)\n        x = self.pos_drop(x)\n        for blk in self.blocks1:\n            x = blk(x)\n        x = self.patch_embed2(x)\n        for blk in self.blocks2:\n            x = blk(x)\n        x = self.patch_embed3(x)\n        for blk in self.blocks3:\n            x = blk(x)\n        x = self.patch_embed4(x)\n        for blk in 
self.blocks4:\n            x = blk(x)\n        x = self.norm(x)\n        x = self.pre_logits(x)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = x.flatten(2).mean(-1)\n        x = self.head(x)\n        return x\n\n\n\n@register_model\ndef container_v1_light(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        img_size=[224, 56, 28, 14], patch_size=[4, 2, 2, 2], embed_dim=[64, 128, 320, 512], depth=[3, 4, 8, 3], num_heads=16, mlp_ratio=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        # NOTE: this URL points to the DeiT-tiny checkpoint, whose parameter names do not match\n        # this four-stage model, so a strict load_state_dict is expected to fail\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[224, 56, 28, 14], \n        patch_size=[4, 2, 2, 2], \n        embed_dim=[64, 128, 320, 512], \n        depth=[3, 4, 8, 3], \n        num_heads=16, \n        mlp_ratio=[8, 8, 4, 4], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6))\n    output=model(input)\n    print(output.shape)\n"
  },
  {
    "path": "model/backbone/ConvMixer.py",
    "content": "import torch.nn as nn\nfrom torch.nn.modules.activation import GELU\nimport torch\nfrom torch.nn.modules.pooling import AdaptiveAvgPool2d\n\nclass Residual(nn.Module):\n    def __init__(self,fn):\n        super().__init__()\n        self.fn=fn\n    def forward(self,x):\n        return x+self.fn(x)\n\ndef ConvMixer(dim,depth,kernel_size=9,patch_size=7,num_classes=1000):\n    return nn.Sequential(\n        nn.Conv2d(3,dim,kernel_size=patch_size,stride=patch_size),\n        nn.GELU(),\n        nn.BatchNorm2d(dim),\n        *[nn.Sequential(\n            Residual(nn.Sequential(\n                nn.Conv2d(dim,dim,kernel_size=kernel_size,groups=dim,padding=kernel_size//2),\n                nn.GELU(),\n                nn.BatchNorm2d(dim)\n            )),\n            nn.Conv2d(dim,dim,kernel_size=1),\n            nn.GELU(),\n            nn.BatchNorm2d(dim)\n        ) for _ in range(depth)],\n        nn.AdaptiveAvgPool2d(1),\n        nn.Flatten(),\n        nn.Linear(dim,num_classes)\n    )\n\nif __name__ == '__main__':\n    x=torch.randn(1,3,224,224)\n    convmixer=ConvMixer(dim=512,depth=12)\n    out=convmixer(x)\n    print(out.shape)  #[1, 1000]\n\n    \n"
  },
  {
    "path": "model/backbone/CrossViT.py",
    "content": "# Copyright IBM All Rights Reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n\n\"\"\"\nModifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.hub\nfrom functools import partial\n\n\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg, Mlp, Block\n\n_model_urls = {\n    'crossvit_15_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth',\n    'crossvit_15_dagger_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth',\n    'crossvit_15_dagger_384': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',\n    'crossvit_18_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth',\n    'crossvit_18_dagger_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth',\n    'crossvit_18_dagger_384': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',\n    'crossvit_9_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth',\n    'crossvit_9_dagger_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth',\n    'crossvit_base_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth',\n    'crossvit_small_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth',\n    'crossvit_tiny_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth',\n}\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, 
multi_conv=False):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n        if multi_conv:\n            if patch_size[0] == 12:\n                self.proj = nn.Sequential(\n                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),\n                    nn.ReLU(inplace=True),\n                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),\n                    nn.ReLU(inplace=True),\n                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),\n                )\n            elif patch_size[0] == 16:\n                self.proj = nn.Sequential(\n                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),\n                    nn.ReLU(inplace=True),\n                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),\n                    nn.ReLU(inplace=True),\n                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),\n                )\n        else:\n            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\n\nclass CrossAttention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        
head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.wq = nn.Linear(dim, dim, bias=qkv_bias)\n        self.wk = nn.Linear(dim, dim, bias=qkv_bias)\n        self.wv = nn.Linear(dim, dim, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n\n        B, N, C = x.shape\n        q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # B1C -> B1H(C/H) -> BH1(C/H)\n        k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # BNC -> BNH(C/H) -> BHN(C/H)\n        v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # BNC -> BNH(C/H) -> BHN(C/H)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale  # BH1(C/H) @ BH(C/H)N -> BH1N\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, 1, C)   # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass CrossAttentionBlock(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = CrossAttention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.has_mlp = has_mlp\n        if has_mlp:\n            self.norm2 = norm_layer(dim)\n            mlp_hidden_dim = int(dim * mlp_ratio)\n            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))\n        if self.has_mlp:\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass MultiScaleBlock(nn.Module):\n\n    def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n\n        num_branches = len(dim)\n        self.num_branches = num_branches\n        # different branch could have different embedding size, the first one is the base\n        self.blocks = nn.ModuleList()\n        for d in range(num_branches):\n            tmp = []\n            for i in range(depth[d]):\n                tmp.append(\n                    Block(dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, \n                          drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))\n            if len(tmp) != 0:\n                self.blocks.append(nn.Sequential(*tmp))\n\n        if len(self.blocks) == 0:\n            self.blocks = None\n\n        self.projs = nn.ModuleList()\n        for d in range(num_branches):\n            if dim[d] == dim[(d+1) % num_branches] and False:\n                tmp = [nn.Identity()]\n            else:\n                tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d+1) % num_branches])]\n            self.projs.append(nn.Sequential(*tmp))\n\n        self.fusion = nn.ModuleList()\n        for d in range(num_branches):\n            d_ = (d+1) % num_branches\n            nh = num_heads[d_]\n            if depth[-1] == 0:  # 
backward compatibility:\n                self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                                       drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,\n                                                       has_mlp=False))\n            else:\n                tmp = []\n                for _ in range(depth[-1]):\n                    tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                                   drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,\n                                                   has_mlp=False))\n                self.fusion.append(nn.Sequential(*tmp))\n\n        self.revert_projs = nn.ModuleList()\n        for d in range(num_branches):\n            if dim[(d+1) % num_branches] == dim[d] and False:\n                tmp = [nn.Identity()]\n            else:\n                tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])]\n            self.revert_projs.append(nn.Sequential(*tmp))\n\n    def forward(self, x):\n        outs_b = [block(x_) for x_, block in zip(x, self.blocks)]\n        # only take the cls token out\n        proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)]\n        # cross attention\n        outs = []\n        for i in range(self.num_branches):\n            tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)\n            tmp = self.fusion[i](tmp)\n            reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...])\n            tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)\n            outs.append(tmp)\n        return outs\n\n\ndef _compute_num_patches(img_size, patches):\n    return [(i // p) * (i // p) for 
i, p in zip(img_size,patches)]\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=(224, 224), patch_size=(8, 16), in_chans=3, num_classes=1000, embed_dim=(192, 384), depth=([1, 3, 1], [1, 3, 1], [1, 3, 1]),\n                 num_heads=(6, 12), mlp_ratio=(2., 2., 4.), qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, multi_conv=False):\n        super().__init__()\n\n        self.num_classes = num_classes\n        if not isinstance(img_size, list):\n            img_size = to_2tuple(img_size)\n        self.img_size = img_size\n\n        num_patches = _compute_num_patches(img_size, patch_size)\n        self.num_branches = len(patch_size)\n\n        self.patch_embed = nn.ModuleList()\n        if hybrid_backbone is None:\n            self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)])\n            for im_s, p, d in zip(img_size, patch_size, embed_dim):\n                self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))\n        else:\n            self.pos_embed = nn.ParameterList()\n            from .t2t import T2T, get_sinusoid_encoding\n            tokens_type = 'transformer' if hybrid_backbone == 't2t' else 'performer'\n            for idx, (im_s, p, d) in enumerate(zip(img_size, patch_size, embed_dim)):\n                self.patch_embed.append(T2T(im_s, tokens_type=tokens_type, patch_size=p, embed_dim=d))\n                self.pos_embed.append(nn.Parameter(data=get_sinusoid_encoding(n_position=1 + num_patches[idx], d_hid=embed_dim[idx]), requires_grad=False))\n\n            del self.pos_embed\n            self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in 
range(self.num_branches)])\n\n        self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)])\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        total_depth = sum([sum(x[-2:]) for x in depth])\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]  # stochastic depth decay rule\n        dpr_ptr = 0\n        self.blocks = nn.ModuleList()\n        for idx, block_cfg in enumerate(depth):\n            curr_depth = max(block_cfg[:-1]) + block_cfg[-1]\n            dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]\n            blk = MultiScaleBlock(embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,\n                                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_,\n                                  norm_layer=norm_layer)\n            dpr_ptr += curr_depth\n            self.blocks.append(blk)\n\n        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])\n        self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])\n\n        for i in range(self.num_branches):\n            if self.pos_embed[i].requires_grad:\n                trunc_normal_(self.pos_embed[i], std=.02)\n            trunc_normal_(self.cls_token[i], std=.02)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        out = {'cls_token'}\n        if self.pos_embed[0].requires_grad:\n            out.add('pos_embed')\n  
      return out\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B, C, H, W = x.shape\n        xs = []\n        for i in range(self.num_branches):\n            x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x\n            tmp = self.patch_embed[i](x_)\n            cls_tokens = self.cls_token[i].expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n            tmp = torch.cat((cls_tokens, tmp), dim=1)\n            tmp = tmp + self.pos_embed[i]\n            tmp = self.pos_drop(tmp)\n            xs.append(tmp)\n\n        for blk in self.blocks:\n            xs = blk(xs)\n\n        # NOTE: was before branch token section, move to here to assure all branch token are before layer norm\n        xs = [self.norm[i](x) for i, x in enumerate(xs)]\n        out = [x[:, 0] for x in xs]\n\n        return out\n\n    def forward(self, x):\n        xs = self.forward_features(x)\n        ce_logits = [self.head[i](x) for i, x in enumerate(xs)]\n        ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)\n        return ce_logits\n\n\n\n\n@register_model\ndef crossvit_tiny_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n                              num_heads=[3, 3], mlp_ratio=[4, 4, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_tiny_224'], map_location='cpu')\n        
model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef crossvit_small_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n                              num_heads=[6, 6], mlp_ratio=[4, 4, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_small_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef crossvit_base_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n                              num_heads=[12, 12], mlp_ratio=[4, 4, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_base_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef crossvit_9_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],\n                              num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_9_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef 
crossvit_15_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],\n                              num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_15_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef crossvit_18_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],\n                              num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_18_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef crossvit_9_dagger_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],\n                              num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_9_dagger_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n@register_model\ndef crossvit_15_dagger_224(pretrained=False, **kwargs):\n    model = 
VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],\n                              num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_15_dagger_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n@register_model\ndef crossvit_15_dagger_384(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[408, 384],\n                              patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],\n                              num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_15_dagger_384'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n@register_model\ndef crossvit_18_dagger_224(pretrained=False, **kwargs):\n    model = VisionTransformer(img_size=[240, 224],\n                              patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],\n                              num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_18_dagger_224'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n@register_model\ndef crossvit_18_dagger_384(pretrained=False, **kwargs):\n    model = 
VisionTransformer(img_size=[408, 384],\n                              patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],\n                              num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True,\n                              norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_18_dagger_384'], map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\nif __name__ == \"__main__\":\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[240, 224],\n        patch_size=[12, 16], \n        embed_dim=[192, 384], \n        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n        num_heads=[6, 6], \n        mlp_ratio=[4, 4, 1], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n    )\n    output=model(input)\n    print(output.shape)\n\n    "
  },
  {
    "path": "model/backbone/DViT.py",
    "content": "\"\"\" \nCode for DeepViT. The implementation has heavy reference to timm.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom functools import partial\nimport pickle\nfrom torch.nn.parameter import Parameter\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.helpers import load_pretrained\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.resnet import resnet26d, resnet50d\nfrom timm.models.registry import register_model\n\nimport torch.nn.init as init\nimport torch.nn.functional as F\n\nfrom torch.nn import functional as F\n\nimport numpy as np\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., expansion_ratio=3):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.act = act_layer()\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., expansion_ratio=3):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.expansion = expansion_ratio\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * self.expansion, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, 
atten=None):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x, attn\nclass ReAttention(nn.Module):\n    \"\"\"\n    It is observed that similarity along same batch of data is extremely large. \n    Thus can reduce the bs dimension when calculating the attention map.\n    \"\"\"\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,expansion_ratio = 3, apply_transform=True, transform_scale=False):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.apply_transform = apply_transform\n        \n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n        if apply_transform:\n            self.reatten_matrix = nn.Conv2d(self.num_heads,self.num_heads, 1, 1)\n            self.var_norm = nn.BatchNorm2d(self.num_heads)\n            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)\n            self.reatten_scale = self.scale if transform_scale else 1.0\n        else:\n            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)\n        \n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n    def forward(self, x, atten=None):\n        B, N, C = x.shape\n        # x = self.fc(x)\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], 
qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n        if self.apply_transform:\n            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale\n        attn_next = attn\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x, attn_next\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, expansion=3, \n                 group = False, share = False, re_atten=False, bs=False, apply_transform=False,\n                 scale_adjustment=1.0, transform_scale=False):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.re_atten = re_atten\n\n        self.adjust_ratio = scale_adjustment\n        self.dim = dim\n        if  self.re_atten:\n            self.attn = ReAttention(\n                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, \n                expansion_ratio = expansion, apply_transform=apply_transform, transform_scale=transform_scale)\n        else:\n            self.attn = Attention(\n                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, \n                expansion_ratio = expansion)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n    def forward(self, x, atten=None):\n        if self.re_atten:\n            x_new, atten = self.attn(self.norm1(x * self.adjust_ratio), atten)\n            x = x + self.drop_path(x_new/self.adjust_ratio)\n            x = x + self.drop_path(self.mlp(self.norm2(x * self.adjust_ratio))) / self.adjust_ratio\n            return x, atten\n        else:\n            x_new, atten = self.attn(self.norm1(x), atten)\n            x= x + self.drop_path(x_new)\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n            return x, atten\n\nclass PatchEmbed_CNN(nn.Module):\n    \"\"\" \n        Following T2T, we use 3 layers of CNN for comparison with other methods.\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,spp=32):\n        super().__init__()\n\n        new_patch_size = to_2tuple(patch_size // 2)\n\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 112x112\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 112x112\n        self.bn2 = nn.BatchNorm2d(64)\n\n        self.proj = nn.Conv2d(64, embed_dim, kernel_size=new_patch_size, stride=new_patch_size)\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n\n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = self.relu(x)\n\n        x = 
self.proj(x).flatten(2).transpose(1, 2)  # [B, C, W, H]\n\n        return x\nclass PatchEmbed(nn.Module):\n    \"\"\" \n        Same embedding as timm lib.\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\n\nclass HybridEmbed(nn.Module):\n    \"\"\" \n        Same embedding as timm lib.\n    \"\"\"\n    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):\n        super().__init__()\n        assert isinstance(backbone, nn.Module)\n        img_size = to_2tuple(img_size)\n        self.img_size = img_size\n        self.backbone = backbone\n        if feature_size is None:\n            with torch.no_grad():\n                training = backbone.training\n                if training:\n                    backbone.eval()\n                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]\n                feature_size = o.shape[-2:]\n                feature_dim = o.shape[1]\n                backbone.train(training)\n        else:\n            feature_size = to_2tuple(feature_size)\n            feature_dim = self.backbone.feature_info.channels()[-1]\n        self.num_patches = feature_size[0] * feature_size[1]\n        
self.proj = nn.Linear(feature_dim, embed_dim)\n\n    def forward(self, x):\n        x = self.backbone(x)[-1]\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    # patch models\n    'Deepvit_base_patch16_224_16B': _cfg(\n        url='',\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),\n    ),\n    'Deepvit_base_patch16_224_24B': _cfg(\n        url='',\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),\n    ),\n    'Deepvit_base_patch16_224_32B': _cfg(\n        url='',\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),\n    ),\n    'Deepvit_L_384': _cfg(\n        url='',\n        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),\n}\n\n\n\nclass DeepVisionTransformer(nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, group = False, re_atten=True, cos_reg = False,\n                 use_cnn_embed=False, apply_transform=None, transform_scale=False, scale_adjustment=1.):\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        # use cosine similarity as a regularization term\n        self.cos_reg = cos_reg\n\n        if hybrid_backbone is not None:\n    
        self.patch_embed = HybridEmbed(\n                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)\n        else:\n            if use_cnn_embed:\n                self.patch_embed = PatchEmbed_CNN(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n            else:\n                self.patch_embed = PatchEmbed(\n                    img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n        d = depth if isinstance(depth, int) else len(depth)\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, d)]  # stochastic depth decay rule\n\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, share=depth[i], num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, group = group, \n                re_atten=re_atten, apply_transform=apply_transform[i], transform_scale=transform_scale, scale_adjustment=scale_adjustment)\n            for i in range(len(depth))])\n        self.norm = norm_layer(embed_dim)\n\n        # Classifier head\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            
nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        if self.cos_reg:\n            atten_list = []\n        B = x.shape[0]\n        x = self.patch_embed(x)\n\n        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n        x = x + self.pos_embed\n        x = self.pos_drop(x)\n        attn = None\n        for blk in self.blocks:\n            x, attn = blk(x, attn)\n            if self.cos_reg:\n                atten_list.append(attn)\n\n        x = self.norm(x)\n        if self.cos_reg and self.training:\n            return x[:, 0], atten_list\n        else:\n            return x[:, 0]\n\n    def forward(self, x):\n        if self.cos_reg and self.training:\n            x, atten = self.forward_features(x)\n            x = self.head(x)\n            return x, atten\n        else:\n            x = self.forward_features(x)\n            x = self.head(x)\n            return x\n\n\n@register_model\ndef deepvit_patch16_224_re_attn_16b(pretrained=False, **kwargs):\n    apply_transform = [False] * 0 + [True] * 16\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, depth=[False] * 16, apply_transform=apply_transform, num_heads=12, mlp_ratio=3, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),  **kwargs)\n    # We follow the same settings as the original ViT\n    model.default_cfg = default_cfgs['Deepvit_base_patch16_224_16B']\n    if pretrained:\n        load_pretrained(\n            model, num_classes=model.num_classes, 
in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n\n@register_model\ndef deepvit_patch16_224_re_attn_24b(pretrained=False, **kwargs):\n    apply_transform = [False] * 0 + [True] * 24\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, depth=[False] * 24, apply_transform=apply_transform, num_heads=12, mlp_ratio=3, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),  **kwargs)\n    # We follow the same settings as the original ViT\n    model.default_cfg = default_cfgs['Deepvit_base_patch16_224_24B']\n    if pretrained:\n        load_pretrained(\n            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n \n@register_model\ndef deepvit_patch16_224_re_attn_32b(pretrained=False, **kwargs):\n    apply_transform = [False] * 0 + [True] * 32\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, depth=[False] * 32, apply_transform=apply_transform, num_heads=12, mlp_ratio=3, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),  **kwargs)\n    # We follow the same settings as the original ViT\n    model.default_cfg = default_cfgs['Deepvit_base_patch16_224_32B']\n    if pretrained:\n        load_pretrained(\n            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n@register_model\ndef deepvit_S(pretrained=False, **kwargs):\n    apply_transform = [False] * 11 + [True] * 5\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=396, depth=[False] * 16, apply_transform=apply_transform, num_heads=12, mlp_ratio=3, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),  transform_scale=True, use_cnn_embed = True, scale_adjustment=0.5, **kwargs)\n    # We follow the same settings as the original ViT\n    model.default_cfg = default_cfgs['Deepvit_base_patch16_224_32B']\n    if pretrained:\n        load_pretrained(\n   
         model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n@register_model\ndef deepvit_L(pretrained=False, **kwargs):\n    apply_transform = [False] * 20 + [True] * 12\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=420, depth=[False] * 32, apply_transform=apply_transform, num_heads=12, mlp_ratio=3, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), use_cnn_embed = True, scale_adjustment=0.5, **kwargs)\n    # We follow the same settings as the original ViT\n    model.default_cfg = default_cfgs['Deepvit_base_patch16_224_32B']\n    if pretrained:\n        load_pretrained(\n            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n\n@register_model\ndef deepvit_L_384(pretrained=False, **kwargs):\n    apply_transform = [False] * 20 + [True] * 12\n    model = DeepVisionTransformer(\n        img_size=384, patch_size=16, embed_dim=420, depth=[False] * 32, apply_transform=apply_transform, num_heads=12, mlp_ratio=3, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), use_cnn_embed = True, scale_adjustment=0.5, **kwargs)\n    # We follow the same settings as the original ViT\n    model.default_cfg = default_cfgs['Deepvit_L_384']\n    if pretrained:\n        load_pretrained(\n            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, \n        depth=[False] * 16, \n        apply_transform=[True] * 16,  # one flag per block, matching len(depth)\n        num_heads=12, \n        mlp_ratio=3, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        )\n    output=model(input)\n    print(output.shape)"
  },
  {
    "path": "model/backbone/DeiT.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom functools import partial\n\nfrom timm.models.vision_transformer import VisionTransformer, _cfg\nfrom timm.models.registry import register_model\nfrom timm.models.layers import trunc_normal_\n\n\n__all__ = [\n    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',\n    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',\n    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',\n    'deit_base_distilled_patch16_384',\n]\n\n\nclass DistilledVisionTransformer(VisionTransformer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n        num_patches = self.patch_embed.num_patches\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))\n        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.dist_token, std=.02)\n        trunc_normal_(self.pos_embed, std=.02)\n        self.head_dist.apply(self._init_weights)\n\n    def forward_features(self, x):\n        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n        # with slight modifications to add the dist_token\n        B = x.shape[0]\n        x = self.patch_embed(x)\n\n        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        dist_token = self.dist_token.expand(B, -1, -1)\n        x = torch.cat((cls_tokens, dist_token, x), dim=1)\n\n        x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x = self.norm(x)\n        return x[:, 0], x[:, 1]\n\n    def forward(self, x):\n        x, x_dist = self.forward_features(x)\n        
x = self.head(x)\n        x_dist = self.head_dist(x_dist)\n        if self.training:\n            return x, x_dist\n        else:\n            # during inference, return the average of both classifier predictions\n            return (x + x_dist) / 2\n\n\n@register_model\ndef deit_tiny_patch16_224(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef deit_small_patch16_224(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef deit_base_patch16_224(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        
model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):\n    model = DistilledVisionTransformer(\n        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef deit_small_distilled_patch16_224(pretrained=False, **kwargs):\n    model = DistilledVisionTransformer(\n        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef deit_base_distilled_patch16_224(pretrained=False, **kwargs):\n    model = DistilledVisionTransformer(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef 
deit_base_patch16_384(pretrained=False, **kwargs):\n    model = VisionTransformer(\n        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\n@register_model\ndef deit_base_distilled_patch16_384(pretrained=False, **kwargs):\n    model = DistilledVisionTransformer(\n        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n    model.default_cfg = _cfg()\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            url=\"https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth\",\n            map_location=\"cpu\", check_hash=True\n        )\n        model.load_state_dict(checkpoint[\"model\"])\n    return model\n\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DistilledVisionTransformer(\n        patch_size=16, \n        embed_dim=384, \n        depth=12, \n        num_heads=6, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output[0].shape)"
  },
  {
    "path": "model/backbone/EfficientFormer.py",
    "content": "\"\"\"\nEfficientFormer\n\"\"\"\nimport os\nimport copy\nimport torch\nimport torch.nn as nn\n\nfrom typing import Dict\nimport itertools\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.layers import DropPath, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.layers.helpers import to_2tuple\n\nEfficientFormer_width = {\n    'l1': [48, 96, 224, 448],\n    'l3': [64, 128, 320, 512],\n    'l7': [96, 192, 384, 768],\n}\n\nEfficientFormer_depth = {\n    'l1': [3, 2, 6, 4],\n    'l3': [4, 4, 12, 6],\n    'l7': [6, 6, 18, 8],\n}\n\n\nclass Attention(torch.nn.Module):\n    def __init__(self, dim=384, key_dim=32, num_heads=8,\n                 attn_ratio=4,\n                 resolution=7):\n        super().__init__()\n        self.num_heads = num_heads\n        self.scale = key_dim ** -0.5\n        self.key_dim = key_dim\n        self.nh_kd = nh_kd = key_dim * num_heads\n        self.d = int(attn_ratio * key_dim)\n        self.dh = int(attn_ratio * key_dim) * num_heads\n        self.attn_ratio = attn_ratio\n        h = self.dh + nh_kd * 2\n        self.qkv = nn.Linear(dim, h)\n        self.proj = nn.Linear(self.dh, dim)\n\n        points = list(itertools.product(range(resolution), range(resolution)))\n        N = len(points)\n        attention_offsets = {}\n        idxs = []\n        for p1 in points:\n            for p2 in points:\n                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))\n                if offset not in attention_offsets:\n                    attention_offsets[offset] = len(attention_offsets)\n                idxs.append(attention_offsets[offset])\n        self.attention_biases = torch.nn.Parameter(\n            torch.zeros(num_heads, len(attention_offsets)))\n        self.register_buffer('attention_bias_idxs',\n                             torch.LongTensor(idxs).view(N, N))\n\n    @torch.no_grad()\n    def train(self, mode=True):\n        super().train(mode)\n  
      if mode and hasattr(self, 'ab'):\n            del self.ab\n        else:\n            self.ab = self.attention_biases[:, self.attention_bias_idxs]\n\n    def forward(self, x):  # x (B,N,C)\n        B, N, C = x.shape\n        qkv = self.qkv(x)\n        q, k, v = qkv.reshape(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)\n        q = q.permute(0, 2, 1, 3)\n        k = k.permute(0, 2, 1, 3)\n        v = v.permute(0, 2, 1, 3)\n\n        attn = (\n                (q @ k.transpose(-2, -1)) * self.scale\n                +\n                (self.attention_biases[:, self.attention_bias_idxs]\n                 if self.training else self.ab)\n        )\n        attn = attn.softmax(dim=-1)\n        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)\n        x = self.proj(x)\n        return x\n\n\ndef stem(in_chs, out_chs):\n    return nn.Sequential(\n        nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),\n        nn.BatchNorm2d(out_chs // 2),\n        nn.ReLU(),\n        nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),\n        nn.BatchNorm2d(out_chs),\n        nn.ReLU(), )\n\n\nclass Embedding(nn.Module):\n    \"\"\"\n    Patch Embedding that is implemented by a layer of conv.\n    Input: tensor in shape [B, C, H, W]\n    Output: tensor in shape [B, C, H/stride, W/stride]\n    \"\"\"\n\n    def __init__(self, patch_size=16, stride=16, padding=0,\n                 in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d):\n        super().__init__()\n        patch_size = to_2tuple(patch_size)\n        stride = to_2tuple(stride)\n        padding = to_2tuple(padding)\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,\n                              stride=stride, padding=padding)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        x = self.proj(x)\n        x = self.norm(x)\n        return x\n\n\nclass 
Flat(nn.Module):\n\n    def __init__(self, ):\n        super().__init__()\n\n    def forward(self, x):\n        x = x.flatten(2).transpose(1, 2)\n        return x\n\n\nclass Pooling(nn.Module):\n    \"\"\"\n    Implementation of pooling for PoolFormer\n    --pool_size: pooling size\n    \"\"\"\n\n    def __init__(self, pool_size=3):\n        super().__init__()\n        self.pool = nn.AvgPool2d(\n            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)\n\n    def forward(self, x):\n        return self.pool(x) - x\n\n\nclass LinearMlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop)\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop2 = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass Mlp(nn.Module):\n    \"\"\"\n    Implementation of MLP with 1*1 convolutions.\n    Input: tensor with shape [B, C, H, W]\n    \"\"\"\n\n    def __init__(self, in_features, hidden_features=None,\n                 out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)\n        self.drop = nn.Dropout(drop)\n        self.apply(self._init_weights)\n\n        self.norm1 = 
nn.BatchNorm2d(hidden_features)\n        self.norm2 = nn.BatchNorm2d(out_features)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Conv2d):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        x = self.fc1(x)\n\n        x = self.norm1(x)\n\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n\n        x = self.norm2(x)\n\n        x = self.drop(x)\n        return x\n\n\nclass Meta3D(nn.Module):\n\n    def __init__(self, dim, mlp_ratio=4.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 drop=0., drop_path=0.,\n                 use_layer_scale=True, layer_scale_init_value=1e-5):\n\n        super().__init__()\n\n        self.norm1 = norm_layer(dim)\n        self.token_mixer = Attention(dim)\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = LinearMlp(in_features=dim, hidden_features=mlp_hidden_dim,\n                             act_layer=act_layer, drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
\\\n            else nn.Identity()\n        self.use_layer_scale = use_layer_scale\n        if use_layer_scale:\n            self.layer_scale_1 = nn.Parameter(\n                layer_scale_init_value * torch.ones((dim)), requires_grad=True)\n            self.layer_scale_2 = nn.Parameter(\n                layer_scale_init_value * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x):\n        if self.use_layer_scale:\n            x = x + self.drop_path(\n                self.layer_scale_1.unsqueeze(0).unsqueeze(0)\n                * self.token_mixer(self.norm1(x)))\n            x = x + self.drop_path(\n                self.layer_scale_2.unsqueeze(0).unsqueeze(0)\n                * self.mlp(self.norm2(x)))\n\n        else:\n            x = x + self.drop_path(self.token_mixer(self.norm1(x)))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass Meta4D(nn.Module):\n\n    def __init__(self, dim, pool_size=3, mlp_ratio=4.,\n                 act_layer=nn.GELU,\n                 drop=0., drop_path=0.,\n                 use_layer_scale=True, layer_scale_init_value=1e-5):\n        super().__init__()\n\n        self.token_mixer = Pooling(pool_size=pool_size)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer, drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
\\\n            else nn.Identity()\n        self.use_layer_scale = use_layer_scale\n        if use_layer_scale:\n            self.layer_scale_1 = nn.Parameter(\n                layer_scale_init_value * torch.ones((dim)), requires_grad=True)\n            self.layer_scale_2 = nn.Parameter(\n                layer_scale_init_value * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x):\n        if self.use_layer_scale:\n\n            x = x + self.drop_path(\n                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)\n                * self.token_mixer(x))\n            x = x + self.drop_path(\n                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)\n                * self.mlp(x))\n        else:\n            x = x + self.drop_path(self.token_mixer(x))\n            x = x + self.drop_path(self.mlp(x))\n        return x\n\n\ndef meta_blocks(dim, index, layers,\n                pool_size=3, mlp_ratio=4.,\n                act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                drop_rate=.0, drop_path_rate=0.,\n                use_layer_scale=True, layer_scale_init_value=1e-5, vit_num=1):\n    blocks = []\n    if index == 3 and vit_num == layers[index]:\n        blocks.append(Flat())\n    for block_idx in range(layers[index]):\n        block_dpr = drop_path_rate * (\n                block_idx + sum(layers[:index])) / (sum(layers) - 1)\n        if index == 3 and layers[index] - block_idx <= vit_num:\n            blocks.append(Meta3D(\n                dim, mlp_ratio=mlp_ratio,\n                act_layer=act_layer, norm_layer=norm_layer,\n                drop=drop_rate, drop_path=block_dpr,\n                use_layer_scale=use_layer_scale,\n                layer_scale_init_value=layer_scale_init_value,\n            ))\n        else:\n            blocks.append(Meta4D(\n                dim, pool_size=pool_size, mlp_ratio=mlp_ratio,\n                act_layer=act_layer,\n                drop=drop_rate, drop_path=block_dpr,\n                
use_layer_scale=use_layer_scale,\n                layer_scale_init_value=layer_scale_init_value,\n            ))\n            if index == 3 and layers[index] - block_idx - 1 == vit_num:\n                blocks.append(Flat())\n\n    blocks = nn.Sequential(*blocks)\n    return blocks\n\n\nclass EfficientFormer(nn.Module):\n\n    def __init__(self, layers, embed_dims=None,\n                 mlp_ratios=4, downsamples=None,\n                 pool_size=3,\n                 norm_layer=nn.LayerNorm, act_layer=nn.GELU,\n                 num_classes=1000,\n                 down_patch_size=3, down_stride=2, down_pad=1,\n                 drop_rate=0., drop_path_rate=0.,\n                 use_layer_scale=True, layer_scale_init_value=1e-5,\n                 fork_feat=False,\n                 init_cfg=None,\n                 pretrained=None,\n                 vit_num=0,\n                 distillation=True,\n                 **kwargs):\n        super().__init__()\n\n        if not fork_feat:\n            self.num_classes = num_classes\n        self.fork_feat = fork_feat\n\n        self.patch_embed = stem(3, embed_dims[0])\n\n        network = []\n        for i in range(len(layers)):\n            stage = meta_blocks(embed_dims[i], i, layers,\n                                pool_size=pool_size, mlp_ratio=mlp_ratios,\n                                act_layer=act_layer, norm_layer=norm_layer,\n                                drop_rate=drop_rate,\n                                drop_path_rate=drop_path_rate,\n                                use_layer_scale=use_layer_scale,\n                                layer_scale_init_value=layer_scale_init_value,\n                                vit_num=vit_num)\n            network.append(stage)\n            if i >= len(layers) - 1:\n                break\n            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:\n                # downsampling between two stages\n                network.append(\n                    Embedding(\n     
                   patch_size=down_patch_size, stride=down_stride,\n                        padding=down_pad,\n                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]\n                    )\n                )\n\n        self.network = nn.ModuleList(network)\n\n        if self.fork_feat:\n            # add a norm layer for each output\n            self.out_indices = [0, 2, 4, 6]\n            for i_emb, i_layer in enumerate(self.out_indices):\n                if i_emb == 0 and os.environ.get('FORK_LAST3', None):\n                    layer = nn.Identity()\n                else:\n                    layer = norm_layer(embed_dims[i_emb])\n                layer_name = f'norm{i_layer}'\n                self.add_module(layer_name, layer)\n        else:\n            # Classifier head\n            self.norm = norm_layer(embed_dims[-1])\n            self.head = nn.Linear(\n                embed_dims[-1], num_classes) if num_classes > 0 \\\n                else nn.Identity()\n            self.dist = distillation\n            if self.dist:\n                self.dist_head = nn.Linear(\n                    embed_dims[-1], num_classes) if num_classes > 0 \\\n                    else nn.Identity()\n\n        self.apply(self.cls_init_weights)\n\n        self.init_cfg = copy.deepcopy(init_cfg)\n        # load pre-trained model\n        if self.fork_feat and (\n                self.init_cfg is not None or pretrained is not None):\n            self.init_weights()\n\n    # init for classification\n    def cls_init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    # init for mmdetection or mmsegmentation by loading\n    # imagenet pre-trained weights\n    def init_weights(self, pretrained=None):\n        logger = get_root_logger()\n        if self.init_cfg is None and pretrained is None:\n            
logger.warning(f'No pre-trained weights for '\n                        f'{self.__class__.__name__}, '\n                        f'training start from scratch')\n            pass\n        else:\n            assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n                                                  f'specify `Pretrained` in ' \\\n                                                  f'`init_cfg` in ' \\\n                                                  f'{self.__class__.__name__} '\n            if self.init_cfg is not None:\n                ckpt_path = self.init_cfg['checkpoint']\n            elif pretrained is not None:\n                ckpt_path = pretrained\n\n            ckpt = _load_checkpoint(\n                ckpt_path, logger=logger, map_location='cpu')\n            if 'state_dict' in ckpt:\n                _state_dict = ckpt['state_dict']\n            elif 'model' in ckpt:\n                _state_dict = ckpt['model']\n            else:\n                _state_dict = ckpt\n\n            state_dict = _state_dict\n            missing_keys, unexpected_keys = \\\n                self.load_state_dict(state_dict, False)\n\n    def forward_tokens(self, x):\n        outs = []\n        for idx, block in enumerate(self.network):\n            x = block(x)\n            if self.fork_feat and idx in self.out_indices:\n                norm_layer = getattr(self, f'norm{idx}')\n                x_out = norm_layer(x)\n                outs.append(x_out)\n        if self.fork_feat:\n            return outs\n        return x\n\n    def forward(self, x):\n        x = self.patch_embed(x)\n        x = self.forward_tokens(x)\n        if self.fork_feat:\n            # output features of four stages for dense prediction\n            return x\n        x = self.norm(x)\n        if self.dist:\n            cls_out = self.head(x.mean(-2)), self.dist_head(x.mean(-2))\n            if not self.training:\n                cls_out = (cls_out[0] + cls_out[1]) / 2\n        else:\n          
  cls_out = self.head(x.mean(-2))\n        # for image classification\n        return cls_out\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .95, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'classifier': 'head',\n        **kwargs\n    }\n\n\n@register_model\ndef efficientformer_l1(pretrained=False, **kwargs):\n    model = EfficientFormer(\n        layers=EfficientFormer_depth['l1'],\n        embed_dims=EfficientFormer_width['l1'],\n        downsamples=[True, True, True, True],\n        vit_num=1,\n        **kwargs)\n    model.default_cfg = _cfg(crop_pct=0.9)\n    return model\n\n\n@register_model\ndef efficientformer_l3(pretrained=False, **kwargs):\n    model = EfficientFormer(\n        layers=EfficientFormer_depth['l3'],\n        embed_dims=EfficientFormer_width['l3'],\n        downsamples=[True, True, True, True],\n        vit_num=4,\n        **kwargs)\n    model.default_cfg = _cfg(crop_pct=0.9)\n    return model\n\n\n@register_model\ndef efficientformer_l7(pretrained=False, **kwargs):\n    model = EfficientFormer(\n        layers=EfficientFormer_depth['l7'],\n        embed_dims=EfficientFormer_width['l7'],\n        downsamples=[True, True, True, True],\n        vit_num=8,\n        **kwargs)\n    model.default_cfg = _cfg(crop_pct=0.9)\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = EfficientFormer(\n        layers=EfficientFormer_depth['l1'],\n        embed_dims=EfficientFormer_width['l1'],\n        downsamples=[True, True, True, True],\n        vit_num=1,\n    )\n    output=model(input)\n    print(output[0].shape)"
  },
  {
    "path": "model/backbone/HATNet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom timm.models.layers import DropPath, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\n\n\nclass InvertedResidual(nn.Module):\n    def __init__(self, in_dim, hidden_dim=None, out_dim=None, kernel_size=3,\n                 drop=0., act_layer=nn.SiLU):\n        super().__init__()\n        hidden_dim = hidden_dim or in_dim\n        out_dim = out_dim or in_dim\n        pad = (kernel_size - 1) // 2\n        self.conv1 = nn.Sequential(\n            nn.GroupNorm(1, in_dim, eps=1e-6),\n            nn.Conv2d(in_dim, hidden_dim, 1, bias=False),\n            act_layer(inplace=True)\n        )\n        self.conv2 = nn.Sequential(\n            nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=pad, groups=hidden_dim, bias=False),\n            act_layer(inplace=True)\n        )\n        self.conv3 = nn.Sequential(\n            nn.Conv2d(hidden_dim, out_dim, 1, bias=False),\n            nn.GroupNorm(1, out_dim, eps=1e-6)\n        )\n        self.drop = nn.Dropout2d(drop, inplace=True)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x = self.drop(x)\n        x = self.conv3(x)\n        x = self.drop(x)\n\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, head_dim, grid_size=1, ds_ratio=1, drop=0.):\n        super().__init__()\n        assert dim % head_dim == 0\n        self.num_heads = dim // head_dim\n        self.head_dim = head_dim\n        self.scale = self.head_dim ** -0.5\n        self.grid_size = grid_size\n\n        self.norm = nn.GroupNorm(1, dim, eps=1e-6)\n        self.qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n        self.proj_norm = nn.GroupNorm(1, dim, eps=1e-6)\n        self.drop = nn.Dropout2d(drop, inplace=True)\n\n        if grid_size > 1:\n            self.grid_norm = 
nn.GroupNorm(1, dim, eps=1e-6)\n            self.avg_pool = nn.AvgPool2d(ds_ratio, stride=ds_ratio)\n            self.ds_norm = nn.GroupNorm(1, dim, eps=1e-6)\n            self.q = nn.Conv2d(dim, dim, 1)\n            self.kv = nn.Conv2d(dim, dim * 2, 1)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        qkv = self.qkv(self.norm(x))\n\n        if self.grid_size > 1:\n            grid_h, grid_w = H // self.grid_size, W // self.grid_size\n            qkv = qkv.reshape(B, 3, self.num_heads, self.head_dim, grid_h,\n                self.grid_size, grid_w, self.grid_size)\n            qkv = qkv.permute(1, 0, 2, 4, 6, 5, 7, 3)\n            qkv = qkv.reshape(3, -1, self.grid_size * self.grid_size, self.head_dim)\n            q, k, v = qkv[0], qkv[1], qkv[2]\n\n            attn = (q * self.scale) @ k.transpose(-2, -1)\n            attn = attn.softmax(dim=-1)\n            grid_x = (attn @ v).reshape(B, self.num_heads, grid_h, grid_w,\n                self.grid_size, self.grid_size, self.head_dim)\n            grid_x = grid_x.permute(0, 1, 6, 2, 4, 3, 5).reshape(B, C, H, W)\n            grid_x = self.grid_norm(x + grid_x)\n\n            q = self.q(grid_x).reshape(B, self.num_heads, self.head_dim, -1)\n            q = q.transpose(-2, -1)\n            kv = self.kv(self.ds_norm(self.avg_pool(grid_x)))\n            kv = kv.reshape(B, 2, self.num_heads, self.head_dim, -1)\n            kv = kv.permute(1, 0, 2, 4, 3)\n            k, v = kv[0], kv[1]\n        else:\n            qkv = qkv.reshape(B, 3, self.num_heads, self.head_dim, -1)\n            qkv = qkv.permute(1, 0, 2, 4, 3)\n            q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q * self.scale) @ k.transpose(-2, -1)\n        attn = attn.softmax(dim=-1)\n        global_x = (attn @ v).transpose(-2, -1).reshape(B, C, H, W)\n        if self.grid_size > 1:\n            global_x = global_x + grid_x\n        x = self.drop(self.proj(global_x))\n\n        return x\n\n\nclass Block(nn.Module):\n    def 
__init__(self, dim, head_dim, grid_size=1, ds_ratio=1, expansion=4,\n                 drop=0., drop_path=0., kernel_size=3, act_layer=nn.SiLU):\n        super().__init__()\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.attn = Attention(dim, head_dim, grid_size=grid_size, ds_ratio=ds_ratio, drop=drop)\n        self.conv = InvertedResidual(dim, hidden_dim=dim * expansion, out_dim=dim,\n            kernel_size=kernel_size, drop=drop, act_layer=act_layer)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(x))\n        x = x + self.drop_path(self.conv(x))\n        return x\n\n\nclass Downsample(nn.Module):\n    def __init__(self, in_dim, out_dim, kernel_size=3):\n        super().__init__()\n        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, padding=1, stride=2)\n        self.norm = nn.GroupNorm(1, out_dim, eps=1e-6)\n\n    def forward(self, x):\n        x = self.norm(self.conv(x))\n        return x\n\n\nclass HATNet(nn.Module):\n    def __init__(self, img_size=224, in_chans=3, num_classes=1000, dims=[64, 128, 256, 512],\n                 head_dim=64, expansions=[4, 4, 6, 6], grid_sizes=[1, 1, 1, 1],\n                 ds_ratios=[8, 4, 2, 1], depths=[3, 4, 8, 3], drop_rate=0.,\n                 drop_path_rate=0., act_layer=nn.SiLU, kernel_sizes=[3, 3, 3, 3]):\n        super().__init__()\n        self.depths = depths\n        self.patch_embed = nn.Sequential(\n            nn.Conv2d(3, 16, 3, padding=1, stride=2),\n            nn.GroupNorm(1, 16, eps=1e-6),\n            act_layer(inplace=True),\n            nn.Conv2d(16, dims[0], 3, padding=1, stride=2),\n        )\n\n        self.blocks = []\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n        for stage in range(len(dims)):\n            self.blocks.append(nn.ModuleList([Block(\n                dims[stage], head_dim, grid_size=grid_sizes[stage], ds_ratio=ds_ratios[stage],\n                
expansion=expansions[stage], drop=drop_rate, drop_path=dpr[sum(depths[:stage]) + i],\n                kernel_size=kernel_sizes[stage], act_layer=act_layer)\n                for i in range(depths[stage])]))\n        self.blocks = nn.ModuleList(self.blocks)\n\n        self.ds2 = Downsample(dims[0], dims[1])\n        self.ds3 = Downsample(dims[1], dims[2])\n        self.ds4 = Downsample(dims[2], dims[3])\n        self.classifier = nn.Sequential(\n            nn.Dropout(0.2, inplace=True),\n            nn.Linear(dims[-1], num_classes),\n        )\n\n        # init weights\n        self.apply(self._init_weights)\n\n    def reset_drop_path(self, drop_path_rate):\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]\n        cur = 0\n        for stage in range(len(self.blocks)):\n            for idx in range(self.depths[stage]):\n                self.blocks[stage][idx].drop_path.drop_prob = dpr[cur + idx]\n            cur += self.depths[stage]\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv2d)):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward(self, x):\n        x = self.patch_embed(x)\n        for block in self.blocks[0]:\n            x = block(x)\n        x = self.ds2(x)\n        for block in self.blocks[1]:\n            x = block(x)\n        x = self.ds3(x)\n        for block in self.blocks[2]:\n            x = block(x)\n        x = self.ds4(x)\n        for block in self.blocks[3]:\n            x = block(x)\n        x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)\n        x = self.classifier(x)\n\n        return x\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    hat = HATNet(dims=[48, 96, 240, 384], head_dim=48, expansions=[8, 
8, 4, 4],\n        grid_sizes=[8, 7, 7, 1], ds_ratios=[8, 4, 2, 1], depths=[2, 2, 6, 3])\n    output=hat(input)\n    print(output.shape)"
  },
  {
    "path": "model/backbone/LeViT.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n\n# Modified from\n# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n# Copyright 2020 Ross Wightman, Apache-2.0 License\n\nimport torch\nimport itertools\n# import utils\n\nfrom timm.models.vision_transformer import trunc_normal_\nfrom timm.models.registry import register_model\n\ndef replace_batchnorm(net):\n    for child_name, child in net.named_children():\n        if hasattr(child, 'fuse'):\n            setattr(net, child_name, child.fuse())\n        elif isinstance(child, torch.nn.Conv2d):\n            child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0)))\n        elif isinstance(child, torch.nn.BatchNorm2d):\n            setattr(net, child_name, torch.nn.Identity())\n        else:\n            replace_batchnorm(child)\n\nspecification = {\n    'LeViT_128S': {\n        'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0,\n        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'},\n    'LeViT_128': {\n        'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0,\n        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'},\n    'LeViT_192': {\n        'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0,\n        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'},\n    'LeViT_256': {\n        'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0,\n        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'},\n    'LeViT_384': {\n        'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1,\n        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'},\n}\n\n# __all__ must be a list of names (strings), not a list wrapping a dict_keys view\n__all__ = list(specification.keys())\n\n\n@register_model\ndef LeViT_128S(num_classes=1000, distillation=True,\n               pretrained=False, 
fuse=False):\n    return model_factory(**specification['LeViT_128S'], num_classes=num_classes,\n                         distillation=distillation, pretrained=pretrained, fuse=fuse)\n\n\n@register_model\ndef LeViT_128(num_classes=1000, distillation=True,\n              pretrained=False, fuse=False):\n    return model_factory(**specification['LeViT_128'], num_classes=num_classes,\n                         distillation=distillation, pretrained=pretrained, fuse=fuse)\n\n\n@register_model\ndef LeViT_192(num_classes=1000, distillation=True,\n              pretrained=False, fuse=False):\n    return model_factory(**specification['LeViT_192'], num_classes=num_classes,\n                         distillation=distillation, pretrained=pretrained, fuse=fuse)\n\n\n@register_model\ndef LeViT_256(num_classes=1000, distillation=True,\n              pretrained=False, fuse=False):\n    return model_factory(**specification['LeViT_256'], num_classes=num_classes,\n                         distillation=distillation, pretrained=pretrained, fuse=fuse)\n\n\n@register_model\ndef LeViT_384(num_classes=1000, distillation=True,\n              pretrained=False, fuse=False):\n    return model_factory(**specification['LeViT_384'], num_classes=num_classes,\n                         distillation=distillation, pretrained=pretrained, fuse=fuse)\n\n\nFLOPS_COUNTER = 0\n\n\nclass Conv2d_BN(torch.nn.Sequential):\n    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,\n                 groups=1, bn_weight_init=1, resolution=-10000):\n        super().__init__()\n        self.add_module('c', torch.nn.Conv2d(\n            a, b, ks, stride, pad, dilation, groups, bias=False))\n        bn = torch.nn.BatchNorm2d(b)\n        torch.nn.init.constant_(bn.weight, bn_weight_init)\n        torch.nn.init.constant_(bn.bias, 0)\n        self.add_module('bn', bn)\n\n        global FLOPS_COUNTER\n        output_points = ((resolution + 2 * pad - dilation *\n                          (ks - 1) - 1) // stride + 
1)**2\n        FLOPS_COUNTER += a * b * output_points * (ks**2) // groups\n\n    @torch.no_grad()\n    def fuse(self):\n        c, bn = self._modules.values()\n        w = bn.weight / (bn.running_var + bn.eps)**0.5\n        w = c.weight * w[:, None, None, None]\n        b = bn.bias - bn.running_mean * bn.weight / \\\n            (bn.running_var + bn.eps)**0.5\n        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(\n            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)\n        m.weight.data.copy_(w)\n        m.bias.data.copy_(b)\n        return m\n\n\nclass Linear_BN(torch.nn.Sequential):\n    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):\n        super().__init__()\n        self.add_module('c', torch.nn.Linear(a, b, bias=False))\n        bn = torch.nn.BatchNorm1d(b)\n        torch.nn.init.constant_(bn.weight, bn_weight_init)\n        torch.nn.init.constant_(bn.bias, 0)\n        self.add_module('bn', bn)\n\n        global FLOPS_COUNTER\n        output_points = resolution**2\n        FLOPS_COUNTER += a * b * output_points\n\n    @torch.no_grad()\n    def fuse(self):\n        l, bn = self._modules.values()\n        w = bn.weight / (bn.running_var + bn.eps)**0.5\n        w = l.weight * w[:, None]\n        b = bn.bias - bn.running_mean * bn.weight / \\\n            (bn.running_var + bn.eps)**0.5\n        m = torch.nn.Linear(w.size(1), w.size(0))\n        m.weight.data.copy_(w)\n        m.bias.data.copy_(b)\n        return m\n\n    def forward(self, x):\n        l, bn = self._modules.values()\n        x = l(x)\n        return bn(x.flatten(0, 1)).reshape_as(x)\n\n\nclass BN_Linear(torch.nn.Sequential):\n    def __init__(self, a, b, bias=True, std=0.02):\n        super().__init__()\n        self.add_module('bn', torch.nn.BatchNorm1d(a))\n        l = torch.nn.Linear(a, b, bias=bias)\n        trunc_normal_(l.weight, std=std)\n        if bias:\n            
torch.nn.init.constant_(l.bias, 0)\n        self.add_module('l', l)\n        global FLOPS_COUNTER\n        FLOPS_COUNTER += a * b\n\n    @torch.no_grad()\n    def fuse(self):\n        bn, l = self._modules.values()\n        w = bn.weight / (bn.running_var + bn.eps)**0.5\n        b = bn.bias - self.bn.running_mean * \\\n            self.bn.weight / (bn.running_var + bn.eps)**0.5\n        w = l.weight * w[None, :]\n        if l.bias is None:\n            b = b @ self.l.weight.T\n        else:\n            b = (l.weight @ b[:, None]).view(-1) + self.l.bias\n        m = torch.nn.Linear(w.size(1), w.size(0))\n        m.weight.data.copy_(w)\n        m.bias.data.copy_(b)\n        return m\n\n\ndef b16(n, activation, resolution=224):\n    return torch.nn.Sequential(\n        Conv2d_BN(3, n // 8, 3, 2, 1, resolution=resolution),\n        activation(),\n        Conv2d_BN(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),\n        activation(),\n        Conv2d_BN(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),\n        activation(),\n        Conv2d_BN(n // 2, n, 3, 2, 1, resolution=resolution // 8))\n\n\nclass Residual(torch.nn.Module):\n    def __init__(self, m, drop):\n        super().__init__()\n        self.m = m\n        self.drop = drop\n\n    def forward(self, x):\n        if self.training and self.drop > 0:\n            return x + self.m(x) * torch.rand(x.size(0), 1, 1,\n                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()\n        else:\n            return x + self.m(x)\n\n\nclass Attention(torch.nn.Module):\n    def __init__(self, dim, key_dim, num_heads=8,\n                 attn_ratio=4,\n                 activation=None,\n                 resolution=14):\n        super().__init__()\n        self.num_heads = num_heads\n        self.scale = key_dim ** -0.5\n        self.key_dim = key_dim\n        self.nh_kd = nh_kd = key_dim * num_heads\n        self.d = int(attn_ratio * key_dim)\n        self.dh = 
int(attn_ratio * key_dim) * num_heads\n        self.attn_ratio = attn_ratio\n        h = self.dh + nh_kd * 2\n        self.qkv = Linear_BN(dim, h, resolution=resolution)\n        self.proj = torch.nn.Sequential(activation(), Linear_BN(\n            self.dh, dim, bn_weight_init=0, resolution=resolution))\n\n        points = list(itertools.product(range(resolution), range(resolution)))\n        N = len(points)\n        attention_offsets = {}\n        idxs = []\n        for p1 in points:\n            for p2 in points:\n                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))\n                if offset not in attention_offsets:\n                    attention_offsets[offset] = len(attention_offsets)\n                idxs.append(attention_offsets[offset])\n        self.attention_biases = torch.nn.Parameter(\n            torch.zeros(num_heads, len(attention_offsets)))\n        self.register_buffer('attention_bias_idxs',\n                             torch.LongTensor(idxs).view(N, N))\n\n        global FLOPS_COUNTER\n        #queries * keys\n        FLOPS_COUNTER += num_heads * (resolution**4) * key_dim\n        # softmax\n        FLOPS_COUNTER += num_heads * (resolution**4)\n        #attention * v\n        FLOPS_COUNTER += num_heads * self.d * (resolution**4)\n\n    @torch.no_grad()\n    def train(self, mode=True):\n        super().train(mode)\n        if mode and hasattr(self, 'ab'):\n            del self.ab\n        else:\n            self.ab = self.attention_biases[:, self.attention_bias_idxs]\n\n    def forward(self, x):  # x (B,N,C)\n        B, N, C = x.shape\n        qkv = self.qkv(x)\n        q, k, v = qkv.view(B, N, self.num_heads, -\n                           1).split([self.key_dim, self.key_dim, self.d], dim=3)\n        q = q.permute(0, 2, 1, 3)\n        k = k.permute(0, 2, 1, 3)\n        v = v.permute(0, 2, 1, 3)\n\n        attn = (\n            (q @ k.transpose(-2, -1)) * self.scale\n            +\n            (self.attention_biases[:, 
self.attention_bias_idxs]\n             if self.training else self.ab)\n        )\n        attn = attn.softmax(dim=-1)\n        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)\n        x = self.proj(x)\n        return x\n\n\nclass Subsample(torch.nn.Module):\n    def __init__(self, stride, resolution):\n        super().__init__()\n        self.stride = stride\n        self.resolution = resolution\n\n    def forward(self, x):\n        B, N, C = x.shape\n        x = x.view(B, self.resolution, self.resolution, C)[\n            :, ::self.stride, ::self.stride].reshape(B, -1, C)\n        return x\n\n\nclass AttentionSubsample(torch.nn.Module):\n    def __init__(self, in_dim, out_dim, key_dim, num_heads=8,\n                 attn_ratio=2,\n                 activation=None,\n                 stride=2,\n                 resolution=14, resolution_=7):\n        super().__init__()\n        self.num_heads = num_heads\n        self.scale = key_dim ** -0.5\n        self.key_dim = key_dim\n        self.nh_kd = nh_kd = key_dim * num_heads\n        self.d = int(attn_ratio * key_dim)\n        self.dh = int(attn_ratio * key_dim) * self.num_heads\n        self.attn_ratio = attn_ratio\n        self.resolution_ = resolution_\n        self.resolution_2 = resolution_**2\n        h = self.dh + nh_kd\n        self.kv = Linear_BN(in_dim, h, resolution=resolution)\n\n        self.q = torch.nn.Sequential(\n            Subsample(stride, resolution),\n            Linear_BN(in_dim, nh_kd, resolution=resolution_))\n        self.proj = torch.nn.Sequential(activation(), Linear_BN(\n            self.dh, out_dim, resolution=resolution_))\n\n        self.stride = stride\n        self.resolution = resolution\n        points = list(itertools.product(range(resolution), range(resolution)))\n        points_ = list(itertools.product(\n            range(resolution_), range(resolution_)))\n        N = len(points)\n        N_ = len(points_)\n        attention_offsets = {}\n        idxs = []\n        for p1 
in points_:\n            for p2 in points:\n                size = 1\n                offset = (\n                    abs(p1[0] * stride - p2[0] + (size - 1) / 2),\n                    abs(p1[1] * stride - p2[1] + (size - 1) / 2))\n                if offset not in attention_offsets:\n                    attention_offsets[offset] = len(attention_offsets)\n                idxs.append(attention_offsets[offset])\n        self.attention_biases = torch.nn.Parameter(\n            torch.zeros(num_heads, len(attention_offsets)))\n        self.register_buffer('attention_bias_idxs',\n                             torch.LongTensor(idxs).view(N_, N))\n\n        global FLOPS_COUNTER\n        #queries * keys\n        FLOPS_COUNTER += num_heads * \\\n            (resolution**2) * (resolution_**2) * key_dim\n        # softmax\n        FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2)\n        #attention * v\n        FLOPS_COUNTER += num_heads * \\\n            (resolution**2) * (resolution_**2) * self.d\n\n    @torch.no_grad()\n    def train(self, mode=True):\n        super().train(mode)\n        if mode and hasattr(self, 'ab'):\n            del self.ab\n        else:\n            self.ab = self.attention_biases[:, self.attention_bias_idxs]\n\n    def forward(self, x):\n        B, N, C = x.shape\n        k, v = self.kv(x).view(B, N, self.num_heads, -\n                               1).split([self.key_dim, self.d], dim=3)\n        k = k.permute(0, 2, 1, 3)  # BHNC\n        v = v.permute(0, 2, 1, 3)  # BHNC\n        q = self.q(x).view(B, self.resolution_2, self.num_heads,\n                           self.key_dim).permute(0, 2, 1, 3)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale + \\\n            (self.attention_biases[:, self.attention_bias_idxs]\n             if self.training else self.ab)\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)\n        x = self.proj(x)\n        return x\n\n\nclass 
LeViT(torch.nn.Module):\n    \"\"\" Vision Transformer with support for patch or hybrid CNN input stage\n    \"\"\"\n\n    def __init__(self, img_size=224,\n                 patch_size=16,\n                 in_chans=3,\n                 num_classes=1000,\n                 embed_dim=[192],\n                 key_dim=[64],\n                 depth=[12],\n                 num_heads=[3],\n                 attn_ratio=[2],\n                 mlp_ratio=[2],\n                 hybrid_backbone=None,\n                 down_ops=[],\n                 attention_activation=torch.nn.Hardswish,\n                 mlp_activation=torch.nn.Hardswish,\n                 distillation=True,\n                 drop_path=0):\n        super().__init__()\n        global FLOPS_COUNTER\n\n        self.num_classes = num_classes\n        self.num_features = embed_dim[-1]\n        self.embed_dim = embed_dim\n        self.distillation = distillation\n\n        self.patch_embed = hybrid_backbone\n\n        self.blocks = []\n        # Copy before adding the end-of-stages sentinel, so that neither the\n        # caller's list nor the shared default argument is mutated.\n        down_ops = list(down_ops) + [['']]\n        resolution = img_size // patch_size\n        for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(\n                zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):\n            for _ in range(dpth):\n                self.blocks.append(\n                    Residual(Attention(\n                        ed, kd, nh,\n                        attn_ratio=ar,\n                        activation=attention_activation,\n                        resolution=resolution,\n                    ), drop_path))\n                if mr > 0:\n                    h = int(ed * mr)\n                    self.blocks.append(\n                        Residual(torch.nn.Sequential(\n                            Linear_BN(ed, h, resolution=resolution),\n                            mlp_activation(),\n                            Linear_BN(h, ed, bn_weight_init=0,\n                                      resolution=resolution),\n                        ), 
drop_path))\n            if do[0] == 'Subsample':\n                #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)\n                resolution_ = (resolution - 1) // do[5] + 1\n                self.blocks.append(\n                    AttentionSubsample(\n                        *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],\n                        attn_ratio=do[3],\n                        activation=attention_activation,\n                        stride=do[5],\n                        resolution=resolution,\n                        resolution_=resolution_))\n                resolution = resolution_\n                if do[4] > 0:  # mlp_ratio\n                    h = int(embed_dim[i + 1] * do[4])\n                    self.blocks.append(\n                        Residual(torch.nn.Sequential(\n                            Linear_BN(embed_dim[i + 1], h,\n                                      resolution=resolution),\n                            mlp_activation(),\n                            Linear_BN(\n                                h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),\n                        ), drop_path))\n        self.blocks = torch.nn.Sequential(*self.blocks)\n\n        # Classifier head\n        self.head = BN_Linear(\n            embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()\n        if distillation:\n            self.head_dist = BN_Linear(\n                embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()\n\n        self.FLOPS = FLOPS_COUNTER\n        FLOPS_COUNTER = 0\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {x for x in self.state_dict().keys() if 'attention_biases' in x}\n\n    def forward(self, x):\n        x = self.patch_embed(x)\n        x = x.flatten(2).transpose(1, 2)\n        x = self.blocks(x)\n        x = x.mean(1)\n        if self.distillation:\n            x = self.head(x), self.head_dist(x)\n            if not 
self.training:\n                x = (x[0] + x[1]) / 2\n        else:\n            x = self.head(x)\n        return x\n\n\ndef model_factory(C, D, X, N, drop_path, weights,\n                  num_classes, distillation, pretrained, fuse):\n    embed_dim = [int(x) for x in C.split('_')]\n    num_heads = [int(x) for x in N.split('_')]\n    depth = [int(x) for x in X.split('_')]\n    act = torch.nn.Hardswish\n    model = LeViT(\n        patch_size=16,\n        embed_dim=embed_dim,\n        num_heads=num_heads,\n        key_dim=[D] * 3,\n        depth=depth,\n        attn_ratio=[2, 2, 2],\n        mlp_ratio=[2, 2, 2],\n        down_ops=[\n            # ('Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride)\n            ['Subsample', D, embed_dim[0] // D, 4, 2, 2],\n            ['Subsample', D, embed_dim[1] // D, 4, 2, 2],\n        ],\n        attention_activation=act,\n        mlp_activation=act,\n        hybrid_backbone=b16(embed_dim[0], activation=act),\n        num_classes=num_classes,\n        drop_path=drop_path,\n        distillation=distillation\n    )\n    if pretrained:\n        checkpoint = torch.hub.load_state_dict_from_url(\n            weights, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n    if fuse:\n        replace_batchnorm(model)\n\n    return model\n\n\nif __name__ == '__main__':\n    # Smoke test: build every variant with BatchNorm fused and run one forward pass.\n    for name in specification:\n        model = globals()[name](fuse=True, pretrained=False)\n        x = torch.randn(1, 3, 224, 224)  # named x to avoid shadowing the builtin input()\n        model.eval()\n        output = model(x)\n        # print(name,\n        #       model.FLOPS, 'FLOPs',\n        #       sum(p.numel() for p in model.parameters() if p.requires_grad), 'parameters')\n        print(output.shape)"
  },
  {
    "path": "model/backbone/MobileNetV3.py",
    "content": "\"\"\" MobileNet V3\nA PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.\nPaper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244\nHacked together by / Copyright 2019, Ross Wightman\n\"\"\"\nfrom functools import partial\nfrom typing import List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD\nfrom timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer\nfrom ._builder import build_model_with_cfg, pretrained_cfg_for_features\nfrom ._efficientnet_blocks import SqueezeExcite\nfrom ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \\\n    round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT\nfrom ._features import FeatureInfo, FeatureHooks\nfrom ._manipulate import checkpoint_seq\nfrom ._registry import register_model\n\n__all__ = ['MobileNetV3', 'MobileNetV3Features']\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),\n        'crop_pct': 0.875, 'interpolation': 'bilinear',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'conv_stem', 'classifier': 'classifier',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    'mobilenetv3_large_075': _cfg(url=''),\n    'mobilenetv3_large_100': _cfg(\n        interpolation='bicubic',\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),\n    'mobilenetv3_large_100_miil': _cfg(\n        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth'),\n    'mobilenetv3_large_100_miil_in21k': _cfg(\n    
    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',\n        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),\n\n    'mobilenetv3_small_050': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',\n        interpolation='bicubic'),\n    'mobilenetv3_small_075': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',\n        interpolation='bicubic'),\n    'mobilenetv3_small_100': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',\n        interpolation='bicubic'),\n\n    'mobilenetv3_rw': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',\n        interpolation='bicubic'),\n\n    'tf_mobilenetv3_large_075': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',\n        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n    'tf_mobilenetv3_large_100': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',\n        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n    'tf_mobilenetv3_large_minimal_100': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',\n        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n    'tf_mobilenetv3_small_075': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',\n        
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n    'tf_mobilenetv3_small_100': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',\n        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n    'tf_mobilenetv3_small_minimal_100': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',\n        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),\n\n    'fbnetv3_b': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',\n        test_input_size=(3, 256, 256), crop_pct=0.95),\n    'fbnetv3_d': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',\n        test_input_size=(3, 256, 256), crop_pct=0.95),\n    'fbnetv3_g': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',\n        input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),\n\n    \"lcnet_035\": _cfg(),\n    \"lcnet_050\": _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',\n        interpolation='bicubic',\n    ),\n    \"lcnet_075\": _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',\n        interpolation='bicubic',\n    ),\n    \"lcnet_100\": _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',\n        interpolation='bicubic',\n    ),\n    \"lcnet_150\": _cfg(),\n}\n\nclass MobileNetV3(nn.Module):\n    \"\"\" MobileNet-V3\n    Based on my EfficientNet implementation and building blocks, this model utilizes the 
MobileNet-v3 specific\n    'efficient head', where global pooling is done before the head convolution without a final batch-norm\n    layer before the classifier.\n    Paper: `Searching for MobileNetV3` - https://arxiv.org/abs/1905.02244\n    Other architectures utilizing MobileNet-V3 efficient head that are supported by this impl include:\n      * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class)\n      * FBNet-V3 - https://arxiv.org/abs/2006.02049\n      * LCNet - https://arxiv.org/abs/2109.15099\n    \"\"\"\n    \n    def __init__(\n            self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,\n            head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,\n            round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):\n        super(MobileNetV3, self).__init__()\n        act_layer = act_layer or nn.ReLU\n        norm_layer = norm_layer or nn.BatchNorm2d\n        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)\n        se_layer = se_layer or SqueezeExcite\n        self.num_classes = num_classes\n        self.num_features = num_features\n        self.drop_rate = drop_rate\n        self.grad_checkpointing = False\n\n        # Stem\n        if not fix_stem:\n            stem_size = round_chs_fn(stem_size)\n        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)\n        self.bn1 = norm_act_layer(stem_size, inplace=True)\n\n        # Middle stages (IR/ER/DS Blocks)\n        builder = EfficientNetBuilder(\n            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,\n            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)\n        self.blocks = nn.Sequential(*builder(stem_size, block_args))\n        self.feature_info = builder.features\n        head_chs = builder.in_chs\n\n   
     # Head + Pooling\n        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)\n        num_pooled_chs = head_chs * self.global_pool.feat_mult()\n        self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)\n        self.act2 = act_layer(inplace=True)\n        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled\n        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        efficientnet_init_weights(self)\n        \n    def as_sequential(self):\n        layers = [self.conv_stem, self.bn1]\n        layers.extend(self.blocks)\n        layers.extend([self.global_pool, self.conv_head, self.act2])\n        layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])\n        return nn.Sequential(*layers)\n    \n    @torch.jit.ignore\n    def group_matcher(self, coarse=False):\n        return dict(\n            stem=r'^conv_stem|bn1',\n            blocks=r'^blocks\\.(\\d+)' if coarse else r'^blocks\\.(\\d+)\\.(\\d+)'\n        )\n    \n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.grad_checkpointing = enable\n        \n    @torch.jit.ignore\n    def get_classifier(self):\n        return self.classifier\n    \n    def reset_classifier(self, num_classes, global_pool='avg'):\n        self.num_classes = num_classes\n        # cannot meaningfully change pooling of efficient head after creation\n        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)\n        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled\n        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n        \n    def forward_features(self, x):\n        x = self.conv_stem(x)\n        x = self.bn1(x)\n        if self.grad_checkpointing and not torch.jit.is_scripting():\n            x = 
checkpoint_seq(self.blocks, x, flatten=True)\n        else:\n            x = self.blocks(x)\n        return x\n    \n    def forward_head(self, x, pre_logits: bool = False):\n        x = self.global_pool(x)\n        x = self.conv_head(x)\n        x = self.act2(x)\n        if pre_logits:\n            return x.flatten(1)\n        else:\n            x = self.flatten(x)\n            if self.drop_rate > 0.:\n                x = F.dropout(x, p=self.drop_rate, training=self.training)\n            return self.classifier(x)\n        \n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.forward_head(x)\n        return x\n    \nclass MobileNetV3Features(nn.Module):\n    \"\"\" MobileNetV3 Feature Extractor\n    A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation\n    and object detection models.\n    \"\"\"\n    \n    def __init__(\n            self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,\n            stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,\n            se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):\n        super(MobileNetV3Features, self).__init__()\n        act_layer = act_layer or nn.ReLU\n        norm_layer = norm_layer or nn.BatchNorm2d\n        se_layer = se_layer or SqueezeExcite\n        self.drop_rate = drop_rate\n\n        # Stem\n        if not fix_stem:\n            stem_size = round_chs_fn(stem_size)\n        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)\n        self.bn1 = norm_layer(stem_size)\n        self.act1 = act_layer(inplace=True)\n\n        # Middle stages (IR/ER/DS Blocks)\n        builder = EfficientNetBuilder(\n            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,\n            act_layer=act_layer, norm_layer=norm_layer, 
se_layer=se_layer,\n            drop_path_rate=drop_path_rate, feature_location=feature_location)\n        self.blocks = nn.Sequential(*builder(stem_size, block_args))\n        self.feature_info = FeatureInfo(builder.features, out_indices)\n        self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}\n\n        efficientnet_init_weights(self)\n\n        # Register feature extraction hooks with FeatureHooks helper\n        self.feature_hooks = None\n        if feature_location != 'bottleneck':\n            hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))\n            self.feature_hooks = FeatureHooks(hooks, self.named_modules())\n            \n    def forward(self, x) -> List[torch.Tensor]:\n        x = self.conv_stem(x)\n        x = self.bn1(x)\n        x = self.act1(x)\n        if self.feature_hooks is None:\n            features = []\n            if 0 in self._stage_out_idx:\n                features.append(x)  # add stem out\n            for i, b in enumerate(self.blocks):\n                x = b(x)\n                if i + 1 in self._stage_out_idx:\n                    features.append(x)\n            return features\n        else:\n            self.blocks(x)\n            out = self.feature_hooks.get_output(x.device)\n            return list(out.values())\n        \ndef _create_mnv3(variant, pretrained=False, **kwargs):\n    features_only = False\n    model_cls = MobileNetV3\n    kwargs_filter = None\n    if kwargs.pop('features_only', False):\n        features_only = True\n        kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')\n        model_cls = MobileNetV3Features\n    model = build_model_with_cfg(\n        model_cls, variant, pretrained,\n        pretrained_strict=not features_only,\n        kwargs_filter=kwargs_filter,\n        **kwargs)\n    if features_only:\n        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)\n    return 
model\n\ndef _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):\n    \"\"\"Creates a MobileNet-V3 model.\n    Ref impl: ?\n    Paper: https://arxiv.org/abs/1905.02244\n    Args:\n      channel_multiplier: multiplier to number of channels per layer.\n    \"\"\"\n    arch_def = [\n        # stage 0, 112x112 in\n        ['ds_r1_k3_s1_e1_c16_nre_noskip'],  # relu\n        # stage 1, 112x112 in\n        ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu\n        # stage 2, 56x56 in\n        ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu\n        # stage 3, 28x28 in\n        ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish\n        # stage 4, 14x14in\n        ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish\n        # stage 5, 14x14in\n        ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish\n        # stage 6, 7x7 in\n        ['cn_r1_k1_s1_c960'],  # hard-swish\n    ]\n    model_kwargs = dict(\n        block_args=decode_arch_def(arch_def),\n        head_bias=False,\n        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),\n        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),\n        act_layer=resolve_act_layer(kwargs, 'hard_swish'),\n        se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),\n        **kwargs,\n    )\n    model = _create_mnv3(variant, pretrained, **model_kwargs)\n    return model\n\ndef _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):\n    \"\"\"Creates a MobileNet-V3 model.\n    Ref impl: ?\n    Paper: https://arxiv.org/abs/1905.02244\n    Args:\n      channel_multiplier: multiplier to number of channels per layer.\n    \"\"\"\n    if 'small' in variant:\n        num_features = 1024\n        if 'minimal' in variant:\n            act_layer = resolve_act_layer(kwargs, 'relu')\n            arch_def = [\n                # stage 0, 112x112 in\n                ['ds_r1_k3_s2_e1_c16'],\n                # stage 1, 
56x56 in\n                ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],\n                # stage 2, 28x28 in\n                ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],\n                # stage 3, 14x14 in\n                ['ir_r2_k3_s1_e3_c48'],\n                # stage 4, 14x14in\n                ['ir_r3_k3_s2_e6_c96'],\n                # stage 6, 7x7 in\n                ['cn_r1_k1_s1_c576'],\n            ]\n        else:\n            act_layer = resolve_act_layer(kwargs, 'hard_swish')\n            arch_def = [\n                # stage 0, 112x112 in\n                ['ds_r1_k3_s2_e1_c16_se0.25_nre'],  # relu\n                # stage 1, 56x56 in\n                ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'],  # relu\n                # stage 2, 28x28 in\n                ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'],  # hard-swish\n                # stage 3, 14x14 in\n                ['ir_r2_k5_s1_e3_c48_se0.25'],  # hard-swish\n                # stage 4, 14x14in\n                ['ir_r3_k5_s2_e6_c96_se0.25'],  # hard-swish\n                # stage 6, 7x7 in\n                ['cn_r1_k1_s1_c576'],  # hard-swish\n            ]\n    else:\n        num_features = 1280\n        if 'minimal' in variant:\n            act_layer = resolve_act_layer(kwargs, 'relu')\n            arch_def = [\n                # stage 0, 112x112 in\n                ['ds_r1_k3_s1_e1_c16'],\n                # stage 1, 112x112 in\n                ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],\n                # stage 2, 56x56 in\n                ['ir_r3_k3_s2_e3_c40'],\n                # stage 3, 28x28 in\n                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],\n                # stage 4, 14x14in\n                ['ir_r2_k3_s1_e6_c112'],\n                # stage 5, 14x14in\n                ['ir_r3_k3_s2_e6_c160'],\n                # stage 6, 7x7 in\n                ['cn_r1_k1_s1_c960'],\n            ]\n        else:\n            
act_layer = resolve_act_layer(kwargs, 'hard_swish')\n            arch_def = [\n                # stage 0, 112x112 in\n                ['ds_r1_k3_s1_e1_c16_nre'],  # relu\n                # stage 1, 112x112 in\n                ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu\n                # stage 2, 56x56 in\n                ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu\n                # stage 3, 28x28 in\n                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish\n                # stage 4, 14x14in\n                ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish\n                # stage 5, 14x14in\n                ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish\n                # stage 6, 7x7 in\n                ['cn_r1_k1_s1_c960'],  # hard-swish\n            ]\n    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)\n    model_kwargs = dict(\n        block_args=decode_arch_def(arch_def),\n        num_features=num_features,\n        stem_size=16,\n        fix_stem=channel_multiplier < 0.75,\n        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),\n        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),\n        act_layer=act_layer,\n        se_layer=se_layer,\n        **kwargs,\n    )\n    model = _create_mnv3(variant, pretrained, **model_kwargs)\n    return model\n\ndef _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):\n    \"\"\" FBNetV3\n    Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`\n        - https://arxiv.org/abs/2006.02049\n    FIXME untested, this is a preliminary impl of some FBNet-V3 variants.\n    \"\"\"\n    vl = variant.split('_')[-1]\n    if vl in ('a', 'b'):\n        stem_size = 16\n        arch_def = [\n            ['ds_r2_k3_s1_e1_c16'],\n            ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],\n            ['ir_r1_k5_s2_e5_c40_se0.25', 
'ir_r4_k5_s1_e3_c40_se0.25'],\n            ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],\n            ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],\n            ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],\n            ['cn_r1_k1_s1_c1344'],\n        ]\n    elif vl == 'd':\n        stem_size = 24\n        arch_def = [\n            ['ds_r2_k3_s1_e1_c16'],\n            ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],\n            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],\n            ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],\n            ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],\n            ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],\n            ['cn_r1_k1_s1_c1440'],\n        ]\n    elif vl == 'g':\n        stem_size = 32\n        arch_def = [\n            ['ds_r3_k3_s1_e1_c24'],\n            ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],\n            ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],\n            ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],\n            ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],\n            ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],\n            ['cn_r1_k1_s1_c1728'],\n        ]\n    else:\n        raise NotImplementedError\n    round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)\n    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn)\n    act_layer = resolve_act_layer(kwargs, 'hard_swish')\n    model_kwargs = dict(\n        block_args=decode_arch_def(arch_def),\n        num_features=1984,\n        head_bias=False,\n        stem_size=stem_size,\n        round_chs_fn=round_chs_fn,\n        se_from_exp=False,\n        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),\n        act_layer=act_layer,\n        se_layer=se_layer,\n   
     **kwargs,\n    )\n    model = _create_mnv3(variant, pretrained, **model_kwargs)\n    return model\n\ndef _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):\n    \"\"\" LCNet\n    Essentially a MobileNet-V3 crossed with a MobileNet-V1\n    Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099\n    Args:\n      channel_multiplier: multiplier to number of channels per layer.\n    \"\"\"\n    arch_def = [\n        # stage 0, 112x112 in\n        ['dsa_r1_k3_s1_c32'],\n        # stage 1, 112x112 in\n        ['dsa_r2_k3_s2_c64'],\n        # stage 2, 56x56 in\n        ['dsa_r2_k3_s2_c128'],\n        # stage 3, 28x28 in\n        ['dsa_r1_k3_s2_c256', 'dsa_r1_k5_s1_c256'],\n        # stage 4, 14x14in\n        ['dsa_r4_k5_s1_c256'],\n        # stage 5, 14x14in\n        ['dsa_r2_k5_s2_c512_se0.25'],\n        # 7x7\n    ]\n    model_kwargs = dict(\n        block_args=decode_arch_def(arch_def),\n        stem_size=16,\n        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),\n        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),\n        act_layer=resolve_act_layer(kwargs, 'hard_swish'),\n        se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU),\n        num_features=1280,\n        **kwargs,\n    )\n    model = _create_mnv3(variant, pretrained, **model_kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_large_075(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_large_100(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_large_100_miil(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3\n    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K\n    \"\"\"\n    model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3, 21k pretraining\n    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K\n    \"\"\"\n    model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_small_050(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    model = 
_gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_small_075(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_small_100(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef mobilenetv3_rw(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    if pretrained:\n        # pretrained model trained with non-default BN epsilon\n        kwargs['bn_eps'] = BN_EPS_TF_DEFAULT\n    model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef tf_mobilenetv3_large_075(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT\n    kwargs['pad_type'] = 'same'\n    model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef tf_mobilenetv3_large_100(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT\n    kwargs['pad_type'] = 'same'\n    model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT\n    kwargs['pad_type'] = 'same'\n    model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef tf_mobilenetv3_small_075(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT\n    kwargs['pad_type'] = 'same'\n    
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef tf_mobilenetv3_small_100(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT\n    kwargs['pad_type'] = 'same'\n    model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):\n    \"\"\" MobileNet V3 \"\"\"\n    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT\n    kwargs['pad_type'] = 'same'\n    model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef fbnetv3_b(pretrained=False, **kwargs):\n    \"\"\" FBNetV3-B \"\"\"\n    model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef fbnetv3_d(pretrained=False, **kwargs):\n    \"\"\" FBNetV3-D \"\"\"\n    model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef fbnetv3_g(pretrained=False, **kwargs):\n    \"\"\" FBNetV3-G \"\"\"\n    model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef lcnet_035(pretrained=False, **kwargs):\n    \"\"\" PP-LCNet 0.35\"\"\"\n    model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef lcnet_050(pretrained=False, **kwargs):\n    \"\"\" PP-LCNet 0.5\"\"\"\n    model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef lcnet_075(pretrained=False, **kwargs):\n    \"\"\" PP-LCNet 0.75\"\"\"\n    model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef lcnet_100(pretrained=False, **kwargs):\n    \"\"\" PP-LCNet 1.0\"\"\"\n    model = _gen_lcnet('lcnet_100', 1.0, 
pretrained=pretrained, **kwargs)\n    return model\n\n\n@register_model\ndef lcnet_150(pretrained=False, **kwargs):\n    \"\"\" PP-LCNet 1.5\"\"\"\n    model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)\n    return model\n"
  },
  {
    "path": "model/backbone/MobileViT.py",
    "content": "from torch import nn\nimport torch\nfrom torch.nn.modules import conv\nfrom torch.nn.modules.conv import Conv2d\nfrom einops import rearrange\n\n\n\ndef conv_bn(inp,oup,kernel_size=3,stride=1):\n    return nn.Sequential(\n        nn.Conv2d(inp,oup,kernel_size=kernel_size,stride=stride,padding=kernel_size//2),\n        nn.BatchNorm2d(oup),\n        nn.SiLU()\n    )\n\nclass PreNorm(nn.Module):\n    def __init__(self,dim,fn):\n        super().__init__()\n        self.ln=nn.LayerNorm(dim)\n        self.fn=fn\n    def forward(self,x,**kwargs):\n        return self.fn(self.ln(x),**kwargs)\n\nclass FeedForward(nn.Module):\n    def __init__(self,dim,mlp_dim,dropout) :\n        super().__init__()\n        self.net=nn.Sequential(\n            nn.Linear(dim,mlp_dim),\n            nn.SiLU(),\n            nn.Dropout(dropout),\n            nn.Linear(mlp_dim,dim),\n            nn.Dropout(dropout)\n        )\n    def forward(self,x):\n        return self.net(x)\n\nclass Attention(nn.Module):\n    def __init__(self,dim,heads,head_dim,dropout):\n        super().__init__()\n        inner_dim=heads*head_dim\n        project_out=not(heads==1 and head_dim==dim)\n\n        self.heads=heads\n        self.scale=head_dim**-0.5\n\n        self.attend=nn.Softmax(dim=-1)\n        self.to_qkv=nn.Linear(dim,inner_dim*3,bias=False)\n        \n        self.to_out=nn.Sequential(\n            nn.Linear(inner_dim,dim),\n            nn.Dropout(dropout)\n        ) if project_out else nn.Identity()\n\n    def forward(self,x):\n        qkv=self.to_qkv(x).chunk(3,dim=-1)\n        q,k,v=map(lambda t:rearrange(t,'b p n (h d) -> b p h n d',h=self.heads),qkv)\n        dots=torch.matmul(q,k.transpose(-1,-2))*self.scale\n        attn=self.attend(dots)\n        out=torch.matmul(attn,v)\n        out=rearrange(out,'b p h n d -> b p n (h d)')\n        return self.to_out(out)\n\n\n\n\n\nclass Transformer(nn.Module):\n    def __init__(self,dim,depth,heads,head_dim,mlp_dim,dropout=0.):\n        
super().__init__()\n        self.layers=nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(nn.ModuleList([\n                PreNorm(dim,Attention(dim,heads,head_dim,dropout)),\n                PreNorm(dim,FeedForward(dim,mlp_dim,dropout))\n            ]))\n\n\n    def forward(self,x):\n        out=x\n        for att,ffn in self.layers:\n            out=out+att(out)\n            out=out+ffn(out)\n        return out\n\nclass MobileViTAttention(nn.Module):\n    def __init__(self,in_channel=3,dim=512,kernel_size=3,patch_size=7,depth=3,mlp_dim=1024):\n        super().__init__()\n        self.ph,self.pw=patch_size,patch_size\n        self.conv1=nn.Conv2d(in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)\n        self.conv2=nn.Conv2d(in_channel,dim,kernel_size=1)\n\n        self.trans=Transformer(dim=dim,depth=depth,heads=8,head_dim=64,mlp_dim=mlp_dim)\n\n        self.conv3=nn.Conv2d(dim,in_channel,kernel_size=1)\n        self.conv4=nn.Conv2d(2*in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)\n\n    def forward(self,x):\n        y=x.clone() #bs,c,h,w\n\n        ## Local Representation\n        y=self.conv2(self.conv1(x)) #bs,dim,h,w\n\n        ## Global Representation\n        _,_,h,w=y.shape\n        y=rearrange(y,'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim',ph=self.ph,pw=self.pw) #bs,ph*pw,nh*nw,dim\n        y=self.trans(y)\n        y=rearrange(y,'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)',ph=self.ph,pw=self.pw,nh=h//self.ph,nw=w//self.pw) #bs,dim,h,w\n\n        ## Fusion\n        y=self.conv3(y) #bs,c,h,w\n        y=torch.cat([x,y],1) #bs,2*c,h,w\n        y=self.conv4(y) #bs,c,h,w\n\n        return y\n\n\nclass MV2Block(nn.Module):\n    def __init__(self,inp,out,stride=1,expansion=4):\n        super().__init__()\n        self.stride=stride\n        hidden_dim=inp*expansion\n        self.use_res_connection=stride==1 and inp==out\n\n        if expansion==1:\n            
self.conv=nn.Sequential(\n                nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=self.stride,padding=1,groups=hidden_dim,bias=False),\n                nn.BatchNorm2d(hidden_dim),\n                nn.SiLU(),\n                nn.Conv2d(hidden_dim,out,kernel_size=1,stride=1,bias=False),\n                nn.BatchNorm2d(out)\n            )\n        else:\n            self.conv=nn.Sequential(\n                nn.Conv2d(inp,hidden_dim,kernel_size=1,stride=1,bias=False),\n                nn.BatchNorm2d(hidden_dim),\n                nn.SiLU(),\n                nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=self.stride,padding=1,groups=hidden_dim,bias=False),\n                nn.BatchNorm2d(hidden_dim),\n                nn.SiLU(),\n                nn.Conv2d(hidden_dim,out,kernel_size=1,stride=1,bias=False),\n                nn.SiLU(),\n                nn.BatchNorm2d(out)\n            )\n    def forward(self,x):\n        if(self.use_res_connection):\n            out=x+self.conv(x)\n        else:\n            out=self.conv(x)\n        return out\n\nclass MobileViT(nn.Module):\n    def __init__(self,image_size,dims,channels,num_classes,depths=[2,4,3],expansion=4,kernel_size=3,patch_size=2):\n        super().__init__()\n        ih,iw=image_size,image_size\n        ph,pw=patch_size,patch_size\n        assert iw%pw==0 and ih%ph==0\n\n        self.conv1=conv_bn(3,channels[0],kernel_size=3,stride=patch_size)\n        self.mv2=nn.ModuleList([])\n        self.m_vits=nn.ModuleList([])\n\n\n        self.mv2.append(MV2Block(channels[0],channels[1],1))\n        self.mv2.append(MV2Block(channels[1],channels[2],2))\n        self.mv2.append(MV2Block(channels[2],channels[3],1))\n        self.mv2.append(MV2Block(channels[2],channels[3],1)) # x2\n        self.mv2.append(MV2Block(channels[3],channels[4],2))\n        self.m_vits.append(MobileViTAttention(channels[4],dim=dims[0],kernel_size=kernel_size,patch_size=patch_size,depth=depths[0],mlp_dim=int(2*dims[0])))\n        
self.mv2.append(MV2Block(channels[4],channels[5],2))\n        self.m_vits.append(MobileViTAttention(channels[5],dim=dims[1],kernel_size=kernel_size,patch_size=patch_size,depth=depths[1],mlp_dim=int(4*dims[1])))\n        self.mv2.append(MV2Block(channels[5],channels[6],2))\n        self.m_vits.append(MobileViTAttention(channels[6],dim=dims[2],kernel_size=kernel_size,patch_size=patch_size,depth=depths[2],mlp_dim=int(4*dims[2])))\n\n        \n        self.conv2=conv_bn(channels[-2],channels[-1],kernel_size=1)\n        self.pool=nn.AvgPool2d(image_size//32,1)\n        self.fc=nn.Linear(channels[-1],num_classes,bias=False)\n\n    def forward(self,x):\n        y=self.conv1(x) #\n        y=self.mv2[0](y)\n        y=self.mv2[1](y) #\n        y=self.mv2[2](y)\n        y=self.mv2[3](y)\n        y=self.mv2[4](y) #\n        y=self.m_vits[0](y)\n\n        y=self.mv2[5](y) #\n        y=self.m_vits[1](y)\n\n        y=self.mv2[6](y) #\n        y=self.m_vits[2](y)\n\n        y=self.conv2(y)\n        y=self.pool(y).view(y.shape[0],-1) \n        y=self.fc(y)\n        return y\n\ndef mobilevit_xxs():\n    dims=[60,80,96]\n    channels= [16, 16, 24, 24, 48, 64, 80, 320]\n    return MobileViT(224,dims,channels,num_classes=1000)\n\ndef mobilevit_xs():\n    dims = [96, 120, 144]\n    channels = [16, 32, 48, 48, 64, 80, 96, 384]\n    return MobileViT(224, dims, channels, num_classes=1000)\n\ndef mobilevit_s():\n    dims = [144, 192, 240]\n    channels = [16, 32, 64, 64, 96, 128, 160, 640]\n    return MobileViT(224, dims, channels, num_classes=1000)\n\n\ndef count_paratermeters(model):\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n\n    ### mobilevit_xxs\n    mvit_xxs=mobilevit_xxs()\n    out=mvit_xxs(input)\n    print(out.shape)\n\n    ### mobilevit_xs\n    mvit_xs=mobilevit_xs()\n    out=mvit_xs(input)\n    print(out.shape)\n\n\n    ### mobilevit_s\n    mvit_s=mobilevit_s()\n    
out=mvit_s(input)\n    print(out.shape)\n\n    "
  },
  {
    "path": "model/backbone/PIT.py",
    "content": "# PiT\n# Copyright 2021-present NAVER Corp.\n# Apache License v2.0\n\nimport torch\nfrom einops import rearrange\nfrom torch import nn\nimport math\n\nfrom functools import partial\nfrom timm.models.layers import trunc_normal_\nfrom timm.models.vision_transformer import Block as transformer_block\nfrom timm.models.registry import register_model\n\nclass Transformer(nn.Module):\n    def __init__(self, base_dim, depth, heads, mlp_ratio,\n                 drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):\n        super(Transformer, self).__init__()\n        self.layers = nn.ModuleList([])\n        embed_dim = base_dim * heads\n\n        if drop_path_prob is None:\n            drop_path_prob = [0.0 for _ in range(depth)]\n\n        self.blocks = nn.ModuleList([\n            transformer_block(\n                dim=embed_dim,\n                num_heads=heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=True,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=drop_path_prob[i],\n                norm_layer=partial(nn.LayerNorm, eps=1e-6)\n            )\n            for i in range(depth)])\n\n    def forward(self, x, cls_tokens):\n        h, w = x.shape[2:4]\n        x = rearrange(x, 'b c h w -> b (h w) c')\n\n        token_length = cls_tokens.shape[1]\n        x = torch.cat((cls_tokens, x), dim=1)\n        for blk in self.blocks:\n            x = blk(x)\n\n        cls_tokens = x[:, :token_length]\n        x = x[:, token_length:]\n        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)\n\n        return x, cls_tokens\n\n\nclass conv_head_pooling(nn.Module):\n    def __init__(self, in_feature, out_feature, stride,\n                 padding_mode='zeros'):\n        super(conv_head_pooling, self).__init__()\n\n        self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=stride + 1,\n                              padding=stride // 2, stride=stride,\n                              
padding_mode=padding_mode, groups=in_feature)\n        self.fc = nn.Linear(in_feature, out_feature)\n\n    def forward(self, x, cls_token):\n\n        x = self.conv(x)\n        cls_token = self.fc(cls_token)\n\n        return x, cls_token\n\n\nclass conv_embedding(nn.Module):\n    def __init__(self, in_channels, out_channels, patch_size,\n                 stride, padding):\n        super(conv_embedding, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size,\n                              stride=stride, padding=padding, bias=True)\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\n\nclass PoolingTransformer(nn.Module):\n    def __init__(self, image_size, patch_size, stride, base_dims, depth, heads,\n                 mlp_ratio, num_classes=1000, in_chans=3,\n                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):\n        super(PoolingTransformer, self).__init__()\n\n        total_block = sum(depth)\n        padding = 0\n        block_idx = 0\n\n        width = math.floor(\n            (image_size + 2 * padding - patch_size) / stride + 1)\n\n        self.base_dims = base_dims\n        self.heads = heads\n        self.num_classes = num_classes\n\n        self.patch_size = patch_size\n        self.pos_embed = nn.Parameter(\n            torch.randn(1, base_dims[0] * heads[0], width, width),\n            requires_grad=True\n        )\n        self.patch_embed = conv_embedding(in_chans, base_dims[0] * heads[0],\n                                          patch_size, stride, padding)\n\n        self.cls_token = nn.Parameter(\n            torch.randn(1, 1, base_dims[0] * heads[0]),\n            requires_grad=True\n        )\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        self.transformers = nn.ModuleList([])\n        self.pools = nn.ModuleList([])\n\n        for stage in range(len(depth)):\n            drop_path_prob = [drop_path_rate * i / total_block\n                             
 for i in range(block_idx, block_idx + depth[stage])]\n            block_idx += depth[stage]\n\n            self.transformers.append(\n                Transformer(base_dims[stage], depth[stage], heads[stage],\n                            mlp_ratio,\n                            drop_rate, attn_drop_rate, drop_path_prob)\n            )\n            if stage < len(heads) - 1:\n                self.pools.append(\n                    conv_head_pooling(base_dims[stage] * heads[stage],\n                                      base_dims[stage + 1] * heads[stage + 1],\n                                      stride=2\n                                      )\n                )\n\n        self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)\n        self.embed_dim = base_dims[-1] * heads[-1]\n\n        # Classifier head\n        if num_classes > 0:\n            self.head = nn.Linear(base_dims[-1] * heads[-1], num_classes)\n        else:\n            self.head = nn.Identity()\n\n        trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        if num_classes > 0:\n            self.head = nn.Linear(self.embed_dim, num_classes)\n        else:\n            self.head = nn.Identity()\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n\n        pos_embed = self.pos_embed\n        x = self.pos_drop(x + pos_embed)\n        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)\n\n        for stage in range(len(self.pools)):\n            x, cls_tokens = 
self.transformers[stage](x, cls_tokens)\n            x, cls_tokens = self.pools[stage](x, cls_tokens)\n        x, cls_tokens = self.transformers[-1](x, cls_tokens)\n\n        cls_tokens = self.norm(cls_tokens)\n\n        return cls_tokens\n\n    def forward(self, x):\n        cls_token = self.forward_features(x)\n        cls_token = self.head(cls_token[:, 0])\n        return cls_token\n\n\nclass DistilledPoolingTransformer(PoolingTransformer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.cls_token = nn.Parameter(\n            torch.randn(1, 2, self.base_dims[0] * self.heads[0]),\n            requires_grad=True)\n        if self.num_classes > 0:\n            self.head_dist = nn.Linear(self.base_dims[-1] * self.heads[-1],\n                                       self.num_classes)\n        else:\n            self.head_dist = nn.Identity()\n\n        trunc_normal_(self.cls_token, std=.02)\n        self.head_dist.apply(self._init_weights)\n\n    def forward(self, x):\n        cls_token = self.forward_features(x)\n        x_cls = self.head(cls_token[:, 0])\n        x_dist = self.head_dist(cls_token[:, 1])\n        if self.training:\n            return x_cls, x_dist\n        else:\n            return (x_cls + x_dist) / 2\n\n@register_model\ndef pit_b(pretrained, **kwargs):\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        torch.load('weights/pit_b_820.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n@register_model\ndef pit_s(pretrained, **kwargs):\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=16,\n        stride=8,\n        base_dims=[48, 48, 48],\n        depth=[2, 6, 4],\n        heads=[3, 6, 12],\n        
mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        torch.load('weights/pit_s_809.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef pit_xs(pretrained, **kwargs):\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=16,\n        stride=8,\n        base_dims=[48, 48, 48],\n        depth=[2, 6, 4],\n        heads=[2, 4, 8],\n        mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        torch.load('weights/pit_xs_781.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n@register_model\ndef pit_ti(pretrained, **kwargs):\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=16,\n        stride=8,\n        base_dims=[32, 32, 32],\n        depth=[2, 6, 4],\n        heads=[2, 4, 8],\n        mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        torch.load('weights/pit_ti_730.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef pit_b_distilled(pretrained, **kwargs):\n    model = DistilledPoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        torch.load('weights/pit_b_distill_840.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef pit_s_distilled(pretrained, **kwargs):\n    model = DistilledPoolingTransformer(\n        image_size=224,\n        patch_size=16,\n        stride=8,\n        base_dims=[48, 48, 48],\n        depth=[2, 6, 4],\n        heads=[3, 6, 12],\n        mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        
torch.load('weights/pit_s_distill_819.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef pit_xs_distilled(pretrained, **kwargs):\n    model = DistilledPoolingTransformer(\n        image_size=224,\n        patch_size=16,\n        stride=8,\n        base_dims=[48, 48, 48],\n        depth=[2, 6, 4],\n        heads=[2, 4, 8],\n        mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        torch.load('weights/pit_xs_distill_791.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\n\n@register_model\ndef pit_ti_distilled(pretrained, **kwargs):\n    model = DistilledPoolingTransformer(\n        image_size=224,\n        patch_size=16,\n        stride=8,\n        base_dims=[32, 32, 32],\n        depth=[2, 6, 4],\n        heads=[2, 4, 8],\n        mlp_ratio=4,\n        **kwargs\n    )\n    if pretrained:\n        state_dict = \\\n        torch.load('weights/pit_ti_distill_746.pth', map_location='cpu')\n        model.load_state_dict(state_dict)\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4\n    )\n    output=model(input)\n    print(output.shape)"
  },
  {
    "path": "model/backbone/PVT.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom functools import partial\n\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nfrom timm.models.vision_transformer import _cfg\n\n__all__ = [\n    'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large'\n]\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):\n        super().__init__()\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n\n        self.dim = dim\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        self.sr_ratio = sr_ratio\n        if sr_ratio > 1:\n            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)\n            self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 
2, 1, 3)\n\n        if self.sr_ratio > 1:\n            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)\n            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)\n            x_ = self.norm(x_)\n            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        else:\n            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        k, v = kv[0], kv[1]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x, H, W):\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n        # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \\\n        #     f\"img_size {img_size} should be divided by patch_size {patch_size}.\"\n        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = nn.LayerNorm(embed_dim)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        H, W = H // self.patch_size[0], W // self.patch_size[1]\n\n        return x, (H, W)\n\n\nclass PyramidVisionTransformer(nn.Module):\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],\n                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,\n                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4):\n        super().__init__()\n        self.num_classes = num_classes\n        self.depths = depths\n        self.num_stages = num_stages\n\n        
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        cur = 0\n\n        for i in range(num_stages):\n            patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),\n                                     patch_size=patch_size if i == 0 else 2,\n                                     in_chans=in_chans if i == 0 else embed_dims[i - 1],\n                                     embed_dim=embed_dims[i])\n            num_patches = patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1\n            pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))\n            pos_drop = nn.Dropout(p=drop_rate)\n\n            block = nn.ModuleList([Block(\n                dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,\n                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],\n                norm_layer=norm_layer, sr_ratio=sr_ratios[i])\n                for j in range(depths[i])])\n            cur += depths[i]\n\n            setattr(self, f\"patch_embed{i + 1}\", patch_embed)\n            setattr(self, f\"pos_embed{i + 1}\", pos_embed)\n            setattr(self, f\"pos_drop{i + 1}\", pos_drop)\n            setattr(self, f\"block{i + 1}\", block)\n\n        self.norm = norm_layer(embed_dims[3])\n\n        # cls_token\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))\n\n        # classification head\n        self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()\n\n        # init weights\n        for i in range(num_stages):\n            pos_embed = getattr(self, f\"pos_embed{i + 1}\")\n            trunc_normal_(pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            
trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        # return {'pos_embed', 'cls_token'} # has pos_embed may be better\n        return {'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        # this class stores no embed_dim attribute; take the final-stage width from cls_token\n        embed_dim = self.cls_token.shape[-1]\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def _get_pos_embed(self, pos_embed, patch_embed, H, W):\n        # compare against this stage's own patch grid, not the first stage's\n        if H * W == patch_embed.num_patches:\n            return pos_embed\n        else:\n            return F.interpolate(\n                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),\n                size=(H, W), mode=\"bilinear\").reshape(1, -1, H * W).permute(0, 2, 1)\n\n    def forward_features(self, x):\n        B = x.shape[0]\n\n        for i in range(self.num_stages):\n            patch_embed = getattr(self, f\"patch_embed{i + 1}\")\n            pos_embed = getattr(self, f\"pos_embed{i + 1}\")\n            pos_drop = getattr(self, f\"pos_drop{i + 1}\")\n            block = getattr(self, f\"block{i + 1}\")\n            x, (H, W) = patch_embed(x)\n\n            if i == self.num_stages - 1:\n                cls_tokens = self.cls_token.expand(B, -1, -1)\n                x = torch.cat((cls_tokens, x), dim=1)\n                pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)\n                pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)\n            else:\n                pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)\n\n            x = pos_drop(x + pos_embed)\n            for blk in block:\n                x = blk(x, H, W)\n            if i != self.num_stages - 1:\n                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        x = self.norm(x)\n\n        return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n\n        return x\n\n\ndef _conv_filter(state_dict, patch_size=16):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    for k, v in state_dict.items():\n        if 'patch_embed.proj.weight' in k:\n            v = v.reshape((v.shape[0], 3, patch_size, patch_size))\n        out_dict[k] = v\n\n    return out_dict\n\n\n@register_model\ndef pvt_tiny(pretrained=False, **kwargs):\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],\n        **kwargs)\n    model.default_cfg = _cfg()\n\n    return model\n\n\n@register_model\ndef pvt_small(pretrained=False, **kwargs):\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)\n    model.default_cfg = _cfg()\n\n    return model\n\n\n@register_model\ndef pvt_medium(pretrained=False, **kwargs):\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],\n        **kwargs)\n    model.default_cfg = _cfg()\n\n    return model\n\n\n@register_model\ndef pvt_large(pretrained=False, **kwargs):\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], 
mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],\n        **kwargs)\n    model.default_cfg = _cfg()\n\n    return model\n\n\n@register_model\ndef pvt_huge_v2(pretrained=False, **kwargs):\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[128, 256, 512, 768], num_heads=[2, 4, 8, 12], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 10, 60, 3], sr_ratios=[8, 4, 2, 1],\n        # drop_rate=0.0, drop_path_rate=0.02)\n        **kwargs)\n    model.default_cfg = _cfg()\n\n    return model\n\nif __name__ == '__main__':\n    # smoke test: run one 224x224 RGB image through the pvt_tiny configuration\n    x = torch.randn(1, 3, 224, 224)  # named x to avoid shadowing the builtin input()\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1])\n    output = model(x)\n    print(output.shape)"
  },
  {
    "path": "model/backbone/PatchConvnet.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n\n\nfrom functools import partial\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.efficientnet_blocks import SqueezeExcite\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\n\n__all__ = ['S60', 'S120', 'B60', 'B120', 'L60', 'L120', 'S60_multi']\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: nn.Module = nn.GELU,\n        drop: float = 0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Learned_Aggregation_Layer(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 1,\n        qkv_bias: bool = False,\n        qk_scale: Optional[float] = None,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim: int = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim**-0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.k = nn.Linear(dim, dim, bias=qkv_bias)\n        self.v = nn.Linear(dim, dim, bias=qkv_bias)\n        self.id = 
nn.Identity()\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B, N, C = x.shape\n        q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n\n        q = q * self.scale\n        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n\n        attn = q @ k.transpose(-2, -1)\n        attn = self.id(attn)\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)\n        x_cls = self.proj(x_cls)\n        x_cls = self.proj_drop(x_cls)\n\n        return x_cls\n\n\nclass Learned_Aggregation_Layer_multi(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        qk_scale: Optional[float] = None,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n        num_classes: int = 1000,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim: int = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.k = nn.Linear(dim, dim, bias=qkv_bias)\n        self.v = nn.Linear(dim, dim, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.num_classes = num_classes\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B, N, C = x.shape\n        q = (\n            self.q(x[:, : self.num_classes])\n            .reshape(B, self.num_classes, self.num_heads, C // self.num_heads)\n            .permute(0, 2, 1, 3)\n        )\n        k = (\n            self.k(x[:, 
self.num_classes:])\n            .reshape(B, N - self.num_classes, self.num_heads, C // self.num_heads)\n            .permute(0, 2, 1, 3)\n        )\n\n        q = q * self.scale\n        v = (\n            self.v(x[:, self.num_classes:])\n            .reshape(B, N - self.num_classes, self.num_heads, C // self.num_heads)\n            .permute(0, 2, 1, 3)\n        )\n\n        attn = q @ k.transpose(-2, -1)\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x_cls = (attn @ v).transpose(1, 2).reshape(B, self.num_classes, C)\n        x_cls = self.proj(x_cls)\n        x_cls = self.proj_drop(x_cls)\n\n        return x_cls\n\n\nclass Layer_scale_init_Block_only_token(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = False,\n        qk_scale: Optional[float] = None,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        drop_path: float = 0.0,\n        act_layer: nn.Module = nn.GELU,\n        norm_layer=nn.LayerNorm,\n        Attention_block=Learned_Aggregation_Layer,\n        Mlp_block=Mlp,\n        init_values: float = 1e-4,\n    ):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention_block(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop\n        )\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n\n    def 
forward(self, x: torch.Tensor, x_cls: torch.Tensor) -> torch.Tensor:\n        u = torch.cat((x_cls, x), dim=1)\n        x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))\n        x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))\n        return x_cls\n\n\nclass Conv_blocks_se(nn.Module):\n    def __init__(self, dim: int):\n        super().__init__()\n\n        self.qkv_pos = nn.Sequential(\n            nn.Conv2d(dim, dim, kernel_size=1),\n            nn.GELU(),\n            nn.Conv2d(dim, dim, groups=dim, kernel_size=3, padding=1, stride=1, bias=True),\n            nn.GELU(),\n            SqueezeExcite(dim, rd_ratio=0.25),\n            nn.Conv2d(dim, dim, kernel_size=1),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B, N, C = x.shape\n        H = W = int(N ** 0.5)\n        x = x.transpose(-1, -2)\n        x = x.reshape(B, C, H, W)\n        x = self.qkv_pos(x)\n        x = x.reshape(B, C, N)\n        x = x.transpose(-1, -2)\n        return x\n\n\nclass Layer_scale_init_Block(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        drop_path: float = 0.0,\n        act_layer: nn.Module = nn.GELU,\n        norm_layer=nn.LayerNorm,\n        Attention_block=None,\n        init_values: float = 1e-4,\n    ):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention_block(dim)\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))\n\n\ndef conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Sequential:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Sequential(\n        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),\n    
)\n\n\nclass ConvStem(nn.Module):\n    \"\"\"Image to Patch Embedding\"\"\"\n\n    def __init__(self, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Sequential(\n            conv3x3(in_chans, embed_dim // 8, 2),\n            nn.GELU(),\n            conv3x3(embed_dim // 8, embed_dim // 4, 2),\n            nn.GELU(),\n            conv3x3(embed_dim // 4, embed_dim // 2, 2),\n            nn.GELU(),\n            conv3x3(embed_dim // 2, embed_dim, 2),\n        )\n\n    def forward(self, x: torch.Tensor, padding_size: Optional[int] = None) -> torch.Tensor:\n        B, C, H, W = x.shape\n        x = self.proj(x).flatten(2).transpose(1, 2)\n        return x\n\n\nclass PatchConvnet(nn.Module):\n    def __init__(\n        self,\n        img_size: int = 224,\n        patch_size: int = 16,\n        in_chans: int = 3,\n        num_classes: int = 1000,\n        embed_dim: int = 768,\n        depth: int = 12,\n        num_heads: int = 1,\n        qkv_bias: bool = False,\n        qk_scale: Optional[float] = None,\n        drop_rate: float = 0.0,\n        attn_drop_rate: float = 0.0,\n        drop_path_rate: float = 0.0,\n        norm_layer=nn.LayerNorm,\n        global_pool: Optional[str] = None,\n        block_layers=Layer_scale_init_Block,\n        block_layers_token=Layer_scale_init_Block_only_token,\n        Patch_layer=ConvStem,\n        act_layer: nn.Module = nn.GELU,\n        Attention_block=Conv_blocks_se,\n        dpr_constant: bool = True,\n        init_scale: float = 1e-4,\n        Attention_block_token_only=Learned_Aggregation_Layer,\n        Mlp_block_token_only=Mlp,\n        depth_token_only: int = 
1,\n        mlp_ratio_clstk: float = 3.0,\n        multiclass: bool = False,\n    ):\n        super().__init__()\n\n        self.multiclass = multiclass\n        self.patch_size = patch_size\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n\n        self.patch_embed = Patch_layer(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim\n        )\n\n        if not self.multiclass:\n            self.cls_token = nn.Parameter(torch.zeros(1, 1, int(embed_dim)))\n        else:\n            self.cls_token = nn.Parameter(torch.zeros(1, num_classes, int(embed_dim)))\n\n        if not dpr_constant:\n            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]\n        else:\n            dpr = [drop_path_rate for i in range(depth)]\n\n        self.blocks = nn.ModuleList(\n            [\n                block_layers(\n                    dim=embed_dim,\n                    drop_path=dpr[i],\n                    norm_layer=norm_layer,\n                    act_layer=act_layer,\n                    Attention_block=Attention_block,\n                    init_values=init_scale,\n                )\n                for i in range(depth)\n            ]\n        )\n\n        self.blocks_token_only = nn.ModuleList(\n            [\n                block_layers_token(\n                    dim=int(embed_dim),\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratio_clstk,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    attn_drop=attn_drop_rate,\n                    drop_path=0.0,\n                    norm_layer=norm_layer,\n                    act_layer=act_layer,\n                    Attention_block=Attention_block_token_only,\n                    Mlp_block=Mlp_block_token_only,\n                    
init_values=init_scale,\n                )\n                for i in range(depth_token_only)\n            ]\n        )\n\n        self.norm = norm_layer(int(embed_dim))\n\n        self.total_len = depth_token_only + depth\n\n        self.feature_info = [dict(num_chs=int(embed_dim), reduction=0, module='head')]\n        if not self.multiclass:\n            self.head = nn.Linear(int(embed_dim), num_classes) if num_classes > 0 else nn.Identity()\n        else:\n            self.head = nn.ModuleList([nn.Linear(int(embed_dim), 1) for _ in range(num_classes)])\n\n        self.rescale: float = 0.02\n\n        trunc_normal_(self.cls_token, std=self.rescale)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=self.rescale)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def get_num_layers(self):\n        return len(self.blocks)\n\n    def reset_classifier(self, num_classes: int, global_pool: str = ''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x: torch.Tensor) -> torch.Tensor:\n        B = x.shape[0]\n        x = self.patch_embed(x)\n        cls_tokens = self.cls_token.expand(B, -1, -1)\n\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n\n        for i, blk in enumerate(self.blocks_token_only):\n            cls_tokens = blk(x, cls_tokens)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        x = self.norm(x)\n\n        if not self.multiclass:\n            return x[:, 0]\n        else:\n            return 
x[:, : self.num_classes].reshape(B, self.num_classes, -1)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B = x.shape[0]\n        x = self.forward_features(x)\n        if not self.multiclass:\n            x = self.head(x)\n            return x\n        else:\n            all_results = []\n            for i in range(self.num_classes):\n                all_results.append(self.head[i](x[:, i]))\n            return torch.cat(all_results, dim=1).reshape(B, self.num_classes)\n\n\n@register_model\ndef S60(pretrained: bool = False, **kwargs):\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n        **kwargs\n    )\n\n    return model\n\n\n@register_model\ndef S120(pretrained: bool = False, **kwargs):\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=120,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        init_scale=1e-6,\n        mlp_ratio_clstk=3.0,\n        **kwargs\n    )\n\n    return model\n\n\n@register_model\ndef B60(pretrained: bool = False, **kwargs):\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=768,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Attention_block=Conv_blocks_se,\n        init_scale=1e-6,\n        **kwargs\n    )\n\n    return model\n\n\n@register_model\ndef B120(pretrained: bool = False, **kwargs):\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=768,\n        depth=120,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, 
eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        init_scale=1e-6,\n        **kwargs\n    )\n\n    return model\n\n\n@register_model\ndef L60(pretrained: bool = False, **kwargs):\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=1024,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        init_scale=1e-6,\n        mlp_ratio_clstk=3.0,\n        **kwargs\n    )\n\n    return model\n\n\n@register_model\ndef L120(pretrained: bool = False, **kwargs):\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=1024,\n        depth=120,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        init_scale=1e-6,\n        mlp_ratio_clstk=3.0,\n        **kwargs\n    )\n\n    return model\n\n\n@register_model\ndef S60_multi(pretrained: bool = False, **kwargs):\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        Attention_block_token_only=Learned_Aggregation_Layer_multi,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n        multiclass=True,\n        **kwargs\n    )\n\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n    )\n    output=model(input)\n    
print(output.shape)"
  },
  {
    "path": "model/backbone/ShuffleTransformer.py",
    "content": "import torch\nfrom torch import nn, einsum\nfrom einops import rearrange, repeat\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0., stride=False):\n        super().__init__()\n        self.stride = stride\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)\n        self.drop = nn.Dropout(drop, inplace=True)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads, window_size=1, shuffle=False, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., relative_pos_embedding=False):\n        super().__init__()\n        self.num_heads = num_heads\n        self.relative_pos_embedding = relative_pos_embedding\n        head_dim = dim // self.num_heads\n        self.ws = window_size\n        self.shuffle = shuffle\n\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Conv2d(dim, dim, 1)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        if self.relative_pos_embedding:\n            # define a parameter table of relative position bias\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n            # get pair-wise relative position index for each token inside the window\n  
          coords_h = torch.arange(self.ws)\n            coords_w = torch.arange(self.ws)\n            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n            relative_coords[:, :, 0] += self.ws - 1  # shift to start from 0\n            relative_coords[:, :, 1] += self.ws - 1\n            relative_coords[:, :, 0] *= 2 * self.ws - 1\n            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n            self.register_buffer(\"relative_position_index\", relative_position_index)\n\n            trunc_normal_(self.relative_position_bias_table, std=.02)\n            # print('The relative_pos_embedding is used')\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n\n        if self.shuffle:\n            q, k, v = rearrange(qkv, 'b (qkv h d) (ws1 hh) (ws2 ww) -> qkv (b hh ww) h (ws1 ws2) d', h=self.num_heads, qkv=3, ws1=self.ws, ws2=self.ws)\n        else:\n            q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d', h=self.num_heads, qkv=3, ws1=self.ws, ws2=self.ws)\n\n        dots = (q @ k.transpose(-2, -1)) * self.scale\n\n        if self.relative_pos_embedding:\n            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.ws * self.ws, self.ws * self.ws, -1)  # Wh*Ww,Wh*Ww,nH\n            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n            dots += relative_position_bias.unsqueeze(0)\n\n        attn = dots.softmax(dim=-1)\n        out = attn @ v\n\n        if self.shuffle:\n            out = rearrange(out, '(b hh ww) h (ws1 ws2) d -> b (h d) (ws1 hh) 
(ws2 ww)', h=self.num_heads, b=b, hh=h//self.ws, ws1=self.ws, ws2=self.ws)\n        else:\n            out = rearrange(out, '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)', h=self.num_heads, b=b, hh=h//self.ws, ws1=self.ws, ws2=self.ws)\n \n        out = self.proj(out)\n        out = self.proj_drop(out)\n\n        return out\n\nclass Block(nn.Module):\n    def __init__(self, dim, out_dim, num_heads, window_size=1, shuffle=False, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d, stride=False, relative_pos_embedding=False):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, window_size=window_size, shuffle=shuffle, qkv_bias=qkv_bias, qk_scale=qk_scale, \n            attn_drop=attn_drop, proj_drop=drop, relative_pos_embedding=relative_pos_embedding)\n        self.local = nn.Conv2d(dim, dim, window_size, 1, window_size//2, groups=dim, bias=qkv_bias)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=out_dim, act_layer=act_layer, drop=drop, stride=stride)\n        self.norm3 = norm_layer(dim)\n        # print(\"input dim={}, output dim={}, stride={}, expand={}, num_heads={}\".format(dim, out_dim, stride, shuffle, num_heads))\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.local(self.norm2(x)) # local connection\n        x = x + self.drop_path(self.mlp(self.norm3(x)))\n        return x\n\n\nclass PatchMerging(nn.Module):\n    def __init__(self, dim, out_dim, norm_layer=nn.BatchNorm2d):\n        super().__init__()\n        self.dim = dim\n        self.out_dim = out_dim\n        self.norm = norm_layer(dim)\n        self.reduction = nn.Conv2d(dim, out_dim, 2, 2, 0, bias=False)\n\n    def forward(self, x):\n        x = self.norm(x)\n        x = self.reduction(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input dim={self.dim}, out dim={self.out_dim}\"\n\n\nclass StageModule(nn.Module):\n    def __init__(self, layers, dim, out_dim, num_heads, window_size=1, shuffle=True, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d, relative_pos_embedding=False):\n        super().__init__()\n        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'\n\n        if dim != out_dim:\n            self.patch_partition = PatchMerging(dim, out_dim)\n        else:\n            self.patch_partition = None\n\n        num = layers // 2\n        self.layers = nn.ModuleList([])\n        for idx in range(num):\n            the_last = (idx==num-1)\n            self.layers.append(nn.ModuleList([\n                Block(dim=out_dim, out_dim=out_dim, num_heads=num_heads, 
window_size=window_size, shuffle=False, mlp_ratio=mlp_ratio,\n                      qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path,\n                      relative_pos_embedding=relative_pos_embedding),\n                Block(dim=out_dim, out_dim=out_dim, num_heads=num_heads, window_size=window_size, shuffle=shuffle, mlp_ratio=mlp_ratio,\n                      qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, \n                      relative_pos_embedding=relative_pos_embedding)\n            ]))\n\n    def forward(self, x):\n        if self.patch_partition is not None:\n            x = self.patch_partition(x)\n            \n        for regular_block, shifted_block in self.layers:\n            x = regular_block(x)\n            x = shifted_block(x)\n        return x\n\n\nclass PatchEmbedding(nn.Module):\n    def __init__(self, inter_channel=32, out_channels=48):\n        super().__init__()\n        self.conv1 = nn.Sequential(\n            nn.Conv2d(3, inter_channel, kernel_size=3, stride=2, padding=1),\n            nn.BatchNorm2d(inter_channel),\n            nn.ReLU6(inplace=True)\n        )\n        self.conv2 = nn.Sequential(\n            nn.Conv2d(inter_channel, out_channels, kernel_size=3, stride=2, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU6(inplace=True)\n        )\n        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        x = self.conv3(self.conv2(self.conv1(x)))\n        return x\n\n\n# Assumed helper: the original file calls get_sinusoid_encoding without defining or importing it.\n# This sketch builds the standard fixed sinusoidal table of shape (1, n_position, d_hid).\ndef get_sinusoid_encoding(n_position, d_hid):\n    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)  # n_position, 1\n    dim_idx = torch.arange(d_hid, dtype=torch.float32).unsqueeze(0)  # 1, d_hid\n    angle = position / torch.pow(10000.0, 2 * torch.floor(dim_idx / 2) / d_hid)\n    table = torch.zeros(n_position, d_hid)\n    table[:, 0::2] = torch.sin(angle[:, 0::2])  # even channels\n    table[:, 1::2] = torch.cos(angle[:, 1::2])  # odd channels\n    return table.unsqueeze(0)\n\n\nclass ShuffleTransformer(nn.Module):\n    def __init__(self, img_size=224, in_chans=3, num_classes=1000, token_dim=32, embed_dim=96, mlp_ratio=4., layers=[2,2,6,2], num_heads=[3,6,12,24], \n                relative_pos_embedding=True, shuffle=True, window_size=7, qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., \n                has_pos_embed=False, **kwargs):\n        
super().__init__()\n        self.num_classes = num_classes\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.has_pos_embed = has_pos_embed\n        dims = [i*32 for i in num_heads]\n\n        self.to_token = PatchEmbedding(inter_channel=token_dim, out_channels=embed_dim)\n\n        num_patches = (img_size*img_size) // 16\n\n        if self.has_pos_embed:\n            self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches, d_hid=embed_dim), requires_grad=False)\n            self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 4)]  # stochastic depth decay rule\n        self.stage1 = StageModule(layers[0], embed_dim, dims[0], num_heads[0], window_size=window_size, shuffle=shuffle,\n                                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0],\n                                  relative_pos_embedding=relative_pos_embedding)\n        self.stage2 = StageModule(layers[1], dims[0], dims[1], num_heads[1], window_size=window_size, shuffle=shuffle,\n                                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1],\n                                  relative_pos_embedding=relative_pos_embedding)\n        self.stage3 = StageModule(layers[2], dims[1], dims[2], num_heads[2], window_size=window_size, shuffle=shuffle,\n                                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[2],\n                                  relative_pos_embedding=relative_pos_embedding)\n        self.stage4 = StageModule(layers[3], dims[2], dims[3], num_heads[3], window_size=window_size, shuffle=shuffle, \n                                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 
qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[3],\n                                  relative_pos_embedding=relative_pos_embedding)\n\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        # Classifier head\n        self.head_in_features = dims[3]  # channel width produced by stage 4\n        self.head = nn.Linear(self.head_in_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):\n            nn.init.constant_(m.weight, 1.0)\n            nn.init.constant_(m.bias, 0)\n        elif isinstance(m, (nn.Linear, nn.Conv2d)):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        # The head consumes stage-4 features (dims[3]), not the embed_dim-wide stem output.\n        self.head = nn.Linear(self.head_in_features, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        x = self.to_token(x)\n        b, c, h, w = x.shape\n\n        if self.has_pos_embed:\n            x = x + self.pos_embed.view(1, h, w, c).permute(0, 3, 1, 2)\n            x = self.pos_drop(x)\n\n        x = self.stage1(x)\n        x = self.stage2(x)\n        x = self.stage3(x)\n        x = self.stage4(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    sft = ShuffleTransformer()\n    output=sft(input)\n    print(output.shape)"
  },
  {
    "path": "model/backbone/TnT.py",
    "content": "# 2021.06.15-Changed for implementation of TNT model\n#            Huawei Technologies Co., Ltd. <foss@huawei.com>\nimport torch\nimport torch.nn as nn\nfrom functools import partial\nimport math\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.helpers import load_pretrained\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.resnet import resnet26d, resnet50d\nfrom timm.models.registry import register_model\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    'tnt_s_patch16_224': _cfg(\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),\n    ),\n    'tnt_b_patch16_224': _cfg(\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),\n    ),\n}\n\n\ndef make_divisible(v, divisor=8, min_value=None):\n    min_value = min_value or divisor\n    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n    # Make sure that round down does not go down by more than 10%.\n    if new_v < 0.9 * v:\n        new_v += divisor\n    return new_v\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        
return x\n\n\nclass SE(nn.Module):\n    def __init__(self, dim, hidden_ratio=None):\n        super().__init__()\n        hidden_ratio = hidden_ratio or 1\n        self.dim = dim\n        hidden_dim = int(dim * hidden_ratio)\n        self.fc = nn.Sequential(\n            nn.LayerNorm(dim),\n            nn.Linear(dim, hidden_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(hidden_dim, dim),\n            nn.Tanh()\n        )\n\n    def forward(self, x):\n        a = x.mean(dim=1, keepdim=True) # B, 1, C\n        a = self.fc(a)\n        x = a * x\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.hidden_dim = hidden_dim\n        self.num_heads = num_heads\n        head_dim = hidden_dim // num_heads\n        self.head_dim = head_dim\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias)\n        self.v = nn.Linear(dim, dim, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop, inplace=True)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop, inplace=True)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)\n        q, k = qk[0], qk[1]   # make torchscript happy (cannot use tensor as tuple)\n        v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    \"\"\" TNT Block\n    
\"\"\"\n    def __init__(self, outer_dim, inner_dim, outer_num_heads, inner_num_heads, num_words, mlp_ratio=4.,\n                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,\n                 norm_layer=nn.LayerNorm, se=0):\n        super().__init__()\n        self.has_inner = inner_dim > 0\n        if self.has_inner:\n            # Inner\n            self.inner_norm1 = norm_layer(inner_dim)\n            self.inner_attn = Attention(\n                inner_dim, inner_dim, num_heads=inner_num_heads, qkv_bias=qkv_bias,\n                qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n            self.inner_norm2 = norm_layer(inner_dim)\n            self.inner_mlp = Mlp(in_features=inner_dim, hidden_features=int(inner_dim * mlp_ratio),\n                                 out_features=inner_dim, act_layer=act_layer, drop=drop)\n\n            self.proj_norm1 = norm_layer(num_words * inner_dim)\n            self.proj = nn.Linear(num_words * inner_dim, outer_dim, bias=False)\n            self.proj_norm2 = norm_layer(outer_dim)\n        # Outer\n        self.outer_norm1 = norm_layer(outer_dim)\n        self.outer_attn = Attention(\n            outer_dim, outer_dim, num_heads=outer_num_heads, qkv_bias=qkv_bias,\n            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.outer_norm2 = norm_layer(outer_dim)\n        self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio),\n                             out_features=outer_dim, act_layer=act_layer, drop=drop)\n        # SE\n        self.se = se\n        self.se_layer = None\n        if self.se > 0:\n            self.se_layer = SE(outer_dim, 0.25)\n\n    def forward(self, inner_tokens, outer_tokens):\n        if self.has_inner:\n            inner_tokens = inner_tokens + self.drop_path(self.inner_attn(self.inner_norm1(inner_tokens))) # B*N, k*k, c\n            inner_tokens = inner_tokens + self.drop_path(self.inner_mlp(self.inner_norm2(inner_tokens))) # B*N, k*k, c\n            B, N, C = outer_tokens.size()\n            outer_tokens[:,1:] = outer_tokens[:,1:] + self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, N-1, -1)))) # B, N, C\n        if self.se > 0:\n            outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))\n            tmp_ = self.outer_mlp(self.outer_norm2(outer_tokens))\n            outer_tokens = outer_tokens + self.drop_path(tmp_ + self.se_layer(tmp_))\n        else:\n            outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))\n            outer_tokens = outer_tokens + self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))\n        return inner_tokens, outer_tokens\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Visual Word Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, outer_dim=768, inner_dim=24, inner_stride=4):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n        self.inner_dim = 
inner_dim\n        self.num_words = math.ceil(patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)\n        \n        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)\n        self.proj = nn.Conv2d(in_chans, inner_dim, kernel_size=7, padding=3, stride=inner_stride)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.unfold(x) # B, Ck2, N\n        x = x.transpose(1, 2).reshape(B * self.num_patches, C, *self.patch_size) # B*N, C, 16, 16\n        x = self.proj(x) # B*N, inner_dim, 16/inner_stride, 16/inner_stride (4x4 for the defaults)\n        x = x.reshape(B * self.num_patches, self.inner_dim, -1).transpose(1, 2) # B*N, num_words, inner_dim\n        return x\n\n\nclass TNT(nn.Module):\n    \"\"\" TNT (Transformer in Transformer) for computer vision\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, outer_dim=768, inner_dim=48,\n                 depth=12, outer_num_heads=12, inner_num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, inner_stride=4, se=0):\n        super().__init__()\n        self.num_classes = num_classes\n        self.num_features = self.outer_dim = outer_dim  # num_features for consistency with other models\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, outer_dim=outer_dim,\n            inner_dim=inner_dim, inner_stride=inner_stride)\n        self.num_patches = num_patches = self.patch_embed.num_patches\n        num_words = self.patch_embed.num_words\n        \n        self.proj_norm1 = norm_layer(num_words * inner_dim)\n        self.proj = nn.Linear(num_words * inner_dim, outer_dim)\n        self.proj_norm2 = 
norm_layer(outer_dim)\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, outer_dim))\n        self.outer_tokens = nn.Parameter(torch.zeros(1, num_patches, outer_dim), requires_grad=False)\n        self.outer_pos = nn.Parameter(torch.zeros(1, num_patches + 1, outer_dim))\n        self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        vanilla_idxs = []\n        blocks = []\n        for i in range(depth):\n            if i in vanilla_idxs:\n                blocks.append(Block(\n                    outer_dim=outer_dim, inner_dim=-1, outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads,\n                    num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,\n                    attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, se=se))\n            else:\n                blocks.append(Block(\n                    outer_dim=outer_dim, inner_dim=inner_dim, outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads,\n                    num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,\n                    attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, se=se))\n        self.blocks = nn.ModuleList(blocks)\n        self.norm = norm_layer(outer_dim)\n\n        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here\n        #self.repr = nn.Linear(outer_dim, representation_size)\n        #self.repr_act = nn.Tanh()\n\n        # Classifier head\n        self.head = nn.Linear(outer_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.cls_token, std=.02)\n        trunc_normal_(self.outer_pos, std=.02)\n        trunc_normal_(self.inner_pos, std=.02)\n        
self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'outer_pos', 'inner_pos', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.outer_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        B = x.shape[0]\n        inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C\n        \n        outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, self.num_patches, -1))))        \n        outer_tokens = torch.cat((self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)\n        \n        outer_tokens = outer_tokens + self.outer_pos\n        outer_tokens = self.pos_drop(outer_tokens)\n\n        for blk in self.blocks:\n            inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens)\n\n        outer_tokens = self.norm(outer_tokens)\n        return outer_tokens[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\ndef _conv_filter(state_dict, patch_size=16):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    for k, v in state_dict.items():\n        if 'patch_embed.proj.weight' in k:\n            v = v.reshape((v.shape[0], 3, patch_size, patch_size))\n        out_dict[k] = v\n    return out_dict\n\n\n@register_model\ndef tnt_s_patch16_224(pretrained=False, **kwargs):\n    patch_size = 16\n 
   inner_stride = 4\n    outer_dim = 384\n    inner_dim = 24\n    outer_num_heads = 6\n    inner_num_heads = 4\n    outer_dim = make_divisible(outer_dim, outer_num_heads)\n    inner_dim = make_divisible(inner_dim, inner_num_heads)\n    model = TNT(img_size=224, patch_size=patch_size, outer_dim=outer_dim, inner_dim=inner_dim, depth=12,\n                outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads, qkv_bias=False,\n                inner_stride=inner_stride, **kwargs)\n    model.default_cfg = default_cfgs['tnt_s_patch16_224']\n    if pretrained:\n        load_pretrained(\n            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n\n\n@register_model\ndef tnt_b_patch16_224(pretrained=False, **kwargs):\n    patch_size = 16\n    inner_stride = 4\n    outer_dim = 640\n    inner_dim = 40\n    outer_num_heads = 10\n    inner_num_heads = 4\n    outer_dim = make_divisible(outer_dim, outer_num_heads)\n    inner_dim = make_divisible(inner_dim, inner_num_heads)\n    model = TNT(img_size=224, patch_size=patch_size, outer_dim=outer_dim, inner_dim=inner_dim, depth=12,\n                outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads, qkv_bias=False,\n                inner_stride=inner_stride, **kwargs)\n    model.default_cfg = default_cfgs['tnt_b_patch16_224']\n    if pretrained:\n        load_pretrained(\n            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = TNT(\n        img_size=224, \n        patch_size=16, \n        outer_dim=384, \n        inner_dim=24, \n        depth=12,\n        outer_num_heads=6, \n        inner_num_heads=4, \n        qkv_bias=False,\n        inner_stride=4)\n    output=model(input)\n    print(output.shape)"
  },
  {
    "path": "model/backbone/VOLO.py",
    "content": "# Copyright 2021 Sea Limited.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nVision OutLOoker (VOLO) implementation\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom timm.models.registry import register_model\nimport math\nimport numpy as np\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .96, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    'volo': _cfg(crop_pct=0.96),\n    'volo_large': _cfg(crop_pct=1.15),\n}\n\n\nclass OutlookAttention(nn.Module):\n    \"\"\"\n    Implementation of outlook attention\n    --dim: hidden dim\n    --num_heads: number of heads\n    --kernel_size: kernel size in each window for outlook attention\n    return: token features after outlook attention\n    \"\"\"\n\n    def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1,\n                 qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        head_dim = dim // num_heads\n        self.num_heads = num_heads\n        self.kernel_size = kernel_size\n        
self.padding = padding\n        self.stride = stride\n        self.scale = qk_scale or head_dim**-0.5\n\n        self.v = nn.Linear(dim, dim, bias=qkv_bias)\n        self.attn = nn.Linear(dim, kernel_size**4 * num_heads)\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)\n        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)\n\n    def forward(self, x):\n        B, H, W, C = x.shape\n\n        v = self.v(x).permute(0, 3, 1, 2)  # B, C, H, W\n\n        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)\n        v = self.unfold(v).reshape(B, self.num_heads, C // self.num_heads,\n                                   self.kernel_size * self.kernel_size,\n                                   h * w).permute(0, 1, 4, 3, 2)  # B,H,N,kxk,C/H\n\n        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)\n        attn = self.attn(attn).reshape(\n            B, h * w, self.num_heads, self.kernel_size * self.kernel_size,\n            self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)  # B,H,N,kxk,kxk\n        attn = attn * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(\n            B, C * self.kernel_size * self.kernel_size, h * w)\n        x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size,\n                   padding=self.padding, stride=self.stride)\n\n        x = self.proj(x.permute(0, 2, 3, 1))\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass Outlooker(nn.Module):\n    \"\"\"\n    Implementation of outlooker layer: which includes outlook attention + MLP\n    Outlooker is the first stage in our VOLO\n    --dim: hidden dim\n    --num_heads: number of heads\n    --mlp_ratio: mlp ratio\n    --kernel_size: kernel size in each 
window for outlook attention\n    return: outlooker layer\n    \"\"\"\n    def __init__(self, dim, kernel_size, padding, stride=1,\n                 num_heads=1,mlp_ratio=3., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU,\n                 norm_layer=nn.LayerNorm, qkv_bias=False,\n                 qk_scale=None):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = OutlookAttention(dim, num_heads, kernel_size=kernel_size,\n                                     padding=padding, stride=stride,\n                                     qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                     attn_drop=attn_drop)\n\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim,\n                       hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass Mlp(nn.Module):\n    \"Implementation of MLP\"\n\n    def __init__(self, in_features, hidden_features=None,\n                 out_features=None, act_layer=nn.GELU,\n                 drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    \"Implementation of self-attention\"\n\n    def __init__(self, dim,  num_heads=8, 
qkv_bias=False,\n                 qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, H, W, C = x.shape\n\n        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,\n                                  C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[\n            2]  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, H, W, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass Transformer(nn.Module):\n    \"\"\"\n    Implementation of Transformer,\n    Transformer is the second stage in our VOLO\n    \"\"\"\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,\n                 qk_scale=None, attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,\n                              qk_scale=qk_scale, attn_drop=attn_drop)\n\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim,\n                       hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass ClassAttention(nn.Module):\n    \"\"\"\n    Class attention layer from CaiT, see details in CaiT\n    Class attention is the post stage in our VOLO, which is optional.\n    \"\"\"\n    def __init__(self, dim, num_heads=8, head_dim=None, qkv_bias=False,\n                 qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        if head_dim is not None:\n            self.head_dim = head_dim\n        else:\n            head_dim = dim // num_heads\n            self.head_dim = head_dim\n        self.scale = qk_scale or head_dim**-0.5\n\n        self.kv = nn.Linear(dim,\n                            self.head_dim * self.num_heads * 2,\n                            bias=qkv_bias)\n        self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(self.head_dim * self.num_heads, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n\n        kv = self.kv(x).reshape(B, N, 2, self.num_heads,\n                                self.head_dim).permute(2, 0, 3, 1, 4)\n        k, v = kv[0], kv[\n            1]  # make torchscript happy (cannot use tensor as tuple)\n        q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim)\n        attn = ((q * self.scale) @ k.transpose(-2, -1))\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        cls_embed = (attn @ v).transpose(1, 2).reshape(\n            B, 1, self.head_dim * 
self.num_heads)\n        cls_embed = self.proj(cls_embed)\n        cls_embed = self.proj_drop(cls_embed)\n        return cls_embed\n\n\nclass ClassBlock(nn.Module):\n    \"\"\"\n    Class attention block from CaiT, see details in CaiT\n    We use two-layers class attention in our VOLO, which is optional.\n    \"\"\"\n\n    def __init__(self, dim, num_heads, head_dim=None, mlp_ratio=4.,\n                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = ClassAttention(\n            dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias,\n            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n        # NOTE: drop path for stochastic depth\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim,\n                       hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer,\n                       drop=drop)\n\n    def forward(self, x):\n        cls_embed = x[:, :1]\n        cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x)))\n        cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed)))\n        return torch.cat([cls_embed, x[:, 1:]], dim=1)\n\n\ndef get_block(block_type, **kargs):\n    \"\"\"\n    get block by name, specifically for class attention block in here\n    \"\"\"\n    if block_type == 'ca':\n        return ClassBlock(**kargs)\n\n\ndef rand_bbox(size, lam, scale=1):\n    \"\"\"\n    get bounding box as token labeling (https://github.com/zihangJiang/TokenLabeling)\n    return: bounding box\n    \"\"\"\n    W = size[1] // scale\n    H = size[2] // scale\n    cut_rat = np.sqrt(1. 
- lam)\n    cut_w = int(W * cut_rat)\n    cut_h = int(H * cut_rat)\n\n    # uniform\n    cx = np.random.randint(W)\n    cy = np.random.randint(H)\n\n    bbx1 = np.clip(cx - cut_w // 2, 0, W)\n    bby1 = np.clip(cy - cut_h // 2, 0, H)\n    bbx2 = np.clip(cx + cut_w // 2, 0, W)\n    bby2 = np.clip(cy + cut_h // 2, 0, H)\n\n    return bbx1, bby1, bbx2, bby2\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    Image to Patch Embedding.\n    Unlike ViT, which uses 1 conv layer, we use 4 conv layers to do patch embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, stem_conv=False, stem_stride=1,\n                 patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384):\n        super().__init__()\n        assert patch_size in [4, 8, 16]\n\n        self.stem_conv = stem_conv\n        if stem_conv:\n            self.conv = nn.Sequential(\n                nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride,\n                          padding=3, bias=False),  # 112x112\n                nn.BatchNorm2d(hidden_dim),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,\n                          padding=1, bias=False),  # 112x112\n                nn.BatchNorm2d(hidden_dim),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,\n                          padding=1, bias=False),  # 112x112\n                nn.BatchNorm2d(hidden_dim),\n                nn.ReLU(inplace=True),\n            )\n\n        self.proj = nn.Conv2d(hidden_dim,\n                              embed_dim,\n                              kernel_size=patch_size // stem_stride,\n                              stride=patch_size // stem_stride)\n        self.num_patches = (img_size // patch_size) * (img_size // patch_size)\n\n    def forward(self, x):\n        if self.stem_conv:\n            x = self.conv(x)\n        x = self.proj(x)  # B, C, H, W\n        return 
x\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    Image to Patch Embedding, downsampling between stage1 and stage2\n    \"\"\"\n    def __init__(self, in_embed_dim, out_embed_dim, patch_size):\n        super().__init__()\n        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim,\n                              kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        x = x.permute(0, 3, 1, 2)\n        x = self.proj(x)  # B, C, H, W\n        x = x.permute(0, 2, 3, 1)\n        return x\n\n\ndef outlooker_blocks(block_fn, index, dim, layers, num_heads=1, kernel_size=3,\n                     padding=1,stride=1, mlp_ratio=3., qkv_bias=False, qk_scale=None,\n                     attn_drop=0, drop_path_rate=0., **kwargs):\n    \"\"\"\n    generate outlooker layer in stage1\n    return: outlooker layers\n    \"\"\"\n    blocks = []\n    for block_idx in range(layers[index]):\n        block_dpr = drop_path_rate * (block_idx +\n                                      sum(layers[:index])) / (sum(layers) - 1)\n        blocks.append(block_fn(dim, kernel_size=kernel_size, padding=padding,\n                               stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,\n                               drop_path=block_dpr))\n\n    blocks = nn.Sequential(*blocks)\n\n    return blocks\n\n\ndef transformer_blocks(block_fn, index, dim, layers, num_heads, mlp_ratio=3.,\n                       qkv_bias=False, qk_scale=None, attn_drop=0,\n                       drop_path_rate=0., **kwargs):\n    \"\"\"\n    generate transformer layers in stage2\n    return: transformer layers\n    \"\"\"\n    blocks = []\n    for block_idx in range(layers[index]):\n        block_dpr = drop_path_rate * (block_idx +\n                                      sum(layers[:index])) / (sum(layers) - 1)\n        blocks.append(\n            block_fn(dim, num_heads,\n                     
mlp_ratio=mlp_ratio,\n                     qkv_bias=qkv_bias,\n                     qk_scale=qk_scale,\n                     attn_drop=attn_drop,\n                     drop_path=block_dpr))\n\n    blocks = nn.Sequential(*blocks)\n\n    return blocks\n\n\nclass VOLO(nn.Module):\n    \"\"\"\n    Vision Outlooker, the main class of our model\n    --layers: [x,x,x,x], four blocks in two stages, the first block is outlooker, the\n              other three are transformer, we set four blocks, which are easily\n              applied to downstream tasks\n    --img_size, --in_chans, --num_classes: these three are very easy to understand\n    --patch_size: patch_size in outlook attention\n    --stem_hidden_dim: hidden dim of patch embedding, d1-d4 is 64, d5 is 128\n    --embed_dims, --num_heads: embedding dim, number of heads in each block\n    --downsamples: flags to apply downsampling or not\n    --outlook_attention: flags to apply outlook attention or not\n    --mlp_ratios, --qkv_bias, --qk_scale, --drop_rate: easy to understand\n    --attn_drop_rate, --drop_path_rate, --norm_layer: easy to understand\n    --post_layers: post layers like two class attention layers using [ca, ca],\n                  if yes, return_mean=False\n    --return_mean: use mean of all feature tokens for classification, if yes, no class token\n    --return_dense: use token labeling, details are here:\n                    https://github.com/zihangJiang/TokenLabeling\n    --mix_token: mixing tokens as token labeling, details are here:\n                    https://github.com/zihangJiang/TokenLabeling\n    --pooling_scale: pooling_scale=2 means we downsample 2x\n    --out_kernel, --out_stride, --out_padding: kernel size,\n                                               stride, and padding for outlook attention\n    \"\"\"\n    def __init__(self, layers, img_size=224, in_chans=3, num_classes=1000, patch_size=8,\n                 stem_hidden_dim=64, embed_dims=None, num_heads=None, downsamples=None,\n       
          outlook_attention=None, mlp_ratios=None, qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,\n                 post_layers=None, return_mean=False, return_dense=True, mix_token=True,\n                 pooling_scale=2, out_kernel=3, out_stride=2, out_padding=1):\n\n        super().__init__()\n        self.num_classes = num_classes\n        self.patch_embed = PatchEmbed(stem_conv=True, stem_stride=2, patch_size=patch_size,\n                                      in_chans=in_chans, hidden_dim=stem_hidden_dim,\n                                      embed_dim=embed_dims[0])\n\n        # initial positional encoding, we add positional encoding after outlooker blocks\n        self.pos_embed = nn.Parameter(\n            torch.zeros(1, img_size // patch_size // pooling_scale,\n                        img_size // patch_size // pooling_scale,\n                        embed_dims[-1]))\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # set the main block in network\n        network = []\n        for i in range(len(layers)):\n            if outlook_attention[i]:\n                # stage 1\n                stage = outlooker_blocks(Outlooker, i, embed_dims[i], layers,\n                                         downsample=downsamples[i], num_heads=num_heads[i],\n                                         kernel_size=out_kernel, stride=out_stride,\n                                         padding=out_padding, mlp_ratio=mlp_ratios[i],\n                                         qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                         attn_drop=attn_drop_rate, norm_layer=norm_layer)\n                network.append(stage)\n            else:\n                # stage 2\n                stage = transformer_blocks(Transformer, i, embed_dims[i], layers,\n                                           num_heads[i], mlp_ratio=mlp_ratios[i],\n                                           
qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                           drop_path_rate=drop_path_rate,\n                                           attn_drop=attn_drop_rate,\n                                           norm_layer=norm_layer)\n                network.append(stage)\n\n            if downsamples[i]:\n                # downsampling between two stages\n                network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2))\n\n        self.network = nn.ModuleList(network)\n\n        # set post block, for example, class attention layers\n        self.post_network = None\n        if post_layers is not None:\n            self.post_network = nn.ModuleList([\n                get_block(post_layers[i],\n                          dim=embed_dims[-1],\n                          num_heads=num_heads[-1],\n                          mlp_ratio=mlp_ratios[-1],\n                          qkv_bias=qkv_bias,\n                          qk_scale=qk_scale,\n                          attn_drop=attn_drop_rate,\n                          drop_path=0.,\n                          norm_layer=norm_layer)\n                for i in range(len(post_layers))\n            ])\n            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))\n            trunc_normal_(self.cls_token, std=.02)\n\n        # set output type\n        self.return_mean = return_mean  # if yes, return mean, not use class token\n        self.return_dense = return_dense  # if yes, return class token and all feature tokens\n        if return_dense:\n            assert not return_mean, \"cannot return both mean and dense\"\n        self.mix_token = mix_token\n        self.pooling_scale = pooling_scale\n        if mix_token:  # enable token mixing, see token labeling for details.\n            self.beta = 1.0\n            assert return_dense, \"return all tokens if mix_token is enabled\"\n        if return_dense:\n            self.aux_head = nn.Linear(\n                embed_dims[-1],\n       
         num_classes) if num_classes > 0 else nn.Identity()\n        self.norm = norm_layer(embed_dims[-1])\n\n        # Classifier head\n        self.head = nn.Linear(\n            embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.pos_embed, std=.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes):\n        self.num_classes = num_classes\n        # self.embed_dim is never defined; pos_embed's last dimension is embed_dims[-1]\n        self.head = nn.Linear(\n            self.pos_embed.shape[-1], num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_embeddings(self, x):\n        # patch embedding\n        x = self.patch_embed(x)\n        # B,C,H,W-> B,H,W,C\n        x = x.permute(0, 2, 3, 1)\n        return x\n\n    def forward_tokens(self, x):\n        for idx, block in enumerate(self.network):\n            if idx == 2:  # add positional encoding after outlooker blocks\n                x = x + self.pos_embed\n                x = self.pos_drop(x)\n            x = block(x)\n\n        B, H, W, C = x.shape\n        x = x.reshape(B, -1, C)\n        return x\n\n    def forward_cls(self, x):\n        B, N, C = x.shape\n        cls_tokens = self.cls_token.expand(B, -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n        for block in self.post_network:\n            x = block(x)\n        return x\n\n    def forward(self, x):\n        # step1: patch embedding\n        x = self.forward_embeddings(x)\n\n        # mix token, see token labeling for details.\n 
       if self.mix_token and self.training:\n            lam = np.random.beta(self.beta, self.beta)\n            patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[\n                2] // self.pooling_scale\n            bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)\n            temp_x = x.clone()\n            sbbx1,sbby1,sbbx2,sbby2=self.pooling_scale*bbx1,self.pooling_scale*bby1,\\\n                                    self.pooling_scale*bbx2,self.pooling_scale*bby2\n            temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]\n            x = temp_x\n        else:\n            bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0\n\n        # step2: tokens learning in the two stages\n        x = self.forward_tokens(x)\n\n        # step3: post network, apply class attention or not\n        if self.post_network is not None:\n            x = self.forward_cls(x)\n        x = self.norm(x)\n\n        if self.return_mean:  # if no class token, return mean\n            return self.head(x.mean(1))\n\n        x_cls = self.head(x[:, 0])\n        if not self.return_dense:\n            return x_cls\n\n        x_aux = self.aux_head(\n            x[:, 1:]\n        )  # generate classes in all feature tokens, see token labeling\n\n        if not self.training:\n            return x_cls + 0.5 * x_aux.max(1)[0]\n\n        if self.mix_token and self.training:  # reverse \"mix token\", see token labeling for details.\n            x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])\n\n            temp_x = x_aux.clone()\n            temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]\n            x_aux = temp_x\n\n            x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])\n\n        # return these: 1. class token, 2. classes from all feature tokens, 3. 
bounding box\n        return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)\n\n\n@register_model\ndef volo_d1(pretrained=False, **kwargs):\n    \"\"\"\n    VOLO-D1 model, Params: 27M\n    --layers: [x,x,x,x], four blocks in two stages, the first stage(block) is outlooker,\n            the other three blocks are transformer, we set four blocks, which are easily\n             applied to downstream tasks\n    --embed_dims, --num_heads,: embedding dim, number of heads in each block\n    --downsamples: flags to apply downsampling or not in four blocks\n    --outlook_attention: flags to apply outlook attention or not\n    --mlp_ratios: mlp ratio in four blocks\n    --post_layers: post layers like two class attention layers using [ca, ca]\n    See detail for all args in the class VOLO()\n    \"\"\"\n    layers = [4, 4, 8, 2]  # num of layers in the four blocks\n    embed_dims = [192, 384, 384, 384]\n    num_heads = [6, 12, 12, 12]\n    mlp_ratios = [3, 3, 3, 3]\n    downsamples = [True, False, False, False] # do downsampling after first block\n    outlook_attention = [True, False, False, False ]\n    # first block is outlooker (stage1), the other three are transformer (stage2)\n    model = VOLO(layers,\n                 embed_dims=embed_dims,\n                 num_heads=num_heads,\n                 mlp_ratios=mlp_ratios,\n                 downsamples=downsamples,\n                 outlook_attention=outlook_attention,\n                 post_layers=['ca', 'ca'],\n                 **kwargs)\n    model.default_cfg = default_cfgs['volo']\n    return model\n\n\n@register_model\ndef volo_d2(pretrained=False, **kwargs):\n    \"\"\"\n    VOLO-D2 model, Params: 59M\n    \"\"\"\n    layers = [6, 4, 10, 4]\n    embed_dims = [256, 512, 512, 512]\n    num_heads = [8, 16, 16, 16]\n    mlp_ratios = [3, 3, 3, 3]\n    downsamples = [True, False, False, False]\n    outlook_attention = [True, False, False, False]\n    model = VOLO(layers,\n                 embed_dims=embed_dims,\n                 
num_heads=num_heads,\n                 mlp_ratios=mlp_ratios,\n                 downsamples=downsamples,\n                 outlook_attention=outlook_attention,\n                 post_layers=['ca', 'ca'],\n                 **kwargs)\n    model.default_cfg = default_cfgs['volo']\n    return model\n\n\n@register_model\ndef volo_d3(pretrained=False, **kwargs):\n    \"\"\"\n    VOLO-D3 model, Params: 86M\n    \"\"\"\n    layers = [8, 8, 16, 4]\n    embed_dims = [256, 512, 512, 512]\n    num_heads = [8, 16, 16, 16]\n    mlp_ratios = [3, 3, 3, 3]\n    downsamples = [True, False, False, False]\n    outlook_attention = [True, False, False, False]\n    model = VOLO(layers,\n                 embed_dims=embed_dims,\n                 num_heads=num_heads,\n                 mlp_ratios=mlp_ratios,\n                 downsamples=downsamples,\n                 outlook_attention=outlook_attention,\n                 post_layers=['ca', 'ca'],\n                 **kwargs)\n    model.default_cfg = default_cfgs['volo']\n    return model\n\n\n@register_model\ndef volo_d4(pretrained=False, **kwargs):\n    \"\"\"\n    VOLO-D4 model, Params: 193M\n    \"\"\"\n    layers = [8, 8, 16, 4]\n    embed_dims = [384, 768, 768, 768]\n    num_heads = [12, 16, 16, 16]\n    mlp_ratios = [3, 3, 3, 3]\n    downsamples = [True, False, False, False]\n    outlook_attention = [True, False, False, False]\n    model = VOLO(layers,\n                 embed_dims=embed_dims,\n                 num_heads=num_heads,\n                 mlp_ratios=mlp_ratios,\n                 downsamples=downsamples,\n                 outlook_attention=outlook_attention,\n                 post_layers=['ca', 'ca'],\n                 **kwargs)\n    model.default_cfg = default_cfgs['volo_large']\n    return model\n\n\n@register_model\ndef volo_d5(pretrained=False, **kwargs):\n    \"\"\"\n    VOLO-D5 model, Params: 296M\n    stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5\n    \"\"\"\n    layers = [12, 12, 20, 4]\n    
embed_dims = [384, 768, 768, 768]\n    num_heads = [12, 16, 16, 16]\n    mlp_ratios = [4, 4, 4, 4]\n    downsamples = [True, False, False, False]\n    outlook_attention = [True, False, False, False]\n    model = VOLO(layers,\n                 embed_dims=embed_dims,\n                 num_heads=num_heads,\n                 mlp_ratios=mlp_ratios,\n                 downsamples=downsamples,\n                 outlook_attention=outlook_attention,\n                 post_layers=['ca', 'ca'],\n                 stem_hidden_dim=128,\n                 **kwargs)\n    model.default_cfg = default_cfgs['volo_large']\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VOLO([4, 4, 8, 2],\n                 embed_dims=[192, 384, 384, 384],\n                 num_heads=[6, 12, 12, 12],\n                 mlp_ratios=[3, 3, 3, 3],\n                 downsamples=[True, False, False, False],\n                 outlook_attention=[True, False, False, False ],\n                 post_layers=['ca', 'ca'],\n                 )\n    output=model(input)\n    print(output[0].shape)"
  },
  {
    "path": "model/backbone/convnextv2.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import trunc_normal_, DropPath\n\nclass Block(nn.Module):\n    \"\"\" ConvNeXtV2 Block.\n    \n    Args:\n        dim (int): Number of input channels.\n        drop_path (float): Stochastic depth rate. Default: 0.0\n    \"\"\"\n    def __init__(self, dim, drop_path=0.):\n        super().__init__()\n        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv\n        self.norm = LayerNorm(dim, eps=1e-6)\n        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers\n        self.act = nn.GELU()\n        self.grn = GRN(4 * dim)\n        self.pwconv2 = nn.Linear(4 * dim, dim)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n    def forward(self, x):\n        input = x\n        x = self.dwconv(x)\n        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)\n        x = self.norm(x)\n        x = self.pwconv1(x)\n        x = self.act(x)\n        x = self.grn(x)\n        x = self.pwconv2(x)\n        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)\n\n        x = input + self.drop_path(x)\n        return x\n\nclass LayerNorm(nn.Module):\n    \"\"\" LayerNorm that supports two data formats: channels_last (default) or channels_first. \n    The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with \n    shape (batch_size, height, width, channels) while channels_first corresponds to inputs \n    with shape (batch_size, channels, height, width).\n    \"\"\"\n    def __init__(self, normalized_shape, eps=1e-6, data_format=\"channels_last\"):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.bias = nn.Parameter(torch.zeros(normalized_shape))\n        self.eps = eps\n        self.data_format = data_format\n        if self.data_format not in [\"channels_last\", \"channels_first\"]:\n            raise NotImplementedError \n        self.normalized_shape = (normalized_shape, )\n    \n    def forward(self, x):\n        if self.data_format == \"channels_last\":\n            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        elif self.data_format == \"channels_first\":\n            u = x.mean(1, keepdim=True)\n            s = (x - u).pow(2).mean(1, keepdim=True)\n            x = (x - u) / torch.sqrt(s + self.eps)\n            x = self.weight[:, None, None] * x + self.bias[:, None, None]\n            return x\n\nclass GRN(nn.Module):\n    \"\"\" GRN (Global Response Normalization) layer\n    \"\"\"\n    def __init__(self, dim):\n        super().__init__()\n        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))\n        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))\n\n    def forward(self, x):\n        Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)\n        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)\n        return self.gamma * (x * Nx) + self.beta + x\n\n\nclass ConvNeXtV2(nn.Module):\n    \"\"\" ConvNeXt V2\n        \n    Args:\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]\n        dims (int): Feature dimension at each stage. 
Default: [96, 192, 384, 768]\n        drop_path_rate (float): Stochastic depth rate. Default: 0.\n        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.\n    \"\"\"\n    def __init__(self, in_chans=3, num_classes=1000, \n                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], \n                 drop_path_rate=0., head_init_scale=1.\n                 ):\n        super().__init__()\n        self.depths = depths\n        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers\n        stem = nn.Sequential(\n            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),\n            LayerNorm(dims[0], eps=1e-6, data_format=\"channels_first\")\n        )\n        self.downsample_layers.append(stem)\n        for i in range(3):\n            downsample_layer = nn.Sequential(\n                    LayerNorm(dims[i], eps=1e-6, data_format=\"channels_first\"),\n                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),\n            )\n            self.downsample_layers.append(downsample_layer)\n\n        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks\n        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] \n        cur = 0\n        for i in range(4):\n            stage = nn.Sequential(\n                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]\n            )\n            self.stages.append(stage)\n            cur += depths[i]\n\n        self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer\n        self.head = nn.Linear(dims[-1], num_classes)\n\n        self.apply(self._init_weights)\n        self.head.weight.data.mul_(head_init_scale)\n        self.head.bias.data.mul_(head_init_scale)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Conv2d, nn.Linear)):\n            trunc_normal_(m.weight, std=.02)\n            
nn.init.constant_(m.bias, 0)\n\n    def forward_features(self, x):\n        for i in range(4):\n            x = self.downsample_layers[i](x)\n            x = self.stages[i](x)\n        return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\ndef convnextv2_atto(**kwargs):\n    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)\n    return model\n\ndef convnextv2_femto(**kwargs):\n    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)\n    return model\n\ndef convnext_pico(**kwargs):\n    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)\n    return model\n\ndef convnextv2_nano(**kwargs):\n    model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)\n    return model\n\ndef convnextv2_tiny(**kwargs):\n    model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)\n    return model\n\ndef convnextv2_base(**kwargs):\n    model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)\n    return model\n\ndef convnextv2_large(**kwargs):\n    model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)\n    return model\n\ndef convnextv2_huge(**kwargs):\n    model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)\n    return model\n\nif __name__ == \"__main__\":\n    model = convnextv2_atto()\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)"
  },
  {
    "path": "model/backbone/resnet.py",
    "content": "import torch\nfrom torch import nn\n\n\n\"\"\"\n    # in_channel: number of channels entering the block\n    # channel: number of channels used inside the block (1/4 of the output dimension)\n    # channel * block.expansion: number of output channels\n\"\"\"\nclass BottleNeck(nn.Module):\n    expansion = 4\n    def __init__(self,in_channel,channel,stride=1,downsample=None):\n        super().__init__()\n\n        self.conv1=nn.Conv2d(in_channel,channel,kernel_size=1,stride=stride,bias=False)\n        self.bn1=nn.BatchNorm2d(channel)\n\n        self.conv2=nn.Conv2d(channel,channel,kernel_size=3,padding=1,bias=False,stride=1)\n        self.bn2=nn.BatchNorm2d(channel)\n\n        self.conv3=nn.Conv2d(channel,channel*self.expansion,kernel_size=1,stride=1,bias=False)\n        self.bn3=nn.BatchNorm2d(channel*self.expansion)\n\n        self.relu=nn.ReLU(False)\n\n        self.downsample=downsample\n        self.stride=stride\n\n    def forward(self,x):\n        residual=x\n\n        out=self.relu(self.bn1(self.conv1(x))) #bs,c,h,w\n        out=self.relu(self.bn2(self.conv2(out))) #bs,c,h,w\n        out=self.relu(self.bn3(self.conv3(out))) #bs,4c,h,w\n\n        if self.downsample is not None:\n            residual=self.downsample(residual)\n\n        out+=residual\n        return self.relu(out)\n\n    \nclass ResNet(nn.Module):\n    def __init__(self,block,layers,num_classes=1000):\n        super().__init__()\n        # number of channels fed into the first stage\n        self.in_channel=64\n        ### stem layer\n        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)\n        self.bn1=nn.BatchNorm2d(64)\n        self.relu=nn.ReLU(False)\n        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=0,ceil_mode=True)\n\n        ### main layer\n        self.layer1=self._make_layer(block,64,layers[0])\n        self.layer2=self._make_layer(block,128,layers[1],stride=2)\n        self.layer3=self._make_layer(block,256,layers[2],stride=2)\n        self.layer4=self._make_layer(block,512,layers[3],stride=2)\n\n        #classifier\n        
self.avgpool=nn.AdaptiveAvgPool2d(1)\n        self.classifier=nn.Linear(512*block.expansion,num_classes)\n        self.softmax=nn.Softmax(-1)\n\n    def forward(self,x):\n        ##stem layer\n        out=self.relu(self.bn1(self.conv1(x))) #bs,64,112,112\n        out=self.maxpool(out) #bs,64,56,56\n\n        ##layers:\n        out=self.layer1(out) #bs,64*4,56,56\n        out=self.layer2(out) #bs,128*4,28,28\n        out=self.layer3(out) #bs,256*4,14,14\n        out=self.layer4(out) #bs,512*4,7,7\n\n        ##classifier\n        out=self.avgpool(out) #bs,512*4,1,1\n        out=out.reshape(out.shape[0],-1) #bs,512*4\n        out=self.classifier(out) #bs,1000\n        out=self.softmax(out)\n\n        return out\n\n        \n    \n    def _make_layer(self,block,channel,blocks,stride=1):\n        # downsample handles the shape mismatch between F(x) and x in H(x)=F(x)+x: it projects the shortcut input so that, when the residual is added, its dimensions (width, height and depth) match those of the block output\n        # it is needed when stride != 1 or in_channel != channel * block.expansion\n        downsample = None\n        if(stride!=1 or self.in_channel!=channel*block.expansion):\n            downsample=nn.Conv2d(self.in_channel,channel*block.expansion,stride=stride,kernel_size=1,bias=False)\n        # the first block of the stage may need the downsample projection on its shortcut\n        layers=[]\n        layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride))\n        self.in_channel=channel*block.expansion\n        for _ in range(1,blocks):\n            layers.append(block(self.in_channel,channel))\n        return nn.Sequential(*layers)\n\n\ndef ResNet50(num_classes=1000):\n    return ResNet(BottleNeck,[3,4,6,3],num_classes=num_classes)\n\n\ndef ResNet101(num_classes=1000):\n    return ResNet(BottleNeck,[3,4,23,3],num_classes=num_classes)\n\n\ndef ResNet152(num_classes=1000):\n    return ResNet(BottleNeck,[3,8,36,3],num_classes=num_classes)\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnet50=ResNet50(1000)\n    # resnet101=ResNet101(1000)\n    # resnet152=ResNet152(1000)\n    
out=resnet50(input)\n    print(out.shape)\n\n    "
  },
  {
    "path": "model/backbone/resnext.py",
    "content": "import torch\nfrom torch import nn\n\n\n\"\"\"\n    # in_channel: number of channels entering the block\n    # channel: number of channels used inside the block (1/2 of the output channels)\n    # channel * block.expansion: number of output channels\n\"\"\"\nclass BottleNeck(nn.Module):\n    expansion = 2\n    def __init__(self,in_channel,channel,stride=1,C=32,downsample=None):\n        super().__init__()\n\n        self.conv1=nn.Conv2d(in_channel,channel,kernel_size=1,stride=stride,bias=False)\n        self.bn1=nn.BatchNorm2d(channel)\n\n        self.conv2=nn.Conv2d(channel,channel,kernel_size=3,padding=1,bias=False,stride=1,groups=C)\n        self.bn2=nn.BatchNorm2d(channel)\n\n        self.conv3=nn.Conv2d(channel,channel*self.expansion,kernel_size=1,stride=1,bias=False)\n        self.bn3=nn.BatchNorm2d(channel*self.expansion)\n\n        self.relu=nn.ReLU(False)\n\n        self.downsample=downsample\n        self.stride=stride\n\n    def forward(self,x):\n        residual=x\n\n        out=self.relu(self.bn1(self.conv1(x))) #bs,c,h,w\n        out=self.relu(self.bn2(self.conv2(out))) #bs,c,h,w\n        out=self.bn3(self.conv3(out)) #bs,2c,h,w (no ReLU here: it is applied after the residual addition)\n\n        if(self.downsample is not None):\n            residual=self.downsample(residual)\n\n        out+=residual\n        return self.relu(out)\n\n    \nclass ResNeXt(nn.Module):\n    def __init__(self,block,layers,num_classes=1000):\n        super().__init__()\n        # number of channels entering the first stage\n        self.in_channel=64\n        ### stem layer\n        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)\n        self.bn1=nn.BatchNorm2d(64)\n        self.relu=nn.ReLU(False)\n        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=0,ceil_mode=True)\n\n        ### main layer\n        self.layer1=self._make_layer(block,128,layers[0])\n        self.layer2=self._make_layer(block,256,layers[1],stride=2)\n        self.layer3=self._make_layer(block,512,layers[2],stride=2)\n        self.layer4=self._make_layer(block,1024,layers[3],stride=2)\n\n        #classifier\n        
self.avgpool=nn.AdaptiveAvgPool2d(1)\n        self.classifier=nn.Linear(1024*block.expansion,num_classes)\n        self.softmax=nn.Softmax(-1)\n\n    def forward(self,x):\n        ##stem layer\n        out=self.relu(self.bn1(self.conv1(x))) #bs,64,112,112\n        out=self.maxpool(out) #bs,64,56,56\n\n        ##layers:\n        out=self.layer1(out) #bs,128*2,56,56\n        out=self.layer2(out) #bs,256*2,28,28\n        out=self.layer3(out) #bs,512*2,14,14\n        out=self.layer4(out) #bs,1024*2,7,7\n\n        ##classifier\n        out=self.avgpool(out) #bs,1024*2,1,1\n        out=out.reshape(out.shape[0],-1) #bs,1024*2\n        out=self.classifier(out) #bs,1000\n        out=self.softmax(out)\n\n        return out\n\n        \n    \n    def _make_layer(self,block,channel,blocks,stride=1):\n        # downsample handles the shape mismatch between F(x) and x in H(x)=F(x)+x: it projects the shortcut input so that, when the residual is added, its dimensions (width, height and depth) match those of the block output\n        # it is needed when stride != 1 or in_channel != channel * block.expansion\n        downsample = None\n        if(stride!=1 or self.in_channel!=channel*block.expansion):\n            downsample=nn.Conv2d(self.in_channel,channel*block.expansion,stride=stride,kernel_size=1,bias=False)\n        # the first block of the stage may need the downsample projection on its shortcut\n        layers=[]\n        layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride))\n        self.in_channel=channel*block.expansion\n        for _ in range(1,blocks):\n            layers.append(block(self.in_channel,channel))\n        return nn.Sequential(*layers)\n\n\ndef ResNeXt50(num_classes=1000):\n    return ResNeXt(BottleNeck,[3,4,6,3],num_classes=num_classes)\n\n\ndef ResNeXt101(num_classes=1000):\n    return ResNeXt(BottleNeck,[3,4,23,3],num_classes=num_classes)\n\n\ndef ResNeXt152(num_classes=1000):\n    return ResNeXt(BottleNeck,[3,8,36,3],num_classes=num_classes)\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnext50=ResNeXt50(1000)\n    # resnext101=ResNeXt101(1000)\n    # 
resnext152=ResNeXt152(1000)\n    out=resnext50(input)\n    print(out.shape)\n\n    "
  },
  {
    "path": "model/backbone/swin_transformer.py",
    "content": "\"\"\" Swin Transformer\nA PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`\n    - https://arxiv.org/pdf/2103.14030\nCode/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below\nS3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from\n    - https://github.com/microsoft/Cream/tree/main/AutoFormerV2\nModifications and additions for timm hacked together by / Copyright 2021, Ross Wightman\n\"\"\"\n# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\nimport logging\nimport math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert\nfrom ._builder import build_model_with_cfg\nfrom ._features_fx import register_notrace_function\nfrom ._manipulate import checkpoint_seq, named_apply\nfrom ._registry import register_model\nfrom .vision_transformer import checkpoint_filter_fn, get_init_weights_vit\n\n__all__ = ['SwinTransformer']  # model_registry will add each entrypoint fn to this\n\n_logger = logging.getLogger(__name__)\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    'swin_base_patch4_window12_384': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',\n 
       input_size=(3, 384, 384), crop_pct=1.0),\n\n    'swin_base_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',\n    ),\n\n    'swin_large_patch4_window12_384': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',\n        input_size=(3, 384, 384), crop_pct=1.0),\n\n    'swin_large_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',\n    ),\n\n    'swin_small_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',\n    ),\n\n    'swin_tiny_patch4_window7_224': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',\n    ),\n\n    'swin_base_patch4_window12_384_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',\n        input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),\n\n    'swin_base_patch4_window7_224_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',\n        num_classes=21841),\n\n    'swin_large_patch4_window12_384_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',\n        input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),\n\n    'swin_large_patch4_window7_224_in22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',\n        num_classes=21841),\n\n    'swin_s3_tiny_224': _cfg(\n        
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'\n    ),\n    'swin_s3_small_224': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'\n    ),\n    'swin_s3_base_224': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'\n    )\n}\n\n\ndef window_partition(x, window_size: int):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n \n@register_notrace_function  # reason: int argument is a Proxy\ndef window_reverse(windows, window_size: int, H: int, W: int):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\ndef get_relative_position_index(win_h, win_w):\n    # get pair-wise relative position index for each token inside the window\n    coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww\n    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n    relative_coords[:, :, 0] += 
win_h - 1  # shift to start from 0\n    relative_coords[:, :, 1] += win_w - 1\n    relative_coords[:, :, 0] *= 2 * win_w - 1\n    return relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        num_heads (int): Number of attention heads.\n        head_dim (int): Number of channels per head (dim // num_heads if not set)\n        window_size (tuple[int]): The height and width of the window.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = to_2tuple(window_size)  # Wh, Ww\n        win_h, win_w = self.window_size\n        self.window_area = win_h * win_w\n        self.num_heads = num_heads\n        head_dim = head_dim or dim // num_heads\n        attn_dim = head_dim * num_heads\n        self.scale = head_dim ** -0.5\n\n        # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH\n        self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))\n\n        # get pair-wise relative position index for each token inside the window\n        self.register_buffer(\"relative_position_index\", get_relative_position_index(win_h, win_w))\n\n        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(attn_dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n     
   trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def _get_rel_pos_bias(self) -> torch.Tensor:\n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        return relative_position_bias.unsqueeze(0)\n\n    def forward(self, x, mask: Optional[torch.Tensor] = None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n        attn = attn + self._get_rel_pos_bias()\n\n        if mask is not None:\n            num_win = mask.shape[0]\n            attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n#\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        window_size (int): Window size.\n        num_heads (int): Number of attention heads.\n        head_dim (int): Enforce the number of channels per head\n        shift_size (int): Shift size for SW-MSA.\n     
   mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(\n            self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,\n            mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,\n            act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),\n            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            cnt = 0\n            for h in (\n                    slice(0, -self.window_size),\n                    slice(-self.window_size, -self.shift_size),\n                    slice(-self.shift_size, None)):\n                for w in (\n                        slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None)):\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n            mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        _assert(L == H * W, \"input feature has wrong size\")\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * 
self.window_size, C)  # num_win*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.out_dim = out_dim or 2 * dim\n        self.norm = norm_layer(4 * dim)\n        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        _assert(L == H * W, \"input feature has wrong size\")\n        _assert(H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\")\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        head_dim (int): Channels per head (dim // num_heads if not set)\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n    \"\"\"\n\n    def __init__(\n            self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None,\n            window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,\n            drop_path=0., norm_layer=nn.LayerNorm, downsample=None):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.grad_checkpointing = False\n\n        # build blocks\n        self.blocks = nn.Sequential(*[\n            SwinTransformerBlock(\n                dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim,\n                window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2,\n                mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        if self.grad_checkpointing and not torch.jit.is_scripting():\n            x = checkpoint_seq(self.blocks, x)\n        else:\n            x = self.blocks(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n\nclass SwinTransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. 
Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        head_dim (int, tuple(int)):\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. 
Default: True\n    \"\"\"\n\n    def __init__(\n            self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',\n            embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None,\n            window_size=7, mlp_ratio=4., qkv_bias=True,\n            drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n            norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs):\n        super().__init__()\n        assert global_pool in ('', 'avg')\n        self.num_classes = num_classes\n        self.global_pool = global_pool\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        self.patch_grid = self.patch_embed.grid_size\n\n        # absolute position embedding\n        self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # build layers\n        if not isinstance(embed_dim, (tuple, list)):\n            embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]\n        embed_out_dim = embed_dim[1:] + [None]\n        head_dim = to_ntuple(self.num_layers)(head_dim)\n        window_size = to_ntuple(self.num_layers)(window_size)\n        mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n        layers = []\n        for i in range(self.num_layers):\n            layers += [BasicLayer(\n                dim=embed_dim[i],\n                out_dim=embed_out_dim[i],\n                
input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)),\n                depth=depths[i],\n                num_heads=num_heads[i],\n                head_dim=head_dim[i],\n                window_size=window_size[i],\n                mlp_ratio=mlp_ratio[i],\n                qkv_bias=qkv_bias,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i < self.num_layers - 1) else None\n            )]\n        self.layers = nn.Sequential(*layers)\n\n        self.norm = norm_layer(self.num_features)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        if weight_init != 'skip':\n            self.init_weights(weight_init)\n\n    @torch.jit.ignore\n    def init_weights(self, mode=''):\n        assert mode in ('jax', 'jax_nlhb', 'moco', '')\n        if self.absolute_pos_embed is not None:\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.\n        named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        nwd = {'absolute_pos_embed'}\n        for n, _ in self.named_parameters():\n            if 'relative_position_bias_table' in n:\n                nwd.add(n)\n        return nwd\n\n    @torch.jit.ignore\n    def group_matcher(self, coarse=False):\n        return dict(\n            stem=r'^absolute_pos_embed|patch_embed',  # stem and embed\n            blocks=r'^layers\\.(\\d+)' if coarse else [\n                (r'^layers\\.(\\d+).downsample', (0,)),\n                (r'^layers\\.(\\d+)\\.\\w+\\.(\\d+)', None),\n                (r'^norm', (99999,)),\n            ]\n        )\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        for l 
in self.layers:\n            l.grad_checkpointing = enable\n\n    @torch.jit.ignore\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=None):\n        self.num_classes = num_classes\n        if global_pool is not None:\n            assert global_pool in ('', 'avg')\n            self.global_pool = global_pool\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.absolute_pos_embed is not None:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n        x = self.layers(x)\n        x = self.norm(x)  # B L C\n        return x\n\n    def forward_head(self, x, pre_logits: bool = False):\n        if self.global_pool == 'avg':\n            x = x.mean(dim=1)\n        return x if pre_logits else self.head(x)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.forward_head(x)\n        return x\n\n\ndef _create_swin_transformer(variant, pretrained=False, **kwargs):\n    model = build_model_with_cfg(\n        SwinTransformer, variant, pretrained,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        **kwargs)\n\n    return model\n\n\n@register_model\ndef swin_base_patch4_window12_384(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_base_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), 
**kwargs)\n    return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window12_384(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_small_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-S @ 224x224, trained ImageNet-1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_tiny_patch4_window7_224(pretrained=False, **kwargs):\n    \"\"\" Swin-T @ 224x224, trained ImageNet-1k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 384x384, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 
16, 32), **kwargs)\n    return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-B @ 224x224, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 384x384, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):\n    \"\"\" Swin-L @ 224x224, trained ImageNet-22k\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_s3_tiny_224(pretrained=False, **kwargs):\n    \"\"\" Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),\n        num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer('swin_s3_tiny_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_s3_small_224(pretrained=False, **kwargs):\n    \"\"\" Swin-S3-S @ 224x224, trained ImageNet-1k. 
https://arxiv.org/abs/2111.14725\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),\n        num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer('swin_s3_small_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swin_s3_base_224(pretrained=False, **kwargs):\n    \"\"\" Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),\n        num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs)\n"
  },
  {
    "path": "model/backbone/swin_transformer_v2.py",
    "content": "\"\"\" Swin Transformer V2\nA PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`\n    - https://arxiv.org/abs/2111.09883\nCode/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below\nModifications and additions for timm hacked together by / Copyright 2022, Ross Wightman\n\"\"\"\n# --------------------------------------------------------\n# Swin Transformer V2\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\nimport math\nfrom typing import Tuple, Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert\nfrom ._builder import build_model_with_cfg\nfrom ._features_fx import register_notrace_function\nfrom ._registry import register_model\n\n__all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to this\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    'swinv2_tiny_window8_256': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',\n        input_size=(3, 256, 256)\n    ),\n    'swinv2_tiny_window16_256': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',\n        input_size=(3, 256, 256)\n    ),\n    
'swinv2_small_window8_256': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',\n        input_size=(3, 256, 256)\n    ),\n    'swinv2_small_window16_256': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',\n        input_size=(3, 256, 256)\n    ),\n    'swinv2_base_window8_256': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',\n        input_size=(3, 256, 256)\n    ),\n    'swinv2_base_window16_256': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',\n        input_size=(3, 256, 256)\n    ),\n\n    'swinv2_base_window12_192_22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',\n        num_classes=21841, input_size=(3, 192, 192)\n    ),\n    'swinv2_base_window12to16_192to256_22kft1k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',\n        input_size=(3, 256, 256)\n    ),\n    'swinv2_base_window12to24_192to384_22kft1k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',\n        input_size=(3, 384, 384), crop_pct=1.0,\n    ),\n    'swinv2_large_window12_192_22k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',\n        num_classes=21841, input_size=(3, 192, 192)\n    ),\n    'swinv2_large_window12to16_192to256_22kft1k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',\n        input_size=(3, 256, 256)\n    ),\n    
'swinv2_large_window12to24_192to384_22kft1k': _cfg(\n        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',\n        input_size=(3, 384, 384), crop_pct=1.0,\n    ),\n}\n\n\ndef window_partition(x, window_size: Tuple[int, int]):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (Tuple[int, int]): Window size (height, width)\n    Returns:\n        windows: (num_windows * B, window_size[0], window_size[1], C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)\n    return windows\n\n\n@register_notrace_function  # reason: int argument is a Proxy\ndef window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):\n    \"\"\"\n    Args:\n        windows: (num_windows * B, window_size[0], window_size[1], C)\n        window_size (Tuple[int, int]): Window size\n        img_size (Tuple[int, int]): Image size\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    H, W = img_size\n    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))\n    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both shifted and non-shifted windows.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.\n    \"\"\"\n\n    def __init__(\n            self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,\n            pretrained_window_size=[0, 0]):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.pretrained_window_size = pretrained_window_size\n        self.num_heads = num_heads\n\n        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))\n\n        # mlp to generate continuous relative position bias\n        self.cpb_mlp = nn.Sequential(\n            nn.Linear(2, 512, bias=True),\n            nn.ReLU(inplace=True),\n            nn.Linear(512, num_heads, bias=False)\n        )\n\n        # get relative_coords_table\n        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)\n        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)\n        relative_coords_table = torch.stack(torch.meshgrid([\n            relative_coords_h,\n            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2\n        if pretrained_window_size[0] > 0:\n            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)\n        else:\n            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)\n        relative_coords_table *= 8  # normalize to -8, 8\n        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(\n            torch.abs(relative_coords_table) + 1.0) / math.log2(8)\n\n        self.register_buffer(\"relative_coords_table\", 
relative_coords_table, persistent=False)\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index, persistent=False)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(dim))\n            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)\n            self.v_bias = nn.Parameter(torch.zeros(dim))\n        else:\n            self.q_bias = None\n            self.k_bias = None\n            self.v_bias = None\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask: Optional[torch.Tensor] = None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))\n        qkv = F.linear(input=x, 
weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(0)\n\n        # cosine attention\n        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))\n        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()\n        attn = attn * logit_scale\n\n        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)\n        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n        pretrained_window_size (int): Window size in pretraining.\n    \"\"\"\n\n    def __init__(\n            self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n            mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,\n            act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = to_2tuple(input_resolution)\n        self.num_heads = num_heads\n        ws, ss = self._calc_window_shift(window_size, shift_size)\n        self.window_size: Tuple[int, int] = ws\n        self.shift_size: Tuple[int, int] = ss\n        self.window_area = self.window_size[0] * self.window_size[1]\n        self.mlp_ratio = mlp_ratio\n\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,\n            pretrained_window_size=to_2tuple(pretrained_window_size))\n        self.norm1 = norm_layer(dim)\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)\n        self.norm2 = norm_layer(dim)\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n        if any(self.shift_size):\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            cnt = 0\n            for h in (\n                    slice(0, -self.window_size[0]),\n                    slice(-self.window_size[0], -self.shift_size[0]),\n                    slice(-self.shift_size[0], None)):\n                for w in (\n                        slice(0, -self.window_size[1]),\n                        slice(-self.window_size[1], -self.shift_size[1]),\n                        slice(-self.shift_size[1], None)):\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_area)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:\n        target_window_size = to_2tuple(target_window_size)\n        target_shift_size = to_2tuple(target_shift_size)\n        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]\n        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]\n        return tuple(window_size), tuple(shift_size)\n\n    def _attn(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        _assert(L == H * W, \"input feature has wrong size\")\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        has_shift = any(self.shift_size)\n        if has_shift:\n  
          shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)\n        shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution)  # B H' W' C\n\n        # reverse cyclic shift\n        if has_shift:\n            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n        return x\n\n    def forward(self, x):\n        x = x + self.drop_path1(self.norm1(self._attn(x)))\n        x = x + self.drop_path2(self.norm2(self.mlp(x)))\n        return x\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(2 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        _assert(L == H * W, \"input feature has wrong size\")\n        _assert(H % 2 == 0, f\"x size ({H}*{W}) are not even.\")\n        _assert(W % 2 == 0, f\"x size ({H}*{W}) are not even.\")\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.reduction(x)\n        x = self.norm(x)\n\n        return x\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        pretrained_window_size (int): Local window size in pre-training.\n    \"\"\"\n\n    def __init__(\n            self, dim, input_resolution, depth, num_heads, window_size,\n            mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,\n            norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.grad_checkpointing = False\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(\n                dim=dim, input_resolution=input_resolution,\n                num_heads=num_heads, window_size=window_size,\n                shift_size=0 if (i % 2 == 0) else window_size // 2,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                drop=drop, attn_drop=attn_drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                norm_layer=norm_layer,\n                pretrained_window_size=pretrained_window_size)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = nn.Identity()\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.grad_checkpointing and not torch.jit.is_scripting():\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        x = self.downsample(x)\n        return x\n\n    def _init_respostnorm(self):\n        for blk in self.blocks:\n            nn.init.constant_(blk.norm1.bias, 0)\n            nn.init.constant_(blk.norm1.weight, 0)\n            
nn.init.constant_(blk.norm2.bias, 0)\n            nn.init.constant_(blk.norm2.weight, 0)\n\n\nclass SwinTransformerV2(nn.Module):\n    r\"\"\" Swin Transformer V2\n        A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`\n            - https://arxiv.org/abs/2111.09883\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.\n    \"\"\"\n\n    def __init__(\n            self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',\n            embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),\n            window_size=7, mlp_ratio=4., qkv_bias=True,\n            drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n            norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n            pretrained_window_sizes=(0, 0, 0, 0), **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        assert global_pool in ('', 'avg')\n        self.global_pool = global_pool\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n\n        # absolute position embedding\n        if ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n        else:\n            self.absolute_pos_embed = None\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2 ** i_layer),\n                input_resolution=(\n                    self.patch_embed.grid_size[0] // (2 ** i_layer),\n                    
self.patch_embed.grid_size[1] // (2 ** i_layer)),\n                depth=depths[i_layer],\n                num_heads=num_heads[i_layer],\n                window_size=window_size,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                drop=drop_rate, attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                pretrained_window_size=pretrained_window_sizes[i_layer]\n            )\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n        for bly in self.layers:\n            bly._init_respostnorm()\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        nod = {'absolute_pos_embed'}\n        for n, m in self.named_modules():\n            if any([kw in n for kw in (\"cpb_mlp\", \"logit_scale\", 'relative_position_bias_table')]):\n                nod.add(n)\n        return nod\n\n    @torch.jit.ignore\n    def group_matcher(self, coarse=False):\n        return dict(\n            stem=r'^absolute_pos_embed|patch_embed',  # stem and embed\n            blocks=r'^layers\\.(\\d+)' if coarse else [\n                (r'^layers\\.(\\d+).downsample', (0,)),\n                (r'^layers\\.(\\d+)\\.\\w+\\.(\\d+)', None),\n                (r'^norm', (99999,)),\n            ]\n        )\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        for l in self.layers:\n            l.grad_checkpointing = enable\n\n    
@torch.jit.ignore\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=None):\n        self.num_classes = num_classes\n        if global_pool is not None:\n            assert global_pool in ('', 'avg')\n            self.global_pool = global_pool\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.absolute_pos_embed is not None:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        return x\n\n    def forward_head(self, x, pre_logits: bool = False):\n        if self.global_pool == 'avg':\n            x = x.mean(dim=1)\n        return x if pre_logits else self.head(x)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.forward_head(x)\n        return x\n\n\ndef checkpoint_filter_fn(state_dict, model):\n    out_dict = {}\n    if 'model' in state_dict:\n        # For deit models\n        state_dict = state_dict['model']\n    for k, v in state_dict.items():\n        if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):\n            continue  # skip buffers that should not be persistent\n        out_dict[k] = v\n    return out_dict\n\n\ndef _create_swin_transformer_v2(variant, pretrained=False, **kwargs):\n    model = build_model_with_cfg(\n        SwinTransformerV2, variant, pretrained,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        **kwargs)\n    return model\n\n\n@register_model\ndef swinv2_tiny_window16_256(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=16, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer_v2('swinv2_tiny_window16_256', pretrained=pretrained, 
**model_kwargs)\n\n\n@register_model\ndef swinv2_tiny_window8_256(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=8, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer_v2('swinv2_tiny_window8_256', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_small_window16_256(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=16, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer_v2('swinv2_small_window16_256', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_small_window8_256(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=8, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)\n    return _create_swin_transformer_v2('swinv2_small_window8_256', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_base_window16_256(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer_v2('swinv2_base_window16_256', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_base_window8_256(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=8, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer_v2('swinv2_base_window8_256', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_base_window12_192_22k(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)\n    return _create_swin_transformer_v2('swinv2_base_window12_192_22k', 
pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),\n        pretrained_window_sizes=(12, 12, 12, 6), **kwargs)\n    return _create_swin_transformer_v2(\n        'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),\n        pretrained_window_sizes=(12, 12, 12, 6), **kwargs)\n    return _create_swin_transformer_v2(\n        'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_large_window12_192_22k(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)\n    return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),\n        pretrained_window_sizes=(12, 12, 12, 6), **kwargs)\n    return _create_swin_transformer_v2(\n        'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs):\n    \"\"\"\n    \"\"\"\n    model_kwargs = dict(\n        window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),\n        pretrained_window_sizes=(12, 12, 12, 6), **kwargs)\n    return 
_create_swin_transformer_v2(\n        'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)\n"
  },
  {
    "path": "model/backbone/swin_transformer_v2_cr.py",
    "content": "\"\"\" Swin Transformer V2\nA PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`\n    - https://arxiv.org/pdf/2111.09883\nCode adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below\nThis implementation is experimental and subject to change in manners that will break weight compat:\n* Size of the pos embed MLP are not spelled out in paper in terms of dim, fixed for all models? vary with num_heads?\n  * currently dim is fixed, I feel it may make sense to scale with num_heads (dim per head)\n* The specifics of the memory saving 'sequential attention' are not detailed, Christoph Reich has an impl at\n  GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.\n* num_heads per stage is not detailed for Huge and Giant model variants\n* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts\n* experiments are ongoing wrt to 'main branch' norm layer use and weight init scheme\nNoteworthy additions over official Swin v1:\n* MLP relative position embedding is looking promising and adapts to different image/window sizes\n* This impl has been designed to allow easy change of image size with matching window size changes\n* Non-square image size and window size are supported\nModifications and additions for timm hacked together by / Copyright 2022, Ross Wightman\n\"\"\"\n# --------------------------------------------------------\n# Swin Transformer V2 reimplementation\n# Copyright (c) 2021 Christoph Reich\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Christoph Reich\n# --------------------------------------------------------\nimport logging\nimport math\nfrom typing import Tuple, Optional, List, Union, Any, Type\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, 
IMAGENET_DEFAULT_STD\nfrom timm.layers import DropPath, Mlp, to_2tuple, _assert\nfrom ._builder import build_model_with_cfg\nfrom ._features_fx import register_notrace_function\nfrom ._manipulate import named_apply\nfrom ._registry import register_model\n\n__all__ = ['SwinTransformerV2Cr']  # model_registry will add each entrypoint fn to this\n\n_logger = logging.getLogger(__name__)\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000,\n        'input_size': (3, 224, 224),\n        'pool_size': (7, 7),\n        'crop_pct': 0.9,\n        'interpolation': 'bicubic',\n        'fixed_input_size': True,\n        'mean': IMAGENET_DEFAULT_MEAN,\n        'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj',\n        'classifier': 'head',\n        **kwargs,\n    }\n\n\ndefault_cfgs = {\n    'swinv2_cr_tiny_384': _cfg(\n        url=\"\", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),\n    'swinv2_cr_tiny_224': _cfg(\n        url=\"\", input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_tiny_ns_224': _cfg(\n        url=\"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth\",\n        input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_small_384': _cfg(\n        url=\"\", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),\n    'swinv2_cr_small_224': _cfg(\n        url=\"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth\",\n        input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_small_ns_224': _cfg(\n        url=\"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth\",\n        input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_base_384': _cfg(\n        url=\"\", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),\n    'swinv2_cr_base_224': 
_cfg(\n        url=\"\", input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_base_ns_224': _cfg(\n        url=\"\", input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_large_384': _cfg(\n        url=\"\", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),\n    'swinv2_cr_large_224': _cfg(\n        url=\"\", input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_huge_384': _cfg(\n        url=\"\", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),\n    'swinv2_cr_huge_224': _cfg(\n        url=\"\", input_size=(3, 224, 224), crop_pct=0.9),\n    'swinv2_cr_giant_384': _cfg(\n        url=\"\", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),\n    'swinv2_cr_giant_224': _cfg(\n        url=\"\", input_size=(3, 224, 224), crop_pct=0.9),\n}\n\n\ndef bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). \"\"\"\n    return x.permute(0, 2, 3, 1)\n\n\ndef bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). 
\"\"\"\n    return x.permute(0, 3, 1, 2)\n\n\ndef window_partition(x, window_size: Tuple[int, int]):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)\n    return windows\n\n\n@register_notrace_function  # reason: int argument is a Proxy\ndef window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):\n    \"\"\"\n    Args:\n        windows: (num_windows * B, window_size[0], window_size[1], C)\n        window_size (Tuple[int, int]): Window size\n        img_size (Tuple[int, int]): Image size\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    H, W = img_size\n    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))\n    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowMultiHeadAttention(nn.Module):\n    r\"\"\"This class implements window-based Multi-Head-Attention with log-spaced continuous position bias.\n    Args:\n        dim (int): Number of input features\n        window_size (int): Window size\n        num_heads (int): Number of attention heads\n        drop_attn (float): Dropout rate of attention map\n        drop_proj (float): Dropout rate after projection\n        meta_hidden_dim (int): Number of hidden features in the two layer MLP meta network\n        sequential_attn (bool): If true sequential self-attention is performed\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        window_size: Tuple[int, int],\n        drop_attn: float = 0.0,\n        drop_proj: float = 0.0,\n        
meta_hidden_dim: int = 384,  # FIXME what's the optimal value?\n        sequential_attn: bool = False,\n    ) -> None:\n        super(WindowMultiHeadAttention, self).__init__()\n        assert dim % num_heads == 0, \\\n            \"The number of input features (in_features) are not divisible by the number of heads (num_heads).\"\n        self.in_features: int = dim\n        self.window_size: Tuple[int, int] = window_size\n        self.num_heads: int = num_heads\n        self.sequential_attn: bool = sequential_attn\n\n        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True)\n        self.attn_drop = nn.Dropout(drop_attn)\n        self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True)\n        self.proj_drop = nn.Dropout(drop_proj)\n        # meta network for positional encodings\n        self.meta_mlp = Mlp(\n            2,  # x, y\n            hidden_features=meta_hidden_dim,\n            out_features=num_heads,\n            act_layer=nn.ReLU,\n            drop=(0.125, 0.)  
# FIXME should there be stochasticity, appears to 'overfit' without?\n        )\n        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn\n        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))\n        self._make_pair_wise_relative_positions()\n\n    def _make_pair_wise_relative_positions(self) -> None:\n        \"\"\"Method initializes the pair-wise relative positions to compute the positional biases.\"\"\"\n        device = self.logit_scale.device\n        coordinates = torch.stack(torch.meshgrid([\n            torch.arange(self.window_size[0], device=device),\n            torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)\n        relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :]\n        relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()\n        relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(\n            1.0 + relative_coordinates.abs())\n        self.register_buffer(\"relative_coordinates_log\", relative_coordinates_log, persistent=False)\n\n    def update_input_size(self, new_window_size: Tuple[int, int], **kwargs: Any) -> None:\n        \"\"\"Method updates the window size and recomputes the pair-wise relative positions.\n        Args:\n            new_window_size (Tuple[int, int]): New window size (height, width)\n            kwargs (Any): Unused\n        \"\"\"\n        # Set new window size and new pair-wise relative positions\n        self.window_size: Tuple[int, int] = new_window_size\n        self._make_pair_wise_relative_positions()\n\n    def _relative_positional_encodings(self) -> torch.Tensor:\n        \"\"\"Method computes the relative positional encodings\n        Returns:\n            relative_position_bias (torch.Tensor): Relative positional encodings of the shape\n            (1, number of heads, window area, window area), window area = window_size[0] * window_size[1]\n        \"\"\"\n        window_area = self.window_size[0] * self.window_size[1]\n        relative_position_bias 
= self.meta_mlp(self.relative_coordinates_log)\n        relative_position_bias = relative_position_bias.transpose(1, 0).reshape(\n            self.num_heads, window_area, window_area\n        )\n        relative_position_bias = relative_position_bias.unsqueeze(0)\n        return relative_position_bias\n\n    def _forward_sequential(\n        self,\n        x: torch.Tensor,\n        mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        \"\"\"\n        # FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)\n        assert False, \"not implemented\"\n\n    def _forward_batch(\n        self,\n        x: torch.Tensor,\n        mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"This function performs standard (non-sequential) scaled cosine self-attention.\n        \"\"\"\n        Bw, L, C = x.shape\n\n        qkv = self.qkv(x).view(Bw, L, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        query, key, value = qkv.unbind(0)\n\n        # compute attention map with scaled cosine attention\n        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))\n        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. 
/ 0.01)).exp()\n        attn = attn * logit_scale\n        attn = attn + self._relative_positional_encodings()\n\n        if mask is not None:\n            # Apply mask if utilized\n            num_win: int = mask.shape[0]\n            attn = attn.view(Bw // num_win, num_win, self.num_heads, L, L)\n            attn = attn + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, L, L)\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ value).transpose(1, 2).reshape(Bw, L, -1)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:\n        \"\"\" Forward pass.\n        Args:\n            x (torch.Tensor): Input tensor of the shape (B * windows, N, C)\n            mask (Optional[torch.Tensor]): Attention mask for the shift case\n        Returns:\n            Output tensor of the shape [B * windows, N, C]\n        \"\"\"\n        if self.sequential_attn:\n            return self._forward_sequential(x, mask)\n        else:\n            return self._forward_batch(x, mask)\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\"This class implements the Swin transformer block.\n    Args:\n        dim (int): Number of input channels\n        num_heads (int): Number of attention heads to be utilized\n        feat_size (Tuple[int, int]): Input resolution\n        window_size (Tuple[int, int]): Window size to be utilized\n        shift_size (int): Shifting size to be used\n        mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels\n        drop (float): Dropout in input mapping\n        drop_attn (float): Dropout rate of attention map\n        drop_path (float): Dropout in main path\n        extra_norm (bool): Insert extra norm on 'main' branch if True\n        sequential_attn (bool): If true sequential self-attention is performed\n        norm_layer 
(Type[nn.Module]): Type of normalization layer to be utilized\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        feat_size: Tuple[int, int],\n        window_size: Tuple[int, int],\n        shift_size: Tuple[int, int] = (0, 0),\n        mlp_ratio: float = 4.0,\n        init_values: Optional[float] = 0,\n        drop: float = 0.0,\n        drop_attn: float = 0.0,\n        drop_path: float = 0.0,\n        extra_norm: bool = False,\n        sequential_attn: bool = False,\n        norm_layer: Type[nn.Module] = nn.LayerNorm,\n    ) -> None:\n        super(SwinTransformerBlock, self).__init__()\n        self.dim: int = dim\n        self.feat_size: Tuple[int, int] = feat_size\n        self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)\n        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))\n        self.window_area = self.window_size[0] * self.window_size[1]\n        self.init_values: Optional[float] = init_values\n\n        # attn branch\n        self.attn = WindowMultiHeadAttention(\n            dim=dim,\n            num_heads=num_heads,\n            window_size=self.window_size,\n            drop_attn=drop_attn,\n            drop_proj=drop,\n            sequential_attn=sequential_attn,\n        )\n        self.norm1 = norm_layer(dim)\n        self.drop_path1 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()\n\n        # mlp branch\n        self.mlp = Mlp(\n            in_features=dim,\n            hidden_features=int(dim * mlp_ratio),\n            drop=drop,\n            out_features=dim,\n        )\n        self.norm2 = norm_layer(dim)\n        self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()\n\n        # Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.\n        # Also being used as final network norm and optional stage ending norm while still in a C-last format.\n        self.norm3 = 
norm_layer(dim) if extra_norm else nn.Identity()\n\n        self._make_attention_mask()\n        self.init_weights()\n\n    def _calc_window_shift(self, target_window_size):\n        window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]\n        shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)]\n        return tuple(window_size), tuple(shift_size)\n\n    def _make_attention_mask(self) -> None:\n        \"\"\"Method generates the attention mask used in shift case.\"\"\"\n        # Make masks for shift case\n        if any(self.shift_size):\n            # calculate attention mask for SW-MSA\n            H, W = self.feat_size\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            cnt = 0\n            for h in (\n                    slice(0, -self.window_size[0]),\n                    slice(-self.window_size[0], -self.shift_size[0]),\n                    slice(-self.shift_size[0], None)):\n                for w in (\n                        slice(0, -self.window_size[1]),\n                        slice(-self.window_size[1], -self.shift_size[1]),\n                        slice(-self.shift_size[1], None)):\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n            mask_windows = window_partition(img_mask, self.window_size)  # num_windows, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_area)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n        self.register_buffer(\"attn_mask\", attn_mask, persistent=False)\n\n    def init_weights(self):\n        # extra, module specific weight init\n        if self.init_values is not None:\n            nn.init.constant_(self.norm1.weight, self.init_values)\n       
     nn.init.constant_(self.norm2.weight, self.init_values)\n\n    def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:\n        \"\"\"Method updates the image resolution to be processed and window size and so the pair-wise relative positions.\n        Args:\n            new_window_size (int): New window size\n            new_feat_size (Tuple[int, int]): New input resolution\n        \"\"\"\n        # Update input resolution\n        self.feat_size: Tuple[int, int] = new_feat_size\n        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size))\n        self.window_area = self.window_size[0] * self.window_size[1]\n        self.attn.update_input_size(new_window_size=self.window_size)\n        self._make_attention_mask()\n\n    def _shifted_window_attn(self, x):\n        H, W = self.feat_size\n        B, L, C = x.shape\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        sh, sw = self.shift_size\n        do_shift: bool = any(self.shift_size)\n        if do_shift:\n            # FIXME PyTorch XLA needs cat impl, roll not lowered\n            # x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)\n            # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)\n            x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))\n\n        # partition windows\n        x_windows = window_partition(x, self.window_size)  # num_windows * B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_windows * B, window_size * window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)\n        x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C\n\n        # reverse cyclic shift\n        if do_shift:\n            # FIXME PyTorch XLA needs cat 
impl, roll not lowered\n            # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)\n            # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)\n            x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))\n\n        x = x.view(B, L, C)\n        return x\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass.\n        Args:\n            x (torch.Tensor): Input tensor of the shape [B, L, C], where L = H * W of the feature map\n        Returns:\n            output (torch.Tensor): Output tensor of the shape [B, L, C]\n        \"\"\"\n        # post-norm branches (op -> norm -> drop)\n        x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))\n        x = x + self.drop_path2(self.norm2(self.mlp(x)))\n        x = self.norm3(x)  # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)\n        return x\n\n\nclass PatchMerging(nn.Module):\n    \"\"\" This class implements patch merging: each 2x2 patch group is concatenated, normalized, and linearly reduced.\n    Args:\n        dim (int): Number of input channels\n        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized.\n    \"\"\"\n\n    def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:\n        super(PatchMerging, self).__init__()\n        self.norm = norm_layer(4 * dim)\n        self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\" Forward pass.\n        Args:\n            x (torch.Tensor): Input tensor of the shape [B, C, H, W]\n        Returns:\n            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]\n        \"\"\"\n        B, C, H, W = x.shape\n        # unfold + BCHW -> BHWC together\n        # ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge\n        x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)\n        x = self.norm(x)\n        
x = bhwc_to_bchw(self.reduction(x))\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        _assert(H == self.img_size[0], f\"Input image height ({H}) doesn't match model ({self.img_size[0]}).\")\n        _assert(W == self.img_size[1], f\"Input image width ({W}) doesn't match model ({self.img_size[1]}).\")\n        x = self.proj(x)\n        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n        return x\n\n\nclass SwinTransformerStage(nn.Module):\n    r\"\"\"This class implements a stage of the Swin transformer including multiple layers.\n    Args:\n        embed_dim (int): Number of input channels\n        depth (int): Depth of the stage (number of layers)\n        downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)\n        feat_size (Tuple[int, int]): input feature map size (H, W)\n        num_heads (int): Number of attention heads to be utilized\n        window_size (int): Window size to be utilized\n        mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels\n        drop (float): Dropout in input mapping\n        drop_attn (float): Dropout rate of attention map\n        drop_path (float): Dropout in main path\n        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. 
Default: nn.LayerNorm\n        extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks\n        extra_norm_stage (bool): End each stage with an extra norm layer in main branch\n        sequential_attn (bool): If true sequential self-attention is performed\n    \"\"\"\n\n    def __init__(\n        self,\n        embed_dim: int,\n        depth: int,\n        downscale: bool,\n        num_heads: int,\n        feat_size: Tuple[int, int],\n        window_size: Tuple[int, int],\n        mlp_ratio: float = 4.0,\n        init_values: Optional[float] = 0.0,\n        drop: float = 0.0,\n        drop_attn: float = 0.0,\n        drop_path: Union[List[float], float] = 0.0,\n        norm_layer: Type[nn.Module] = nn.LayerNorm,\n        extra_norm_period: int = 0,\n        extra_norm_stage: bool = False,\n        sequential_attn: bool = False,\n    ) -> None:\n        super(SwinTransformerStage, self).__init__()\n        self.downscale: bool = downscale\n        self.grad_checkpointing: bool = False\n        self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size\n\n        self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()\n\n        def _extra_norm(index):\n            i = index + 1\n            if extra_norm_period and i % extra_norm_period == 0:\n                return True\n            return i == depth if extra_norm_stage else False\n\n        embed_dim = embed_dim * 2 if downscale else embed_dim\n        self.blocks = nn.Sequential(*[\n            SwinTransformerBlock(\n                dim=embed_dim,\n                num_heads=num_heads,\n                feat_size=self.feat_size,\n                window_size=window_size,\n                shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),\n                mlp_ratio=mlp_ratio,\n                init_values=init_values,\n                drop=drop,\n                
drop_attn=drop_attn,\n                drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,\n                extra_norm=_extra_norm(index),\n                sequential_attn=sequential_attn,\n                norm_layer=norm_layer,\n            )\n            for index in range(depth)]\n        )\n\n    def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None:\n        \"\"\"Method updates the feature map resolution and the window size, and so the pair-wise relative positions.\n        Args:\n            new_window_size (int): New window size\n            new_feat_size (Tuple[int, int]): New input resolution\n        \"\"\"\n        self.feat_size: Tuple[int, int] = (\n            (new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size\n        )\n        for block in self.blocks:\n            block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass.\n        Args:\n            x (torch.Tensor): Input tensor of the shape [B, C, H, W]\n        Returns:\n            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] if downscale is set,\n                otherwise [B, C, H, W]\n        \"\"\"\n        x = self.downsample(x)\n        B, C, H, W = x.shape\n        L = H * W\n\n        x = bchw_to_bhwc(x).reshape(B, L, C)\n        for block in self.blocks:\n            # Perform checkpointing if utilized\n            if self.grad_checkpointing and not torch.jit.is_scripting():\n                x = checkpoint.checkpoint(block, x)\n            else:\n                x = block(x)\n        x = bhwc_to_bchw(x.reshape(B, H, W, -1))\n        return x\n\n\nclass SwinTransformerV2Cr(nn.Module):\n    r\"\"\" Swin Transformer V2\n        A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`  -\n          https://arxiv.org/pdf/2111.09883\n    Args:\n        img_size (Tuple[int, int]): Input 
resolution.\n        window_size (Optional[int]): Window size. If None, img_size // img_window_ratio. Default: None\n        img_window_ratio (int): Image size to window size ratio. Default: 32\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input channels.\n        depths (Tuple[int, ...]): Depth of each stage (number of layers).\n        num_heads (Tuple[int, ...]): Number of attention heads to be utilized in each stage.\n        embed_dim (int): Patch embedding dimension. Default: 96\n        num_classes (int): Number of output classes. Default: 1000\n        mlp_ratio (float): Ratio of the hidden dimension in the FFN to the input channels. Default: 4\n        drop_rate (float): Dropout rate. Default: 0.0\n        attn_drop_rate (float): Dropout rate of attention map. Default: 0.0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.0\n        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm\n        extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage\n        extra_norm_stage (bool): End each stage with an extra norm layer in main branch\n        sequential_attn (bool): If true sequential self-attention is performed. Default: False\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size: Tuple[int, int] = (224, 224),\n        patch_size: int = 4,\n        window_size: Optional[int] = None,\n        img_window_ratio: int = 32,\n        in_chans: int = 3,\n        num_classes: int = 1000,\n        embed_dim: int = 96,\n        depths: Tuple[int, ...] = (2, 2, 6, 2),\n        num_heads: Tuple[int, ...] 
= (3, 6, 12, 24),\n        mlp_ratio: float = 4.0,\n        init_values: Optional[float] = 0.,\n        drop_rate: float = 0.0,\n        attn_drop_rate: float = 0.0,\n        drop_path_rate: float = 0.0,\n        norm_layer: Type[nn.Module] = nn.LayerNorm,\n        extra_norm_period: int = 0,\n        extra_norm_stage: bool = False,\n        sequential_attn: bool = False,\n        global_pool: str = 'avg',\n        weight_init='skip',\n        **kwargs: Any\n    ) -> None:\n        super(SwinTransformerV2Cr, self).__init__()\n        img_size = to_2tuple(img_size)\n        window_size = tuple([\n            s // img_window_ratio for s in img_size]) if window_size is None else to_2tuple(window_size)\n\n        self.num_classes: int = num_classes\n        self.patch_size: int = patch_size\n        self.img_size: Tuple[int, int] = img_size\n        self.window_size: int = window_size\n        self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans,\n            embed_dim=embed_dim, norm_layer=norm_layer)\n        patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size\n\n        drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()\n        stages = []\n        for index, (depth, num_heads) in enumerate(zip(depths, num_heads)):\n            stage_scale = 2 ** max(index - 1, 0)\n            stages.append(\n                SwinTransformerStage(\n                    embed_dim=embed_dim * stage_scale,\n                    depth=depth,\n                    downscale=index != 0,\n                    feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),\n                    num_heads=num_heads,\n                    window_size=window_size,\n                    mlp_ratio=mlp_ratio,\n                    init_values=init_values,\n                    drop=drop_rate,\n                    
drop_attn=attn_drop_rate,\n                    drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],\n                    extra_norm_period=extra_norm_period,\n                    extra_norm_stage=extra_norm_stage or (index + 1) == len(depths),  # last stage ends w/ norm\n                    sequential_attn=sequential_attn,\n                    norm_layer=norm_layer,\n                )\n            )\n        self.stages = nn.Sequential(*stages)\n\n        self.global_pool: str = global_pool\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()\n\n        # current weight init skips custom init and uses pytorch layer defaults, seems to work well\n        # FIXME more experiments needed\n        if weight_init != 'skip':\n            named_apply(init_weights, self)\n\n    def update_input_size(\n            self,\n            new_img_size: Optional[Tuple[int, int]] = None,\n            new_window_size: Optional[int] = None,\n            img_window_ratio: int = 32,\n    ) -> None:\n        \"\"\"Updates the image resolution to be processed and the window size, and with them the pair-wise relative positions.\n        Args:\n            new_window_size (Optional[int]): New window size, if None computed as new_img_size // img_window_ratio\n            new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used\n            img_window_ratio (int): divisor for calculating window size from image size\n        \"\"\"\n        # Check parameters\n        if new_img_size is None:\n            new_img_size = self.img_size\n        else:\n            new_img_size = to_2tuple(new_img_size)\n        if new_window_size is None:\n            new_window_size = tuple([s // img_window_ratio for s in new_img_size])\n        # Compute new patch resolution & update resolution of each stage\n        new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size)\n        for 
index, stage in enumerate(self.stages):\n            stage_scale = 2 ** max(index - 1, 0)\n            stage.update_input_size(\n                new_window_size=new_window_size,\n                new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),\n            )\n\n    @torch.jit.ignore\n    def group_matcher(self, coarse=False):\n        return dict(\n            stem=r'^patch_embed',  # stem and embed\n            blocks=r'^stages\\.(\\d+)' if coarse else [\n                (r'^stages\\.(\\d+).downsample', (0,)),\n                (r'^stages\\.(\\d+)\\.\\w+\\.(\\d+)', None),\n            ]\n        )\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        for s in self.stages:\n            s.grad_checkpointing = enable\n\n    @torch.jit.ignore()\n    def get_classifier(self) -> nn.Module:\n        \"\"\"Method returns the classification head of the model.\n        Returns:\n            head (nn.Module): Current classification head\n        \"\"\"\n        return self.head\n\n    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:\n        \"\"\"Method resets the classification head\n        Args:\n            num_classes (int): Number of classes to be predicted\n            global_pool (Optional[str]): New global pooling mode; if None, the current mode is kept\n        \"\"\"\n        self.num_classes: int = num_classes\n        if global_pool is not None:\n            self.global_pool = global_pool\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.patch_embed(x)\n        x = self.stages(x)\n        return x\n\n    def forward_head(self, x, pre_logits: bool = False):\n        if self.global_pool == 'avg':\n            x = x.mean(dim=(2, 3))\n        return x if pre_logits else self.head(x)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = 
self.forward_features(x)\n        x = self.forward_head(x)\n        return x\n\n\ndef init_weights(module: nn.Module, name: str = ''):\n    # FIXME WIP determining if there's a better weight init\n    if isinstance(module, nn.Linear):\n        if 'qkv' in name:\n            # treat the weights of Q, K, V separately\n            val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))\n            nn.init.uniform_(module.weight, -val, val)\n        elif 'head' in name:\n            nn.init.zeros_(module.weight)\n        else:\n            nn.init.xavier_uniform_(module.weight)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif hasattr(module, 'init_weights'):\n        module.init_weights()\n\n\ndef checkpoint_filter_fn(state_dict, model):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    if 'model' in state_dict:\n        # For deit models\n        state_dict = state_dict['model']\n    for k, v in state_dict.items():\n        if 'tau' in k:\n            # convert old tau based checkpoints -> logit_scale (inverse)\n            v = torch.log(1 / v)\n            k = k.replace('tau', 'logit_scale')\n        out_dict[k] = v\n    return out_dict\n\n\ndef _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):\n    if kwargs.get('features_only', None):\n        raise RuntimeError('features_only not implemented for Vision Transformer models.')\n    model = build_model_with_cfg(\n        SwinTransformerV2Cr, variant, pretrained,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        **kwargs\n    )\n    return model\n\n\n@register_model\ndef swinv2_cr_tiny_384(pretrained=False, **kwargs):\n    \"\"\"Swin-T V2 CR @ 384x384, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=96,\n        depths=(2, 2, 6, 2),\n        num_heads=(3, 6, 12, 24),\n        **kwargs\n    )\n    return 
_create_swin_transformer_v2_cr('swinv2_cr_tiny_384', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_tiny_224(pretrained=False, **kwargs):\n    \"\"\"Swin-T V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=96,\n        depths=(2, 2, 6, 2),\n        num_heads=(3, 6, 12, 24),\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_tiny_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_tiny_ns_224(pretrained=False, **kwargs):\n    \"\"\"Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.\n    ** Experimental, may make default if results are improved. **\n    \"\"\"\n    model_kwargs = dict(\n        embed_dim=96,\n        depths=(2, 2, 6, 2),\n        num_heads=(3, 6, 12, 24),\n        extra_norm_stage=True,\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_small_384(pretrained=False, **kwargs):\n    \"\"\"Swin-S V2 CR @ 384x384, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=96,\n        depths=(2, 2, 18, 2),\n        num_heads=(3, 6, 12, 24),\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_small_384', pretrained=pretrained, **model_kwargs\n    )\n\n\n@register_model\ndef swinv2_cr_small_224(pretrained=False, **kwargs):\n    \"\"\"Swin-S V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=96,\n        depths=(2, 2, 18, 2),\n        num_heads=(3, 6, 12, 24),\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_small_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_small_ns_224(pretrained=False, **kwargs):\n    \"\"\"Swin-S V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=96,\n        depths=(2, 2, 18, 2),\n        num_heads=(3, 6, 12, 
24),\n        extra_norm_stage=True,\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_base_384(pretrained=False, **kwargs):\n    \"\"\"Swin-B V2 CR @ 384x384, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=128,\n        depths=(2, 2, 18, 2),\n        num_heads=(4, 8, 16, 32),\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_base_384', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_base_224(pretrained=False, **kwargs):\n    \"\"\"Swin-B V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=128,\n        depths=(2, 2, 18, 2),\n        num_heads=(4, 8, 16, 32),\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_base_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_base_ns_224(pretrained=False, **kwargs):\n    \"\"\"Swin-B V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=128,\n        depths=(2, 2, 18, 2),\n        num_heads=(4, 8, 16, 32),\n        extra_norm_stage=True,\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_base_ns_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_large_384(pretrained=False, **kwargs):\n    \"\"\"Swin-L V2 CR @ 384x384, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=192,\n        depths=(2, 2, 18, 2),\n        num_heads=(6, 12, 24, 48),\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs\n    )\n\n\n@register_model\ndef swinv2_cr_large_224(pretrained=False, **kwargs):\n    \"\"\"Swin-L V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=192,\n        depths=(2, 2, 18, 2),\n        num_heads=(6, 12, 24, 48),\n        
**kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_large_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_huge_384(pretrained=False, **kwargs):\n    \"\"\"Swin-H V2 CR @ 384x384, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=352,\n        depths=(2, 2, 18, 2),\n        num_heads=(11, 22, 44, 88),  # head count not certain for Huge, 384 & 224 trying diff values\n        extra_norm_period=6,\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_huge_384', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_huge_224(pretrained=False, **kwargs):\n    \"\"\"Swin-H V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=352,\n        depths=(2, 2, 18, 2),\n        num_heads=(8, 16, 32, 64),  # head count not certain for Huge, 384 & 224 trying diff values\n        extra_norm_period=6,\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_huge_224', pretrained=pretrained, **model_kwargs)\n\n\n@register_model\ndef swinv2_cr_giant_384(pretrained=False, **kwargs):\n    \"\"\"Swin-G V2 CR @ 384x384, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=512,\n        depths=(2, 2, 42, 2),\n        num_heads=(16, 32, 64, 128),\n        extra_norm_period=6,\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_giant_384', pretrained=pretrained, **model_kwargs\n    )\n\n\n@register_model\ndef swinv2_cr_giant_224(pretrained=False, **kwargs):\n    \"\"\"Swin-G V2 CR @ 224x224, trained ImageNet-1k\"\"\"\n    model_kwargs = dict(\n        embed_dim=512,\n        depths=(2, 2, 42, 2),\n        num_heads=(16, 32, 64, 128),\n        extra_norm_period=6,\n        **kwargs\n    )\n    return _create_swin_transformer_v2_cr('swinv2_cr_giant_224', pretrained=pretrained, **model_kwargs)\n"
  },
  {
    "path": "model/conv/CondConv.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass Attention(nn.Module):\n    def __init__(self,in_planes,K,init_weight=True):\n        super().__init__()\n        self.avgpool=nn.AdaptiveAvgPool2d(1)\n        self.net=nn.Conv2d(in_planes,K,kernel_size=1,bias=False)\n        self.sigmoid=nn.Sigmoid()\n\n        if(init_weight):\n            self._initialize_weights()\n\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            if isinstance(m ,nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self,x):\n        att=self.avgpool(x) #bs,dim,1,1\n        att=self.net(att).view(x.shape[0],-1) #bs,K\n        return self.sigmoid(att)\n\nclass CondConv(nn.Module):\n    def __init__(self,in_planes,out_planes,kernel_size,stride,padding=0,dilation=1,grounps=1,bias=True,K=4,init_weight=True):\n        super().__init__()\n        self.in_planes=in_planes\n        self.out_planes=out_planes\n        self.kernel_size=kernel_size\n        self.stride=stride\n        self.padding=padding\n        self.dilation=dilation\n        self.groups=grounps\n        self.bias=bias\n        self.K=K\n        self.init_weight=init_weight\n        self.attention=Attention(in_planes=in_planes,K=K,init_weight=init_weight)\n\n        self.weight=nn.Parameter(torch.randn(K,out_planes,in_planes//grounps,kernel_size,kernel_size),requires_grad=True)\n        if(bias):\n            self.bias=nn.Parameter(torch.randn(K,out_planes),requires_grad=True)\n        else:\n            self.bias=None\n        \n        if(self.init_weight):\n            self._initialize_weights()\n\n        #TODO 初始化\n    def _initialize_weights(self):\n        
for i in range(self.K):\n            nn.init.kaiming_uniform_(self.weight[i])\n\n    def forward(self,x):\n        bs,in_planes,h,w=x.shape\n        softmax_att=self.attention(x) #bs,K (sigmoid routing weights, despite the name)\n        x=x.view(1,-1,h,w)\n        weight=self.weight.view(self.K,-1) #K,-1\n        aggregate_weight=torch.mm(softmax_att,weight).view(bs*self.out_planes,self.in_planes//self.groups,self.kernel_size,self.kernel_size) #bs*out_p,in_p//groups,k,k\n\n        if(self.bias is not None):\n            bias=self.bias.view(self.K,-1) #K,out_p\n            aggregate_bias=torch.mm(softmax_att,bias).view(-1) #bs*out_p\n            output=F.conv2d(x,weight=aggregate_weight,bias=aggregate_bias,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)\n        else:\n            output=F.conv2d(x,weight=aggregate_weight,bias=None,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)\n        \n        # reshape with the conv output's own spatial size; it differs from (h,w) when stride>1 or the padding does not preserve size\n        output=output.view(bs,self.out_planes,output.shape[-2],output.shape[-1])\n        return output\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape)"
  },
  {
    "path": "model/conv/DepthwiseSeparableConvolution.py",
    "content": "import torch\nfrom torch import nn\n\nclass DepthwiseSeparableConvolution(nn.Module):\n    def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=1):\n        super().__init__()\n        self.depthwise_conv=nn.Conv2d(\n            in_channels=in_ch,\n            out_channels=in_ch,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            groups=in_ch\n        )\n        self.pointwise_conv=nn.Conv2d(\n            in_channels=in_ch,\n            out_channels=out_ch,\n            kernel_size=1,\n            stride=1,\n            padding=0,\n            groups=1\n        )\n        \n    def forward(self, x):\n        out=self.depthwise_conv(x)\n        out=self.pointwise_conv(out)\n        return out\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    dsconv=DepthwiseSeparableConvolution(3,64)\n    out=dsconv(input)\n    print(out.shape)\n    "
  },
  {
    "path": "model/conv/DynamicConv.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass Attention(nn.Module):\n    def __init__(self,in_planes,ratio,K,temprature=30,init_weight=True):\n        super().__init__()\n        self.avgpool=nn.AdaptiveAvgPool2d(1)\n        self.temprature=temprature\n        assert in_planes>ratio\n        hidden_planes=in_planes//ratio\n        self.net=nn.Sequential(\n            nn.Conv2d(in_planes,hidden_planes,kernel_size=1,bias=False),\n            nn.ReLU(),\n            nn.Conv2d(hidden_planes,K,kernel_size=1,bias=False)\n        )\n\n        if(init_weight):\n            self._initialize_weights()\n\n    def update_temprature(self):\n        if(self.temprature>1):\n            self.temprature-=1\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            if isinstance(m ,nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self,x):\n        att=self.avgpool(x) #bs,dim,1,1\n        att=self.net(att).view(x.shape[0],-1) #bs,K\n        return F.softmax(att/self.temprature,-1)\n\nclass DynamicConv(nn.Module):\n    def __init__(self,in_planes,out_planes,kernel_size,stride,padding=0,dilation=1,grounps=1,bias=True,K=4,temprature=30,ratio=4,init_weight=True):\n        super().__init__()\n        self.in_planes=in_planes\n        self.out_planes=out_planes\n        self.kernel_size=kernel_size\n        self.stride=stride\n        self.padding=padding\n        self.dilation=dilation\n        self.groups=grounps\n        self.bias=bias\n        self.K=K\n        self.init_weight=init_weight\n        self.attention=Attention(in_planes=in_planes,ratio=ratio,K=K,temprature=temprature,init_weight=init_weight)\n\n       
 self.weight=nn.Parameter(torch.randn(K,out_planes,in_planes//grounps,kernel_size,kernel_size),requires_grad=True)\n        if(bias):\n            self.bias=nn.Parameter(torch.randn(K,out_planes),requires_grad=True)\n        else:\n            self.bias=None\n        \n        if(self.init_weight):\n            self._initialize_weights()\n\n        #TODO: initialization\n    def _initialize_weights(self):\n        for i in range(self.K):\n            nn.init.kaiming_uniform_(self.weight[i])\n\n    def forward(self,x):\n        bs,in_planes,h,w=x.shape\n        softmax_att=self.attention(x) #bs,K\n        x=x.view(1,-1,h,w)\n        weight=self.weight.view(self.K,-1) #K,-1\n        aggregate_weight=torch.mm(softmax_att,weight).view(bs*self.out_planes,self.in_planes//self.groups,self.kernel_size,self.kernel_size) #bs*out_p,in_p//groups,k,k\n\n        if(self.bias is not None):\n            bias=self.bias.view(self.K,-1) #K,out_p\n            aggregate_bias=torch.mm(softmax_att,bias).view(-1) #bs*out_p\n            output=F.conv2d(x,weight=aggregate_weight,bias=aggregate_bias,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)\n        else:\n            output=F.conv2d(x,weight=aggregate_weight,bias=None,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)\n        \n        # reshape with the conv output's own spatial size; it differs from (h,w) when stride>1 or the padding does not preserve size\n        output=output.view(bs,self.out_planes,output.shape[-2],output.shape[-1])\n        return output\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape)"
  },
  {
    "path": "model/conv/HorNet.py",
    "content": "from functools import partial\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import trunc_normal_, DropPath\nfrom timm.models.registry import register_model\nimport torch.fft\n\ndef get_dwconv(dim, kernel, bias):\n    return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)\n\nclass GlobalLocalFilter(nn.Module):\n    # https://arxiv.org/abs/2207.14284\n    def __init__(self, dim, h=14, w=8):\n        super().__init__()\n        self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)\n        self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)\n        trunc_normal_(self.complex_weight, std=.02)\n        self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')\n        self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')\n\n    def forward(self, x):\n        x = self.pre_norm(x)\n        x1, x2 = torch.chunk(x, 2, dim=1)\n        x1 = self.dw(x1)\n\n        x2 = x2.to(torch.float32)\n        B, C, a, b = x2.shape\n        x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')\n\n        weight = self.complex_weight\n        if not weight.shape[1:3] == x2.shape[2:4]:\n            weight = F.interpolate(weight.permute(3,0,1,2), size=x2.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)\n\n        weight = torch.view_as_complex(weight.contiguous())\n\n        x2 = x2 * weight\n        x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')\n\n        x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)\n        x = self.post_norm(x)\n        return x\n\n\nclass gnconv(nn.Module):\n    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):\n        super().__init__()\n        self.order = order\n        self.dims = [dim // 2 ** i for i in range(order)]\n        self.dims.reverse()\n        self.proj_in = 
nn.Conv2d(dim, 2*dim, 1)\n\n        if gflayer is None:\n            self.dwconv = get_dwconv(sum(self.dims), 7, True)\n        else:\n            self.dwconv = gflayer(sum(self.dims), h=h, w=w)\n        \n        self.proj_out = nn.Conv2d(dim, dim, 1)\n\n        self.pws = nn.ModuleList(\n            [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]\n        )\n\n        self.scale = s\n\n    def forward(self, x, mask=None, dummy=False):\n        B, C, H, W = x.shape\n\n        fused_x = self.proj_in(x)\n        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)\n\n        dw_abc = self.dwconv(abc) * self.scale\n\n        dw_list = torch.split(dw_abc, self.dims, dim=1)\n        x = pwa * dw_list[0]\n\n        for i in range(self.order -1):\n            x = self.pws[i](x) * dw_list[i+1]\n\n        x = self.proj_out(x)\n\n        return x\n\nclass Block(nn.Module):\n    r\"\"\" HorNet block\n    \"\"\"\n    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, gnconv=gnconv):\n        super().__init__()\n\n        self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')\n        self.gnconv = gnconv(dim) # depthwise conv\n        self.norm2 = LayerNorm(dim, eps=1e-6)\n        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers\n        self.act = nn.GELU()\n        self.pwconv2 = nn.Linear(4 * dim, dim)\n\n        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), \n                                    requires_grad=True) if layer_scale_init_value > 0 else None\n\n        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), \n                                    requires_grad=True) if layer_scale_init_value > 0 else None\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n    def forward(self, x):\n        B, C, H, W  = x.shape\n        if self.gamma1 is not None:\n            gamma1 = self.gamma1.view(C, 1, 1)\n        else:\n            gamma1 = 1\n        x = x + self.drop_path(gamma1 * self.gnconv(self.norm1(x)))\n\n        input = x\n        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)\n        x = self.norm2(x)\n        x = self.pwconv1(x)\n        x = self.act(x)\n        x = self.pwconv2(x)\n        if self.gamma2 is not None:\n            x = self.gamma2 * x\n        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)\n\n        x = input + self.drop_path(x)\n        return x\n\n\nclass HorNet(nn.Module):\n    def __init__(self, in_chans=3, num_classes=1000, \n                 depths=[3, 3, 9, 3], base_dim=96, drop_path_rate=0.,\n                 layer_scale_init_value=1e-6, head_init_scale=1.,\n                 gnconv=gnconv, block=Block, uniform_init=False, **kwargs\n                 ):\n        super().__init__()\n        dims = [base_dim, base_dim*2, base_dim*4, base_dim*8]\n\n        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers\n        stem = nn.Sequential(\n            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),\n            LayerNorm(dims[0], eps=1e-6, data_format=\"channels_first\")\n        )\n        self.downsample_layers.append(stem)\n        for i in range(3):\n            downsample_layer = nn.Sequential(\n                    LayerNorm(dims[i], eps=1e-6, data_format=\"channels_first\"),\n                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),\n            )\n            self.downsample_layers.append(downsample_layer)\n\n        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks\n        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] \n\n\n        if not isinstance(gnconv, list):\n            gnconv = [gnconv, 
gnconv, gnconv, gnconv]\n        else:\n            gnconv = gnconv\n            assert len(gnconv) == 4\n\n        cur = 0\n        for i in range(4):\n            stage = nn.Sequential(\n                *[block(dim=dims[i], drop_path=dp_rates[cur + j], \n                layer_scale_init_value=layer_scale_init_value, gnconv=gnconv[i]) for j in range(depths[i])]\n            )\n            self.stages.append(stage)\n            cur += depths[i]\n\n        self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer\n        self.head = nn.Linear(dims[-1], num_classes)\n\n        self.uniform_init = uniform_init\n\n        self.apply(self._init_weights)\n        self.head.weight.data.mul_(head_init_scale)\n        self.head.bias.data.mul_(head_init_scale)\n\n    def _init_weights(self, m):\n        if not self.uniform_init:\n            if isinstance(m, (nn.Conv2d, nn.Linear)):\n                trunc_normal_(m.weight, std=.02)\n                if hasattr(m, 'bias') and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n        else:\n            if isinstance(m, (nn.Conv2d, nn.Linear)):\n                nn.init.xavier_uniform_(m.weight)\n                if hasattr(m, 'bias') and m.bias is not None:\n                    nn.init.constant_(m.bias, 0)            \n\n    def forward_features(self, x):\n        for i in range(4):\n            x = self.downsample_layers[i](x)\n            for j, blk in enumerate(self.stages[i]):\n                    x = blk(x)\n        return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\nclass LayerNorm(nn.Module):\n    r\"\"\" LayerNorm that supports two data formats: channels_last (default) or channels_first. \n    The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with \n    shape (batch_size, height, width, channels) while channels_first corresponds to inputs \n    with shape (batch_size, channels, height, width).\n    \"\"\"\n    def __init__(self, normalized_shape, eps=1e-6, data_format=\"channels_last\"):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.bias = nn.Parameter(torch.zeros(normalized_shape))\n        self.eps = eps\n        self.data_format = data_format\n        if self.data_format not in [\"channels_last\", \"channels_first\"]:\n            raise NotImplementedError \n        self.normalized_shape = (normalized_shape, )\n    \n    def forward(self, x):\n        if self.data_format == \"channels_last\":\n            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        elif self.data_format == \"channels_first\":\n            u = x.mean(1, keepdim=True)\n            s = (x - u).pow(2).mean(1, keepdim=True)\n            x = (x - u) / torch.sqrt(s + self.eps)\n            x = self.weight[:, None, None] * x + self.bias[:, None, None]\n            return x\n\n@register_model\ndef hornet_tiny_7x7(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s),\n        partial(gnconv, order=5, s=s),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_tiny_gf(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),\n        partial(gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter),\n    ],\n    **kwargs\n    )\n    return 
model\n\n@register_model\ndef hornet_small_7x7(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=96, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s),\n        partial(gnconv, order=5, s=s),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_small_gf(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=96, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),\n        partial(gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_base_7x7(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=128, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s),\n        partial(gnconv, order=5, s=s),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_base_gf(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=128, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),\n        partial(gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_base_gf_img384(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=128, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        
partial(gnconv, order=4, s=s, h=24, w=13, gflayer=GlobalLocalFilter),\n        partial(gnconv, order=5, s=s, h=12, w=7, gflayer=GlobalLocalFilter),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_large_7x7(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s),\n        partial(gnconv, order=5, s=s),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_large_gf(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),\n        partial(gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter),\n    ],\n    **kwargs\n    )\n    return model\n\n@register_model\ndef hornet_large_gf_img384(pretrained=False,in_22k=False, **kwargs):\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s, h=24, w=13, gflayer=GlobalLocalFilter),\n        partial(gnconv, order=5, s=s, h=12, w=7, gflayer=GlobalLocalFilter),\n    ],\n    **kwargs\n    )\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    s = 1.0/3.0\n    model = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block,\n    gnconv=[\n        partial(gnconv, order=2, s=s),\n        partial(gnconv, order=3, s=s),\n        partial(gnconv, order=4, s=s),\n        partial(gnconv, order=5, s=s),\n    ],\n    )\n    outputs = model(input)\n    print(outputs.shape)"
  },
  {
    "path": "model/conv/Involution.py",
    "content": "import math\nfrom functools import partial\n\nimport torch\nfrom torch import nn, select\nfrom torch.nn import functional as F\n\n\nclass Involution(nn.Module):\n    def __init__(self, kernel_size, in_channel=4, stride=1, group=1,ratio=4):\n        super().__init__()\n        self.kernel_size=kernel_size\n        self.in_channel=in_channel\n        self.stride=stride\n        self.group=group\n        assert self.in_channel%group==0\n        self.group_channel=self.in_channel//group\n        self.conv1=nn.Conv2d(\n            self.in_channel,\n            self.in_channel//ratio,\n            kernel_size=1\n        )\n        self.bn=nn.BatchNorm2d(in_channel//ratio)\n        self.relu=nn.ReLU()\n        self.conv2=nn.Conv2d(\n            self.in_channel//ratio,\n            self.group*self.kernel_size*self.kernel_size,\n            kernel_size=1\n        )\n        self.avgpool=nn.AvgPool2d(stride,stride) if stride>1 else nn.Identity()\n        self.unfold=nn.Unfold(kernel_size=kernel_size,stride=stride,padding=kernel_size//2)\n        \n\n    def forward(self, inputs):\n        B,C,H,W=inputs.shape\n        weight=self.conv2(self.relu(self.bn(self.conv1(self.avgpool(inputs))))) #(bs,G*K*K,H//stride,W//stride)\n        b,c,h,w=weight.shape\n        weight=weight.reshape(b,self.group,self.kernel_size*self.kernel_size,h,w).unsqueeze(2) #(bs,G,1,K*K,H//stride,W//stride)\n\n        x_unfold=self.unfold(inputs)\n        x_unfold=x_unfold.reshape(B,self.group,C//self.group,self.kernel_size*self.kernel_size,H//self.stride,W//self.stride) #(bs,G,G//C,K*K,H//stride,W//stride)\n\n        out=(x_unfold*weight).sum(dim=3)#(bs,G,G//C,1,H//stride,W//stride)\n        out=out.reshape(B,C,H//self.stride,W//self.stride) #(bs,C,H//stride,W//stride)\n\n        return out\n\nif __name__ == '__main__':\n    input=torch.randn(1,4,64,64)\n    involution=Involution(kernel_size=3,in_channel=4,stride=2)\n    out=involution(input)\n    print(out.shape)"
  },
  {
    "path": "model/conv/MBConv.py",
    "content": "import math\nfrom functools import partial\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass SwishImplementation(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, i):\n        result = i * torch.sigmoid(i)\n        ctx.save_for_backward(i)\n        return result\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        i = ctx.saved_variables[0]\n        sigmoid_i = torch.sigmoid(i)\n        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))\n\nclass MemoryEfficientSwish(nn.Module):\n    def forward(self, x):\n        return SwishImplementation.apply(x)\n\n\ndef drop_connect(inputs, p, training):\n    \"\"\" Drop connect. \"\"\"\n    if not training: return inputs\n    batch_size = inputs.shape[0]\n    keep_prob = 1 - p\n    random_tensor = keep_prob\n    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)\n    binary_tensor = torch.floor(random_tensor)\n    output = inputs / keep_prob * binary_tensor\n    return output\n\n\ndef get_same_padding_conv2d(image_size=None):\n     return partial(Conv2dStaticSamePadding, image_size=image_size)\n\ndef get_width_and_height_from_size(x):\n    \"\"\" Obtains width and height from a int or tuple \"\"\"\n    if isinstance(x, int): return x, x\n    if isinstance(x, list) or isinstance(x, tuple): return x\n    else: raise TypeError()\n\ndef calculate_output_image_size(input_image_size, stride):\n    \"\"\"\n    计算出 Conv2dSamePadding with a stride.\n    \"\"\"\n    if input_image_size is None: return None\n    image_height, image_width = get_width_and_height_from_size(input_image_size)\n    stride = stride if isinstance(stride, int) else stride[0]\n    image_height = int(math.ceil(image_height / stride))\n    image_width = int(math.ceil(image_width / stride))\n    return [image_height, image_width]\n\n\n\nclass Conv2dStaticSamePadding(nn.Conv2d):\n    \"\"\" 2D Convolutions like TensorFlow, for a fixed 
image size\"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):\n        super().__init__(in_channels, out_channels, kernel_size, **kwargs)\n        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2\n\n        # Calculate padding based on image size and save it\n        assert image_size is not None\n        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size\n        kh, kw = self.weight.size()[-2:]\n        sh, sw = self.stride\n        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)\n        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)\n        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)\n        if pad_h > 0 or pad_w > 0:\n            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))\n        else:\n            self.static_padding = Identity()\n\n    def forward(self, x):\n        x = self.static_padding(x)\n        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\n        return x\n\nclass Identity(nn.Module):\n    def __init__(self, ):\n        super(Identity, self).__init__()\n\n    def forward(self, input):\n        return input\n\n\n# MBConvBlock\nclass MBConvBlock(nn.Module):\n    '''\n    层 ksize3*3 输入32 输出16  conv1  stride步长1\n    '''\n    def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1, image_size=224):\n        super().__init__()\n        self._bn_mom = 0.1\n        self._bn_eps = 0.01\n        self._se_ratio = 0.25\n        self._input_filters = input_filters\n        self._output_filters = output_filters\n        self._expand_ratio = expand_ratio\n        self._kernel_size = ksize\n        self._stride = stride\n\n        inp = self._input_filters\n        oup = self._input_filters * self._expand_ratio\n        if self._expand_ratio != 1:\n         
   Conv2d = get_same_padding_conv2d(image_size=image_size)\n            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)\n            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)\n\n\n        # Depthwise convolution\n        k = self._kernel_size\n        s = self._stride\n        Conv2d = get_same_padding_conv2d(image_size=image_size)\n        self._depthwise_conv = Conv2d(\n            in_channels=oup, out_channels=oup, groups=oup,\n            kernel_size=k, stride=s, bias=False)\n        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)\n        image_size = calculate_output_image_size(image_size, s)\n\n        # Squeeze and Excitation layer, if desired\n        Conv2d = get_same_padding_conv2d(image_size=(1,1))\n        num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))\n        self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)\n        self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)\n\n        # Output phase\n        final_oup = self._output_filters\n        Conv2d = get_same_padding_conv2d(image_size=image_size)\n        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)\n        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)\n        self._swish = MemoryEfficientSwish()\n\n    def forward(self, inputs, drop_connect_rate=None):\n        \"\"\"\n        :param inputs: input tensor\n        :param drop_connect_rate: drop connect rate (float, between 0 and 1)\n        :return: output of block\n        \"\"\"\n\n        # Expansion and Depthwise Convolution\n        x = inputs\n        if self._expand_ratio != 1:\n            expand = self._expand_conv(inputs)\n            bn0 = self._bn0(expand)\n            x = self._swish(bn0)\n        depthwise = 
self._depthwise_conv(x)\n        bn1 = self._bn1(depthwise)\n        x = self._swish(bn1)\n\n        # Squeeze and Excitation\n        x_squeezed = F.adaptive_avg_pool2d(x, 1)\n        x_squeezed = self._se_reduce(x_squeezed)\n        x_squeezed = self._swish(x_squeezed)\n        x_squeezed = self._se_expand(x_squeezed)\n        x = torch.sigmoid(x_squeezed) * x\n\n        x = self._bn2(self._project_conv(x))\n\n        # Skip connection and drop connect\n        input_filters, output_filters = self._input_filters, self._output_filters\n        if self._stride == 1 and input_filters == output_filters:\n            if drop_connect_rate:\n                x = drop_connect(x, p=drop_connect_rate, training=self.training)\n            x = x + inputs  # skip connection\n        return x\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,112,112)\n    mbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=3,image_size=112)\n    out=mbconv(input)\n    print(out.shape)"
  },
  {
    "path": "model/fighingcv.egg-info/PKG-INFO",
    "content": "Metadata-Version: 2.1\nName: fighingcv\nVersion: 1.0.0\nSummary: Client library to download and publish models, datasets and other repos on the huggingface.co hub\nHome-page: https://github.com/xmu-xiaoma666/External-Attention-pytorch\nAuthor: Hugging Face, Inc.\nAuthor-email: julien@huggingface.co\nLicense: Apache\nKeywords: model-hub machine-learning models natural-language-processing deep-learning pytorch pretrained-models\nPlatform: UNKNOWN\nClassifier: Intended Audience :: Developers\nClassifier: Intended Audience :: Education\nClassifier: Intended Audience :: Science/Research\nClassifier: License :: OSI Approved :: Apache Software License\nClassifier: Operating System :: OS Independent\nClassifier: Programming Language :: Python :: 3\nClassifier: Topic :: Scientific/Engineering :: Artificial Intelligence\nRequires-Python: >=3.7.0\nDescription-Content-Type: text/markdown\nProvides-Extra: torch\nProvides-Extra: fastai\nProvides-Extra: tensorflow\nProvides-Extra: testing\nProvides-Extra: quality\nProvides-Extra: all\nProvides-Extra: dev\nLicense-File: LICENSE\n\n\n<img src=\"./FightingCVimg/LOGO.gif\" height=\"200\" width=\"400\"/>\n\n简体中文 | [English](./README_EN.md)\n\n# FightingCV 代码库， 包含 [***Attention***](#attention-series),[***Backbone***](#backbone-series), [***MLP***](#mlp-series), [***Re-parameter***](#re-parameter-series), [**Convolution**](#convolution-series)\n\n![](https://img.shields.io/badge/fightingcv-v0.0.1-brightgreen)\n![](https://img.shields.io/badge/python->=v3.0-blue)\n![](https://img.shields.io/badge/pytorch->=v1.4-red)\n\n<!--\n-------\n*If this project is helpful to you, welcome to give a ***star***.* \n\n*Don't forget to ***follow*** me to learn about project updates.*\n\n-->\n\n-------\n\n\n🔥🔥🔥 **重磅！！！作为项目补充，最近全新开源了一个目标检测代码库 
[YOLOAir](https://github.com/iscyy/yoloair)，里面在目标检测算法中集成了各种Attention机制，代码简洁易读，欢迎大家来玩呀！**\n\n\n\n\n![image](https://user-images.githubusercontent.com/33897496/184842902-9acff374-b3e7-401a-80fd-9d484e40c637.png)\n\n\n\n-------\n\nHello，大家好，我是小马🚀🚀🚀\n\n***For 小白（Like Me）：***\n最近在读论文的时候会发现一个问题，有时候论文核心思想非常简单，核心代码可能也就十几行。但是打开作者release的源码时，却发现提出的模块嵌入到分类、检测、分割等任务框架中，导致代码比较冗余，对于特定任务框架不熟悉的我，**很难找到核心代码**，导致在论文和网络思想的理解上会有一定困难。\n\n***For 进阶者（Like You）：***\n如果把Conv、FC、RNN这些基本单元看做小的Lego积木，把Transformer、ResNet这些结构看成已经搭好的Lego城堡。那么本项目提供的模块就是一个个具有完整语义信息的Lego组件。**让科研工作者们避免反复造轮子**，只需思考如何利用这些“Lego组件”，搭建出更多绚烂多彩的作品。\n\n***For 大神（May Be Like You）：***\n能力有限，**不喜轻喷**！！！\n\n***For All：***\n本项目就是要实现一个既能**让深度学习小白也能搞懂**，又能**服务科研和工业社区**的代码库。作为[**FightingCV公众号**](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA)的补充，本项目的宗旨是从代码角度，实现🚀**让世界上没有难读的论文**🚀。\n\n（同时也非常欢迎各位科研工作者将自己的工作的核心代码整理到本项目中，推动科研社区的发展，会在readme中注明代码的作者~）\n\n\n\n\n\n\n## 公众号 & 微信交流群\n\n欢迎大家关注公众号：**FightingCV**\n\n公众号**每天**都会进行**论文、算法和代码的干货分享**哦~\n\n\n<!-- 已建立**机器学习/深度学习算法/计算机视觉/多模态交流群**微信交流群！\n\n（加不进去可以加微信：**775629340**，记得备注【**公司/学校+方向+ID**】） -->\n\n**每天在群里分享一些近期的论文和解析**，欢迎大家一起**学习交流**哈~~~\n（加不进去可以加微信：**775629340**，记得备注【**公司/学校+方向+ID**】）\n\n![](./FightingCVimg/wechat.jpg)\n\n强烈推荐大家关注[**知乎**](https://www.zhihu.com/people/jason-14-58-38/posts)账号和[**FightingCV公众号**](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA)，可以快速了解到最新优质的干货资源。\n\n\n\n\n***\n\n# 目录\n\n- [Attention Series](#attention-series)\n    - [1. External Attention Usage](#1-external-attention-usage)\n\n    - [2. Self Attention Usage](#2-self-attention-usage)\n\n    - [3. Simplified Self Attention Usage](#3-simplified-self-attention-usage)\n\n    - [4. Squeeze-and-Excitation Attention Usage](#4-squeeze-and-excitation-attention-usage)\n\n    - [5. SK Attention Usage](#5-sk-attention-usage)\n\n    - [6. CBAM Attention Usage](#6-cbam-attention-usage)\n\n    - [7. BAM Attention Usage](#7-bam-attention-usage)\n    \n    - [8. ECA Attention Usage](#8-eca-attention-usage)\n\n    - [9. 
DANet Attention Usage](#9-danet-attention-usage)\n\n    - [10. Pyramid Split Attention (PSA) Usage](#10-Pyramid-Split-Attention-Usage)\n\n    - [11. Efficient Multi-Head Self-Attention(EMSA) Usage](#11-Efficient-Multi-Head-Self-Attention-Usage)\n\n    - [12. Shuffle Attention Usage](#12-Shuffle-Attention-Usage)\n    \n    - [13. MUSE Attention Usage](#13-MUSE-Attention-Usage)\n  \n    - [14. SGE Attention Usage](#14-SGE-Attention-Usage)\n\n    - [15. A2 Attention Usage](#15-A2-Attention-Usage)\n\n    - [16. AFT Attention Usage](#16-AFT-Attention-Usage)\n\n    - [17. Outlook Attention Usage](#17-Outlook-Attention-Usage)\n\n    - [18. ViP Attention Usage](#18-ViP-Attention-Usage)\n\n    - [19. CoAtNet Attention Usage](#19-CoAtNet-Attention-Usage)\n\n    - [20. HaloNet Attention Usage](#20-HaloNet-Attention-Usage)\n\n    - [21. Polarized Self-Attention Usage](#21-Polarized-Self-Attention-Usage)\n\n    - [22. CoTAttention Usage](#22-CoTAttention-Usage)\n\n    - [23. Residual Attention Usage](#23-Residual-Attention-Usage)\n  \n    - [24. S2 Attention Usage](#24-S2-Attention-Usage)\n\n    - [25. GFNet Attention Usage](#25-GFNet-Attention-Usage)\n\n    - [26. Triplet Attention Usage](#26-TripletAttention-Usage)\n\n    - [27. Coordinate Attention Usage](#27-Coordinate-Attention-Usage)\n\n    - [28. MobileViT Attention Usage](#28-MobileViT-Attention-Usage)\n\n    - [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)\n\n    - [30. UFO Attention Usage](#30-UFO-Attention-Usage)\n\n    - [31. ACmix Attention Usage](#31-Acmix-Attention-Usage)\n  \n    - [32. MobileViTv2 Attention Usage](#32-MobileViTv2-Attention-Usage)\n\n    - [33. DAT Attention Usage](#33-DAT-Attention-Usage)\n\n    - [34. CrossFormer Attention Usage](#34-CrossFormer-Attention-Usage)\n\n    - [35. MOATransformer Attention Usage](#35-MOATransformer-Attention-Usage)\n\n    - [36. CrissCrossAttention Attention Usage](#36-CrissCrossAttention-Attention-Usage)\n\n    - [37. 
Axial_attention Attention Usage](#37-Axial_attention-Attention-Usage)\n\n- [Backbone Series](#Backbone-series)\n\n    - [1. ResNet Usage](#1-ResNet-Usage)\n\n    - [2. ResNeXt Usage](#2-ResNeXt-Usage)\n\n    - [3. MobileViT Usage](#3-MobileViT-Usage)\n\n    - [4. ConvMixer Usage](#4-ConvMixer-Usage)\n\n    - [5. ShuffleTransformer Usage](#5-ShuffleTransformer-Usage)\n\n    - [6. ConTNet Usage](#6-ConTNet-Usage)\n\n    - [7. HATNet Usage](#7-HATNet-Usage)\n\n    - [8. CoaT Usage](#8-CoaT-Usage)\n\n    - [9. PVT Usage](#9-PVT-Usage)\n\n    - [10. CPVT Usage](#10-CPVT-Usage)\n\n    - [11. PIT Usage](#11-PIT-Usage)\n\n    - [12. CrossViT Usage](#12-CrossViT-Usage)\n\n    - [13. TnT Usage](#13-TnT-Usage)\n\n    - [14. DViT Usage](#14-DViT-Usage)\n\n    - [15. CeiT Usage](#15-CeiT-Usage)\n\n    - [16. ConViT Usage](#16-ConViT-Usage)\n\n    - [17. CaiT Usage](#17-CaiT-Usage)\n\n    - [18. PatchConvnet Usage](#18-PatchConvnet-Usage)\n\n    - [19. DeiT Usage](#19-DeiT-Usage)\n\n    - [20. LeViT Usage](#20-LeViT-Usage)\n\n    - [21. VOLO Usage](#21-VOLO-Usage)\n    \n    - [22. Container Usage](#22-Container-Usage)\n\n    - [23. CMT Usage](#23-CMT-Usage)\n\n\n- [MLP Series](#mlp-series)\n\n    - [1. RepMLP Usage](#1-RepMLP-Usage)\n\n    - [2. MLP-Mixer Usage](#2-MLP-Mixer-Usage)\n\n    - [3. ResMLP Usage](#3-ResMLP-Usage)\n\n    - [4. gMLP Usage](#4-gMLP-Usage)\n\n    - [5. sMLP Usage](#5-sMLP-Usage)\n\n    - [6. vip-mlp Usage](#6-vip-mlp-Usage)\n\n- [Re-Parameter(ReP) Series](#Re-Parameter-series)\n\n    - [1. RepVGG Usage](#1-RepVGG-Usage)\n\n    - [2. ACNet Usage](#2-ACNet-Usage)\n\n    - [3. Diverse Branch Block(DDB) Usage](#3-Diverse-Branch-Block-Usage)\n\n- [Convolution Series](#Convolution-series)\n\n    - [1. Depthwise Separable Convolution Usage](#1-Depthwise-Separable-Convolution-Usage)\n\n    - [2. MBConv Usage](#2-MBConv-Usage)\n\n    - [3. Involution Usage](#3-Involution-Usage)\n\n    - [4. DynamicConv Usage](#4-DynamicConv-Usage)\n\n    - [5. 
CondConv Usage](#5-CondConv-Usage)\n\n***\n\n\n# Attention Series\n\n- Pytorch implementation of [\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks---arXiv 2021.05.05\"](https://arxiv.org/abs/2105.02358)\n\n- Pytorch implementation of [\"Attention Is All You Need---NIPS2017\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n- Pytorch implementation of [\"Squeeze-and-Excitation Networks---CVPR2018\"](https://arxiv.org/abs/1709.01507)\n\n- Pytorch implementation of [\"Selective Kernel Networks---CVPR2019\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n- Pytorch implementation of [\"CBAM: Convolutional Block Attention Module---ECCV2018\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n- Pytorch implementation of [\"BAM: Bottleneck Attention Module---BMVC2018\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n- Pytorch implementation of [\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n- Pytorch implementation of [\"Dual Attention Network for Scene Segmentation---CVPR2019\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n- Pytorch implementation of [\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network---arXiv 2021.05.30\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n- Pytorch implementation of [\"ResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28\"](https://arxiv.org/abs/2105.13677)\n\n- Pytorch implementation of [\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS---ICASSP 2021\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n- Pytorch implementation of [\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning---arXiv 2019.11.17\"](https://arxiv.org/abs/1911.09483)\n\n- Pytorch implementation of [\"Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks---arXiv 
2019.05.23\"](https://arxiv.org/pdf/1905.09646.pdf)\n\n- Pytorch implementation of [\"A2-Nets: Double Attention Networks---NIPS2018\"](https://arxiv.org/pdf/1810.11579.pdf)\n\n\n- Pytorch implementation of [\"An Attention Free Transformer---ICLR2021 (Apple New Work)\"](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition---arXiv 2021.06.24](https://arxiv.org/abs/2106.13112) \n  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/385561050)\n\n\n- Pytorch implementation of [Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition---arXiv 2021.06.23](https://arxiv.org/abs/2106.12368) \n  [【Paper Analysis】](https://mp.weixin.qq.com/s/5gonUQgBho_m2O54jyXF_Q)\n\n\n- Pytorch implementation of [CoAtNet: Marrying Convolution and Attention for All Data Sizes---arXiv 2021.06.09](https://arxiv.org/abs/2106.04803) \n  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/385578588)\n\n\n- Pytorch implementation of [Scaling Local Self-Attention for Parameter Efficient Visual Backbones---CVPR2021 Oral](https://arxiv.org/pdf/2103.12731.pdf)  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/388598744)\n\n\n\n- Pytorch implementation of [Polarized Self-Attention: Towards High-quality Pixel-wise Regression---arXiv 2021.07.02](https://arxiv.org/abs/2107.00782)  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/389770482) \n\n\n- Pytorch implementation of [Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292)  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/394795481) \n\n\n- Pytorch implementation of [Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n- Pytorch implementation of [S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) [【Paper Analysis】](https://zhuanlan.zhihu.com/p/397003638) \n\n- Pytorch implementation of [Global Filter Networks for Image Classification---arXiv 
2021.07.01](https://arxiv.org/abs/2107.00645) \n\n- Pytorch implementation of [Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n- Pytorch implementation of [Coordinate Attention for Efficient Mobile Network Design ---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n- Pytorch implementation of [UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n- Pytorch implementation of [Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n- Pytorch implementation of [On the Integration of Self-Attention and Convolution---ArXiv 2022.03.14](https://arxiv.org/pdf/2111.14556.pdf)\n\n- Pytorch implementation of [CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n- Pytorch implementation of [Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n- Pytorch implementation of [CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n- Pytorch implementation of [Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n***\n\n\n### 1. External Attention Usage\n#### 1.1. Paper\n[\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks\"](https://arxiv.org/abs/2105.02358)\n\n#### 1.2. Overview\n![](./model/img/External_Attention.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom model.attention.ExternalAttention import ExternalAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nea = ExternalAttention(d_model=512,S=8)\noutput=ea(input)\nprint(output.shape)\n```\n\n***\n\n\n### 2. Self Attention Usage\n#### 2.1. Paper\n[\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n#### 2.2. Overview\n![](./model/img/SA.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.attention.SelfAttention import ScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nsa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n```\n\n***\n\n### 3. Simplified Self Attention Usage\n#### 3.1. Paper\nNone\n\n#### 3.2. Overview\n![](./model/img/SSA.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)\noutput=ssa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n### 4. Squeeze-and-Excitation Attention Usage\n#### 4.1. Paper\n[\"Squeeze-and-Excitation Networks\"](https://arxiv.org/abs/1709.01507)\n\n#### 4.2. Overview\n![](./model/img/SE.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.attention.SEAttention import SEAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SEAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n\n***\n\n### 5. SK Attention Usage\n#### 5.1. Paper\n[\"Selective Kernel Networks\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n#### 5.2. Overview\n![](./model/img/SK.png)\n\n#### 5.3. Usage Code\n```python\nfrom model.attention.SKAttention import SKAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nsk = SKAttention(channel=512,reduction=8)\noutput=sk(input)\nprint(output.shape)\n\n```\n***\n\n### 6. CBAM Attention Usage\n#### 6.1. 
Paper\n[\"CBAM: Convolutional Block Attention Module\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n#### 6.2. Overview\n![](./model/img/CBAM1.png)\n\n![](./model/img/CBAM2.png)\n\n#### 6.3. Usage Code\n```python\nfrom model.attention.CBAM import CBAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nkernel_size=input.shape[2]\ncbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\noutput=cbam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 7. BAM Attention Usage\n#### 7.1. Paper\n[\"BAM: Bottleneck Attention Module\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n#### 7.2. Overview\n![](./model/img/BAM.png)\n\n#### 7.3. Usage Code\n```python\nfrom model.attention.BAM import BAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nbam = BAMBlock(channel=512,reduction=16,dia_val=2)\noutput=bam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 8. ECA Attention Usage\n#### 8.1. Paper\n[\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n#### 8.2. Overview\n![](./model/img/ECA.png)\n\n#### 8.3. Usage Code\n```python\nfrom model.attention.ECAAttention import ECAAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\neca = ECAAttention(kernel_size=3)\noutput=eca(input)\nprint(output.shape)\n\n```\n\n***\n\n### 9. DANet Attention Usage\n#### 9.1. Paper\n[\"Dual Attention Network for Scene Segmentation\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n#### 9.2. Overview\n![](./model/img/danet.png)\n\n#### 9.3. Usage Code\n```python\nfrom model.attention.DANet import DAModule\nimport torch\n\ninput=torch.randn(50,512,7,7)\ndanet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\nprint(danet(input).shape)\n\n```\n\n***\n\n### 10. Pyramid Split Attention Usage\n\n#### 10.1. Paper\n[\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n#### 10.2. 
Overview\n![](./model/img/psa.png)\n\n#### 10.3. Usage Code\n```python\nfrom model.attention.PSA import PSA\nimport torch\n\ninput=torch.randn(50,512,7,7)\npsa = PSA(channel=512,reduction=8)\noutput=psa(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 11. Efficient Multi-Head Self-Attention Usage\n\n#### 11.1. Paper\n[\"ResT: An Efficient Transformer for Visual Recognition\"](https://arxiv.org/abs/2105.13677)\n\n#### 11.2. Overview\n![](./model/img/EMSA.png)\n\n#### 11.3. Usage Code\n```python\n\nfrom model.attention.EMSA import EMSA\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,64,512)\nemsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\noutput=emsa(input,input,input)\nprint(output.shape)\n    \n```\n\n***\n\n\n### 12. Shuffle Attention Usage\n\n#### 12.1. Paper\n[\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n#### 12.2. Overview\n![](./model/img/ShuffleAttention.jpg)\n\n#### 12.3. Usage Code\n```python\n\nfrom model.attention.ShuffleAttention import ShuffleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,512,7,7)\nse = ShuffleAttention(channel=512,G=8)\noutput=se(input)\nprint(output.shape)\n\n    \n```\n\n\n***\n\n\n### 13. MUSE Attention Usage\n\n#### 13.1. Paper\n[\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning\"](https://arxiv.org/abs/1911.09483)\n\n#### 13.2. Overview\n![](./model/img/MUSE.png)\n\n#### 13.3. Usage Code\n```python\nfrom model.attention.MUSEAttention import MUSEAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,49,512)\nsa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 14. SGE Attention Usage\n\n#### 14.1. 
Paper\n[Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks](https://arxiv.org/pdf/1905.09646.pdf)\n\n#### 14.2. Overview\n![](./model/img/SGE.png)\n\n#### 14.3. Usage Code\n```python\nfrom model.attention.SGE import SpatialGroupEnhance\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nsge = SpatialGroupEnhance(groups=8)\noutput=sge(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 15. A2 Attention Usage\n\n#### 15.1. Paper\n[A2-Nets: Double Attention Networks](https://arxiv.org/pdf/1810.11579.pdf)\n\n#### 15.2. Overview\n![](./model/img/A2.png)\n\n#### 15.3. Usage Code\n```python\nfrom model.attention.A2Atttention import DoubleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\na2 = DoubleAttention(512,128,128,True)\noutput=a2(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 16. AFT Attention Usage\n\n#### 16.1. Paper\n[An Attention Free Transformer](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n#### 16.2. Overview\n![](./model/img/AFT.jpg)\n\n#### 16.3. Usage Code\n```python\nfrom model.attention.AFT import AFT_FULL\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,49,512)\naft_full = AFT_FULL(d_model=512, n=49)\noutput=aft_full(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 17. Outlook Attention Usage\n\n#### 17.1. Paper\n\n\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n\n#### 17.2. Overview\n![](./model/img/OutlookAttention.png)\n\n#### 17.3. Usage Code\n```python\nfrom model.attention.OutlookAttention import OutlookAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,28,28,512)\noutlook = OutlookAttention(dim=512)\noutput=outlook(input)\nprint(output.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 18. ViP Attention Usage\n\n#### 18.1. 
Paper\n\n\n[Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition](https://arxiv.org/abs/2106.12368)\n\n\n#### 18.2. Overview\n![](./model/img/ViP.png)\n\n#### 18.3. Usage Code\n```python\n\nfrom model.attention.ViP import WeightedPermuteMLP\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(64,8,8,512)\nseg_dim=8\nvip=WeightedPermuteMLP(512,seg_dim)\nout=vip(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n### 19. CoAtNet Attention Usage\n\n#### 19.1. Paper\n\n\n[CoAtNet: Marrying Convolution and Attention for All Data Sizes](https://arxiv.org/abs/2106.04803) \n\n\n#### 19.2. Overview\nNone\n\n\n#### 19.3. Usage Code\n```python\n\nfrom model.attention.CoAtNet import CoAtNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\ncoatnet=CoAtNet(in_ch=3,image_size=224)\nout=coatnet(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 20. HaloNet Attention Usage\n\n#### 20.1. Paper\n\n\n[Scaling Local Self-Attention for Parameter Efficient Visual Backbones](https://arxiv.org/pdf/2103.12731.pdf) \n\n\n#### 20.2. Overview\n\n![](./model/img/HaloNet.png)\n\n#### 20.3. Usage Code\n```python\n\nfrom model.attention.HaloAttention import HaloAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,8,8)\nhalo = HaloAttention(dim=512,\n    block_size=2,\n    halo_size=1,)\noutput=halo(input)\nprint(output.shape)\n\n```\n\n\n***\n\n### 21. Polarized Self-Attention Usage\n\n#### 21.1. Paper\n\n[Polarized Self-Attention: Towards High-quality Pixel-wise Regression](https://arxiv.org/abs/2107.00782)  \n\n\n#### 21.2. Overview\n\n![](./model/img/PoSA.png)\n\n#### 21.3. 
Usage Code\n```python\n\nfrom model.attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention,SequentialPolarizedSelfAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,7,7)\npsa = SequentialPolarizedSelfAttention(channel=512)\noutput=psa(input)\nprint(output.shape)\n\n\n```\n\n\n***\n\n\n### 22. CoTAttention Usage\n\n#### 22.1. Paper\n\n[Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292) \n\n\n#### 22.2. Overview\n\n![](./model/img/CoT.png)\n\n#### 22.3. Usage Code\n```python\n\nfrom model.attention.CoTAttention import CoTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ncot = CoTAttention(dim=512,kernel_size=3)\noutput=cot(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n### 23. Residual Attention Usage\n\n#### 23.1. Paper\n\n[Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n#### 23.2. Overview\n\n![](./model/img/ResAtt.png)\n\n#### 23.3. Usage Code\n```python\n\nfrom model.attention.ResidualAttention import ResidualAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nresatt = ResidualAttention(channel=512,num_class=1000,la=0.2)\noutput=resatt(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n\n### 24. S2 Attention Usage\n\n#### 24.1. Paper\n\n[S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) \n\n\n#### 24.2. Overview\n\n![](./model/img/S2Attention.png)\n\n#### 24.3. Usage Code\n```python\nfrom model.attention.S2Attention import S2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ns2att = S2Attention(channels=512)\noutput=s2att(input)\nprint(output.shape)\n\n```\n\n***\n\n\n\n### 25. 
GFNet Attention Usage\n\n#### 25.1. Paper\n\n[Global Filter Networks for Image Classification---arXiv 2021.07.01](https://arxiv.org/abs/2107.00645) \n\n\n#### 25.2. Overview\n\n![](./model/img/GFNet.jpg)\n\n#### 25.3. Usage Code - Implemented by [Wenliang Zhao (Author)](https://scholar.google.com/citations?user=lyPWvuEAAAAJ&hl=en)\n\n```python\nfrom model.attention.gfnet import GFNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nx = torch.randn(1, 3, 224, 224)\ngfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)\nout = gfnet(x)\nprint(out.shape)\n\n```\n\n***\n\n\n### 26. TripletAttention Usage\n\n#### 26.1. Paper\n\n[Rotate to Attend: Convolutional Triplet Attention Module---CVPR 2021](https://arxiv.org/abs/2010.03045) \n\n#### 26.2. Overview\n\n![](./model/img/triplet.png)\n\n#### 26.3. Usage Code - Implemented by [digantamisra98](https://github.com/digantamisra98)\n\n```python\nfrom model.attention.TripletAttention import TripletAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\ninput=torch.randn(50,512,7,7)\ntriplet = TripletAttention()\noutput=triplet(input)\nprint(output.shape)\n```\n\n\n***\n\n\n### 27. Coordinate Attention Usage\n\n#### 27.1. Paper\n\n[Coordinate Attention for Efficient Mobile Network Design---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n\n#### 27.2. Overview\n\n![](./model/img/CoordAttention.png)\n\n#### 27.3. Usage Code - Implemented by [Andrew-Qibin](https://github.com/Andrew-Qibin)\n\n```python\nfrom model.attention.CoordAttention import CoordAtt\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninp=torch.rand([2, 96, 56, 56])\ninp_dim, oup_dim = 96, 96\nreduction=32\n\ncoord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)\noutput=coord_attention(inp)\nprint(output.shape)\n```\n\n***\n\n\n### 28. MobileViT Attention Usage\n\n#### 28.1. 
Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n\n#### 28.2. Overview\n\n![](./model/img/MobileViTAttention.png)\n\n#### 28.3. Usage Code\n\n```python\nfrom model.attention.MobileViTAttention import MobileViTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    m=MobileViTAttention()\n    input=torch.randn(1,3,49,49)\n    output=m(input)\n    print(output.shape)  #output:(1,3,49,49)\n    \n```\n\n***\n\n\n### 29. ParNet Attention Usage\n\n#### 29.1. Paper\n\n[Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n\n#### 29.2. Overview\n\n![](./model/img/ParNet.png)\n\n#### 29.3. Usage Code\n\n```python\nfrom model.attention.ParNetAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    pna = ParNetAttention(channel=512)\n    output=pna(input)\n    print(output.shape) #50,512,7,7\n    \n```\n\n***\n\n\n### 30. UFO Attention Usage\n\n#### 30.1. Paper\n\n[UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n\n#### 30.2. Overview\n\n![](./model/img/UFO.png)\n\n#### 30.3. Usage Code\n\n```python\nfrom model.attention.UFOAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=ufo(input,input,input)\n    print(output.shape) #[50, 49, 512]\n    \n```\n\n***\n\n### 31. ACmix Attention Usage\n\n#### 31.1. Paper\n\n[On the Integration of Self-Attention and Convolution](https://arxiv.org/pdf/2111.14556.pdf)\n\n#### 31.2. 
Usage Code\n\n```python\nfrom model.attention.ACmix import ACmix\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,256,7,7)\n    acmix = ACmix(in_planes=256, out_planes=256)\n    output=acmix(input)\n    print(output.shape)\n    \n```\n\n### 32. MobileViTv2 Attention Usage\n\n#### 32.1. Paper\n\n[Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n\n#### 32.2. Overview\n\n![](./model/img/MobileViTv2.png)\n\n#### 32.3. Usage Code\n\n```python\nfrom model.attention.MobileViTv2Attention import MobileViTv2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n    \n```\n\n### 33. DAT Attention Usage\n\n#### 33.1. Paper\n\n[Vision Transformer with Deformable Attention---CVPR2022](https://arxiv.org/abs/2201.00520)\n\n#### 33.2. Usage Code\n\n```python\nfrom model.attention.DAT import DAT\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DAT(\n        img_size=224,\n        patch_size=4,\n        num_classes=1000,\n        expansion=4,\n        dim_stem=96,\n        dims=[96, 192, 384, 768],\n        depths=[2, 2, 6, 2],\n        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],\n        heads=[3, 6, 12, 24],\n        window_sizes=[7, 7, 7, 7] ,\n        groups=[-1, -1, 3, 6],\n        use_pes=[False, False, True, True],\n        dwc_pes=[False, False, False, False],\n        strides=[-1, -1, 1, 1],\n        sr_ratios=[-1, -1, -1, -1],\n        offset_range_factor=[-1, -1, 2, 2],\n        no_offs=[False, False, False, False],\n        fixed_pes=[False, False, False, False],\n        use_dwc_mlps=[False, False, False, False],\n        use_conv_patches=False,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        
drop_path_rate=0.2,\n    )\n    output=model(input)\n    print(output[0].shape)\n    \n```\n\n### 34. CrossFormer Attention Usage\n\n#### 34.1. Paper\n\n[CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n#### 34.2. Usage Code\n\n```python\nfrom model.attention.Crossformer import CrossFormer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CrossFormer(img_size=224,\n        patch_size=[4, 8, 16, 32],\n        in_chans= 3,\n        num_classes=1000,\n        embed_dim=48,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        group_size=[7, 7, 7, 7],\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        merge_size=[[2, 4], [2,4], [2, 4]]\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 35. MOATransformer Attention Usage\n\n#### 35.1. Paper\n\n[Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n#### 35.2. Usage Code\n\n```python\nfrom model.attention.MOATransformer import MOATransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = MOATransformer(\n        img_size=224,\n        patch_size=4,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6],\n        num_heads=[3, 6, 12],\n        window_size=14,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 36. CrissCrossAttention Attention Usage\n\n#### 36.1. 
Paper\n\n[CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n#### 36.2. Usage Code\n\n```python\nfrom model.attention.CrissCrossAttention import CrissCrossAttention\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 64, 7, 7)\n    model = CrissCrossAttention(64)\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n### 37. Axial_attention Attention Usage\n\n#### 37.1. Paper\n\n[Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n\n#### 37.2. Usage Code\n\n```python\nfrom model.attention.Axial_attention import AxialImageTransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 128, 7, 7)\n    model = AxialImageTransformer(\n        dim = 128,\n        depth = 12,\n        reversible = True\n    )\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n***\n\n\n# Backbone Series\n\n- Pytorch implementation of [\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n- Pytorch implementation of [\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n\n- Pytorch implementation of [Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer---ArXiv 2021.06.07](https://arxiv.org/abs/2106.03650)\n\n- Pytorch implementation of [ConTNet: Why not use convolution and transformer at the same time?---ArXiv 2021.04.27](https://arxiv.org/abs/2104.13497)\n\n- Pytorch implementation of [Vision Transformers with Hierarchical Attention---ArXiv 2022.06.15](https://arxiv.org/abs/2106.03180)\n\n- Pytorch implementation of 
[Co-Scale Conv-Attentional Image Transformers---ArXiv 2021.08.26](https://arxiv.org/abs/2104.06399)\n\n- Pytorch implementation of [Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n- Pytorch implementation of [Rethinking Spatial Dimensions of Vision Transformers---ICCV 2021](https://arxiv.org/abs/2103.16302)\n\n- Pytorch implementation of [CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification---ICCV 2021](https://arxiv.org/abs/2103.14899)\n\n- Pytorch implementation of [Transformer in Transformer---NeurIPS 2021](https://arxiv.org/abs/2103.00112)\n\n- Pytorch implementation of [DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n- Pytorch implementation of [Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n- Pytorch implementation of [ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n- Pytorch implementation of [Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n- Pytorch implementation of [Going deeper with Image Transformers---ICCV 2021 (Oral)](https://arxiv.org/abs/2103.17239)\n\n- Pytorch implementation of [Training data-efficient image transformers & distillation through attention---ICML 2021](https://arxiv.org/abs/2012.12877)\n\n- Pytorch implementation of [LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n- Pytorch implementation of [Container: Context Aggregation Network---NeurIPS 2021](https://arxiv.org/abs/2106.01401)\n\n- Pytorch implementation of [CMT: Convolutional Neural Networks Meet Vision Transformers---CVPR 2022](https://arxiv.org/abs/2107.06263)\n\n- Pytorch implementation of [Vision Transformer with 
Deformable Attention---CVPR 2022](https://arxiv.org/abs/2201.00520)\n\n\n### 1. ResNet Usage\n#### 1.1. Paper\n[\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n#### 1.2. Overview\n![](./model/img/resnet.png)\n![](./model/img/resnet2.jpg)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.backbone.resnet import ResNet50,ResNet101,ResNet152\nimport torch\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnet50=ResNet50(1000)\n    # resnet101=ResNet101(1000)\n    # resnet152=ResNet152(1000)\n    out=resnet50(input)\n    print(out.shape)\n\n```\n\n\n### 2. ResNeXt Usage\n#### 2.1. Paper\n\n[\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n#### 2.2. Overview\n![](./model/img/resnext.png)\n\n#### 2.3. Usage Code\n```python\n\nfrom model.backbone.resnext import ResNeXt50,ResNeXt101,ResNeXt152\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnext50=ResNeXt50(1000)\n    # resnext101=ResNeXt101(1000)\n    # resnext152=ResNeXt152(1000)\n    out=resnext50(input)\n    print(out.shape)\n\n\n```\n\n\n\n### 3. MobileViT Usage\n#### 3.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n#### 3.2. Overview\n![](./model/img/mobileViT.jpg)\n\n#### 3.3. Usage Code\n```python\n\nfrom model.backbone.MobileViT import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n\n    ### mobilevit_xxs\n    mvit_xxs=mobilevit_xxs()\n    out=mvit_xxs(input)\n    print(out.shape)\n\n    ### mobilevit_xs\n    mvit_xs=mobilevit_xs()\n    out=mvit_xs(input)\n    print(out.shape)\n\n\n    ### mobilevit_s\n    mvit_s=mobilevit_s()\n    out=mvit_s(input)\n    print(out.shape)\n\n```\n\n\n\n\n\n### 4. ConvMixer Usage\n#### 4.1. 
Paper\n[Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n#### 4.2. Overview\n![](./model/img/ConvMixer.png)\n\n#### 4.3. Usage Code\n```python\n\nfrom model.backbone.ConvMixer import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    x=torch.randn(1,3,224,224)\n    convmixer=ConvMixer(dim=512,depth=12)\n    out=convmixer(x)\n    print(out.shape)  #[1, 1000]\n\n\n```\n\n### 5. ShuffleTransformer Usage\n#### 5.1. Paper\n[Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer](https://arxiv.org/pdf/2106.03650.pdf)\n\n#### 5.2. Usage Code\n```python\n\nfrom model.backbone.ShuffleTransformer import ShuffleTransformer\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    sft = ShuffleTransformer()\n    output=sft(input)\n    print(output.shape)\n\n\n```\n\n### 6. ConTNet Usage\n#### 6.1. Paper\n[ConTNet: Why not use convolution and transformer at the same time?](https://arxiv.org/abs/2104.13497)\n\n#### 6.2. Usage Code\n```python\n\n# build_model is assumed to be defined in the same module as ConTNet\nfrom model.backbone.ConTNet import ConTNet, build_model\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == \"__main__\":\n    model = build_model(use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)\n\n\n```\n\n### 7. HATNet Usage\n#### 7.1. Paper\n[Vision Transformers with Hierarchical Attention](https://arxiv.org/abs/2106.03180)\n\n#### 7.2. 
Usage Code\n```python\n\nfrom model.backbone.HATNet import HATNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    hat = HATNet(dims=[48, 96, 240, 384], head_dim=48, expansions=[8, 8, 4, 4],\n        grid_sizes=[8, 7, 7, 1], ds_ratios=[8, 4, 2, 1], depths=[2, 2, 6, 3])\n    output=hat(input)\n    print(output.shape)\n\n\n```\n\n### 8. CoaT Usage\n#### 8.1. Paper\n[Co-Scale Conv-Attentional Image Transformers](https://arxiv.org/abs/2104.06399)\n\n#### 8.2. Usage Code\n```python\n\nfrom model.backbone.CoaT import CoaT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4])\n    output=model(input)\n    print(output.shape) # torch.Size([1, 1000])\n\n```\n\n### 9. PVT Usage\n#### 9.1. Paper\n[PVT v2: Improved Baselines with Pyramid Vision Transformer](https://arxiv.org/pdf/2106.13797.pdf)\n\n#### 9.2. Usage Code\n```python\n\nfrom model.backbone.PVT import PyramidVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n\n### 10. CPVT Usage\n#### 10.1. Paper\n[Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n#### 10.2. 
Usage Code\n```python\n\nfrom model.backbone.CPVT import CPVTV2\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 11. PIT Usage\n#### 11.1. Paper\n[Rethinking Spatial Dimensions of Vision Transformers](https://arxiv.org/abs/2103.16302)\n\n#### 11.2. Usage Code\n```python\n\nfrom model.backbone.PIT import PoolingTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 12. CrossViT Usage\n#### 12.1. Paper\n[CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification](https://arxiv.org/abs/2103.14899)\n\n#### 12.2. Usage Code\n```python\n\nfrom model.backbone.CrossViT import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == \"__main__\":\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[240, 224],\n        patch_size=[12, 16], \n        embed_dim=[192, 384], \n        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n        num_heads=[6, 6], \n        mlp_ratio=[4, 4, 1], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 13. TnT Usage\n#### 13.1. Paper\n[Transformer in Transformer](https://arxiv.org/abs/2103.00112)\n\n#### 13.2. 
Usage Code\n```python\n\nfrom model.backbone.TnT import TNT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = TNT(\n        img_size=224, \n        patch_size=16, \n        outer_dim=384, \n        inner_dim=24, \n        depth=12,\n        outer_num_heads=6, \n        inner_num_heads=4, \n        qkv_bias=False,\n        inner_stride=4)\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 14. DViT Usage\n#### 14.1. Paper\n[DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n#### 14.2. Usage Code\n```python\n\nfrom model.backbone.DViT import DeepVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, \n        depth=[False] * 16, \n        apply_transform=[False] * 0 + [True] * 32, \n        num_heads=12, \n        mlp_ratio=3, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 15. CeiT Usage\n#### 15.1. Paper\n[Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n#### 15.2. Usage Code\n```python\n\n# Image2Tokens is assumed to be defined in the same module as CeIT\nfrom model.backbone.CeiT import CeIT, Image2Tokens\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CeIT(\n        hybrid_backbone=Image2Tokens(),\n        patch_size=4, \n        embed_dim=192, \n        depth=12, \n        num_heads=3, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 16. ConViT Usage\n#### 16.1. Paper\n[ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n#### 16.2. 
Usage Code\n```python\n\nfrom model.backbone.ConViT import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        num_heads=16,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 17. CaiT Usage\n#### 17.1. Paper\n[Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239)\n\n#### 17.2. Usage Code\n```python\n\nfrom model.backbone.CaiT import CaiT\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CaiT(\n        img_size= 224,\n        patch_size=16, \n        embed_dim=192, \n        depth=24, \n        num_heads=4, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 18. PatchConvnet Usage\n#### 18.1. Paper\n[Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n#### 18.2. Usage Code\n```python\n\n# ConvStem and Conv_blocks_se are assumed to be defined in the same module as PatchConvnet\nfrom model.backbone.PatchConvnet import PatchConvnet, ConvStem, Conv_blocks_se\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 19. DeiT Usage\n#### 19.1. Paper\n[Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)\n\n#### 19.2. 
Usage Code\n```python\n\nfrom model.backbone.DeiT import DistilledVisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DistilledVisionTransformer(\n        patch_size=16, \n        embed_dim=384, \n        depth=12, \n        num_heads=6, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 20. LeViT Usage\n#### 20.1. Paper\n[LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n#### 20.2. Usage Code\n```python\n\nfrom model.backbone.LeViT import *\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    for name in specification:\n        input=torch.randn(1,3,224,224)\n        model = globals()[name](fuse=True, pretrained=False)\n        model.eval()\n        output = model(input)\n        print(output.shape)\n\n```\n\n### 21. VOLO Usage\n#### 21.1. Paper\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n#### 21.2. Usage Code\n```python\n\nfrom model.backbone.VOLO import VOLO\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VOLO([4, 4, 8, 2],\n                 embed_dims=[192, 384, 384, 384],\n                 num_heads=[6, 12, 12, 12],\n                 mlp_ratios=[3, 3, 3, 3],\n                 downsamples=[True, False, False, False],\n                 outlook_attention=[True, False, False, False ],\n                 post_layers=['ca', 'ca'],\n                 )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 22. Container Usage\n#### 22.1. Paper\n[Container: Context Aggregation Network](https://arxiv.org/abs/2106.01401)\n\n#### 22.2. 
Usage Code\n```python\n\nfrom model.backbone.Container import VisionTransformer\nimport torch\nfrom torch import nn\nfrom functools import partial  # required by norm_layer below\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[224, 56, 28, 14], \n        patch_size=[4, 2, 2, 2], \n        embed_dim=[64, 128, 320, 512], \n        depth=[3, 4, 8, 3], \n        num_heads=16, \n        mlp_ratio=[8, 8, 4, 4], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6))\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 23. CMT Usage\n#### 23.1. Paper\n[CMT: Convolutional Neural Networks Meet Vision Transformers](https://arxiv.org/abs/2107.06263)\n\n#### 23.2. Usage Code\n```python\n\nfrom model.backbone.CMT import CMT_Tiny\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CMT_Tiny()\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n\n\n\n\n\n# MLP Series\n\n- Pytorch implementation of [\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition---arXiv 2021.05.05\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n- Pytorch implementation of [\"MLP-Mixer: An all-MLP Architecture for Vision---arXiv 2021.05.17\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n- Pytorch implementation of [\"ResMLP: Feedforward networks for image classification with data-efficient training---arXiv 2021.05.07\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n- Pytorch implementation of [\"Pay Attention to MLPs---arXiv 2021.05.17\"](https://arxiv.org/abs/2105.08050)\n\n\n- Pytorch implementation of [\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?---arXiv 2021.09.12\"](https://arxiv.org/abs/2109.05422)\n\n### 1. RepMLP Usage\n#### 1.1. Paper\n[\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n#### 1.2. Overview\n![](./model/img/repmlp.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom model.mlp.repmlp import RepMLP\nimport torch\nfrom torch import nn\n\nN=4 #batch size\nC=512 #input dim\nO=1024 #output dim\nH=14 #image height\nW=14 #image width\nh=7 #patch height\nw=7 #patch width\nfc1_fc2_reduction=1 #reduction ratio\nfc3_groups=8 # groups\nrepconv_kernels=[1,3,5,7] #kernel list\nrepmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)\nx=torch.randn(N,C,H,W)\nrepmlp.eval()\nfor module in repmlp.modules():\n    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):\n        nn.init.uniform_(module.running_mean, 0, 0.1)\n        nn.init.uniform_(module.running_var, 0, 0.1)\n        nn.init.uniform_(module.weight, 0, 0.1)\n        nn.init.uniform_(module.bias, 0, 0.1)\n\n#training result\nout=repmlp(x)\n#inference result\nrepmlp.switch_to_deploy()\ndeployout = repmlp(x)\n\nprint(((deployout-out)**2).sum())\n```\n\n### 2. MLP-Mixer Usage\n#### 2.1. Paper\n[\"MLP-Mixer: An all-MLP Architecture for Vision\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n#### 2.2. Overview\n![](./model/img/mlpmixer.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.mlp.mlp_mixer import MlpMixer\nimport torch\nmlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)\ninput=torch.randn(50,3,40,40)\noutput=mlp_mixer(input)\nprint(output.shape)\n```\n\n***\n\n### 3. ResMLP Usage\n#### 3.1. Paper\n[\"ResMLP: Feedforward networks for image classification with data-efficient training\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n#### 3.2. Overview\n![](./model/img/resmlp.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.mlp.resmlp import ResMLP\nimport torch\n\ninput=torch.randn(50,3,14,14)\nresmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)\nout=resmlp(input)\nprint(out.shape) #the last dimension is class_num\n```\n\n***\n\n### 4. gMLP Usage\n#### 4.1. 
Paper\n[\"Pay Attention to MLPs\"](https://arxiv.org/abs/2105.08050)\n\n#### 4.2. Overview\n![](./model/img/gMLP.jpg)\n\n#### 4.3. Usage Code\n```python\nfrom model.mlp.g_mlp import gMLP\nimport torch\n\nnum_tokens=10000\nbs=50\nlen_sen=49\nnum_layers=6\ninput=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen\ngmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)\noutput=gmlp(input)\nprint(output.shape)\n```\n\n***\n\n### 5. sMLP Usage\n#### 5.1. Paper\n[\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?\"](https://arxiv.org/abs/2109.05422)\n\n#### 5.2. Overview\n![](./model/img/sMLP.jpg)\n\n#### 5.3. Usage Code\n```python\nfrom model.mlp.sMLP_block import sMLPBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    smlp=sMLPBlock(h=224,w=224)\n    out=smlp(input)\n    print(out.shape)\n```\n\n### 6. vip-mlp Usage\n#### 6.1. Paper\n[\"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n#### 6.2. 
Usage Code\n```python\nimport importlib\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# 'vip-mlp' contains a hyphen, so a plain import statement is a syntax error; load the module by name instead\nvip_mlp = importlib.import_module('model.mlp.vip-mlp')\nVisionPermutator = vip_mlp.VisionPermutator\nWeightedPermuteMLP = vip_mlp.WeightedPermuteMLP  # assumed to be defined in the same module\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionPermutator(\n        layers=[4, 3, 8, 3], \n        embed_dims=[384, 384, 384, 384], \n        patch_size=14, \n        transitions=[False, False, False, False],\n        segment_dim=[16, 16, 16, 16], \n        mlp_ratios=[3, 3, 3, 3], \n        mlp_fn=WeightedPermuteMLP\n    )\n    output=model(input)\n    print(output.shape)\n```\n\n\n# Re-Parameter Series\n\n- Pytorch implementation of [\"RepVGG: Making VGG-style ConvNets Great Again---CVPR2021\"](https://arxiv.org/abs/2101.03697)\n\n- Pytorch implementation of [\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks---ICCV2019\"](https://arxiv.org/abs/1908.03930)\n\n- Pytorch implementation of [\"Diverse Branch Block: Building a Convolution as an Inception-like Unit---CVPR2021\"](https://arxiv.org/abs/2103.13425)\n\n\n***\n\n### 1. RepVGG Usage\n#### 1.1. Paper\n[\"RepVGG: Making VGG-style ConvNets Great Again\"](https://arxiv.org/abs/2101.03697)\n\n#### 1.2. Overview\n![](./model/img/repvgg.png)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.rep.repvgg import RepBlock\nimport torch\n\n\ninput=torch.randn(50,512,49,49)\nrepblock=RepBlock(512,512)\nrepblock.eval()\nout=repblock(input)\nrepblock._switch_to_deploy()\nout2=repblock(input)\nprint('difference between vgg and repvgg')\nprint(((out2-out)**2).sum())\n```\n\n\n\n***\n\n### 2. ACNet Usage\n#### 2.1. Paper\n[\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks\"](https://arxiv.org/abs/1908.03930)\n\n#### 2.2. Overview\n![](./model/img/acnet.png)\n\n#### 2.3. 
Usage Code\n```python\nfrom model.rep.acnet import ACNet\nimport torch\nfrom torch import nn\n\ninput=torch.randn(50,512,49,49)\nacnet=ACNet(512,512)\nacnet.eval()\nout=acnet(input)\nacnet._switch_to_deploy()\nout2=acnet(input)\nprint('difference:')\nprint(((out2-out)**2).sum())\n\n```\n\n\n\n***\n\n### 3. Diverse Branch Block Usage\n#### 3.1. Paper\n[\"Diverse Branch Block: Building a Convolution as an Inception-like Unit\"](https://arxiv.org/abs/2103.13425)\n\n#### 3.2. Overview\n![](./model/img/ddb.png)\n\n#### 3.3. Usage Code\n##### 3.3.1 Transform I\n```python\nfrom model.rep.ddb import transI_conv_bn\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n#conv+bn\nconv1=nn.Conv2d(64,64,3,padding=1)\nbn1=nn.BatchNorm2d(64)\nbn1.eval()\nout1=bn1(conv1(input))\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transI_conv_bn(conv1,bn1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.2 Transform II\n```python\nfrom model.rep.ddb import transII_conv_branch\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,3,padding=1)\nconv2=nn.Conv2d(64,64,3,padding=1)\nout1=conv1(input)+conv2(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transII_conv_branch(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.3 Transform III\n```python\nfrom model.rep.ddb import transIII_conv_sequential\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as 
F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,1,padding=0,bias=False)\nconv2=nn.Conv2d(64,64,3,padding=1,bias=False)\nout1=conv2(conv1(input))\n\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1,bias=False)\nconv_fuse.weight.data=transIII_conv_sequential(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.4 Transform IV\n```python\nfrom model.rep.ddb import transIV_conv_concat\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,32,3,padding=1)\nconv2=nn.Conv2d(64,32,3,padding=1)\nout1=torch.cat([conv1(input),conv2(input)],dim=1)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transIV_conv_concat(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.5 Transform V\n```python\nfrom model.rep.ddb import transV_avg\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\navg=nn.AvgPool2d(kernel_size=3,stride=1)\nout1=avg(input)\n\nconv=transV_avg(64,3)\nout2=conv(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n##### 3.3.6 Transform VI\n```python\nfrom model.rep.ddb import transVI_conv_scale\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1x1=nn.Conv2d(64,64,1)\nconv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))\nconv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))\nout1=conv1x1(input)+conv1x3(input)+conv3x1(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n\n\n# Convolution Series\n\n- Pytorch implementation of [\"MobileNets: Efficient Convolutional Neural Networks for Mobile 
Vision Applications---CVPR2017\"](https://arxiv.org/abs/1704.04861)\n\n- Pytorch implementation of [\"Efficientnet: Rethinking model scaling for convolutional neural networks---PMLR2019\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n- Pytorch implementation of [\"Involution: Inverting the Inherence of Convolution for Visual Recognition---CVPR2021\"](https://arxiv.org/abs/2103.06255)\n\n- Pytorch implementation of [\"Dynamic Convolution: Attention over Convolution Kernels---CVPR2020 Oral\"](https://arxiv.org/abs/1912.03458)\n\n- Pytorch implementation of [\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference---NeurIPS2019\"](https://arxiv.org/abs/1904.04971)\n\n***\n\n### 1. Depthwise Separable Convolution Usage\n#### 1.1. Paper\n[\"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications\"](https://arxiv.org/abs/1704.04861)\n\n#### 1.2. Overview\n![](./model/img/DepthwiseSeparableConv.png)\n\n#### 1.3. Usage Code\n```python\nfrom model.conv.DepthwiseSeparableConvolution import DepthwiseSeparableConvolution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\ndsconv=DepthwiseSeparableConvolution(3,64)\nout=dsconv(input)\nprint(out.shape)\n```\n\n***\n\n\n### 2. MBConv Usage\n#### 2.1. Paper\n[\"Efficientnet: Rethinking model scaling for convolutional neural networks\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n#### 2.2. Overview\n![](./model/img/MBConv.jpg)\n\n#### 2.3. Usage Code\n```python\nfrom model.conv.MBConv import MBConvBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=512,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n\n```\n\n***\n\n\n### 3. Involution Usage\n#### 3.1. Paper\n[\"Involution: Inverting the Inherence of Convolution for Visual Recognition\"](https://arxiv.org/abs/2103.06255)\n\n#### 3.2. 
Overview\n![](./model/img/Involution.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.conv.Involution import Involution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,4,64,64)\ninvolution=Involution(kernel_size=3,in_channel=4,stride=2)\nout=involution(input)\nprint(out.shape)\n```\n\n***\n\n\n### 4. DynamicConv Usage\n#### 4.1. Paper\n[\"Dynamic Convolution: Attention over Convolution Kernels\"](https://arxiv.org/abs/1912.03458)\n\n#### 4.2. Overview\n![](./model/img/DynamicConv.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.conv.DynamicConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape) # 2,64,64,64 (out_planes=64; stride 1 with padding 1 keeps the 64x64 spatial size)\n\n```\n\n***\n\n\n### 5. CondConv Usage\n#### 5.1. Paper\n[\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference\"](https://arxiv.org/abs/1904.04971)\n\n#### 5.2. Overview\n![](./model/img/CondConv.png)\n\n#### 5.3. Usage Code\n```python\nfrom model.conv.CondConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape)\n\n```\n\n***\n\n\n"
  },
  {
    "path": "model/fighingcv.egg-info/SOURCES.txt",
    "content": "LICENSE\nREADME.md\nsetup.py\nmodel/fighingcv.egg-info/PKG-INFO\nmodel/fighingcv.egg-info/SOURCES.txt\nmodel/fighingcv.egg-info/dependency_links.txt\nmodel/fighingcv.egg-info/entry_points.txt\nmodel/fighingcv.egg-info/requires.txt\nmodel/fighingcv.egg-info/top_level.txt"
  },
  {
    "path": "model/fighingcv.egg-info/dependency_links.txt",
    "content": "\n"
  },
  {
    "path": "model/fighingcv.egg-info/entry_points.txt",
    "content": "[console_scripts]\nhuggingface-cli = huggingface_hub.commands.huggingface_cli:main\n\n"
  },
  {
    "path": "model/fighingcv.egg-info/requires.txt",
    "content": "filelock\nrequests\ntqdm\npyyaml>=5.1\ntyping-extensions>=3.7.4.3\npackaging>=20.9\n\n[:python_version < \"3.8\"]\nimportlib_metadata\n\n[all]\npytest\npytest-cov\ndatasets\nsoundfile\nblack==22.3\nisort>=5.5.4\nflake8>=3.8.3\nflake8-bugbear\n\n[dev]\npytest\npytest-cov\ndatasets\nsoundfile\nblack==22.3\nisort>=5.5.4\nflake8>=3.8.3\nflake8-bugbear\n\n[fastai]\ntoml\nfastai>=2.4\nfastcore>=1.3.27\n\n[quality]\nblack==22.3\nisort>=5.5.4\nflake8>=3.8.3\nflake8-bugbear\n\n[tensorflow]\ntensorflow\npydot\ngraphviz\n\n[testing]\npytest\npytest-cov\ndatasets\nsoundfile\n\n[torch]\ntorch\n"
  },
  {
    "path": "model/fighingcv.egg-info/top_level.txt",
    "content": "\n"
  },
  {
    "path": "model/huggingface_hub.egg-info/PKG-INFO",
    "content": "Metadata-Version: 2.1\nName: huggingface-hub\nVersion: 1.0.0\nSummary: Client library to download and publish models, datasets and other repos on the huggingface.co hub\nHome-page: https://github.com/xmu-xiaoma666/External-Attention-pytorch\nAuthor: Hugging Face, Inc.\nAuthor-email: julien@huggingface.co\nLicense: Apache\nKeywords: model-hub machine-learning models natural-language-processing deep-learning pytorch pretrained-models\nPlatform: UNKNOWN\nClassifier: Intended Audience :: Developers\nClassifier: Intended Audience :: Education\nClassifier: Intended Audience :: Science/Research\nClassifier: License :: OSI Approved :: Apache Software License\nClassifier: Operating System :: OS Independent\nClassifier: Programming Language :: Python :: 3\nClassifier: Topic :: Scientific/Engineering :: Artificial Intelligence\nRequires-Python: >=3.7.0\nDescription-Content-Type: text/markdown\nProvides-Extra: torch\nProvides-Extra: fastai\nProvides-Extra: tensorflow\nProvides-Extra: testing\nProvides-Extra: quality\nProvides-Extra: all\nProvides-Extra: dev\nLicense-File: LICENSE\n\n\n<img src=\"./FightingCVimg/LOGO.gif\" height=\"200\" width=\"400\"/>\n\n简体中文 | [English](./README_EN.md)\n\n# FightingCV 代码库， 包含 [***Attention***](#attention-series),[***Backbone***](#backbone-series), [***MLP***](#mlp-series), [***Re-parameter***](#re-parameter-series), [**Convolution**](#convolution-series)\n\n![](https://img.shields.io/badge/fightingcv-v0.0.1-brightgreen)\n![](https://img.shields.io/badge/python->=v3.0-blue)\n![](https://img.shields.io/badge/pytorch->=v1.4-red)\n\n<!--\n-------\n*If this project is helpful to you, welcome to give a ***star***.* \n\n*Don't forget to ***follow*** me to learn about project updates.*\n\n-->\n\n-------\n\n\n🔥🔥🔥 **重磅！！！作为项目补充，最近全新开源了一个目标检测代码库 
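[YOLOAir](https://github.com/iscyy/yoloair). It integrates a variety of attention mechanisms into object detection algorithms, and the code is concise and easy to read. Come and try it out!**\n\n\n\n\n![image](https://user-images.githubusercontent.com/33897496/184842902-9acff374-b3e7-401a-80fd-9d484e40c637.png)\n\n\n\n-------\n\nHello everyone, I'm Xiaoma🚀🚀🚀\n\n***For beginners (Like Me):***\nWhen reading papers recently, I kept running into the same problem: the core idea of a paper is sometimes very simple, and the core code may be only a dozen or so lines. But when I open the source code released by the authors, the proposed module is embedded in a classification, detection, or segmentation framework, which makes the code redundant. Since I am not familiar with those task-specific frameworks, **it is hard to find the core code**, which makes the paper and the network design harder to understand.\n\n***For intermediate users (Like You):***\nIf basic units such as Conv, FC, and RNN are small Lego bricks, and architectures such as Transformer and ResNet are finished Lego castles, then the modules in this project are Lego components that each carry complete semantic information. They **save researchers from reinventing the wheel**, so you only need to think about how to use these “Lego components” to build more colorful work.\n\n***For experts (May Be Like You):***\nMy abilities are limited, so **please go easy on the criticism**!\n\n***For All:***\nThis project aims to be a codebase that **even deep learning beginners can understand** and that also **serves the research and industrial communities**. As a complement to the [**FightingCV WeChat official account**](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA), its goal is, from the code side, to 🚀**make sure no paper in the world is hard to read**🚀.\n\n(Researchers are also very welcome to contribute the core code of their own work to this project to help the research community grow; code authors will be credited in the README.)\n\n\n\n\n\n\n## WeChat Official Account & WeChat Group\n\nYou are welcome to follow the WeChat official account: **FightingCV**\n\nThe official account shares **practical content on papers, algorithms, and code** **every day**.\n\n\n<!-- A WeChat group for **machine learning / deep learning algorithms / computer vision / multimodal** discussion has been set up!\n\n(If you cannot join, add WeChat: **775629340** and remember to include the note [**company/school + research area + ID**]) -->\n\n**Recent papers and analyses are shared in the group every day**. Everyone is welcome to **learn and discuss** together.\n(If you cannot join, add WeChat: **775629340** and remember to include the note [**company/school + research area + ID**])\n\n![](./FightingCVimg/wechat.jpg)\n\nI strongly recommend following the [**Zhihu**](https://www.zhihu.com/people/jason-14-58-38/posts) account and the [**FightingCV WeChat official account**](https://mp.weixin.qq.com/s/m9RiivbbDPdjABsTd6q8FA) to quickly learn about the latest high-quality resources.\n\n\n\n\n***\n\n# Contents\n\n- [Attention Series](#attention-series)\n    - [1. External Attention Usage](#1-external-attention-usage)\n\n    - [2. Self Attention Usage](#2-self-attention-usage)\n\n    - [3. Simplified Self Attention Usage](#3-simplified-self-attention-usage)\n\n    - [4. Squeeze-and-Excitation Attention Usage](#4-squeeze-and-excitation-attention-usage)\n\n    - [5. SK Attention Usage](#5-sk-attention-usage)\n\n    - [6. CBAM Attention Usage](#6-cbam-attention-usage)\n\n    - [7. BAM Attention Usage](#7-bam-attention-usage)\n    \n    - [8. ECA Attention Usage](#8-eca-attention-usage)\n\n    - [9. 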
DANet Attention Usage](#9-danet-attention-usage)\n\n    - [10. Pyramid Split Attention (PSA) Usage](#10-Pyramid-Split-Attention-Usage)\n\n    - [11. Efficient Multi-Head Self-Attention(EMSA) Usage](#11-Efficient-Multi-Head-Self-Attention-Usage)\n\n    - [12. Shuffle Attention Usage](#12-Shuffle-Attention-Usage)\n    \n    - [13. MUSE Attention Usage](#13-MUSE-Attention-Usage)\n  \n    - [14. SGE Attention Usage](#14-SGE-Attention-Usage)\n\n    - [15. A2 Attention Usage](#15-A2-Attention-Usage)\n\n    - [16. AFT Attention Usage](#16-AFT-Attention-Usage)\n\n    - [17. Outlook Attention Usage](#17-Outlook-Attention-Usage)\n\n    - [18. ViP Attention Usage](#18-ViP-Attention-Usage)\n\n    - [19. CoAtNet Attention Usage](#19-CoAtNet-Attention-Usage)\n\n    - [20. HaloNet Attention Usage](#20-HaloNet-Attention-Usage)\n\n    - [21. Polarized Self-Attention Usage](#21-Polarized-Self-Attention-Usage)\n\n    - [22. CoTAttention Usage](#22-CoTAttention-Usage)\n\n    - [23. Residual Attention Usage](#23-Residual-Attention-Usage)\n  \n    - [24. S2 Attention Usage](#24-S2-Attention-Usage)\n\n    - [25. GFNet Attention Usage](#25-GFNet-Attention-Usage)\n\n    - [26. Triplet Attention Usage](#26-TripletAttention-Usage)\n\n    - [27. Coordinate Attention Usage](#27-Coordinate-Attention-Usage)\n\n    - [28. MobileViT Attention Usage](#28-MobileViT-Attention-Usage)\n\n    - [29. ParNet Attention Usage](#29-ParNet-Attention-Usage)\n\n    - [30. UFO Attention Usage](#30-UFO-Attention-Usage)\n\n    - [31. ACmix Attention Usage](#31-Acmix-Attention-Usage)\n  \n    - [32. MobileViTv2 Attention Usage](#32-MobileViTv2-Attention-Usage)\n\n    - [33. DAT Attention Usage](#33-DAT-Attention-Usage)\n\n    - [34. CrossFormer Attention Usage](#34-CrossFormer-Attention-Usage)\n\n    - [35. MOATransformer Attention Usage](#35-MOATransformer-Attention-Usage)\n\n    - [36. CrissCrossAttention Attention Usage](#36-CrissCrossAttention-Attention-Usage)\n\n    - [37. 
Axial_attention Attention Usage](#37-Axial_attention-Attention-Usage)\n\n- [Backbone Series](#Backbone-series)\n\n    - [1. ResNet Usage](#1-ResNet-Usage)\n\n    - [2. ResNeXt Usage](#2-ResNeXt-Usage)\n\n    - [3. MobileViT Usage](#3-MobileViT-Usage)\n\n    - [4. ConvMixer Usage](#4-ConvMixer-Usage)\n\n    - [5. ShuffleTransformer Usage](#5-ShuffleTransformer-Usage)\n\n    - [6. ConTNet Usage](#6-ConTNet-Usage)\n\n    - [7. HATNet Usage](#7-HATNet-Usage)\n\n    - [8. CoaT Usage](#8-CoaT-Usage)\n\n    - [9. PVT Usage](#9-PVT-Usage)\n\n    - [10. CPVT Usage](#10-CPVT-Usage)\n\n    - [11. PIT Usage](#11-PIT-Usage)\n\n    - [12. CrossViT Usage](#12-CrossViT-Usage)\n\n    - [13. TnT Usage](#13-TnT-Usage)\n\n    - [14. DViT Usage](#14-DViT-Usage)\n\n    - [15. CeiT Usage](#15-CeiT-Usage)\n\n    - [16. ConViT Usage](#16-ConViT-Usage)\n\n    - [17. CaiT Usage](#17-CaiT-Usage)\n\n    - [18. PatchConvnet Usage](#18-PatchConvnet-Usage)\n\n    - [19. DeiT Usage](#19-DeiT-Usage)\n\n    - [20. LeViT Usage](#20-LeViT-Usage)\n\n    - [21. VOLO Usage](#21-VOLO-Usage)\n    \n    - [22. Container Usage](#22-Container-Usage)\n\n    - [23. CMT Usage](#23-CMT-Usage)\n\n\n- [MLP Series](#mlp-series)\n\n    - [1. RepMLP Usage](#1-RepMLP-Usage)\n\n    - [2. MLP-Mixer Usage](#2-MLP-Mixer-Usage)\n\n    - [3. ResMLP Usage](#3-ResMLP-Usage)\n\n    - [4. gMLP Usage](#4-gMLP-Usage)\n\n    - [5. sMLP Usage](#5-sMLP-Usage)\n\n    - [6. vip-mlp Usage](#6-vip-mlp-Usage)\n\n- [Re-Parameter(ReP) Series](#Re-Parameter-series)\n\n    - [1. RepVGG Usage](#1-RepVGG-Usage)\n\n    - [2. ACNet Usage](#2-ACNet-Usage)\n\n    - [3. Diverse Branch Block(DDB) Usage](#3-Diverse-Branch-Block-Usage)\n\n- [Convolution Series](#Convolution-series)\n\n    - [1. Depthwise Separable Convolution Usage](#1-Depthwise-Separable-Convolution-Usage)\n\n    - [2. MBConv Usage](#2-MBConv-Usage)\n\n    - [3. Involution Usage](#3-Involution-Usage)\n\n    - [4. DynamicConv Usage](#4-DynamicConv-Usage)\n\n    - [5. 
CondConv Usage](#5-CondConv-Usage)\n\n***\n\n\n# Attention Series\n\n- Pytorch implementation of [\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks---arXiv 2021.05.05\"](https://arxiv.org/abs/2105.02358)\n\n- Pytorch implementation of [\"Attention Is All You Need---NIPS2017\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n- Pytorch implementation of [\"Squeeze-and-Excitation Networks---CVPR2018\"](https://arxiv.org/abs/1709.01507)\n\n- Pytorch implementation of [\"Selective Kernel Networks---CVPR2019\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n- Pytorch implementation of [\"CBAM: Convolutional Block Attention Module---ECCV2018\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n- Pytorch implementation of [\"BAM: Bottleneck Attention Module---BMVC2018\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n- Pytorch implementation of [\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks---CVPR2020\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n- Pytorch implementation of [\"Dual Attention Network for Scene Segmentation---CVPR2019\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n- Pytorch implementation of [\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network---arXiv 2021.05.30\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n- Pytorch implementation of [\"ResT: An Efficient Transformer for Visual Recognition---arXiv 2021.05.28\"](https://arxiv.org/abs/2105.13677)\n\n- Pytorch implementation of [\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS---ICASSP 2021\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n- Pytorch implementation of [\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning---arXiv 2019.11.17\"](https://arxiv.org/abs/1911.09483)\n\n- Pytorch implementation of [\"Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks---arXiv 
2019.05.23\"](https://arxiv.org/pdf/1905.09646.pdf)\n\n- Pytorch implementation of [\"A2-Nets: Double Attention Networks---NIPS2018\"](https://arxiv.org/pdf/1810.11579.pdf)\n\n\n- Pytorch implementation of [\"An Attention Free Transformer---ICLR2021 (Apple New Work)\"](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n\n- Pytorch implementation of [\"VOLO: Vision Outlooker for Visual Recognition---arXiv 2021.06.24\"](https://arxiv.org/abs/2106.13112) \n  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/385561050)\n\n\n- Pytorch implementation of [Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition---arXiv 2021.06.23](https://arxiv.org/abs/2106.12368) \n  [【Paper Analysis】](https://mp.weixin.qq.com/s/5gonUQgBho_m2O54jyXF_Q)\n\n\n- Pytorch implementation of [CoAtNet: Marrying Convolution and Attention for All Data Sizes---arXiv 2021.06.09](https://arxiv.org/abs/2106.04803) \n  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/385578588)\n\n\n- Pytorch implementation of [Scaling Local Self-Attention for Parameter Efficient Visual Backbones---CVPR2021 Oral](https://arxiv.org/pdf/2103.12731.pdf)  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/388598744)\n\n\n\n- Pytorch implementation of [Polarized Self-Attention: Towards High-quality Pixel-wise Regression---arXiv 2021.07.02](https://arxiv.org/abs/2107.00782)  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/389770482) \n\n\n- Pytorch implementation of [Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292)  [【Paper Analysis】](https://zhuanlan.zhihu.com/p/394795481) \n\n\n- Pytorch implementation of [Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n- Pytorch implementation of [S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) [【Paper Analysis】](https://zhuanlan.zhihu.com/p/397003638) \n\n- Pytorch implementation of [Global Filter Networks for Image Classification---arXiv 
2021.07.01](https://arxiv.org/abs/2107.00645) \n\n- Pytorch implementation of [Rotate to Attend: Convolutional Triplet Attention Module---WACV 2021](https://arxiv.org/abs/2010.03045) \n\n- Pytorch implementation of [Coordinate Attention for Efficient Mobile Network Design ---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n- Pytorch implementation of [UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n- Pytorch implementation of [Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n- Pytorch implementation of [On the Integration of Self-Attention and Convolution---ArXiv 2022.03.14](https://arxiv.org/pdf/2111.14556.pdf)\n\n- Pytorch implementation of [CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n- Pytorch implementation of [Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n- Pytorch implementation of [CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n- Pytorch implementation of [Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n***\n\n\n### 1. External Attention Usage\n#### 1.1. Paper\n[\"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks\"](https://arxiv.org/abs/2105.02358)\n\n#### 1.2. Overview\n![](./model/img/External_Attention.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom model.attention.ExternalAttention import ExternalAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nea = ExternalAttention(d_model=512,S=8)\noutput=ea(input)\nprint(output.shape)\n```\n\n***\n\n\n### 2. Self Attention Usage\n#### 2.1. Paper\n[\"Attention Is All You Need\"](https://arxiv.org/pdf/1706.03762.pdf)\n\n#### 2.2. Overview\n![](./model/img/SA.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.attention.SelfAttention import ScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nsa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n```\n\n***\n\n### 3. Simplified Self Attention Usage\n#### 3.1. Paper\n[None]()\n\n#### 3.2. Overview\n![](./model/img/SSA.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention\nimport torch\n\ninput=torch.randn(50,49,512)\nssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)\noutput=ssa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n### 4. Squeeze-and-Excitation Attention Usage\n#### 4.1. Paper\n[\"Squeeze-and-Excitation Networks\"](https://arxiv.org/abs/1709.01507)\n\n#### 4.2. Overview\n![](./model/img/SE.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.attention.SEAttention import SEAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SEAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n\n***\n\n### 5. SK Attention Usage\n#### 5.1. Paper\n[\"Selective Kernel Networks\"](https://arxiv.org/pdf/1903.06586.pdf)\n\n#### 5.2. Overview\n![](./model/img/SK.png)\n\n#### 5.3. Usage Code\n```python\nfrom model.attention.SKAttention import SKAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\nse = SKAttention(channel=512,reduction=8)\noutput=se(input)\nprint(output.shape)\n\n```\n***\n\n### 6. CBAM Attention Usage\n#### 6.1. 
Paper\n[\"CBAM: Convolutional Block Attention Module\"](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)\n\n#### 6.2. Overview\n![](./model/img/CBAM1.png)\n\n![](./model/img/CBAM2.png)\n\n#### 6.3. Usage Code\n```python\nfrom model.attention.CBAM import CBAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nkernel_size=input.shape[2]\ncbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)\noutput=cbam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 7. BAM Attention Usage\n#### 7.1. Paper\n[\"BAM: Bottleneck Attention Module\"](https://arxiv.org/pdf/1807.06514.pdf)\n\n#### 7.2. Overview\n![](./model/img/BAM.png)\n\n#### 7.3. Usage Code\n```python\nfrom model.attention.BAM import BAMBlock\nimport torch\n\ninput=torch.randn(50,512,7,7)\nbam = BAMBlock(channel=512,reduction=16,dia_val=2)\noutput=bam(input)\nprint(output.shape)\n\n```\n\n***\n\n### 8. ECA Attention Usage\n#### 8.1. Paper\n[\"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks\"](https://arxiv.org/pdf/1910.03151.pdf)\n\n#### 8.2. Overview\n![](./model/img/ECA.png)\n\n#### 8.3. Usage Code\n```python\nfrom model.attention.ECAAttention import ECAAttention\nimport torch\n\ninput=torch.randn(50,512,7,7)\neca = ECAAttention(kernel_size=3)\noutput=eca(input)\nprint(output.shape)\n\n```\n\n***\n\n### 9. DANet Attention Usage\n#### 9.1. Paper\n[\"Dual Attention Network for Scene Segmentation\"](https://arxiv.org/pdf/1809.02983.pdf)\n\n#### 9.2. Overview\n![](./model/img/danet.png)\n\n#### 9.3. Usage Code\n```python\nfrom model.attention.DANet import DAModule\nimport torch\n\ninput=torch.randn(50,512,7,7)\ndanet=DAModule(d_model=512,kernel_size=3,H=7,W=7)\nprint(danet(input).shape)\n\n```\n\n***\n\n### 10. Pyramid Split Attention Usage\n\n#### 10.1. Paper\n[\"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network\"](https://arxiv.org/pdf/2105.14447.pdf)\n\n#### 10.2. 
Overview\n![](./model/img/psa.png)\n\n#### 10.3. Usage Code\n```python\nfrom model.attention.PSA import PSA\nimport torch\n\ninput=torch.randn(50,512,7,7)\npsa = PSA(channel=512,reduction=8)\noutput=psa(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 11. Efficient Multi-Head Self-Attention Usage\n\n#### 11.1. Paper\n[\"ResT: An Efficient Transformer for Visual Recognition\"](https://arxiv.org/abs/2105.13677)\n\n#### 11.2. Overview\n![](./model/img/EMSA.png)\n\n#### 11.3. Usage Code\n```python\n\nfrom model.attention.EMSA import EMSA\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,64,512)\nemsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)\noutput=emsa(input,input,input)\nprint(output.shape)\n    \n```\n\n***\n\n\n### 12. Shuffle Attention Usage\n\n#### 12.1. Paper\n[\"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS\"](https://arxiv.org/pdf/2102.00240.pdf)\n\n#### 12.2. Overview\n![](./model/img/ShuffleAttention.jpg)\n\n#### 12.3. Usage Code\n```python\n\nfrom model.attention.ShuffleAttention import ShuffleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,512,7,7)\nse = ShuffleAttention(channel=512,G=8)\noutput=se(input)\nprint(output.shape)\n\n    \n```\n\n\n***\n\n\n### 13. MUSE Attention Usage\n\n#### 13.1. Paper\n[\"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning\"](https://arxiv.org/abs/1911.09483)\n\n#### 13.2. Overview\n![](./model/img/MUSE.png)\n\n#### 13.3. Usage Code\n```python\nfrom model.attention.MUSEAttention import MUSEAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ninput=torch.randn(50,49,512)\nsa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)\noutput=sa(input,input,input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 14. SGE Attention Usage\n\n#### 14.1. 
Paper\n[Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks](https://arxiv.org/pdf/1905.09646.pdf)\n\n#### 14.2. Overview\n![](./model/img/SGE.png)\n\n#### 14.3. Usage Code\n```python\nfrom model.attention.SGE import SpatialGroupEnhance\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nsge = SpatialGroupEnhance(groups=8)\noutput=sge(input)\nprint(output.shape)\n\n```\n\n***\n\n\n### 15. A2 Attention Usage\n\n#### 15.1. Paper\n[A2-Nets: Double Attention Networks](https://arxiv.org/pdf/1810.11579.pdf)\n\n#### 15.2. Overview\n![](./model/img/A2.png)\n\n#### 15.3. Usage Code\n```python\nfrom model.attention.A2Atttention import DoubleAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\na2 = DoubleAttention(512,128,128,True)\noutput=a2(input)\nprint(output.shape)\n\n```\n\n\n\n### 16. AFT Attention Usage\n\n#### 16.1. Paper\n[An Attention Free Transformer](https://arxiv.org/pdf/2105.14103v1.pdf)\n\n#### 16.2. Overview\n![](./model/img/AFT.jpg)\n\n#### 16.3. Usage Code\n```python\nfrom model.attention.AFT import AFT_FULL\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,49,512)\naft_full = AFT_FULL(d_model=512, n=49)\noutput=aft_full(input)\nprint(output.shape)\n\n```\n\n\n\n\n\n\n### 17. Outlook Attention Usage\n\n#### 17.1. Paper\n\n\n[VOLO: Vision Outlooker for Visual Recognition\"](https://arxiv.org/abs/2106.13112)\n\n\n#### 17.2. Overview\n![](./model/img/OutlookAttention.png)\n\n#### 17.3. Usage Code\n```python\nfrom model.attention.OutlookAttention import OutlookAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,28,28,512)\noutlook = OutlookAttention(dim=512)\noutput=outlook(input)\nprint(output.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 18. ViP Attention Usage\n\n#### 18.1. 
Paper\n\n\n[Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n\n#### 18.2. Overview\n![](./model/img/ViP.png)\n\n#### 18.3. Usage Code\n```python\n\nfrom model.attention.ViP import WeightedPermuteMLP\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(64,8,8,512)\nseg_dim=8\nvip=WeightedPermuteMLP(512,seg_dim)\nout=vip(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n### 19. CoAtNet Attention Usage\n\n#### 19.1. Paper\n\n\n[CoAtNet: Marrying Convolution and Attention for All Data Sizes\"](https://arxiv.org/abs/2106.04803) \n\n\n#### 19.2. Overview\nNone\n\n\n#### 19.3. Usage Code\n```python\n\nfrom model.attention.CoAtNet import CoAtNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=CoAtNet(in_ch=3,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n```\n\n\n***\n\n\n\n\n\n\n### 20. HaloNet Attention Usage\n\n#### 20.1. Paper\n\n\n[Scaling Local Self-Attention for Parameter Efficient Visual Backbones\"](https://arxiv.org/pdf/2103.12731.pdf) \n\n\n#### 20.2. Overview\n\n![](./model/img/HaloNet.png)\n\n#### 20.3. Usage Code\n```python\n\nfrom model.attention.HaloAttention import HaloAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,8,8)\nhalo = HaloAttention(dim=512,\n    block_size=2,\n    halo_size=1,)\noutput=halo(input)\nprint(output.shape)\n\n```\n\n\n***\n\n### 21. Polarized Self-Attention Usage\n\n#### 21.1. Paper\n\n[Polarized Self-Attention: Towards High-quality Pixel-wise Regression\"](https://arxiv.org/abs/2107.00782)  \n\n\n#### 21.2. Overview\n\n![](./model/img/PoSA.png)\n\n#### 21.3. 
Usage Code\n```python\n\nfrom model.attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention,SequentialPolarizedSelfAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,512,7,7)\npsa = SequentialPolarizedSelfAttention(channel=512)\noutput=psa(input)\nprint(output.shape)\n\n\n```\n\n\n***\n\n\n### 22. CoTAttention Usage\n\n#### 22.1. Paper\n\n[Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26](https://arxiv.org/abs/2107.12292) \n\n\n#### 22.2. Overview\n\n![](./model/img/CoT.png)\n\n#### 22.3. Usage Code\n```python\n\nfrom model.attention.CoTAttention import CoTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ncot = CoTAttention(dim=512,kernel_size=3)\noutput=cot(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n### 23. Residual Attention Usage\n\n#### 23.1. Paper\n\n[Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021](https://arxiv.org/abs/2108.02456) \n\n\n#### 23.2. Overview\n\n![](./model/img/ResAtt.png)\n\n#### 23.3. Usage Code\n```python\n\nfrom model.attention.ResidualAttention import ResidualAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\nresatt = ResidualAttention(channel=512,num_class=1000,la=0.2)\noutput=resatt(input)\nprint(output.shape)\n\n\n\n```\n\n***\n\n\n\n### 24. S2 Attention Usage\n\n#### 24.1. Paper\n\n[S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02](https://arxiv.org/abs/2108.01072) \n\n\n#### 24.2. Overview\n\n![](./model/img/S2Attention.png)\n\n#### 24.3. Usage Code\n```python\nfrom model.attention.S2Attention import S2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(50,512,7,7)\ns2att = S2Attention(channels=512)\noutput=s2att(input)\nprint(output.shape)\n\n```\n\n***\n\n\n\n### 25. 
GFNet Attention Usage\n\n#### 25.1. Paper\n\n[Global Filter Networks for Image Classification---arXiv 2021.07.01](https://arxiv.org/abs/2107.00645) \n\n\n#### 25.2. Overview\n\n![](./model/img/GFNet.jpg)\n\n#### 25.3. Usage Code - Implemented by [Wenliang Zhao (Author)](https://scholar.google.com/citations?user=lyPWvuEAAAAJ&hl=en)\n\n```python\nfrom model.attention.gfnet import GFNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nx = torch.randn(1, 3, 224, 224)\ngfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)\nout = gfnet(x)\nprint(out.shape)\n\n```\n\n***\n\n\n### 26. TripletAttention Usage\n\n#### 26.1. Paper\n\n[Rotate to Attend: Convolutional Triplet Attention Module---CVPR 2021](https://arxiv.org/abs/2010.03045) \n\n#### 26.2. Overview\n\n![](./model/img/triplet.png)\n\n#### 26.3. Usage Code - Implemented by [digantamisra98](https://github.com/digantamisra98)\n\n```python\nfrom model.attention.TripletAttention import TripletAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\ninput=torch.randn(50,512,7,7)\ntriplet = TripletAttention()\noutput=triplet(input)\nprint(output.shape)\n```\n\n\n***\n\n\n### 27. Coordinate Attention Usage\n\n#### 27.1. Paper\n\n[Coordinate Attention for Efficient Mobile Network Design---CVPR 2021](https://arxiv.org/abs/2103.02907)\n\n\n#### 27.2. Overview\n\n![](./model/img/CoordAttention.png)\n\n#### 27.3. Usage Code - Implemented by [Andrew-Qibin](https://github.com/Andrew-Qibin)\n\n```python\nfrom model.attention.CoordAttention import CoordAtt\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninp=torch.rand([2, 96, 56, 56])\ninp_dim, oup_dim = 96, 96\nreduction=32\n\ncoord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)\noutput=coord_attention(inp)\nprint(output.shape)\n```\n\n***\n\n\n### 28. MobileViT Attention Usage\n\n#### 28.1. 
Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n\n#### 28.2. Overview\n\n![](./model/img/MobileViTAttention.png)\n\n#### 28.3. Usage Code\n\n```python\nfrom model.attention.MobileViTAttention import MobileViTAttention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    m=MobileViTAttention()\n    input=torch.randn(1,3,49,49)\n    output=m(input)\n    print(output.shape)  #output:(1,3,49,49)\n    \n```\n\n***\n\n\n### 29. ParNet Attention Usage\n\n#### 29.1. Paper\n\n[Non-deep Networks---ArXiv 2021.10.20](https://arxiv.org/abs/2110.07641)\n\n\n#### 29.2. Overview\n\n![](./model/img/ParNet.png)\n\n#### 29.3. Usage Code\n\n```python\nfrom model.attention.ParNetAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,7,7)\n    pna = ParNetAttention(channel=512)\n    output=pna(input)\n    print(output.shape) #50,512,7,7\n    \n```\n\n***\n\n\n### 30. UFO Attention Usage\n\n#### 30.1. Paper\n\n[UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29](https://arxiv.org/abs/2109.14382)\n\n\n#### 30.2. Overview\n\n![](./model/img/UFO.png)\n\n#### 30.3. Usage Code\n\n```python\nfrom model.attention.UFOAttention import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)\n    output=ufo(input,input,input)\n    print(output.shape) #[50, 49, 512]\n    \n```\n\n***\n\n### 31. ACmix Attention Usage\n\n#### 31.1. Paper\n\n[On the Integration of Self-Attention and Convolution](https://arxiv.org/pdf/2111.14556.pdf)\n\n#### 31.2. 
Usage Code\n\n```python\nfrom model.attention.ACmix import ACmix\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,256,7,7)\n    acmix = ACmix(in_planes=256, out_planes=256)\n    output=acmix(input)\n    print(output.shape)\n    \n```\n\n### 32. MobileViTv2 Attention Usage\n\n#### 32.1. Paper\n\n[Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06](https://arxiv.org/abs/2206.02680)\n\n\n#### 32.2. Overview\n\n![](./model/img/MobileViTv2.png)\n\n#### 32.3. Usage Code\n\n```python\nfrom model.attention.MobileViTv2Attention import MobileViTv2Attention\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,49,512)\n    sa = MobileViTv2Attention(d_model=512)\n    output=sa(input)\n    print(output.shape)\n    \n```\n\n### 33. DAT Attention Usage\n\n#### 33.1. Paper\n\n[Vision Transformer with Deformable Attention---CVPR2022](https://arxiv.org/abs/2201.00520)\n\n#### 33.2. Usage Code\n\n```python\nfrom model.attention.DAT import DAT\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DAT(\n        img_size=224,\n        patch_size=4,\n        num_classes=1000,\n        expansion=4,\n        dim_stem=96,\n        dims=[96, 192, 384, 768],\n        depths=[2, 2, 6, 2],\n        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],\n        heads=[3, 6, 12, 24],\n        window_sizes=[7, 7, 7, 7] ,\n        groups=[-1, -1, 3, 6],\n        use_pes=[False, False, True, True],\n        dwc_pes=[False, False, False, False],\n        strides=[-1, -1, 1, 1],\n        sr_ratios=[-1, -1, -1, -1],\n        offset_range_factor=[-1, -1, 2, 2],\n        no_offs=[False, False, False, False],\n        fixed_pes=[False, False, False, False],\n        use_dwc_mlps=[False, False, False, False],\n        use_conv_patches=False,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        
drop_path_rate=0.2,\n    )\n    output=model(input)\n    print(output[0].shape)\n    \n```\n\n### 34. CrossFormer Attention Usage\n\n#### 34.1. Paper\n\n[CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022](https://arxiv.org/pdf/2108.00154.pdf)\n\n#### 34.2. Usage Code\n\n```python\nfrom model.attention.Crossformer import CrossFormer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CrossFormer(img_size=224,\n        patch_size=[4, 8, 16, 32],\n        in_chans= 3,\n        num_classes=1000,\n        embed_dim=48,\n        depths=[2, 2, 6, 2],\n        num_heads=[3, 6, 12, 24],\n        group_size=[7, 7, 7, 7],\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False,\n        merge_size=[[2, 4], [2,4], [2, 4]]\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 35. MOATransformer Attention Usage\n\n#### 35.1. Paper\n\n[Aggregating Global Features into Local Vision Transformer](https://arxiv.org/abs/2201.12903)\n\n#### 35.2. Usage Code\n\n```python\nfrom model.attention.MOATransformer import MOATransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = MOATransformer(\n        img_size=224,\n        patch_size=4,\n        in_chans=3,\n        num_classes=1000,\n        embed_dim=96,\n        depths=[2, 2, 6],\n        num_heads=[3, 6, 12],\n        window_size=14,\n        mlp_ratio=4.,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        drop_path_rate=0.1,\n        ape=False,\n        patch_norm=True,\n        use_checkpoint=False\n    )\n    output=model(input)\n    print(output.shape)\n    \n```\n\n### 36. CrissCrossAttention Attention Usage\n\n#### 36.1. 
Paper\n\n[CCNet: Criss-Cross Attention for Semantic Segmentation](https://arxiv.org/abs/1811.11721)\n\n#### 36.2. Usage Code\n\n```python\nfrom model.attention.CrissCrossAttention import CrissCrossAttention\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 64, 7, 7)\n    model = CrissCrossAttention(64)\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n### 37. Axial_attention Attention Usage\n\n#### 37.1. Paper\n\n[Axial Attention in Multidimensional Transformers](https://arxiv.org/abs/1912.12180)\n\n#### 37.2. Usage Code\n\n```python\nfrom model.attention.Axial_attention import AxialImageTransformer\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(3, 128, 7, 7)\n    model = AxialImageTransformer(\n        dim = 128,\n        depth = 12,\n        reversible = True\n    )\n    outputs = model(input)\n    print(outputs.shape)\n    \n```\n\n***\n\n\n# Backbone Series\n\n- Pytorch implementation of [\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n- Pytorch implementation of [\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n- Pytorch implementation of [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n- Pytorch implementation of [Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n\n- Pytorch implementation of [Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer---ArXiv 2021.06.07](https://arxiv.org/abs/2106.03650)\n\n- Pytorch implementation of [ConTNet: Why not use convolution and transformer at the same time?---ArXiv 2021.04.27](https://arxiv.org/abs/2104.13497)\n\n- Pytorch implementation of [Vision Transformers with Hierarchical Attention---ArXiv 2022.06.15](https://arxiv.org/abs/2106.03180)\n\n- Pytorch implementation of 
[Co-Scale Conv-Attentional Image Transformers---ArXiv 2021.08.26](https://arxiv.org/abs/2104.06399)\n\n- Pytorch implementation of [Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n- Pytorch implementation of [Rethinking Spatial Dimensions of Vision Transformers---ICCV 2021](https://arxiv.org/abs/2103.16302)\n\n- Pytorch implementation of [CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification---ICCV 2021](https://arxiv.org/abs/2103.14899)\n\n- Pytorch implementation of [Transformer in Transformer---NeurIPS 2021](https://arxiv.org/abs/2103.00112)\n\n- Pytorch implementation of [DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n- Pytorch implementation of [Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n- Pytorch implementation of [ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n- Pytorch implementation of [Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n- Pytorch implementation of [Going deeper with Image Transformers---ICCV 2021 (Oral)](https://arxiv.org/abs/2103.17239)\n\n- Pytorch implementation of [Training data-efficient image transformers & distillation through attention---ICML 2021](https://arxiv.org/abs/2012.12877)\n\n- Pytorch implementation of [LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n- Pytorch implementation of [VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n- Pytorch implementation of [Container: Context Aggregation Network---NeurIPS 2021](https://arxiv.org/abs/2106.01401)\n\n- Pytorch implementation of [CMT: Convolutional Neural Networks Meet Vision Transformers---CVPR 2022](https://arxiv.org/abs/2107.06263)\n\n- Pytorch implementation of [Vision Transformer with 
Deformable Attention---CVPR 2022](https://arxiv.org/abs/2201.00520)\n\n\n### 1. ResNet Usage\n#### 1.1. Paper\n[\"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper\"](https://arxiv.org/pdf/1512.03385.pdf)\n\n#### 1.2. Overview\n![](./model/img/resnet.png)\n![](./model/img/resnet2.jpg)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.backbone.resnet import ResNet50,ResNet101,ResNet152\nimport torch\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnet50=ResNet50(1000)\n    # resnet101=ResNet101(1000)\n    # resnet152=ResNet152(1000)\n    out=resnet50(input)\n    print(out.shape)\n\n```\n\n\n### 2. ResNeXt Usage\n#### 2.1. Paper\n\n[\"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017\"](https://arxiv.org/abs/1611.05431v2)\n\n#### 2.2. Overview\n![](./model/img/resnext.png)\n\n#### 2.3. Usage Code\n```python\n\nfrom model.backbone.resnext import ResNeXt50,ResNeXt101,ResNeXt152\nimport torch\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    resnext50=ResNeXt50(1000)\n    # resnext101=ResNeXt101(1000)\n    # resnext152=ResNeXt152(1000)\n    out=resnext50(input)\n    print(out.shape)\n\n\n```\n\n\n\n### 3. MobileViT Usage\n#### 3.1. Paper\n\n[MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05](https://arxiv.org/abs/2110.02178)\n\n#### 3.2. Overview\n![](./model/img/mobileViT.jpg)\n\n#### 3.3. Usage Code\n```python\n\nfrom model.backbone.MobileViT import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n\n    ### mobilevit_xxs\n    mvit_xxs=mobilevit_xxs()\n    out=mvit_xxs(input)\n    print(out.shape)\n\n    ### mobilevit_xs\n    mvit_xs=mobilevit_xs()\n    out=mvit_xs(input)\n    print(out.shape)\n\n\n    ### mobilevit_s\n    mvit_s=mobilevit_s()\n    out=mvit_s(input)\n    print(out.shape)\n\n```\n\n\n\n\n\n### 4. ConvMixer Usage\n#### 4.1. 
Paper\n[Patches Are All You Need?---ICLR2022 (Under Review)](https://openreview.net/forum?id=TVHS5Y4dNvM)\n#### 4.2. Overview\n![](./model/img/ConvMixer.png)\n\n#### 4.3. Usage Code\n```python\n\nfrom model.backbone.ConvMixer import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    x=torch.randn(1,3,224,224)\n    convmixer=ConvMixer(dim=512,depth=12)\n    out=convmixer(x)\n    print(out.shape)  #[1, 1000]\n\n\n```\n\n### 5. ShuffleTransformer Usage\n#### 5.1. Paper\n[Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer](https://arxiv.org/pdf/2106.03650.pdf)\n\n#### 5.2. Usage Code\n```python\n\nfrom model.backbone.ShuffleTransformer import ShuffleTransformer\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    sft = ShuffleTransformer()\n    output=sft(input)\n    print(output.shape)\n\n\n```\n\n### 6. ConTNet Usage\n#### 6.1. Paper\n[ConTNet: Why not use convolution and transformer at the same time?](https://arxiv.org/abs/2104.13497)\n\n#### 6.2. Usage Code\n```python\n\n# build_model is the name called below; it is assumed to live in the same module as ConTNet\nfrom model.backbone.ConTNet import build_model\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == \"__main__\":\n    model = build_model(use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)\n    input = torch.randn(1, 3, 224, 224)\n    out = model(input)\n    print(out.shape)\n\n\n```\n\n### 7. HATNet Usage\n#### 7.1. Paper\n[Vision Transformers with Hierarchical Attention](https://arxiv.org/abs/2106.03180)\n\n#### 7.2. 
Usage Code\n```python\n\nfrom model.backbone.HATNet import HATNet\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    hat = HATNet(dims=[48, 96, 240, 384], head_dim=48, expansions=[8, 8, 4, 4],\n        grid_sizes=[8, 7, 7, 1], ds_ratios=[8, 4, 2, 1], depths=[2, 2, 6, 3])\n    output=hat(input)\n    print(output.shape)\n\n\n```\n\n### 8. CoaT Usage\n#### 8.1. Paper\n[Co-Scale Conv-Attentional Image Transformers](https://arxiv.org/abs/2104.06399)\n\n#### 8.2. Usage Code\n```python\n\nfrom model.backbone.CoaT import CoaT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4])\n    output=model(input)\n    print(output.shape) # torch.Size([1, 1000])\n\n```\n\n### 9. PVT Usage\n#### 9.1. Paper\n[PVT v2: Improved Baselines with Pyramid Vision Transformer](https://arxiv.org/pdf/2106.13797.pdf)\n\n#### 9.2. Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.PVT import PyramidVisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PyramidVisionTransformer(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n\n### 10. CPVT Usage\n#### 10.1. Paper\n[Conditional Positional Encodings for Vision Transformers](https://arxiv.org/abs/2102.10882)\n\n#### 10.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.CPVT import CPVTV2\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CPVTV2(\n        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 11. PIT Usage\n#### 11.1. Paper\n[Rethinking Spatial Dimensions of Vision Transformers](https://arxiv.org/abs/2103.16302)\n\n#### 11.2. Usage Code\n```python\n\nfrom model.backbone.PIT import PoolingTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PoolingTransformer(\n        image_size=224,\n        patch_size=14,\n        stride=7,\n        base_dims=[64, 64, 64],\n        depth=[3, 6, 4],\n        heads=[4, 8, 16],\n        mlp_ratio=4\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 12. CrossViT Usage\n#### 12.1. Paper\n[CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification](https://arxiv.org/abs/2103.14899)\n\n#### 12.2. Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.CrossViT import VisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == \"__main__\":\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[240, 224],\n        patch_size=[12, 16], \n        embed_dim=[192, 384], \n        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],\n        num_heads=[6, 6], \n        mlp_ratio=[4, 4, 1], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 13. TnT Usage\n#### 13.1. Paper\n[Transformer in Transformer](https://arxiv.org/abs/2103.00112)\n\n#### 13.2. 
Usage Code\n```python\n\nfrom model.backbone.TnT import TNT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = TNT(\n        img_size=224, \n        patch_size=16, \n        outer_dim=384, \n        inner_dim=24, \n        depth=12,\n        outer_num_heads=6, \n        inner_num_heads=4, \n        qkv_bias=False,\n        inner_stride=4)\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 14. DViT Usage\n#### 14.1. Paper\n[DeepViT: Towards Deeper Vision Transformer](https://arxiv.org/abs/2103.11886)\n\n#### 14.2. Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.DViT import DeepVisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DeepVisionTransformer(\n        patch_size=16, embed_dim=384, \n        depth=[False] * 16, \n        apply_transform=[False] * 0 + [True] * 32, \n        num_heads=12, \n        mlp_ratio=3, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 15. CeiT Usage\n#### 15.1. Paper\n[Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816)\n\n#### 15.2. Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\n# Image2Tokens is assumed to be defined in the same module as CeIT\nfrom model.backbone.CeiT import CeIT, Image2Tokens\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CeIT(\n        hybrid_backbone=Image2Tokens(),\n        patch_size=4, \n        embed_dim=192, \n        depth=12, \n        num_heads=3, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 16. ConViT Usage\n#### 16.1. Paper\n[ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases](https://arxiv.org/abs/2103.10697)\n\n#### 16.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.ConViT import VisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        num_heads=16,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 17. CaiT Usage\n#### 17.1. Paper\n[Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239)\n\n#### 17.2. Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.CaiT import CaiT\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CaiT(\n        img_size= 224,\n        patch_size=16, \n        embed_dim=192, \n        depth=24, \n        num_heads=4, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        init_scale=1e-5,\n        depth_token_only=2\n        )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 18. PatchConvnet Usage\n#### 18.1. Paper\n[Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)\n\n#### 18.2. Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\n# ConvStem and Conv_blocks_se are assumed to be defined in the same module as PatchConvnet\nfrom model.backbone.PatchConvnet import PatchConvnet, ConvStem, Conv_blocks_se\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = PatchConvnet(\n        patch_size=16,\n        embed_dim=384,\n        depth=60,\n        num_heads=1,\n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        Patch_layer=ConvStem,\n        Attention_block=Conv_blocks_se,\n        depth_token_only=1,\n        mlp_ratio_clstk=3.0,\n    )\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 19. DeiT Usage\n#### 19.1. Paper\n[Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)\n\n#### 19.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.DeiT import DistilledVisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = DistilledVisionTransformer(\n        patch_size=16, \n        embed_dim=384, \n        depth=12, \n        num_heads=6, \n        mlp_ratio=4, \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6)\n        )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 20. LeViT Usage\n#### 20.1. Paper\n[LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference](https://arxiv.org/abs/2104.01136)\n\n#### 20.2. Usage Code\n```python\n\nfrom model.backbone.LeViT import *\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    for name in specification:\n        input=torch.randn(1,3,224,224)\n        model = globals()[name](fuse=True, pretrained=False)\n        model.eval()\n        output = model(input)\n        print(output.shape)\n\n```\n\n### 21. VOLO Usage\n#### 21.1. Paper\n[VOLO: Vision Outlooker for Visual Recognition](https://arxiv.org/abs/2106.13112)\n\n#### 21.2. Usage Code\n```python\n\nfrom model.backbone.VOLO import VOLO\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VOLO([4, 4, 8, 2],\n                 embed_dims=[192, 384, 384, 384],\n                 num_heads=[6, 12, 12, 12],\n                 mlp_ratios=[3, 3, 3, 3],\n                 downsamples=[True, False, False, False],\n                 outlook_attention=[True, False, False, False ],\n                 post_layers=['ca', 'ca'],\n                 )\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n### 22. Container Usage\n#### 22.1. Paper\n[Container: Context Aggregation Network](https://arxiv.org/abs/2106.01401)\n\n#### 22.2. 
Usage Code\n```python\n\nfrom functools import partial  # needed for the norm_layer argument below\nfrom model.backbone.Container import VisionTransformer\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionTransformer(\n        img_size=[224, 56, 28, 14], \n        patch_size=[4, 2, 2, 2], \n        embed_dim=[64, 128, 320, 512], \n        depth=[3, 4, 8, 3], \n        num_heads=16, \n        mlp_ratio=[8, 8, 4, 4], \n        qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6))\n    output=model(input)\n    print(output.shape)\n\n```\n\n### 23. CMT Usage\n#### 23.1. Paper\n[CMT: Convolutional Neural Networks Meet Vision Transformers](https://arxiv.org/abs/2107.06263)\n\n#### 23.2. Usage Code\n```python\n\nfrom model.backbone.CMT import CMT_Tiny\nimport torch\nfrom torch import nn\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = CMT_Tiny()\n    output=model(input)\n    print(output[0].shape)\n\n```\n\n\n\n\n\n\n# MLP Series\n\n- Pytorch implementation of [\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition---arXiv 2021.05.05\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n- Pytorch implementation of [\"MLP-Mixer: An all-MLP Architecture for Vision---arXiv 2021.05.17\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n- Pytorch implementation of [\"ResMLP: Feedforward networks for image classification with data-efficient training---arXiv 2021.05.07\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n- Pytorch implementation of [\"Pay Attention to MLPs---arXiv 2021.05.17\"](https://arxiv.org/abs/2105.08050)\n\n\n- Pytorch implementation of [\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?---arXiv 2021.09.12\"](https://arxiv.org/abs/2109.05422)\n\n### 1. RepMLP Usage\n#### 1.1. Paper\n[\"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition\"](https://arxiv.org/pdf/2105.01883v1.pdf)\n\n#### 1.2. Overview\n![](./model/img/repmlp.png)\n\n#### 1.3. 
Usage Code\n```python\nfrom model.mlp.repmlp import RepMLP\nimport torch\nfrom torch import nn\n\nN=4 #batch size\nC=512 #input dim\nO=1024 #output dim\nH=14 #image height\nW=14 #image width\nh=7 #patch height\nw=7 #patch width\nfc1_fc2_reduction=1 #reduction ratio\nfc3_groups=8 # groups\nrepconv_kernels=[1,3,5,7] #kernel list\nrepmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)\nx=torch.randn(N,C,H,W)\nrepmlp.eval()\nfor module in repmlp.modules():\n    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):\n        nn.init.uniform_(module.running_mean, 0, 0.1)\n        nn.init.uniform_(module.running_var, 0, 0.1)\n        nn.init.uniform_(module.weight, 0, 0.1)\n        nn.init.uniform_(module.bias, 0, 0.1)\n\n#training result\nout=repmlp(x)\n#inference result\nrepmlp.switch_to_deploy()\ndeployout = repmlp(x)\n\nprint(((deployout-out)**2).sum())\n```\n\n### 2. MLP-Mixer Usage\n#### 2.1. Paper\n[\"MLP-Mixer: An all-MLP Architecture for Vision\"](https://arxiv.org/pdf/2105.01601.pdf)\n\n#### 2.2. Overview\n![](./model/img/mlpmixer.png)\n\n#### 2.3. Usage Code\n```python\nfrom model.mlp.mlp_mixer import MlpMixer\nimport torch\nmlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)\ninput=torch.randn(50,3,40,40)\noutput=mlp_mixer(input)\nprint(output.shape)\n```\n\n***\n\n### 3. ResMLP Usage\n#### 3.1. Paper\n[\"ResMLP: Feedforward networks for image classification with data-efficient training\"](https://arxiv.org/pdf/2105.03404.pdf)\n\n#### 3.2. Overview\n![](./model/img/resmlp.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.mlp.resmlp import ResMLP\nimport torch\n\ninput=torch.randn(50,3,14,14)\nresmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)\nout=resmlp(input)\nprint(out.shape) #the last dimension is class_num\n```\n\n***\n\n### 4. gMLP Usage\n#### 4.1. 
Paper\n[\"Pay Attention to MLPs\"](https://arxiv.org/abs/2105.08050)\n\n#### 4.2. Overview\n![](./model/img/gMLP.jpg)\n\n#### 4.3. Usage Code\n```python\nfrom model.mlp.g_mlp import gMLP\nimport torch\n\nnum_tokens=10000\nbs=50\nlen_sen=49\nnum_layers=6\ninput=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen\ngmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)\noutput=gmlp(input)\nprint(output.shape)\n```\n\n***\n\n### 5. sMLP Usage\n#### 5.1. Paper\n[\"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?\"](https://arxiv.org/abs/2109.05422)\n\n#### 5.2. Overview\n![](./model/img/sMLP.jpg)\n\n#### 5.3. Usage Code\n```python\nfrom model.mlp.sMLP_block import sMLPBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    smlp=sMLPBlock(h=224,w=224)\n    out=smlp(input)\n    print(out.shape)\n```\n\n### 6. vip-mlp Usage\n#### 6.1. Paper\n[\"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition\"](https://arxiv.org/abs/2106.12368)\n\n#### 6.2. 
Usage Code\n```python\nimport importlib\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n# The module file name contains a hyphen, so a plain import statement would be a syntax error\nvip_mlp = importlib.import_module('model.mlp.vip-mlp')\nVisionPermutator = vip_mlp.VisionPermutator\n# WeightedPermuteMLP is assumed to be defined in the same module\nWeightedPermuteMLP = vip_mlp.WeightedPermuteMLP\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionPermutator(\n        layers=[4, 3, 8, 3], \n        embed_dims=[384, 384, 384, 384], \n        patch_size=14, \n        transitions=[False, False, False, False],\n        segment_dim=[16, 16, 16, 16], \n        mlp_ratios=[3, 3, 3, 3], \n        mlp_fn=WeightedPermuteMLP\n    )\n    output=model(input)\n    print(output.shape)\n```\n\n\n# Re-Parameter Series\n\n- Pytorch implementation of [\"RepVGG: Making VGG-style ConvNets Great Again---CVPR2021\"](https://arxiv.org/abs/2101.03697)\n\n- Pytorch implementation of [\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks---ICCV2019\"](https://arxiv.org/abs/1908.03930)\n\n- Pytorch implementation of [\"Diverse Branch Block: Building a Convolution as an Inception-like Unit---CVPR2021\"](https://arxiv.org/abs/2103.13425)\n\n\n***\n\n### 1. RepVGG Usage\n#### 1.1. Paper\n[\"RepVGG: Making VGG-style ConvNets Great Again\"](https://arxiv.org/abs/2101.03697)\n\n#### 1.2. Overview\n![](./model/img/repvgg.png)\n\n#### 1.3. Usage Code\n```python\n\nfrom model.rep.repvgg import RepBlock\nimport torch\n\n\ninput=torch.randn(50,512,49,49)\nrepblock=RepBlock(512,512)\nrepblock.eval()\nout=repblock(input)\nrepblock._switch_to_deploy()\nout2=repblock(input)\nprint('difference between vgg and repvgg')\nprint(((out2-out)**2).sum())\n```\n\n\n\n***\n\n### 2. ACNet Usage\n#### 2.1. Paper\n[\"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks\"](https://arxiv.org/abs/1908.03930)\n\n#### 2.2. Overview\n![](./model/img/acnet.png)\n\n#### 2.3. 
Usage Code\n```python\nfrom model.rep.acnet import ACNet\nimport torch\nfrom torch import nn\n\ninput=torch.randn(50,512,49,49)\nacnet=ACNet(512,512)\nacnet.eval()\nout=acnet(input)\nacnet._switch_to_deploy()\nout2=acnet(input)\nprint('difference:')\nprint(((out2-out)**2).sum())\n\n```\n\n\n\n***\n\n### 3. Diverse Branch Block Usage\n#### 3.1. Paper\n[\"Diverse Branch Block: Building a Convolution as an Inception-like Unit\"](https://arxiv.org/abs/2103.13425)\n\n#### 3.2. Overview\n![](./model/img/ddb.png)\n\n#### 3.3. Usage Code\n##### 3.3.1 Transform I\n```python\nfrom model.rep.ddb import transI_conv_bn\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n#conv+bn\nconv1=nn.Conv2d(64,64,3,padding=1)\nbn1=nn.BatchNorm2d(64)\nbn1.eval()\nout1=bn1(conv1(input))\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transI_conv_bn(conv1,bn1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.2 Transform II\n```python\nfrom model.rep.ddb import transII_conv_branch\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,3,padding=1)\nconv2=nn.Conv2d(64,64,3,padding=1)\nout1=conv1(input)+conv2(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transII_conv_branch(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.3 Transform III\n```python\nfrom model.rep.ddb import transIII_conv_sequential\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as 
F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,64,1,padding=0,bias=False)\nconv2=nn.Conv2d(64,64,3,padding=1,bias=False)\nout1=conv2(conv1(input))\n\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1,bias=False)\nconv_fuse.weight.data=transIII_conv_sequential(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.4 Transform IV\n```python\nfrom model.rep.ddb import transIV_conv_concat\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1=nn.Conv2d(64,32,3,padding=1)\nconv2=nn.Conv2d(64,32,3,padding=1)\nout1=torch.cat([conv1(input),conv2(input)],dim=1)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transIV_conv_concat(conv1,conv2)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n##### 3.3.5 Transform V\n```python\nfrom model.rep.ddb import transV_avg\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\navg=nn.AvgPool2d(kernel_size=3,stride=1)\nout1=avg(input)\n\nconv=transV_avg(64,3)\nout2=conv(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n##### 3.3.6 Transform VI\n```python\nfrom model.rep.ddb import transVI_conv_scale\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,64,7,7)\n\n#conv+conv\nconv1x1=nn.Conv2d(64,64,1)\nconv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))\nconv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))\nout1=conv1x1(input)+conv1x3(input)+conv3x1(input)\n\n#conv_fuse\nconv_fuse=nn.Conv2d(64,64,3,padding=1)\nconv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)\nout2=conv_fuse(input)\n\nprint(\"difference:\",((out2-out1)**2).sum().item())\n```\n\n\n\n\n\n# Convolution Series\n\n- Pytorch implementation of [\"MobileNets: Efficient Convolutional Neural Networks for Mobile 
Vision Applications---CVPR2017\"](https://arxiv.org/abs/1704.04861)\n\n- Pytorch implementation of [\"Efficientnet: Rethinking model scaling for convolutional neural networks---PMLR2019\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n- Pytorch implementation of [\"Involution: Inverting the Inherence of Convolution for Visual Recognition---CVPR2021\"](https://arxiv.org/abs/2103.06255)\n\n- Pytorch implementation of [\"Dynamic Convolution: Attention over Convolution Kernels---CVPR2020 Oral\"](https://arxiv.org/abs/1912.03458)\n\n- Pytorch implementation of [\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference---NeurIPS2019\"](https://arxiv.org/abs/1904.04971)\n\n***\n\n### 1. Depthwise Separable Convolution Usage\n#### 1.1. Paper\n[\"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications\"](https://arxiv.org/abs/1704.04861)\n\n#### 1.2. Overview\n![](./model/img/DepthwiseSeparableConv.png)\n\n#### 1.3. Usage Code\n```python\nfrom model.conv.DepthwiseSeparableConvolution import DepthwiseSeparableConvolution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\ndsconv=DepthwiseSeparableConvolution(3,64)\nout=dsconv(input)\nprint(out.shape)\n```\n\n***\n\n\n### 2. MBConv Usage\n#### 2.1. Paper\n[\"Efficientnet: Rethinking model scaling for convolutional neural networks\"](http://proceedings.mlr.press/v97/tan19a.html)\n\n#### 2.2. Overview\n![](./model/img/MBConv.jpg)\n\n#### 2.3. Usage Code\n```python\nfrom model.conv.MBConv import MBConvBlock\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,3,224,224)\nmbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=512,image_size=224)\nout=mbconv(input)\nprint(out.shape)\n\n\n```\n\n***\n\n\n### 3. Involution Usage\n#### 3.1. Paper\n[\"Involution: Inverting the Inherence of Convolution for Visual Recognition\"](https://arxiv.org/abs/2103.06255)\n\n#### 3.2. 
Overview\n![](./model/img/Involution.png)\n\n#### 3.3. Usage Code\n```python\nfrom model.conv.Involution import Involution\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\ninput=torch.randn(1,4,64,64)\ninvolution=Involution(kernel_size=3,in_channel=4,stride=2)\nout=involution(input)\nprint(out.shape)\n```\n\n***\n\n\n### 4. DynamicConv Usage\n#### 4.1. Paper\n[\"Dynamic Convolution: Attention over Convolution Kernels\"](https://arxiv.org/abs/1912.03458)\n\n#### 4.2. Overview\n![](./model/img/DynamicConv.png)\n\n#### 4.3. Usage Code\n```python\nfrom model.conv.DynamicConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape) # 2,64,64,64\n\n```\n\n***\n\n\n### 5. CondConv Usage\n#### 5.1. Paper\n[\"CondConv: Conditionally Parameterized Convolutions for Efficient Inference\"](https://arxiv.org/abs/1904.04971)\n\n#### 5.2. Overview\n![](./model/img/CondConv.png)\n\n#### 5.3. Usage Code\n```python\nfrom model.conv.CondConv import *\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nif __name__ == '__main__':\n    input=torch.randn(2,32,64,64)\n    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)\n    out=m(input)\n    print(out.shape)\n\n```\n\n***\n\n\n"
  },
  {
    "path": "model/huggingface_hub.egg-info/SOURCES.txt",
    "content": "LICENSE\nREADME.md\nsetup.py\nmodel/huggingface_hub.egg-info/PKG-INFO\nmodel/huggingface_hub.egg-info/SOURCES.txt\nmodel/huggingface_hub.egg-info/dependency_links.txt\nmodel/huggingface_hub.egg-info/entry_points.txt\nmodel/huggingface_hub.egg-info/requires.txt\nmodel/huggingface_hub.egg-info/top_level.txt"
  },
  {
    "path": "model/huggingface_hub.egg-info/dependency_links.txt",
    "content": "\n"
  },
  {
    "path": "model/huggingface_hub.egg-info/entry_points.txt",
    "content": "[console_scripts]\nhuggingface-cli = huggingface_hub.commands.huggingface_cli:main\n\n"
  },
  {
    "path": "model/huggingface_hub.egg-info/requires.txt",
    "content": "filelock\nrequests\ntqdm\npyyaml>=5.1\ntyping-extensions>=3.7.4.3\npackaging>=20.9\n\n[:python_version < \"3.8\"]\nimportlib_metadata\n\n[all]\npytest\npytest-cov\ndatasets\nsoundfile\nblack==22.3\nisort>=5.5.4\nflake8>=3.8.3\nflake8-bugbear\n\n[dev]\npytest\npytest-cov\ndatasets\nsoundfile\nblack==22.3\nisort>=5.5.4\nflake8>=3.8.3\nflake8-bugbear\n\n[fastai]\ntoml\nfastai>=2.4\nfastcore>=1.3.27\n\n[quality]\nblack==22.3\nisort>=5.5.4\nflake8>=3.8.3\nflake8-bugbear\n\n[tensorflow]\ntensorflow\npydot\ngraphviz\n\n[testing]\npytest\npytest-cov\ndatasets\nsoundfile\n\n[torch]\ntorch\n"
  },
  {
    "path": "model/huggingface_hub.egg-info/top_level.txt",
    "content": "\n"
  },
  {
    "path": "model/mlp/g_mlp.py",
    "content": "from collections import OrderedDict\nimport torch\nfrom torch import nn\n\n\ndef exist(x):\n    return x is not None\n\nclass Residual(nn.Module):\n    def __init__(self,fn):\n        super().__init__()\n        self.fn=fn\n    \n    def forward(self,x):\n        return self.fn(x)+x\n\nclass SpatialGatingUnit(nn.Module):\n    def __init__(self,dim,len_sen):\n        super().__init__()\n        self.ln=nn.LayerNorm(dim)\n        self.proj=nn.Conv1d(len_sen,len_sen,1)\n\n        nn.init.zeros_(self.proj.weight)\n        nn.init.ones_(self.proj.bias)\n    \n    def forward(self,x):\n        res,gate=torch.chunk(x,2,-1) #bs,n,d_ff\n        ###Norm\n        gate=self.ln(gate) #bs,n,d_ff\n        ###Spatial Proj\n        gate=self.proj(gate) #bs,n,d_ff\n\n        return res*gate\n\nclass gMLP(nn.Module):\n    def __init__(self,num_tokens=None,len_sen=49,dim=512,d_ff=1024,num_layers=6):\n        super().__init__()\n        self.num_layers=num_layers\n        self.embedding=nn.Embedding(num_tokens,dim) if exist(num_tokens) else nn.Identity()\n\n        self.gmlp=nn.ModuleList([Residual(nn.Sequential(OrderedDict([\n            ('ln1_%d'%i,nn.LayerNorm(dim)),\n            ('fc1_%d'%i,nn.Linear(dim,d_ff*2)),\n            ('gelu_%d'%i,nn.GELU()),\n            ('sgu_%d'%i,SpatialGatingUnit(d_ff,len_sen)),\n            ('fc2_%d'%i,nn.Linear(d_ff,dim)),\n        ])))  for i in range(num_layers)])\n\n\n\n        self.to_logits=nn.Sequential(\n            nn.LayerNorm(dim),\n            nn.Linear(dim,num_tokens),\n            nn.Softmax(-1)\n        )\n\n\n    def forward(self,x):\n        #embedding\n        embeded=self.embedding(x)\n\n        #gMLP\n        y=nn.Sequential(*self.gmlp)(embeded)\n\n\n        #to logits\n        logits=self.to_logits(y)\n\n\n        return logits\n\n\n            \n\n\nif __name__ == '__main__':\n\n    num_tokens=10000\n    bs=50\n    len_sen=49\n    num_layers=6\n    input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen\n    
gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)\n    output=gmlp(input)\n    print(output.shape)\n        "
  },
  {
    "path": "model/mlp/mlp_mixer.py",
    "content": "import torch\nfrom torch import nn\n\nclass MlpBlock(nn.Module):\n    def __init__(self,input_dim,mlp_dim=512) :\n        super().__init__()\n        self.fc1=nn.Linear(input_dim,mlp_dim)\n        self.gelu=nn.GELU()\n        self.fc2=nn.Linear(mlp_dim,input_dim)\n    \n    def forward(self,x):\n        #x: (bs,tokens,channels) or (bs,channels,tokens)\n        return self.fc2(self.gelu(self.fc1(x)))\n\n\n\nclass MixerBlock(nn.Module):\n    def __init__(self,tokens_mlp_dim=16,channels_mlp_dim=1024,tokens_hidden_dim=32,channels_hidden_dim=1024):\n        super().__init__()\n        self.ln=nn.LayerNorm(channels_mlp_dim)\n        self.tokens_mlp_block=MlpBlock(tokens_mlp_dim,mlp_dim=tokens_hidden_dim)\n        self.channels_mlp_block=MlpBlock(channels_mlp_dim,mlp_dim=channels_hidden_dim)\n\n    def forward(self,x):\n        \"\"\"\n        x: (bs,tokens,channels)\n        \"\"\"\n        ### tokens mixing\n        y=self.ln(x)\n        y=y.transpose(1,2) #(bs,channels,tokens)\n        y=self.tokens_mlp_block(y) #(bs,channels,tokens)\n        ### channels mixing\n        y=y.transpose(1,2) #(bs,tokens,channels)\n        out =x+y #(bs,tokens,channels)\n        y=self.ln(out) #(bs,tokens,channels)\n        y=out+self.channels_mlp_block(y) #(bs,tokens,channels)\n        return y\n\nclass MlpMixer(nn.Module):\n    def __init__(self,num_classes,num_blocks,patch_size,tokens_hidden_dim,channels_hidden_dim,tokens_mlp_dim,channels_mlp_dim):\n        super().__init__()\n        self.num_classes=num_classes\n        self.num_blocks=num_blocks #num of mlp layers\n        self.patch_size=patch_size\n        self.tokens_mlp_dim=tokens_mlp_dim\n        self.channels_mlp_dim=channels_mlp_dim\n        self.embd=nn.Conv2d(3,channels_mlp_dim,kernel_size=patch_size,stride=patch_size) \n        self.ln=nn.LayerNorm(channels_mlp_dim)\n        self.mlp_blocks=[]\n        for _ in range(num_blocks):\n            
self.mlp_blocks.append(MixerBlock(tokens_mlp_dim,channels_mlp_dim,tokens_hidden_dim,channels_hidden_dim))\n        # wrap the plain list in nn.ModuleList so the blocks' parameters are registered with the model\n        self.mlp_blocks=nn.ModuleList(self.mlp_blocks)\n        self.fc=nn.Linear(channels_mlp_dim,num_classes)\n\n    def forward(self,x):\n        y=self.embd(x) # bs,channels,h,w\n        bs,c,h,w=y.shape\n        y=y.view(bs,c,-1).transpose(1,2) # bs,tokens,channels\n\n        if(self.tokens_mlp_dim!=y.shape[1]):\n            raise ValueError('Tokens_mlp_dim is not correct.')\n\n        for i in range(self.num_blocks):\n            y=self.mlp_blocks[i](y) # bs,tokens,channels\n        y=self.ln(y) # bs,tokens,channels\n        y=torch.mean(y,dim=1,keepdim=False) # bs,channels\n        probs=self.fc(y) # bs,num_classes\n        return probs\n            \n\n\nif __name__ == '__main__':\n    mlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)\n    input=torch.randn(50,3,40,40)\n    output=mlp_mixer(input)\n    print(output.shape)\n    "
  },
  {
    "path": "model/mlp/repmlp.py",
    "content": "import torch\nfrom torch import nn\nfrom collections import OrderedDict\nfrom torch.nn import functional as F\nimport numpy as np\nfrom numpy import random\n\n\ndef setup_seed(seed):\n     torch.manual_seed(seed)\n     torch.cuda.manual_seed_all(seed)\n     np.random.seed(seed)\n     random.seed(seed)\n     torch.backends.cudnn.deterministic = True\n\nclass RepMLP(nn.Module):\n    def __init__(self,C,O,H,W,h,w,fc1_fc2_reduction=1,fc3_groups=8,repconv_kernels=None,deploy=False):\n        super().__init__()\n        self.C=C\n        self.O=O\n        self.H=H\n        self.W=W\n        self.h=h\n        self.w=w\n        self.fc1_fc2_reduction=fc1_fc2_reduction\n        self.repconv_kernels=repconv_kernels\n        self.h_part=H//h\n        self.w_part=W//w\n        self.deploy=deploy\n        self.fc3_groups=fc3_groups\n\n        # make sure H,W can divided by h,w respectively\n        assert H%h==0\n        assert W%w==0\n\n        self.is_global_perceptron= (H!=h) or (W!=w)\n        ### global perceptron\n        if(self.is_global_perceptron):\n            if(not self.deploy):\n                self.avg=nn.Sequential(OrderedDict([\n                    ('avg',nn.AvgPool2d(kernel_size=(self.h,self.w))),\n                    ('bn',nn.BatchNorm2d(num_features=C))\n                ])\n                )\n            else:\n                self.avg=nn.AvgPool2d(kernel_size=(self.h,self.w))\n            hidden_dim=self.C//self.fc1_fc2_reduction\n            self.fc1_fc2=nn.Sequential(OrderedDict([\n                ('fc1',nn.Linear(C*self.h_part*self.w_part,hidden_dim)),\n                ('relu',nn.ReLU()),\n                ('fc2',nn.Linear(hidden_dim,C*self.h_part*self.w_part))\n            ])\n            )\n\n        self.fc3=nn.Conv2d(self.C*self.h*self.w,self.O*self.h*self.w,kernel_size=1,groups=fc3_groups,bias=self.deploy)\n        self.fc3_bn=nn.Identity() if self.deploy else nn.BatchNorm2d(self.O*self.h*self.w)\n        \n        if not self.deploy 
and self.repconv_kernels is not None:\n            for k in self.repconv_kernels:\n                repconv=nn.Sequential(OrderedDict([\n                    ('conv',nn.Conv2d(self.C,self.O,kernel_size=k,padding=(k-1)//2, groups=fc3_groups,bias=False)),\n                    ('bn',nn.BatchNorm2d(self.O))\n                ])\n\n                )\n                self.__setattr__('repconv{}'.format(k),repconv)\n                \n\n    def switch_to_deploy(self):\n        self.deploy=True\n        fc1_weight,fc1_bias,fc3_weight,fc3_bias=self.get_equivalent_fc1_fc3_params()\n        #del conv\n        if(self.repconv_kernels is not None):\n            for k in self.repconv_kernels:\n                self.__delattr__('repconv{}'.format(k))\n        #del fc3,bn\n        self.__delattr__('fc3')\n        self.__delattr__('fc3_bn')\n        self.fc3 = nn.Conv2d(self.C * self.h * self.w, self.O * self.h * self.w, 1, 1, 0, bias=True, groups=self.fc3_groups)\n        self.fc3_bn = nn.Identity()\n        #   Remove the BN after AVG\n        if self.is_global_perceptron:\n            self.__delattr__('avg')\n            self.avg = nn.AvgPool2d(kernel_size=(self.h, self.w))\n        #   Set values\n        if fc1_weight is not None:\n            self.fc1_fc2.fc1.weight.data = fc1_weight\n            self.fc1_fc2.fc1.bias.data = fc1_bias\n        self.fc3.weight.data = fc3_weight\n        self.fc3.bias.data = fc3_bias\n\n\n\n\n    def get_equivalent_fc1_fc3_params(self):\n        #training fc3+bn weight\n        fc_weight,fc_bias=self._fuse_bn(self.fc3,self.fc3_bn)\n        #training conv weight\n        if(self.repconv_kernels is not None):\n            max_kernel=max(self.repconv_kernels)\n            max_branch=self.__getattr__('repconv{}'.format(max_kernel))\n            conv_weight,conv_bias=self._fuse_bn(max_branch.conv,max_branch.bn)\n            for k in self.repconv_kernels:\n                if(k!=max_kernel):\n                    
tmp_branch=self.__getattr__('repconv{}'.format(k))\n                    tmp_weight,tmp_bias=self._fuse_bn(tmp_branch.conv,tmp_branch.bn)\n                    tmp_weight=F.pad(tmp_weight,[(max_kernel-k)//2]*4)\n                    conv_weight+=tmp_weight\n                    conv_bias+=tmp_bias\n            repconv_weight,repconv_bias=self._conv_to_fc(conv_weight,conv_bias)\n            final_fc3_weight=fc_weight+repconv_weight.reshape_as(fc_weight)\n            final_fc3_bias=fc_bias+repconv_bias\n        else:\n            final_fc3_weight=fc_weight\n            final_fc3_bias=fc_bias\n\n        #fc1\n        if(self.is_global_perceptron):\n            #remove BN after avg\n            avgbn = self.avg.bn\n            std = (avgbn.running_var + avgbn.eps).sqrt()\n            scale = avgbn.weight / std\n            avgbias = avgbn.bias - avgbn.running_mean * scale\n            fc1 = self.fc1_fc2.fc1\n            replicate_times = fc1.in_features // len(avgbias)\n            replicated_avgbias = avgbias.repeat_interleave(replicate_times).view(-1, 1)\n            bias_diff = fc1.weight.matmul(replicated_avgbias).squeeze()\n            final_fc1_bias = fc1.bias + bias_diff\n            final_fc1_weight = fc1.weight * scale.repeat_interleave(replicate_times).view(1, -1)\n\n        else:\n            final_fc1_weight=None\n            final_fc1_bias=None\n        \n        return final_fc1_weight,final_fc1_bias,final_fc3_weight,final_fc3_bias\n\n\n\n\n    # def _conv_to_fc(self,weight,bias):\n    #     i_maxtrix=torch.eye(self.C*self.h*self.w//self.fc3_groups).repeat(1,self.fc3_groups).reshape(self.C*self.h*self.w//self.fc3_groups,self.C,self.h,self.w)\n    #     fc_weight=F.conv2d(i_maxtrix,weight=weight,bias=bias,padding=weight.shape[2]//2,groups=self.fc3_groups)\n    #     fc_weight=fc_weight.reshape(self.C*self.h*self.w//self.fc3_groups,-1)\n    #     fc_bias = bias.repeat_interleave(self.h * self.w)\n    #     return fc_weight,fc_bias\n\n\n    def 
_conv_to_fc(self,conv_kernel, conv_bias):\n        I = torch.eye(self.C * self.h * self.w // self.fc3_groups).repeat(1, self.fc3_groups).reshape(self.C * self.h * self.w // self.fc3_groups, self.C, self.h, self.w).to(conv_kernel.device) \n        fc_k = F.conv2d(I, conv_kernel, padding=conv_kernel.size(2)//2, groups=self.fc3_groups)\n        fc_k = fc_k.reshape(self.C * self.h * self.w // self.fc3_groups, self.O * self.h * self.w).t()\n        fc_bias = conv_bias.repeat_interleave(self.h * self.w)\n        return fc_k, fc_bias\n\n\n    def _fuse_bn(self, conv_or_fc, bn):\n        std = (bn.running_var + bn.eps).sqrt()\n        t = bn.weight / std\n        if conv_or_fc.weight.ndim == 4:\n            t = t.reshape(-1, 1, 1, 1)\n        else:\n            t = t.reshape(-1, 1)\n        return conv_or_fc.weight * t, bn.bias - bn.running_mean * bn.weight / std\n\n\n    def forward(self,x) :\n        ### global partition\n        if(self.is_global_perceptron):\n            input=x\n            v=self.avg(x) #bs,C,h_part,w_part\n            v=v.reshape(-1,self.C*self.h_part*self.w_part) #bs,C*h_part*w_part\n            v=self.fc1_fc2(v) #bs,C*h_part*w_part\n            v=v.reshape(-1,self.C,self.h_part,1,self.w_part,1) #bs,C,h_part,w_part\n            input=input.reshape(-1,self.C,self.h_part,self.h,self.w_part,self.w) #bs,C,h_part,h,w_part,w\n            input=v+input\n        else:\n            input=x.view(-1,self.C,self.h_part,self.h,self.w_part,self.w) #bs,C,h_part,h,w_part,w\n        partition=input.permute(0,2,4,1,3,5) #bs,h_part,w_part,C,h,w\n\n        ### partition partition\n        fc3_out=partition.reshape(-1,self.C*self.h*self.w,1,1) #bs*h_part*w_part,C*h*w,1,1\n        fc3_out=self.fc3_bn(self.fc3(fc3_out)) #bs*h_part*w_part,O*h*w,1,1\n        fc3_out=fc3_out.reshape(-1,self.h_part,self.w_part,self.O,self.h,self.w) #bs,h_part,w_part,O,h,w\n\n        ### local perceptron\n        if(self.repconv_kernels is not None and not self.deploy):\n            
conv_input=partition.reshape(-1,self.C,self.h,self.w) #bs*h_part*w_part,C,h,w\n            conv_out=0\n            for k in self.repconv_kernels:\n                repconv=self.__getattr__('repconv{}'.format(k))\n                conv_out+=repconv(conv_input) ##bs*h_part*w_part,O,h,w\n            conv_out=conv_out.view(-1,self.h_part,self.w_part,self.O,self.h,self.w) #bs,h_part,w_part,O,h,w\n            fc3_out+=conv_out\n        fc3_out=fc3_out.permute(0,3,1,4,2,5)#bs,O,h_part,h,w_part,w\n        fc3_out=fc3_out.reshape(-1,self.O,self.H,self.W) #bs,O,H,W (output channels are O, not C)\n\n\n        return fc3_out\n\n\n\nif __name__ == '__main__':\n    setup_seed(20)\n    N=4 #batch size\n    C=512 #input dim\n    O=1024 #output dim\n    H=14 #image height\n    W=14 #image width\n    h=7 #patch height\n    w=7 #patch width\n    fc1_fc2_reduction=1 #reduction ratio\n    fc3_groups=8 # groups\n    repconv_kernels=[1,3,5,7] #kernel list\n    repmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)\n    x=torch.randn(N,C,H,W)\n    repmlp.eval()\n    for module in repmlp.modules():\n        if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):\n            nn.init.uniform_(module.running_mean, 0, 0.1)\n            nn.init.uniform_(module.running_var, 0, 0.1)\n            nn.init.uniform_(module.weight, 0, 0.1)\n            nn.init.uniform_(module.bias, 0, 0.1)\n\n    #training result\n    out=repmlp(x)\n\n\n    #inference result\n    repmlp.switch_to_deploy()\n    deployout = repmlp(x)\n\n    print(((deployout-out)**2).sum())"
  },
  {
    "path": "model/mlp/resmlp.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass Rearange(nn.Module):\n    def __init__(self,image_size=14,patch_size=7) :\n        self.h=patch_size\n        self.w=patch_size\n        self.nw=image_size // patch_size\n        self.nh=image_size // patch_size\n\n        num_patches = (image_size // patch_size) ** 2\n        super().__init__()\n\n    def forward(self,x):\n        ### bs,c,H,W\n        bs,c,H,W=x.shape\n\n        y=x.reshape(bs,c,self.h,self.nh,self.w,self.nw)\n        y=y.permute(0,3,5,2,4,1) #bs,nh,nw,h,w,c\n        y=y.contiguous().view(bs,self.nh*self.nw,-1) #bs,nh*nw,h*w*c\n        return y\n\nclass Affine(nn.Module):\n    def __init__(self, channel):\n        super().__init__()\n        self.g = nn.Parameter(torch.ones(1, 1, channel))\n        self.b = nn.Parameter(torch.zeros(1, 1, channel))\n\n    def forward(self, x):\n        return x * self.g + self.b\n\nclass PreAffinePostLayerScale(nn.Module): # https://arxiv.org/abs/2103.17239\n    def __init__(self, dim, depth, fn):\n        super().__init__()\n        if depth <= 18:\n            init_eps = 0.1\n        elif depth > 18 and depth <= 24:\n            init_eps = 1e-5\n        else:\n            init_eps = 1e-6\n\n        scale = torch.zeros(1, 1, dim).fill_(init_eps)\n        self.scale = nn.Parameter(scale)\n        self.affine = Affine(dim)\n        self.fn = fn\n\n    def forward(self, x):\n        return self.fn(self.affine(x)) * self.scale + x\n\n\nclass ResMLP(nn.Module):\n    def __init__(self,dim=128,image_size=14,patch_size=7,expansion_factor=4,depth=4,class_num=1000):\n        super().__init__()\n        self.flatten=Rearange(image_size,patch_size)\n        num_patches = (image_size // patch_size) ** 2\n        wrapper = lambda i, fn: PreAffinePostLayerScale(dim, i + 1, fn)\n        self.embedding=nn.Linear((patch_size ** 2) * 3, dim)\n        self.mlp=nn.Sequential()\n\n        for i in range(depth):\n            self.mlp.add_module('fc1_%d'%i,wrapper(i, 
nn.Conv1d(num_patches, num_patches, 1))) # token mixing across the num_patches axis\n            # use a distinct name: re-using 'fc1_%d' would overwrite the token-mixing layer above\n            self.mlp.add_module('fc2_%d'%i,wrapper(i, nn.Sequential(\n                    nn.Linear(dim, dim * expansion_factor),\n                    nn.GELU(),\n                    nn.Linear(dim * expansion_factor, dim)\n                )))\n\n        self.aff=Affine(dim)\n\n        self.classifier=nn.Linear(dim,class_num)\n        self.softmax=nn.Softmax(1)\n    \n    def forward(self, x) :\n        y=self.flatten(x)\n        y=self.embedding(y)\n        y=self.mlp(y)\n        y=self.aff(y)\n        y=torch.mean(y,dim=1) #bs,dim\n        out=self.softmax(self.classifier(y))\n        return out\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,14,14)\n    resmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)\n    out=resmlp(input)\n    print(out.shape)\n\n    "
  },
  {
    "path": "model/mlp/sMLP_block.py",
    "content": "import torch\nfrom torch import nn\n\n\n\n\n\n\n\nclass sMLPBlock(nn.Module):\n    def __init__(self,h=224,w=224,c=3):\n        super().__init__()\n        self.proj_h=nn.Linear(h,h)\n        self.proj_w=nn.Linear(w,w)\n        self.fuse=nn.Linear(3*c,c)\n    \n    def forward(self,x):\n        x_h=self.proj_h(x.permute(0,1,3,2)).permute(0,1,3,2)\n        x_w=self.proj_w(x)\n        x_id=x\n        x_fuse=torch.cat([x_h,x_w,x_id],dim=1)\n        out=self.fuse(x_fuse.permute(0,2,3,1)).permute(0,3,1,2)\n        return out\n\n\nif __name__ == '__main__':\n    input=torch.randn(50,3,224,224)\n    smlp=sMLPBlock(h=224,w=224)\n    out=smlp(input)\n    print(out.shape)"
  },
  {
    "path": "model/mlp/vip-mlp.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.models.layers import DropPath, trunc_normal_\nfrom timm.models.registry import register_model\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .96, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',\n        **kwargs\n    }\n\ndefault_cfgs = {\n    'ViP_S': _cfg(crop_pct=0.9),\n    'ViP_M': _cfg(crop_pct=0.9),\n    'ViP_L': _cfg(crop_pct=0.875),\n}\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass WeightedPermuteMLP(nn.Module):\n    def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.segment_dim = segment_dim\n\n        self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)\n        self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias)\n        self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias)\n\n        self.reweight = Mlp(dim, dim // 4, dim *3)\n        \n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n\n\n    def forward(self, x):\n        B, H, W, C = x.shape\n\n        S = C // self.segment_dim\n        h = x.reshape(B, H, W, self.segment_dim, 
S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H*S)\n        h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)\n\n        w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W*S)\n        w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)\n\n        c = self.mlp_c(x)\n        \n        a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)\n        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)\n\n        x = h * a[0] + w * a[1] + c * a[2]\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n\nclass PermutatorBlock(nn.Module):\n\n    def __init__(self, dim, segment_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn = WeightedPermuteMLP):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = mlp_fn(dim, segment_dim=segment_dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop)\n\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)\n        self.skip_lam = skip_lam\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam\n        x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam\n        return x\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        x = self.proj(x) # B, C, H, W\n        return x\n\n\nclass Downsample(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n    def __init__(self, in_embed_dim, out_embed_dim, patch_size):\n        super().__init__()\n        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        x = x.permute(0, 3, 1, 2)\n        x = self.proj(x) # B, C, H, W\n        x = x.permute(0, 2, 3, 1)\n        return x\n\ndef basic_blocks(dim, index, layers, segment_dim, mlp_ratio=3., qkv_bias=False, qk_scale=None, \\\n    attn_drop=0, drop_path_rate=0., skip_lam=1.0, mlp_fn = WeightedPermuteMLP, **kwargs):\n    blocks = []\n\n    for block_idx in range(layers[index]):\n        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)\n        blocks.append(PermutatorBlock(dim, segment_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\\\n            attn_drop=attn_drop, drop_path=block_dpr, skip_lam=skip_lam, mlp_fn = mlp_fn))\n\n    blocks = nn.Sequential(*blocks)\n\n    return blocks\n\nclass VisionPermutator(nn.Module):\n    \"\"\" Vision Permutator\n    \"\"\"\n    def __init__(self, layers, img_size=224, patch_size=4, 
in_chans=3, num_classes=1000,\n        embed_dims=None, transitions=None, segment_dim=None, mlp_ratios=None, skip_lam=1.0,\n        qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,\n        norm_layer=nn.LayerNorm,mlp_fn = WeightedPermuteMLP):\n\n        super().__init__()\n        self.num_classes = num_classes\n        # width of the final stage; reset_classifier() reads it to rebuild the head\n        self.embed_dim = embed_dims[-1]\n\n        self.patch_embed = PatchEmbed(img_size = img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])\n\n        network = []\n        for i in range(len(layers)):\n            stage = basic_blocks(embed_dims[i], i, layers, segment_dim[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,\n                    qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate, norm_layer=norm_layer, skip_lam=skip_lam,\n                    mlp_fn = mlp_fn)\n            network.append(stage)\n            if i >= len(layers) - 1:\n                break\n            if transitions[i] or embed_dims[i] != embed_dims[i+1]:\n                patch_size = 2 if transitions[i] else 1\n                network.append(Downsample(embed_dims[i], embed_dims[i+1], patch_size))\n\n\n        self.network = nn.ModuleList(network)\n\n        self.norm = norm_layer(embed_dims[-1])\n\n        # Classifier head\n        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if 
num_classes > 0 else nn.Identity()\n\n    def forward_embeddings(self, x):\n        x = self.patch_embed(x)\n        # B,C,H,W-> B,H,W,C\n        x = x.permute(0, 2, 3, 1)\n        return x\n\n    def forward_tokens(self,x):\n        for idx, block in enumerate(self.network):\n            x = block(x)\n        B, H, W, C = x.shape\n        x = x.reshape(B, -1, C)\n        return x\n\n    def forward(self, x):\n        x = self.forward_embeddings(x)\n        # B, H, W, C -> B, N, C\n        x = self.forward_tokens(x)\n        x = self.norm(x)\n        return self.head(x.mean(1))\n\n\n\n\n@register_model\ndef vip_s14(pretrained=False, **kwargs):\n    layers = [4, 3, 8, 3]\n    transitions = [False, False, False, False]\n    segment_dim = [16, 16, 16, 16]\n    mlp_ratios = [3, 3, 3, 3]\n    embed_dims = [384, 384, 384, 384]\n    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=14, transitions=transitions,\n        segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)\n    model.default_cfg = default_cfgs['ViP_S']\n    return model\n\n@register_model\ndef vip_s7(pretrained=False, **kwargs):\n    layers = [4, 3, 8, 3]\n    transitions = [True, False, False, False]\n    segment_dim = [32, 16, 16, 16]\n    mlp_ratios = [3, 3, 3, 3]\n    embed_dims = [192, 384, 384, 384]\n    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,\n        segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)\n    model.default_cfg = default_cfgs['ViP_S']\n    return model\n\n@register_model\ndef vip_m7(pretrained=False, **kwargs):\n    # 55534632\n    layers = [4, 3, 14, 3]\n    transitions = [False, True, False, False]\n    segment_dim = [32, 32, 16, 16]\n    mlp_ratios = [3, 3, 3, 3]\n    embed_dims = [256, 256, 512, 512]\n    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,\n        segment_dim=segment_dim, 
mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)\n    model.default_cfg = default_cfgs['ViP_M']\n    return model\n\n\n@register_model\ndef vip_l7(pretrained=False, **kwargs):\n    layers = [8, 8, 16, 4]\n    transitions = [True, False, False, False]\n    segment_dim = [32, 16, 16, 16]\n    mlp_ratios = [3, 3, 3, 3]\n    embed_dims = [256, 512, 512, 512]\n    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,\n        segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)\n    model.default_cfg = default_cfgs['ViP_L']\n    return model\n\nif __name__ == '__main__':\n    input=torch.randn(1,3,224,224)\n    model = VisionPermutator(\n        layers=[4, 3, 8, 3], \n        embed_dims=[384, 384, 384, 384], \n        patch_size=14, \n        transitions=[False, False, False, False],\n        segment_dim=[16, 16, 16, 16], \n        mlp_ratios=[3, 3, 3, 3], \n        mlp_fn=WeightedPermuteMLP\n    )\n    output=model(input)\n    print(output.shape)"
  },
  {
    "path": "model/rep/acnet.py",
    "content": "import torch\nfrom torch import mean, nn\nfrom collections import OrderedDict\nfrom torch.nn import functional as F\nimport numpy as np\nfrom numpy import random\n\ndef setup_seed(seed):\n     torch.manual_seed(seed)\n     torch.cuda.manual_seed_all(seed)\n     np.random.seed(seed)\n     random.seed(seed)\n     torch.backends.cudnn.deterministic = True\n\ndef _conv_bn(input_channel,output_channel,kernel_size=3,padding=1,stride=1,groups=1):\n     res=nn.Sequential()\n     res.add_module('conv',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=kernel_size,padding=padding,padding_mode='zeros',stride=stride,groups=groups,bias=False))\n     res.add_module('bn',nn.BatchNorm2d(output_channel))\n     return res\n\nclass ACNet(nn.Module):\n     def __init__(self,input_channel,output_channel,kernel_size=3,groups=1,stride=1,deploy=False,use_se=False):\n          super().__init__()\n          self.use_se=use_se\n          self.input_channel=input_channel\n          self.output_channel=output_channel\n          self.deploy=deploy\n          self.kernel_size=kernel_size\n          self.padding=kernel_size//2\n          self.groups=groups\n          self.activation=nn.ReLU()\n\n\n          if(not self.deploy):\n               self.brb_3x3=_conv_bn(input_channel,output_channel,kernel_size=3,padding=1,stride=stride,groups=groups)\n               self.brb_1x3=_conv_bn(input_channel,output_channel,kernel_size=(1,3),padding=(0,1),stride=stride,groups=groups)\n               self.brb_3x1=_conv_bn(input_channel,output_channel,kernel_size=(3,1),padding=(1,0),stride=stride,groups=groups)\n          else:\n               self.brb_rep=nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=self.kernel_size,padding=self.padding,padding_mode='zeros',stride=stride,groups=groups,bias=True)\n\n\n     \n     def forward(self, inputs):\n          if(self.deploy):\n               return self.activation(self.brb_rep(inputs))\n\n          return 
self.activation(self.brb_1x3(inputs)+self.brb_3x1(inputs)+self.brb_3x3(inputs))\n\n     \n     \n\n     def _switch_to_deploy(self):\n          self.deploy=True\n          kernel,bias=self._get_equivalent_kernel_bias()\n          self.brb_rep=nn.Conv2d(in_channels=self.brb_3x3.conv.in_channels,out_channels=self.brb_3x3.conv.out_channels,\n                                   kernel_size=self.brb_3x3.conv.kernel_size,padding=self.brb_3x3.conv.padding,\n                                   padding_mode=self.brb_3x3.conv.padding_mode,stride=self.brb_3x3.conv.stride,\n                                   groups=self.brb_3x3.conv.groups,bias=True)\n          self.brb_rep.weight.data=kernel\n          self.brb_rep.bias.data=bias\n          #detach the parameters so they no longer receive gradient updates\n          for para in self.parameters():\n               para.detach_()\n          #delete the branches that are no longer needed\n          self.__delattr__('brb_3x3')\n          self.__delattr__('brb_3x1')\n          self.__delattr__('brb_1x3')\n\n\n\n     #zero-pad a 1x3 kernel to 3x3\n     def _pad_1x3_kernel(self,kernel):\n          if(kernel is None):\n               return 0\n          else:\n               return F.pad(kernel,[0,0,1,1])\n\n     #zero-pad a 3x1 kernel to 3x3\n     def _pad_3x1_kernel(self,kernel):\n          if(kernel is None):\n               return 0\n          else:\n               return F.pad(kernel,[1,1,0,0])\n\n\n     #fuse the 3x3, 1x3 and 3x1 branches into the parameters of a single 3x3 convolution\n     def _get_equivalent_kernel_bias(self):\n          brb_3x3_weight,brb_3x3_bias=self._fuse_conv_bn(self.brb_3x3)\n          brb_1x3_weight,brb_1x3_bias=self._fuse_conv_bn(self.brb_1x3)\n          brb_3x1_weight,brb_3x1_bias=self._fuse_conv_bn(self.brb_3x1)\n          return brb_3x3_weight+self._pad_1x3_kernel(brb_1x3_weight)+self._pad_3x1_kernel(brb_3x1_weight),brb_3x3_bias+brb_1x3_bias+brb_3x1_bias\n     \n     \n     ### fold the BN statistics into the convolution weight and bias\n     def _fuse_conv_bn(self,branch):\n          kernel=branch.conv.weight\n          running_mean=branch.bn.running_mean\n          running_var=branch.bn.running_var\n  
        gamma=branch.bn.weight\n          beta=branch.bn.bias\n          eps=branch.bn.eps\n          \n          std=(running_var+eps).sqrt()\n          t=gamma/std\n          t=t.view(-1,1,1,1)\n          return kernel*t,beta-running_mean*gamma/std\n          \n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,49,49)\n    acnet=ACNet(512,512)\n    acnet.eval()\n    out=acnet(input)\n    acnet._switch_to_deploy()\n    out2=acnet(input)\n    print('difference:')\n    print(((out2-out)**2).sum())\n    "
  },
  {
    "path": "model/rep/ddb.py",
    "content": "from torch import conv2d, nn\nimport torch\nfrom torch.nn import functional as F\n\ndef transI_conv_bn(conv, bn):\n\n    std = (bn.running_var + bn.eps).sqrt()\n    gamma=bn.weight\n\n    weight=conv.weight*((gamma/std).reshape(-1, 1, 1, 1))\n    if(conv.bias is not None):\n        bias=gamma/std*conv.bias-gamma/std*bn.running_mean+bn.bias\n    else:\n        bias=bn.bias-gamma/std*bn.running_mean\n    return weight,bias\n\ndef transII_conv_branch(conv1, conv2):\n    weight=conv1.weight.data+conv2.weight.data\n    bias=conv1.bias.data+conv2.bias.data\n    return weight,bias\n\n\ndef transIII_conv_sequential(conv1, conv2):\n    weight=F.conv2d(conv2.weight.data,conv1.weight.data.permute(1,0,2,3))\n    # bias=((conv2.weight.data*(conv1.bias.data.reshape(1,-1,1,1))).sum(-1).sum(-1).sum(-1))+conv2.bias.data\n    return weight#,bias\n\ndef transIV_conv_concat(conv1, conv2):\n    print(conv1.bias.data.shape)\n    print(conv2.bias.data.shape)\n    weight=torch.cat([conv1.weight.data,conv2.weight.data],0)\n    bias=torch.cat([conv1.bias.data,conv2.bias.data],0)\n    return weight,bias\n\ndef transV_avg(channel,kernel):\n    conv=nn.Conv2d(channel,channel,kernel,bias=False)\n    conv.weight.data[:]=0\n    for i in range(channel):\n        conv.weight.data[i,i,:,:]=1/(kernel*kernel)\n    return conv\n\ndef transVI_conv_scale(conv1, conv2, conv3):\n    weight=F.pad(conv1.weight.data,(1,1,1,1))+F.pad(conv2.weight.data,(0,0,1,1))+F.pad(conv3.weight.data,(1,1,0,0))\n    bias=conv1.bias.data+conv2.bias.data+conv3.bias.data\n    return weight,bias\n\nif __name__ == '__main__':\n    input=torch.randn(1,64,7,7)\n\n    #conv+conv\n    conv1x1=nn.Conv2d(64,64,1)\n    conv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))\n    conv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))\n    out1=conv1x1(input)+conv1x3(input)+conv3x1(input)\n\n    #conv_fuse\n    conv_fuse=nn.Conv2d(64,64,3,padding=1)\n    conv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)\n    
out2=conv_fuse(input)\n\n    print(\"difference:\",((out2-out1)**2).sum().item())\n    "
  },
  {
    "path": "model/rep/mobileone.py",
    "content": ""
  },
  {
    "path": "model/rep/repvgg.py",
    "content": "import torch\nfrom torch import mean, nn\nfrom collections import OrderedDict\nfrom torch.nn import functional as F\nimport numpy as np\nfrom numpy import random\n\ndef setup_seed(seed):\n     torch.manual_seed(seed)\n     torch.cuda.manual_seed_all(seed)\n     np.random.seed(seed)\n     random.seed(seed)\n     torch.backends.cudnn.deterministic = True\n\ndef _conv_bn(input_channel,output_channel,kernel_size=3,padding=1,stride=1,groups=1):\n     res=nn.Sequential()\n     res.add_module('conv',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=kernel_size,padding=padding,padding_mode='zeros',stride=stride,groups=groups,bias=False))\n     res.add_module('bn',nn.BatchNorm2d(output_channel))\n     return res\n\nclass RepBlock(nn.Module):\n     def __init__(self,input_channel,output_channel,kernel_size=3,groups=1,stride=1,deploy=False,use_se=False):\n          super().__init__()\n          self.use_se=use_se\n          self.input_channel=input_channel\n          self.output_channel=output_channel\n          self.deploy=deploy\n          self.kernel_size=kernel_size\n          self.padding=kernel_size//2\n          self.groups=groups\n          self.activation=nn.ReLU()\n\n\n\n          #make sure kernel_size=3 padding=1\n          assert self.kernel_size==3\n          assert self.padding==1\n          if(not self.deploy):\n               self.brb_3x3=_conv_bn(input_channel,output_channel,kernel_size=self.kernel_size,padding=self.padding,stride=stride,groups=groups)\n               self.brb_1x1=_conv_bn(input_channel,output_channel,kernel_size=1,padding=0,stride=stride,groups=groups)\n               self.brb_identity=nn.BatchNorm2d(self.input_channel) if self.input_channel == self.output_channel and stride == 1 else None\n          else:\n               self.brb_rep=nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=self.kernel_size,padding=self.padding,padding_mode='zeros',stride=stride,groups=groups,bias=True)\n\n\n     \n     def forward(self, inputs):\n       
   if(self.deploy):\n               return self.activation(self.brb_rep(inputs))\n          \n          if(self.brb_identity is None):\n               identity_out=0\n          else:\n               identity_out=self.brb_identity(inputs)\n          \n          return self.activation(self.brb_1x1(inputs)+self.brb_3x3(inputs)+identity_out)\n\n     \n     \n\n     def _switch_to_deploy(self):\n          self.deploy=True\n          kernel,bias=self._get_equivalent_kernel_bias()\n          self.brb_rep=nn.Conv2d(in_channels=self.brb_3x3.conv.in_channels,out_channels=self.brb_3x3.conv.out_channels,\n                                   kernel_size=self.brb_3x3.conv.kernel_size,padding=self.brb_3x3.conv.padding,\n                                   padding_mode=self.brb_3x3.conv.padding_mode,stride=self.brb_3x3.conv.stride,\n                                   groups=self.brb_3x3.conv.groups,bias=True)\n          self.brb_rep.weight.data=kernel\n          self.brb_rep.bias.data=bias\n          #detach the parameters so they no longer receive gradient updates\n          for para in self.parameters():\n               para.detach_()\n          #delete the branches that are no longer needed\n          self.__delattr__('brb_3x3')\n          self.__delattr__('brb_1x1')\n          self.__delattr__('brb_identity')\n\n\n     #zero-pad a 1x1 kernel to 3x3\n     def _pad_1x1_kernel(self,kernel):\n          if(kernel is None):\n               return 0\n          else:\n               return F.pad(kernel,[1]*4)\n\n\n     #fuse the identity, 1x1 and 3x3 branches into the parameters of a single 3x3 convolution\n     def _get_equivalent_kernel_bias(self):\n          brb_3x3_weight,brb_3x3_bias=self._fuse_conv_bn(self.brb_3x3)\n          brb_1x1_weight,brb_1x1_bias=self._fuse_conv_bn(self.brb_1x1)\n          brb_id_weight,brb_id_bias=self._fuse_conv_bn(self.brb_identity)\n          return brb_3x3_weight+self._pad_1x1_kernel(brb_1x1_weight)+brb_id_weight,brb_3x3_bias+brb_1x1_bias+brb_id_bias\n     \n     \n     ### fold the BN statistics into the convolution weight and bias\n     def _fuse_conv_bn(self,branch):\n          if(branch is None):\n               return 0,0\n          
elif(isinstance(branch,nn.Sequential)):\n               kernel=branch.conv.weight\n               running_mean=branch.bn.running_mean\n               running_var=branch.bn.running_var\n               gamma=branch.bn.weight\n               beta=branch.bn.bias\n               eps=branch.bn.eps\n          else:\n               assert isinstance(branch, nn.BatchNorm2d)\n               if not hasattr(self, 'id_tensor'):\n                    input_dim = self.input_channel // self.groups\n                    kernel_value = np.zeros((self.input_channel, input_dim, 3, 3), dtype=np.float32)\n                    for i in range(self.input_channel):\n                         kernel_value[i, i % input_dim, 1, 1] = 1\n                    self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)\n               kernel = self.id_tensor\n               running_mean = branch.running_mean\n               running_var = branch.running_var\n               gamma = branch.weight\n               beta = branch.bias\n               eps = branch.eps\n          \n          std=(running_var+eps).sqrt()\n          t=gamma/std\n          t=t.view(-1,1,1,1)\n          return kernel*t,beta-running_mean*gamma/std\n          \n\n\nif __name__ == '__main__':\n    input=torch.randn(50,512,49,49)\n    repblock=RepBlock(512,512)\n    repblock.eval()\n    out=repblock(input)\n    repblock._switch_to_deploy()\n    out2=repblock(input)\n    print('difference between vgg and repvgg')\n    print(((out2-out)**2).sum())\n    "
  },
  {
    "path": "setup.py",
    "content": "from setuptools import find_packages, setup\n\n\n\nsetup(\n    name=\"fighingcv\",\n    version=\"1.0.0\",\n    author=\"xmu-xiaoma666\",\n    author_email=\"julien@huggingface.co\",\n    description=(\n        \"FightingCV Codebase For Attention,Backbone, MLP, Re-parameter, Convolution\"\n    ),\n    long_description=open(\"README.md\", \"r\", encoding=\"utf-8\").read(),\n    long_description_content_type=\"text/markdown\",\n    keywords=[\n        \"Attention\",\n        \"Backbone\",\n    ],\n    license=\"MIT\",\n    url=\"https://github.com/xmu-xiaoma666/External-Attention-pytorch\",\n    package_dir={\"\": \".\"},\n    packages=find_packages(\".\"),\n    # entry_points={\n    #     \"console_scripts\": [\n    #         \"huggingface-cli=huggingface_hub.commands.huggingface_cli:main\"\n    #     ]\n    # },\n    python_requires=\">=3.7.0\",\n    # install_requires=install_requires,\n    classifiers=[\n        \"Intended Audience :: Developers\",\n        \"Intended Audience :: Education\",\n        \"Intended Audience :: Science/Research\",\n        \"License :: OSI Approved :: MIT License\",\n        \"Operating System :: OS Independent\",\n        \"Programming Language :: Python :: 3\",\n        \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    ],\n)"
  }
]