[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# launch bash\n*.sh\n# nsight system report files\n*.nsys-rep\n*.sqlite\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\n\nResources:\n\n- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)\n- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)\n- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns\n"
  },
  {
    "path": "LICENSE",
    "content": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any person obtaining a copy\n    of this software and associated documentation files (the \"Software\"), to deal\n    in the Software without restriction, including without limitation the rights\n    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    copies of the Software, and to permit persons to whom the Software is\n    furnished to do so, subject to the following conditions:\n\n    The above copyright notice and this permission notice shall be included in all\n    copies or substantial portions of the Software.\n\n    THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    SOFTWARE\n"
  },
  {
    "path": "MODELHUB.md",
    "content": "Access code for `baidu` is `swin`.\n\n## ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models\n\n| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |\n| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) |\n| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) |\n| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278  | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) |\n| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) |\n| Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) |\n| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) |\n\n## ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models\n\n| name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model |\n|:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: |\n| SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) |\n| SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) |\n| SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) |\n| SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) |\n| SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G  | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) |\n| SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) |\n| SwinV2-B<sup>\\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 |  88M | 21.8G | 174 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) |\n| SwinV2-B<sup>\\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) |\n| SwinV2-L<sup>\\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) |\n| SwinV2-L<sup>\\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) |\n\nNote:\n\n- SwinV2-B<sup>\\*</sup> (SwinV2-L<sup>\\*</sup>) models at input resolutions of 256x256 and 384x384 are both fine-tuned from\n  the same pre-trained model, which uses a smaller input resolution of 192x192.\n- SwinV2-B<sup>\\*</sup> (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2, while SwinV2-L<sup>\\*</sup> (384x384) achieves\n  78.31.\n\n## ImageNet-1K Pretrained Swin MLP Models\n\n| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS |  1K model |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) |\n| [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) |\n| [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G |  231 | [timm](https://github.com/rwightman/pytorch-image-models) |\n| Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) |\n| SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swin/swin_mlp_tiny_c24_patch4_window8_256.yaml) |\n| SwinMLP-T/C12 | 
ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swin/swin_mlp_tiny_c12_patch4_window8_256.yaml) |\n| SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swin/swin_mlp_tiny_c6_patch4_window8_256.yaml) |\n| SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swin/swin_mlp_base_patch4_window7_224.yaml) |\n\nNote: C24 means each head has 24 channels.\n\n## ImageNet-22K Pretrained Swin-MoE Models\n\n| name | #experts | k | router | resolution | window | IN-22K acc@1 | IN-1K/ft acc@1 | IN-1K/5-shot acc@1 | 22K model |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| Swin-MoE-S | 1 (dense) | - | - | 192x192 | 8x8 | 35.5| 83.5 | 70.3 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_densebaseline_22k.zip)/[baidu](https://pan.baidu.com/s/1O1m9jT2pGoago_RiRX914w?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml) |\n| Swin-MoE-S | 8 | 1 | Linear | 192x192 | 8x8 | 36.8 | 84.5 | 75.2 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/198IlYUrWOxEUp7wNdoJT5Q?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml) |\n| Swin-MoE-S | 16 | 1 | Linear |192x192 | 8x8 | 37.6 | 84.9 
| 76.5 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1vRQweedtT42VwMTqe9-r2A?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml) |\n| Swin-MoE-S | 32 | 1 | Linear | 192x192 | 8x8 | 37.4 | 84.7 | 75.9 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1i7rImt5pwO8gJC-PRRuZwQ?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml) |\n| Swin-MoE-S | 32 | 1 | Cosine | 192x192 | 8x8 | 37.2 | 84.3 | 75.2 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1Yghr_12ntSrv01I9yatPDQ?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml) |\n| Swin-MoE-S | 64 | 1 | Linear | 192x192 | 8x8 | 37.8 | 84.7 | 75.7 | - |\n| Swin-MoE-S | 128 | 1 | Linear | 192x192 | 8x8 | 37.4 | 84.5 | 75.4 | - |\n| Swin-MoE-B | 1 (dense) | - | - | 192x192 | 8x8 | 37.3 | 85.1 | 75.9 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml) |\n| Swin-MoE-B | 8 | 1 | Linear | 192x192 | 8x8 | 38.1 | 85.3 | 77.2 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml) |\n| Swin-MoE-B | 16 | 1 | Linear | 192x192 | 8x8 | 38.7 | 85.5 | 78.2 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml) |\n| Swin-MoE-B | 32 | 1 | Linear | 192x192 | 8x8 | 38.6 | 85.5 | 77.9 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml) |\n| Swin-MoE-B | 32 | 1 | Cosine | 192x192 | 8x8 | 38.5 | 85.3 | 77.3 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml) |\n| 
Swin-MoE-B | 32 | 2 | Linear | 192x192 | 8x8 | 38.6 | 85.5 | 78.7 | - |\n\n## SimMIM Pretrained Swin-V2 Models\n\n> Please note that all SimMIM pretrained Swin-V2 models will be stored in the Huggingface repository starting July 2024. For more details, refer to the [huggingface repository](https://huggingface.co/zdaxie/SimMIM).\n\n- **Model size** only includes the backbone weights and excludes weights in the decoders/classification heads.\n- **Batch size** for all models is set to 2048.\n- **Validation loss** is calculated on the ImageNet-1K validation set.\n- **Fine-tuned acc@1** refers to the top-1 accuracy on the ImageNet-1K validation set after fine-tuning.\n\n| name | model size | pre-train dataset | pre-train iterations | validation loss | fine-tuned acc@1 | pre-trained model | fine-tuned model |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| SwinV2-Small | 49M | ImageNet-1K 10% | 125k | 0.4820 | 82.69 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_125k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 10% | 250k | 0.4961 | 83.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_250k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 10% | 500k | 0.5115 | 83.17 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_500k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 20% | 125k 
| 0.4751 | 83.05 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_125k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 20% | 250k | 0.4722 | 83.56 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_250k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 20% | 500k | 0.4734 | 83.75 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_500k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 50% | 125k | 0.4732 | 83.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_125k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 50% | 250k | 0.4681 | 83.67 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_250k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K 50% | 500k | 0.4646 | 83.96 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_500k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_500k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K | 125k | 0.4728 | 82.92 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_125k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K | 250k | 0.4674 | 83.66 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_250k.pth?download=true) |\n| SwinV2-Small | 49M | ImageNet-1K | 500k | 0.4641 | 84.08 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_500k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 10% | 125k | 0.4822 | 83.33 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_125k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 10% | 250k | 0.4997 | 83.60 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_250k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 10% | 500k | 0.5112 | 83.41 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_500k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 20% | 125k | 0.4703 | 83.86 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_125k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 20% | 250k | 0.4679 | 84.37 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_250k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 20% | 500k | 0.4711 | 84.61 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_500k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 50% | 125k | 0.4683 | 84.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_125k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 50% | 250k | 0.4633 | 84.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_250k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_250k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K 50% | 500k | 0.4598 | 84.95 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_500k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K | 125k | 0.4680 | 84.13 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_125k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K | 250k | 0.4626 | 84.65 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_250k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-1K | 500k | 0.4588 | 85.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_500k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-22K | 125k | 0.4695 | 84.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_125k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-22K | 250k | 0.4649 | 84.57 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_250k.pth?download=true) |\n| SwinV2-Base | 87M | ImageNet-22K | 500k | 0.4614 | 85.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_500k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 10% | 125k | 0.4995 | 83.69 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_125k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 10% | 250k | 0.5140 | 83.66 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_250k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 10% | 500k | 0.5150 | 83.50 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_500k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 20% | 125k | 0.4675 | 84.38 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_125k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_125k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 20% | 250k | 0.4746 | 84.71 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_250k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 20% | 500k | 0.4960 | 84.59 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_500k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 50% | 125k | 0.4622 | 84.78 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_125k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 50% | 250k | 0.4566 | 85.38 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_250k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K 50% | 500k | 0.4530 | 85.80 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_500k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K | 125k | 0.4611 | 
84.98 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_125k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K | 250k | 0.4552 | 85.45 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_250k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-1K | 500k | 0.4507 | 85.91 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_500k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-22K | 125k | 0.4649 | 84.61 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_125k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-22K | 250k | 0.4586 | 85.39 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_250k.pth?download=true) |\n| SwinV2-Large | 195M | ImageNet-22K | 500k | 0.4536 | 85.81 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_500k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_500k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K 20% | 125k | 0.4789 | 84.35 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_125k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K 20% | 250k | 0.5038 | 84.16 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_250k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K 20% | 500k | 0.5071 | 83.44 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_500k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K 50% | 125k | 0.4549 | 85.09 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_125k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K 50% | 250k | 0.4511 | 85.64 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_250k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K 50% | 500k | 0.4559 | 85.69 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_500k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K | 125k | 0.4531 | 85.23 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_125k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K | 250k | 0.4464 | 85.90 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_250k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-1K | 500k | 0.4416 | 86.34 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_500k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-22K | 125k | 0.4564 | 85.14 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_125k.pth?download=true) |\n| SwinV2-Huge | 655M | ImageNet-22K | 250k | 0.4499 | 85.86 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_250k.pth?download=true) 
|\n| SwinV2-Huge | 655M | ImageNet-22K | 500k | 0.4444 | 86.27 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_500k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-1K 50% | 125k | 0.4534 | 85.44 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_125k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-1K 50% | 250k | 0.4515 | 85.76 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_250k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-1K 50% | 500k | 0.4719 | 85.51 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_500k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-1K | 125k | 0.4513 | 85.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_125k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-1K | 250k | 0.4442 | 86.12 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_250k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_250k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-1K | 500k | 0.4395 | 86.46 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_500k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-22K | 125k | 0.4544 | 85.39 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_125k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-22K | 250k | 0.4475 | 85.96 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_250k.pth?download=true) |\n| SwinV2-giant | 1.06B | ImageNet-22K | 500k | 0.4416 | 86.53 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_500k.pth?download=true) |\n\n## SimMIM Pretrained Swin-V1 Models\n\n**ImageNet-1K Pre-trained and Fine-tuned Models**\n\n| name | pre-train epochs | pre-train resolution | fine-tune resolution | acc@1 | pre-trained model | fine-tuned model |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| Swin-Base | 100 | 192x192 | 192x192 | 82.8 | 
[google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1RsgHfjB4B1ZYblXEQVT-FPX3WSvBrxcs/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img192_window6__100ep.yaml) |\n| Swin-Base | 100 | 192x192 | 224x224 | 83.5 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1mb43BkW56F5smwiX-g7QUUD7f1Rftq8u/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img224_window7__100ep.yaml) |\n| Swin-Base | 800 | 192x192 | 224x224 | 84.0 | [google](https://drive.google.com/file/d/15zENvGjHlM71uKQ3d2FbljWPubtrPtjl/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml) | [google](https://drive.google.com/file/d/1xEKyfMTsdh6TfnYhk5vbw0Yz7a-viZ0w/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml) |\n| Swin-Large | 800 | 192x192 | 224x224 | 85.4 | [google](https://drive.google.com/file/d/1qDxrTl2YUDB0505_4QrU5LU2R1kKmcBP/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_pretrain__swin_large__img192_window12__800ep.yaml) | [google](https://drive.google.com/file/d/1mf0ZpXttEvFsH87Www4oQ-t8Kwr0x485/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_finetune__swin_large__img224_window14__800ep.yaml) |\n| SwinV2-Huge | 800 | 192x192 | 224x224 | 85.7 | / | / |\n| SwinV2-Huge | 800 | 192x192 | 512x512 | 87.1 | / | / |\n"
  },
  {
    "path": "README.md",
    "content": "# Swin Transformer\n\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-v2-scaling-up-capacity-and)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-v2-scaling-up-capacity-and)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-v2-scaling-up-capacity-and)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=swin-transformer-v2-scaling-up-capacity-and)\n\nThis repo is the official implementation of [\"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\"](https://arxiv.org/pdf/2103.14030.pdf) as well as the follow-ups. It currently includes code and models for the following tasks:\n\n> **Image Classification**: Included in this repo. 
See [get_started.md](get_started.md) for a quick start.\n\n> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).\n\n> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).\n\n> **Video Action Recognition**: See [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).\n\n> **Semi-Supervised Object Detection**: See [Soft Teacher](https://github.com/microsoft/SoftTeacher).\n\n> **SSL: Contrastive Learning**: See [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL).\n\n> **SSL: Masked Image Modeling**: See [get_started.md#simmim-support](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md#simmim-support).\n\n> **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions.\n\n> **Feature-Distillation**: See [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).\n\n## Updates\n\n***12/29/2022***\n\n1. **Nvidia**'s [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md) now supports Swin Transformer V2 inference, with significant speed improvements on `T4 and A100 GPUs`.\n\n***11/30/2022***\n\n1. Models and code for **Feature Distillation** are released. Please refer to [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation) for details and for the checkpoints (FD-EsViT-Swin-B, FD-DeiT-ViT-B, FD-DINO-ViT-B, FD-CLIP-ViT-B, FD-CLIP-ViT-L).\n\n***09/24/2022***\n\n1. Merged [SimMIM](https://github.com/microsoft/SimMIM), a **Masked Image Modeling** based pre-training approach applicable to Swin and SwinV2 (and also to ViT and ResNet). Please refer to [get started with SimMIM](get_started.md#simmim-support) to try SimMIM pre-training.\n\n2. 
Released a series of Swin and SwinV2 models pre-trained using the SimMIM approach (see [MODELHUB for SimMIM](MODELHUB.md#simmim-pretrained-swin-v2-models)), with model sizes ranging from SwinV2-Small-50M to SwinV2-giant-1B, data sizes ranging from ImageNet-1K-10% to ImageNet-22K, and training iterations from 125k to 500k. You may use these models to study the properties of MIM methods. See the [data scaling](https://arxiv.org/abs/2206.04664) paper for more details.\n\n***07/09/2022***\n\n`News`: \n\n1. SwinV2-G achieves `61.4 mIoU` on ADE20K semantic segmentation (+1.5 mIoU over the previous SwinV2-G model), using an additional [feature distillation (FD)](https://github.com/SwinTransformer/Feature-Distillation) approach, **setting a new record** on this benchmark. FD is an approach that can generally improve the fine-tuning performance of various pre-trained models, including DeiT, DINO, and CLIP. In particular, it improves CLIP pre-trained ViT-L by +1.6% to reach `89.0%` on ImageNet-1K image classification, making it **the most accurate ViT-L model**.\n2. Merged a PR from **Nvidia** that links to a faster Swin Transformer inference implementation with significant speed improvements on `T4 and A100 GPUs`.\n3. Merged a PR from **Nvidia** that adds an option to use `pure FP16 (Apex O2)` in training while nearly maintaining accuracy.\n\n***06/03/2022***\n\n1. Added **Swin-MoE**, the Mixture-of-Experts variant of Swin Transformer implemented using [Tutel](https://github.com/microsoft/tutel) (an optimized Mixture-of-Experts implementation). **Swin-MoE** is introduced in the [TuTel](https://arxiv.org/abs/2206.03382) paper.\n\n***05/12/2022***\n\n1. Pretrained models of [Swin Transformer V2](https://arxiv.org/abs/2111.09883) on ImageNet-1K and ImageNet-22K are released.\n2. ImageNet-22K pretrained models for Swin-V1-Tiny and Swin-V2-Small are released.\n\n***03/02/2022***\n\n1. Swin Transformer V2 and SimMIM were accepted to CVPR 2022. 
[SimMIM](https://github.com/microsoft/SimMIM) is a self-supervised pre-training approach based on masked image modeling, a key technique that made it possible to train the 3-billion-parameter Swin V2 model using `40x less labelled data` than previous billion-scale models based on JFT-3B.\n\n***02/09/2022***\n\n1. Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/Swin-Transformer)\n\n***10/12/2021***\n\n1. Swin Transformer received the ICCV 2021 best paper award (Marr Prize).\n\n***08/09/2021***\n1. [Soft Teacher](https://arxiv.org/pdf/2106.09018v2.pdf) will appear at ICCV 2021. The code will be released at [GitHub Repo](https://github.com/microsoft/SoftTeacher). `Soft Teacher` is an end-to-end semi-supervised object detection method, achieving a new record on COCO test-dev: `61.3 box AP` and `53.0 mask AP`.\n\n***07/03/2021***\n1. Added **Swin MLP**, an adaptation of `Swin Transformer` that replaces all multi-head self-attention (MHSA) blocks with MLP layers (more precisely, group linear layers). The shifted window configuration can also significantly improve the performance of vanilla MLP architectures.\n\n***06/25/2021***\n1. [Video Swin Transformer](https://arxiv.org/abs/2106.13230) is released at [Video-Swin-Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).\n`Video Swin Transformer` achieves state-of-the-art accuracy on a broad range of video recognition benchmarks, including action recognition (`84.9` top-1 accuracy on Kinetics-400 and `86.1` top-1 accuracy on Kinetics-600 with `~20x` less pre-training data and `~3x` smaller model size) and temporal modeling (`69.6` top-1 accuracy on Something-Something v2).\n\n***05/12/2021***\n1. 
Used as a backbone for `Self-Supervised Learning`: [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL)\n\nUsing Swin-Transformer as the backbone for self-supervised learning enables us to evaluate the transfer performance of the learnt representations on downstream tasks. Previous works lacked this evaluation because they used ViT/DeiT, which had not been well adapted to downstream tasks.\n\n***04/12/2021***\n\nInitial commits:\n\n1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.\n2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.\n3. The CUDA kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).\n\n## Introduction\n\n**Swin Transformer** (the name `Swin` stands for **S**hifted **win**dow) was initially described in [arxiv](https://arxiv.org/abs/2103.14030) and serves as a\ngeneral-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is\ncomputed with shifted windows. 
The shifted windowing scheme brings greater efficiency by limiting self-attention\ncomputation to non-overlapping local windows while also allowing for cross-window connection.\n\nSwin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and\nADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin.\n\n![teaser](figures/teaser.png)\n\n## Main Results on ImageNet with Pretrained Models\n\n**ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models**\n\n| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |\n| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) |\n| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) |\n| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278  | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) |\n| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 
85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) |\n| Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-B | 
ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) |\n| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) |\n| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) |\n\n**ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models**\n\n| name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model |\n|:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: |\n| SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) |\n| SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) |\n| SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) |\n| SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) |\n| SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G  | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) |\n| SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) |\n| SwinV2-B<sup>\\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 |  88M | 21.8G | 174 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) |\n| SwinV2-B<sup>\\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) |\n| SwinV2-L<sup>\\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) |\n| SwinV2-L<sup>\\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33  | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) |\n\nNote: \n- The 256x256 and 384x384 variants of SwinV2-B<sup>\\*</sup> (SwinV2-L<sup>\\*</sup>) are both fine-tuned from the same pre-trained model, which uses a smaller input resolution of 192x192.\n- SwinV2-B<sup>\\*</sup> (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2 while SwinV2-L<sup>\\*</sup> (384x384) achieves 78.31.\n\n**ImageNet-1K Pretrained Swin MLP Models**\n\n| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS |  1K model |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) |\n| [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) |\n| [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G |  231 | [timm](https://github.com/rwightman/pytorch-image-models) |\n| Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) |\n| SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swin/swin_mlp_tiny_c24_patch4_window8_256.yaml) |\n| SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swin/swin_mlp_tiny_c12_patch4_window8_256.yaml) |\n| SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swin/swin_mlp_tiny_c6_patch4_window8_256.yaml) |\n| SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swin/swin_mlp_base_patch4_window7_224.yaml) |\n\nNote: access code for `baidu` is `swin`. C24 means each head has 24 channels.\n\n**ImageNet-22K Pretrained Swin-MoE Models**\n\n- Please refer to [get_started](get_started.md#mixture-of-experts-support) for instructions on running Swin-MoE. 
\n- Pretrained models for Swin-MoE can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models).\n\n## Main Results on Downstream Tasks\n\n**COCO Object Detection (2017 val)**\n\n| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |\n| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |\n| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |\n| Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G |\n| Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G |\n| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |\n| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |\n| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |\n| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |\n| Swin-L | HTC++<sup>*</sup> | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |\n\nNote: <sup>*</sup> indicates multi-scale testing.\n\n**ADE20K Semantic Segmentation (val)**\n\n| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |\n| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |\n| Swin-S | UPerNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |\n| Swin-B | UPerNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |\n| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |\n| Swin-L | UPerNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |\n\n## Citing Swin Transformer\n\n```\n@inproceedings{liu2021Swin,\n  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},\n  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and 
Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},\n  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},\n  year={2021}\n}\n```\n## Citing Local Relation Networks (the first full-attention visual backbone)\n```\n@inproceedings{hu2019local,\n  title={Local Relation Networks for Image Recognition},\n  author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen},\n  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},\n  pages={3464--3473},\n  year={2019}\n}\n```\n## Citing Swin Transformer V2\n```\n@inproceedings{liu2021swinv2,\n  title={Swin Transformer V2: Scaling Up Capacity and Resolution}, \n  author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n  year={2022}\n}\n```\n## Citing SimMIM (a self-supervised approach that enables SwinV2-G)\n```\n@inproceedings{xie2021simmim,\n  title={SimMIM: A Simple Framework for Masked Image Modeling},\n  author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Bao, Jianmin and Yao, Zhuliang and Dai, Qi and Hu, Han},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n  year={2022}\n}\n```\n## Citing SimMIM-data-scaling\n```\n@article{xie2022data,\n  title={On Data Scaling in Masked Image Modeling},\n  author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Wei, Yixuan and Dai, Qi and Hu, Han},\n  journal={arXiv preprint arXiv:2206.04664},\n  year={2022}\n}\n```\n## Citing Swin-MoE\n```\n@misc{hwang2022tutel,\n      title={Tutel: Adaptive Mixture-of-Experts at Scale}, \n      author={Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan 
Yang and Mao Yang and Yongqiang Xiong},\n      year={2022},\n      eprint={2206.03382},\n      archivePrefix={arXiv}\n}\n```\n\n## Getting Started\n\n- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.\n- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).\n- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).\n- For **Self-Supervised Learning**, please see [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL).\n- For **Video Recognition**, please see [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).\n\n## Third-party Usage and Experiments\n\n***In this paragraph, we cross-link third-party repositories that use Swin and report results. You can let us know about yours by raising an issue.*** \n\n(`Note: please report accuracy numbers and provide trained models in your new repository so that others can check correctness and model behavior.`)\n\n[12/29/2022] Swin Transformers (V2) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md)\n\n[06/30/2022] Swin Transformers (V1) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md)\n\n[05/12/2022] Swin Transformers (V1) implemented in TensorFlow with the pre-trained parameters ported into them. 
Find the implementation, TensorFlow weights, and code examples in [this repository](https://github.com/sayakpaul/swin-transformers-tf/).\n\n[04/06/2022] Swin Transformer for Audio Classification: [Hierarchical Token Semantic Audio Transformer](https://github.com/RetroCirce/HTS-Audio-Transformer).\n\n[12/21/2021] Swin Transformer for StyleGAN: [StyleSwin](https://github.com/microsoft/StyleSwin)\n\n[12/13/2021] Swin Transformer for Face Recognition: [FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo)\n\n[08/29/2021] Swin Transformer for Image Restoration: [SwinIR](https://github.com/JingyunLiang/SwinIR)\n\n[08/12/2021] Swin Transformer for person reID: [https://github.com/layumi/Person_reID_baseline_pytorch](https://github.com/layumi/Person_reID_baseline_pytorch)\n\n[06/29/2021] Swin-Transformer in PaddleClas, with inference via its whl package: [https://github.com/PaddlePaddle/PaddleClas](https://github.com/PaddlePaddle/PaddleClas)\n\n[04/14/2021] Swin for RetinaNet in Detectron2: https://github.com/xiaohu2015/SwinT_detectron2.\n\n[04/16/2021] Included in the popular timm model zoo: https://github.com/rwightman/pytorch-image-models.\n\n[04/20/2021] Swin-Transformer classifier inference using TorchServe: https://github.com/kamalkraj/Swin-Transformer-Serve\n\n## Contributing\n\nThis project welcomes contributions and suggestions.  Most contributions require you to agree to a\nContributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us\nthe rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.\n\nWhen you submit a pull request, a CLA bot will automatically determine whether you need to provide\na CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions\nprovided by the bot. 
You will only need to do this once across all repos using our CLA.\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\nFor more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or\ncontact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.\n\n## Trademarks\n\nThis project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft \ntrademarks or logos is subject to and must follow \n[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).\nUse of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.\nAny use of third-party trademarks or logos are subject to those third-party's policies.\n"
  },
  {
    "path": "SECURITY.md",
    "content": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).\n\nIf you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.\n\n## Reporting Security Issues\n\n**Please do not report security vulnerabilities through public GitHub issues.**\n\nInstead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).\n\nIf you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).\n\nYou should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). \n\nPlease include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:\n\n  * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.)\n  * Full paths of source file(s) related to the manifestation of the issue\n  * The location of the affected source code (tag/branch/commit or direct URL)\n  * Any special configuration required to reproduce the issue\n  * Step-by-step instructions to reproduce the issue\n  * Proof-of-concept or exploit code (if possible)\n  * Impact of the issue, including how an attacker might exploit the issue\n\nThis information will help us triage your report more quickly.\n\nIf you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.\n\n## Preferred Languages\n\nWe prefer all communications to be in English.\n\n## Policy\n\nMicrosoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).\n\n<!-- END MICROSOFT SECURITY.MD BLOCK -->"
  },
  {
    "path": "SUPPORT.md",
    "content": "# TODO: The maintainer of this repo has not yet edited this file\r\n\r\n**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?\r\n\r\n- **No CSS support:** Fill out this template with information about how to file issues and get help.\r\n- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).\r\n- **Not sure?** Fill out a SPOT intake as though the answer were \"Yes\". CSS will help you decide.\r\n\r\n*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*\r\n\r\n# Support\r\n\r\n## How to file issues and get help  \r\n\r\nThis project uses GitHub Issues to track bugs and feature requests. Please search the existing \r\nissues before filing new issues to avoid duplicates.  For new issues, file your bug or \r\nfeature request as a new Issue.\r\n\r\nFor help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE \r\nFOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER\r\nCHANNEL. WHERE WILL YOU HELP PEOPLE?**.\r\n\r\n## Microsoft Support Policy  \r\n\r\nSupport for this **PROJECT or PRODUCT** is limited to the resources listed above.\r\n"
  },
  {
    "path": "config.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport os\nimport torch\nimport yaml\nfrom yacs.config import CfgNode as CN\n\n# pytorch major version (1.x or 2.x)\nPYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])\n\n_C = CN()\n\n# Base config files\n_C.BASE = ['']\n\n# -----------------------------------------------------------------------------\n# Data settings\n# -----------------------------------------------------------------------------\n_C.DATA = CN()\n# Batch size for a single GPU, could be overwritten by command line argument\n_C.DATA.BATCH_SIZE = 128\n# Path to dataset, could be overwritten by command line argument\n_C.DATA.DATA_PATH = ''\n# Dataset name\n_C.DATA.DATASET = 'imagenet'\n# Input image size\n_C.DATA.IMG_SIZE = 224\n# Interpolation to resize image (random, bilinear, bicubic)\n_C.DATA.INTERPOLATION = 'bicubic'\n# Use zipped dataset instead of folder dataset\n# could be overwritten by command line argument\n_C.DATA.ZIP_MODE = False\n# Cache Data in Memory, could be overwritten by command line argument\n_C.DATA.CACHE_MODE = 'part'\n# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.\n_C.DATA.PIN_MEMORY = True\n# Number of data loading worker processes\n_C.DATA.NUM_WORKERS = 8\n\n# [SimMIM] Mask patch size for MaskGenerator\n_C.DATA.MASK_PATCH_SIZE = 32\n# [SimMIM] Mask ratio for MaskGenerator\n_C.DATA.MASK_RATIO = 0.6\n\n# -----------------------------------------------------------------------------\n# Model settings\n# -----------------------------------------------------------------------------\n_C.MODEL = CN()\n# Model type\n_C.MODEL.TYPE = 'swin'\n# Model name\n_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'\n# Pretrained weight from checkpoint, could be imagenet22k pretrained weight\n# could 
be overwritten by command line argument\n_C.MODEL.PRETRAINED = ''\n# Checkpoint to resume, could be overwritten by command line argument\n_C.MODEL.RESUME = ''\n# Number of classes, overwritten in data preparation\n_C.MODEL.NUM_CLASSES = 1000\n# Dropout rate\n_C.MODEL.DROP_RATE = 0.0\n# Drop path rate\n_C.MODEL.DROP_PATH_RATE = 0.1\n# Label Smoothing\n_C.MODEL.LABEL_SMOOTHING = 0.1\n\n# Swin Transformer parameters\n_C.MODEL.SWIN = CN()\n_C.MODEL.SWIN.PATCH_SIZE = 4\n_C.MODEL.SWIN.IN_CHANS = 3\n_C.MODEL.SWIN.EMBED_DIM = 96\n_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]\n_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]\n_C.MODEL.SWIN.WINDOW_SIZE = 7\n_C.MODEL.SWIN.MLP_RATIO = 4.\n_C.MODEL.SWIN.QKV_BIAS = True\n_C.MODEL.SWIN.QK_SCALE = None\n_C.MODEL.SWIN.APE = False\n_C.MODEL.SWIN.PATCH_NORM = True\n\n# Swin Transformer V2 parameters\n_C.MODEL.SWINV2 = CN()\n_C.MODEL.SWINV2.PATCH_SIZE = 4\n_C.MODEL.SWINV2.IN_CHANS = 3\n_C.MODEL.SWINV2.EMBED_DIM = 96\n_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]\n_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]\n_C.MODEL.SWINV2.WINDOW_SIZE = 7\n_C.MODEL.SWINV2.MLP_RATIO = 4.\n_C.MODEL.SWINV2.QKV_BIAS = True\n_C.MODEL.SWINV2.APE = False\n_C.MODEL.SWINV2.PATCH_NORM = True\n_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]\n\n# Swin Transformer MoE parameters\n_C.MODEL.SWIN_MOE = CN()\n_C.MODEL.SWIN_MOE.PATCH_SIZE = 4\n_C.MODEL.SWIN_MOE.IN_CHANS = 3\n_C.MODEL.SWIN_MOE.EMBED_DIM = 96\n_C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2]\n_C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24]\n_C.MODEL.SWIN_MOE.WINDOW_SIZE = 7\n_C.MODEL.SWIN_MOE.MLP_RATIO = 4.\n_C.MODEL.SWIN_MOE.QKV_BIAS = True\n_C.MODEL.SWIN_MOE.QK_SCALE = None\n_C.MODEL.SWIN_MOE.APE = False\n_C.MODEL.SWIN_MOE.PATCH_NORM = True\n_C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True\n_C.MODEL.SWIN_MOE.INIT_STD = 0.02\n_C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]\n_C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]]\n_C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1\n_C.MODEL.SWIN_MOE.TOP_VALUE = 1\n_C.MODEL.SWIN_MOE.CAPACITY_FACTOR 
= 1.25\n_C.MODEL.SWIN_MOE.COSINE_ROUTER = False\n_C.MODEL.SWIN_MOE.NORMALIZE_GATE = False\n_C.MODEL.SWIN_MOE.USE_BPR = True\n_C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False\n_C.MODEL.SWIN_MOE.GATE_NOISE = 1.0\n_C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256\n_C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5\n_C.MODEL.SWIN_MOE.MOE_DROP = 0.0\n_C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01\n\n# Swin MLP parameters\n_C.MODEL.SWIN_MLP = CN()\n_C.MODEL.SWIN_MLP.PATCH_SIZE = 4\n_C.MODEL.SWIN_MLP.IN_CHANS = 3\n_C.MODEL.SWIN_MLP.EMBED_DIM = 96\n_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]\n_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]\n_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7\n_C.MODEL.SWIN_MLP.MLP_RATIO = 4.\n_C.MODEL.SWIN_MLP.APE = False\n_C.MODEL.SWIN_MLP.PATCH_NORM = True\n\n# [SimMIM] Norm target during training\n_C.MODEL.SIMMIM = CN()\n_C.MODEL.SIMMIM.NORM_TARGET = CN()\n_C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False\n_C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47\n\n# -----------------------------------------------------------------------------\n# Training settings\n# -----------------------------------------------------------------------------\n_C.TRAIN = CN()\n_C.TRAIN.START_EPOCH = 0\n_C.TRAIN.EPOCHS = 300\n_C.TRAIN.WARMUP_EPOCHS = 20\n_C.TRAIN.WEIGHT_DECAY = 0.05\n_C.TRAIN.BASE_LR = 5e-4\n_C.TRAIN.WARMUP_LR = 5e-7\n_C.TRAIN.MIN_LR = 5e-6\n# Clip gradient norm\n_C.TRAIN.CLIP_GRAD = 5.0\n# Auto resume from latest checkpoint\n_C.TRAIN.AUTO_RESUME = True\n# Gradient accumulation steps\n# could be overwritten by command line argument\n_C.TRAIN.ACCUMULATION_STEPS = 1\n# Whether to use gradient checkpointing to save memory\n# could be overwritten by command line argument\n_C.TRAIN.USE_CHECKPOINT = False\n\n# LR scheduler\n_C.TRAIN.LR_SCHEDULER = CN()\n_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'\n# Epoch interval to decay LR, used in StepLRScheduler\n_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30\n# LR decay rate, used in StepLRScheduler\n_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1\n# warmup_prefix used in 
CosineLRScheduler\n_C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True\n# [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler\n_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1\n_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []\n\n# Optimizer\n_C.TRAIN.OPTIMIZER = CN()\n_C.TRAIN.OPTIMIZER.NAME = 'adamw'\n# Optimizer Epsilon\n_C.TRAIN.OPTIMIZER.EPS = 1e-8\n# Optimizer Betas\n_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)\n# SGD momentum\n_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9\n\n# [SimMIM] Layer decay for fine-tuning\n_C.TRAIN.LAYER_DECAY = 1.0\n\n# MoE\n_C.TRAIN.MOE = CN()\n# Only save model on master device\n_C.TRAIN.MOE.SAVE_MASTER = False\n# -----------------------------------------------------------------------------\n# Augmentation settings\n# -----------------------------------------------------------------------------\n_C.AUG = CN()\n# Color jitter factor\n_C.AUG.COLOR_JITTER = 0.4\n# Auto-augmentation policy: \"v0\" or \"original\" for AutoAugment, or a \"rand-...\" string for RandAugment\n_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'\n# Random erase prob\n_C.AUG.REPROB = 0.25\n# Random erase mode\n_C.AUG.REMODE = 'pixel'\n# Random erase count\n_C.AUG.RECOUNT = 1\n# Mixup alpha, mixup enabled if > 0\n_C.AUG.MIXUP = 0.8\n# Cutmix alpha, cutmix enabled if > 0\n_C.AUG.CUTMIX = 1.0\n# Cutmix min/max ratio, overrides alpha and enables cutmix if set\n_C.AUG.CUTMIX_MINMAX = None\n# Probability of performing mixup or cutmix when either/both is enabled\n_C.AUG.MIXUP_PROB = 1.0\n# Probability of switching to cutmix when both mixup and cutmix enabled\n_C.AUG.MIXUP_SWITCH_PROB = 0.5\n# How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"\n_C.AUG.MIXUP_MODE = 'batch'\n\n# -----------------------------------------------------------------------------\n# Testing settings\n# -----------------------------------------------------------------------------\n_C.TEST = CN()\n# Whether to use center crop when testing\n_C.TEST.CROP = True\n# Whether to use SequentialSampler as validation sampler\n_C.TEST.SEQUENTIAL = False\n_C.TEST.SHUFFLE = False\n\n# -----------------------------------------------------------------------------\n# Misc\n# -----------------------------------------------------------------------------\n# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument\n_C.ENABLE_AMP = False\n\n# Enable Pytorch automatic mixed precision (amp).\n_C.AMP_ENABLE = True\n# [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2')\n_C.AMP_OPT_LEVEL = ''\n# Path to output folder, overwritten by command line argument\n_C.OUTPUT = ''\n# Tag of experiment, overwritten by command line argument\n_C.TAG = 'default'\n# Frequency to save checkpoint\n_C.SAVE_FREQ = 1\n# Frequency of logging info\n_C.PRINT_FREQ = 10\n# Fixed random seed\n_C.SEED = 0\n# Perform evaluation only, overwritten by command line argument\n_C.EVAL_MODE = False\n# Test throughput only, overwritten by command line argument\n_C.THROUGHPUT_MODE = False\n# local rank for DistributedDataParallel, given by the --local_rank argument (PyTorch 1.x) or the LOCAL_RANK environment variable (PyTorch 2.x)\n_C.LOCAL_RANK = 0\n# for acceleration\n_C.FUSED_WINDOW_PROCESS = False\n_C.FUSED_LAYERNORM = False\n\n\ndef _update_config_from_file(config, cfg_file):\n    config.defrost()\n    with open(cfg_file, 'r') as f:\n        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)\n\n    for cfg in yaml_cfg.setdefault('BASE', ['']):\n        if cfg:\n            _update_config_from_file(\n                config, os.path.join(os.path.dirname(cfg_file), cfg)\n            )\n    print('=> merge config from {}'.format(cfg_file))\n    
config.merge_from_file(cfg_file)\n    config.freeze()\n\n\ndef update_config(config, args):\n    _update_config_from_file(config, args.cfg)\n\n    config.defrost()\n    if args.opts:\n        config.merge_from_list(args.opts)\n\n    def _check_args(name):\n        # True only if the argument exists and is truthy (getattr instead of eval)\n        return bool(getattr(args, name, None))\n\n    # merge from specific arguments\n    if _check_args('batch_size'):\n        config.DATA.BATCH_SIZE = args.batch_size\n    if _check_args('data_path'):\n        config.DATA.DATA_PATH = args.data_path\n    if _check_args('zip'):\n        config.DATA.ZIP_MODE = True\n    if _check_args('cache_mode'):\n        config.DATA.CACHE_MODE = args.cache_mode\n    if _check_args('pretrained'):\n        config.MODEL.PRETRAINED = args.pretrained\n    if _check_args('resume'):\n        config.MODEL.RESUME = args.resume\n    if _check_args('accumulation_steps'):\n        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps\n    if _check_args('use_checkpoint'):\n        config.TRAIN.USE_CHECKPOINT = True\n    if _check_args('amp_opt_level'):\n        print(\"[warning] Apex amp has been deprecated, please use pytorch amp instead!\")\n        if args.amp_opt_level == 'O0':\n            config.AMP_ENABLE = False\n    if _check_args('disable_amp'):\n        config.AMP_ENABLE = False\n    if _check_args('output'):\n        config.OUTPUT = args.output\n    if _check_args('tag'):\n        config.TAG = args.tag\n    if _check_args('eval'):\n        config.EVAL_MODE = True\n    if _check_args('throughput'):\n        config.THROUGHPUT_MODE = True\n\n    # [SimMIM]\n    if _check_args('enable_amp'):\n        config.ENABLE_AMP = args.enable_amp\n\n    # for acceleration\n    if _check_args('fused_window_process'):\n        config.FUSED_WINDOW_PROCESS = True\n    if _check_args('fused_layernorm'):\n        config.FUSED_LAYERNORM = True\n    ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb]\n    
if _check_args('optim'):\n        config.TRAIN.OPTIMIZER.NAME = args.optim\n\n    # set local rank for distributed training\n    if PYTORCH_MAJOR_VERSION == 1:\n        config.LOCAL_RANK = args.local_rank\n    else:\n        config.LOCAL_RANK = int(os.environ['LOCAL_RANK'])\n\n    # output folder\n    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)\n\n    config.freeze()\n\n\ndef get_config(args):\n    \"\"\"Get a yacs CfgNode object with default values.\"\"\"\n    # Return a clone so that the defaults will not be altered\n    # This is for the \"local variable\" use pattern\n    config = _C.clone()\n    update_config(config, args)\n\n    return config\n"
  },
  {
    "path": "configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: simmim_finetune\n  DROP_PATH_RATE: 0.1\n  SWIN:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 7\nDATA:\n  IMG_SIZE: 224\nTRAIN:\n  EPOCHS: 100\n  WARMUP_EPOCHS: 20\n  BASE_LR: 1.25e-3\n  WARMUP_LR: 2.5e-7\n  MIN_LR: 2.5e-7\n  WEIGHT_DECAY: 0.05\n  LAYER_DECAY: 0.8\nPRINT_FREQ: 100\nSAVE_FREQ: 5\nTAG: simmim_finetune__swin_base__img224_window7__800ep"
  },
  {
    "path": "configs/simmim/simmim_finetune__swinv2_base__img224_window14__800ep.yaml",
    "content": "MODEL:\n  TYPE: swinv2\n  NAME: simmim_finetune\n  DROP_PATH_RATE: 0.1\n  SWINV2:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 14\n    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]\nDATA:\n  IMG_SIZE: 224\nTRAIN:\n  EPOCHS: 100\n  WARMUP_EPOCHS: 20\n  BASE_LR: 1.25e-3\n  WARMUP_LR: 2.5e-7\n  MIN_LR: 2.5e-7\n  WEIGHT_DECAY: 0.05\n  LAYER_DECAY: 0.75\nPRINT_FREQ: 100\nSAVE_FREQ: 5\nTAG: simmim_finetune__swinv2_base__img224_window14__800ep"
  },
  {
    "path": "configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: simmim_pretrain\n  DROP_PATH_RATE: 0.0\n  SWIN:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 6\nDATA:\n  IMG_SIZE: 192\n  MASK_PATCH_SIZE: 32\n  MASK_RATIO: 0.6\nTRAIN:\n  EPOCHS: 800\n  WARMUP_EPOCHS: 10\n  BASE_LR: 1e-4\n  WARMUP_LR: 5e-7\n  WEIGHT_DECAY: 0.05\n  LR_SCHEDULER:\n    NAME: 'multistep'\n    GAMMA: 0.1\n    MULTISTEPS: [700,]\nPRINT_FREQ: 100\nSAVE_FREQ: 5\nTAG: simmim_pretrain__swin_base__img192_window6__800ep"
  },
  {
    "path": "configs/simmim/simmim_pretrain__swinv2_base__img192_window12__800ep.yaml",
    "content": "MODEL:\n  TYPE: swinv2\n  NAME: simmim_pretrain\n  DROP_PATH_RATE: 0.1\n  SIMMIM:\n    NORM_TARGET:\n      ENABLE: True\n      PATCH_SIZE: 47\n  SWINV2:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\nDATA:\n  IMG_SIZE: 192\n  MASK_PATCH_SIZE: 32\n  MASK_RATIO: 0.6\nTRAIN:\n  EPOCHS: 800\n  WARMUP_EPOCHS: 10\n  BASE_LR: 1e-4\n  WARMUP_LR: 5e-7\n  WEIGHT_DECAY: 0.05\n  LR_SCHEDULER:\n    NAME: 'multistep'\n    GAMMA: 0.1\n    MULTISTEPS: [700,]\nPRINT_FREQ: 100\nSAVE_FREQ: 5\nTAG: simmim_pretrain__swinv2_base__img192_window12__800ep"
  },
  {
    "path": "configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml",
    "content": "DATA:\n  IMG_SIZE: 384\nMODEL:\n  TYPE: swin\n  NAME: swin_base_patch4_window12_384_22kto1k_finetune\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07\nTEST:\n  CROP: False"
  },
  {
    "path": "configs/swin/swin_base_patch4_window12_384_finetune.yaml",
    "content": "DATA:\n  IMG_SIZE: 384\nMODEL:\n  TYPE: swin\n  NAME: swin_base_patch4_window12_384_finetune\n  DROP_PATH_RATE: 0.5\n  SWIN:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07\nTEST:\n  CROP: False"
  },
  {
    "path": "configs/swin/swin_base_patch4_window7_224.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: swin_base_patch4_window7_224\n  DROP_PATH_RATE: 0.5\n  SWIN:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 7"
  },
  {
    "path": "configs/swin/swin_base_patch4_window7_224_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\nMODEL:\n  TYPE: swin\n  NAME: swin_base_patch4_window7_224_22k\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 0.05\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6"
  },
  {
    "path": "configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: swin_base_patch4_window7_224_22kto1k_finetune\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07"
  },
  {
    "path": "configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml",
    "content": "DATA:\n  IMG_SIZE: 384\nMODEL:\n  TYPE: swin\n  NAME: swin_large_patch4_window12_384_22kto1k_finetune\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 192\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 6, 12, 24, 48 ]\n    WINDOW_SIZE: 12\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07\nTEST:\n  CROP: False"
  },
  {
    "path": "configs/swin/swin_large_patch4_window7_224_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\nMODEL:\n  TYPE: swin\n  NAME: swin_large_patch4_window7_224_22k\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 192\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 6, 12, 24, 48 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 0.05\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6"
  },
  {
    "path": "configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: swin_large_patch4_window7_224_22kto1k_finetune\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 192\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 6, 12, 24, 48 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07"
  },
  {
    "path": "configs/swin/swin_small_patch4_window7_224.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: swin_small_patch4_window7_224\n  DROP_PATH_RATE: 0.3\n  SWIN:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 7"
  },
  {
    "path": "configs/swin/swin_small_patch4_window7_224_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\nMODEL:\n  TYPE: swin\n  NAME: swin_small_patch4_window7_224_22k\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 0.05\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6"
  },
  {
    "path": "configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: swin_small_patch4_window7_224_22kto1k_finetune\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07"
  },
  {
    "path": "configs/swin/swin_tiny_c24_patch4_window8_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swin\n  NAME: swin_tiny_c24_patch4_window8_256\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 8"
  },
  {
    "path": "configs/swin/swin_tiny_patch4_window7_224.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: swin_tiny_patch4_window7_224\n  DROP_PATH_RATE: 0.2\n  SWIN:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 7"
  },
  {
    "path": "configs/swin/swin_tiny_patch4_window7_224_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\nMODEL:\n  TYPE: swin\n  NAME: swin_tiny_patch4_window7_224_22k\n  DROP_PATH_RATE: 0.1\n  SWIN:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 0.05\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6"
  },
  {
    "path": "configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml",
    "content": "MODEL:\n  TYPE: swin\n  NAME: swin_tiny_patch4_window7_224_22kto1k_finetune\n  DROP_PATH_RATE: 0.1\n  SWIN:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 7\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07"
  },
  {
    "path": "configs/swinmlp/swin_mlp_base_patch4_window7_224.yaml",
    "content": "MODEL:\n  TYPE: swin_mlp\n  NAME: swin_mlp_base_patch4_window7_224\n  DROP_PATH_RATE: 0.5\n  SWIN_MLP:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 7\n"
  },
  {
    "path": "configs/swinmlp/swin_mlp_tiny_c12_patch4_window8_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swin_mlp\n  NAME: swin_mlp_tiny_c12_patch4_window8_256\n  DROP_PATH_RATE: 0.2\n  SWIN_MLP:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 8, 16, 32, 64 ]\n    WINDOW_SIZE: 8"
  },
  {
    "path": "configs/swinmlp/swin_mlp_tiny_c24_patch4_window8_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swin_mlp\n  NAME: swin_mlp_tiny_c24_patch4_window8_256\n  DROP_PATH_RATE: 0.2\n  SWIN_MLP:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 8"
  },
  {
    "path": "configs/swinmlp/swin_mlp_tiny_c6_patch4_window8_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swin_mlp\n  NAME: swin_mlp_tiny_c6_patch4_window8_256\n  DROP_PATH_RATE: 0.2\n  SWIN_MLP:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 16, 32, 64, 128 ]\n    WINDOW_SIZE: 8"
  },
  {
    "path": "configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_base_patch4_window12_192_16expert_32gpu_22k\n  DROP_PATH_RATE: 0.3\n  SWIN_MOE:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: -2\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_base_patch4_window12_192_32expert_32gpu_22k\n  DROP_PATH_RATE: 0.3\n  SWIN_MOE:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: 1\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_base_patch4_window12_192_8expert_32gpu_22k\n  DROP_PATH_RATE: 0.3\n  SWIN_MOE:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: -4\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k\n  DROP_PATH_RATE: 0.3\n  SWIN_MOE:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: 1\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    COSINE_ROUTER: True\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_base_patch4_window12_192_densebaseline_22k\n  DROP_PATH_RATE: 0.2\n  SWIN_MOE:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ]\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\n  MOE:\n    SAVE_MASTER: True\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_small_patch4_window12_192_16expert_32gpu_22k\n  DROP_PATH_RATE: 0.2\n  SWIN_MOE:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: -2\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_small_patch4_window12_192_32expert_32gpu_22k\n  DROP_PATH_RATE: 0.2\n  SWIN_MOE:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: 1\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_small_patch4_window12_192_64expert_64gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_small_patch4_window12_192_64expert_64gpu_22k\n  DROP_PATH_RATE: 0.2\n  SWIN_MOE:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: 1\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_small_patch4_window12_192_8expert_32gpu_22k\n  DROP_PATH_RATE: 0.2\n  SWIN_MOE:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: -4\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k\n  DROP_PATH_RATE: 0.2\n  SWIN_MOE:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    INIT_STD: 0.005\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]\n    NUM_LOCAL_EXPERTS: 1\n    TOP_VALUE: 1\n    CAPACITY_FACTOR: 1.25\n    COSINE_ROUTER: True\n    IS_GSHARD_LOSS: False\n    MOE_DROP: 0.1\n    AUX_LOSS_WEIGHT: 0.01\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swin_moe\n  NAME: swin_moe_small_patch4_window12_192_densebaseline_22k\n  DROP_PATH_RATE: 0.2\n  SWIN_MOE:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 12\n    MLP_FC2_BIAS: False\n    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ]\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 10\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6\n  CLIP_GRAD: 3.0\n  MOE:\n    SAVE_MASTER: True\nTEST:\n  SHUFFLE: True"
  },
  {
    "path": "configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_base_patch4_window12_192_22k\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 12\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6"
  },
  {
    "path": "configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 16\n    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07"
  },
  {
    "path": "configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml",
    "content": "DATA:\n  IMG_SIZE: 384\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 24\n    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07\nTEST:\n  CROP: False"
  },
  {
    "path": "configs/swinv2/swinv2_base_patch4_window16_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_base_patch4_window16_256\n  DROP_PATH_RATE: 0.5\n  SWINV2:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 16"
  },
  {
    "path": "configs/swinv2/swinv2_base_patch4_window8_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_base_patch4_window8_256\n  DROP_PATH_RATE: 0.5\n  SWINV2:\n    EMBED_DIM: 128\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 4, 8, 16, 32 ]\n    WINDOW_SIZE: 8"
  },
  {
    "path": "configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml",
    "content": "DATA:\n  DATASET: imagenet22K\n  IMG_SIZE: 192\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_large_patch4_window12_192_22k\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 192\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 6, 12, 24, 48 ]\n    WINDOW_SIZE: 12\nTRAIN:\n  EPOCHS: 90\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 0.1\n  BASE_LR: 1.25e-4 # 4096 batch-size\n  WARMUP_LR: 1.25e-7\n  MIN_LR: 1.25e-6"
  },
  {
    "path": "configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_large_patch4_window12to16_192to256_22kto1k_ft\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 192\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 6, 12, 24, 48 ]\n    WINDOW_SIZE: 16\n    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07"
  },
  {
    "path": "configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml",
    "content": "DATA:\n  IMG_SIZE: 384\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 192\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 6, 12, 24, 48 ]\n    WINDOW_SIZE: 24\n    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]\nTRAIN:\n  EPOCHS: 30\n  WARMUP_EPOCHS: 5\n  WEIGHT_DECAY: 1e-8\n  BASE_LR: 2e-05\n  WARMUP_LR: 2e-08\n  MIN_LR: 2e-07\nTEST:\n  CROP: False"
  },
  {
    "path": "configs/swinv2/swinv2_small_patch4_window16_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_small_patch4_window16_256\n  DROP_PATH_RATE: 0.3\n  SWINV2:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 16"
  },
  {
    "path": "configs/swinv2/swinv2_small_patch4_window8_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_small_patch4_window8_256\n  DROP_PATH_RATE: 0.3\n  SWINV2:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 18, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 8"
  },
  {
    "path": "configs/swinv2/swinv2_tiny_patch4_window16_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_tiny_patch4_window16_256\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 16"
  },
  {
    "path": "configs/swinv2/swinv2_tiny_patch4_window8_256.yaml",
    "content": "DATA:\n  IMG_SIZE: 256\nMODEL:\n  TYPE: swinv2\n  NAME: swinv2_tiny_patch4_window8_256\n  DROP_PATH_RATE: 0.2\n  SWINV2:\n    EMBED_DIM: 96\n    DEPTHS: [ 2, 2, 6, 2 ]\n    NUM_HEADS: [ 3, 6, 12, 24 ]\n    WINDOW_SIZE: 8"
  },
  {
    "path": "data/__init__.py",
    "content": "from .build import build_loader as _build_loader\nfrom .data_simmim_pt import build_loader_simmim\nfrom .data_simmim_ft import build_loader_finetune\n\n\ndef build_loader(config, simmim=False, is_pretrain=False):\n    if not simmim:\n        return _build_loader(config)\n    if is_pretrain:\n        return build_loader_simmim(config)\n    else:\n        return build_loader_finetune(config)\n"
  },
  {
    "path": "data/build.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport os\nimport torch\nimport numpy as np\nimport torch.distributed as dist\nfrom torchvision import datasets, transforms\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.data import Mixup\nfrom timm.data import create_transform\n\nfrom .cached_image_folder import CachedImageFolder\nfrom .imagenet22k_dataset import IN22KDATASET\nfrom .samplers import SubsetRandomSampler\n\ntry:\n    from torchvision.transforms import InterpolationMode\n\n\n    def _pil_interp(method):\n        if method == 'bicubic':\n            return InterpolationMode.BICUBIC\n        elif method == 'lanczos':\n            return InterpolationMode.LANCZOS\n        elif method == 'hamming':\n            return InterpolationMode.HAMMING\n        else:\n            # default bilinear, do we want to allow nearest?\n            return InterpolationMode.BILINEAR\n\n\n    import timm.data.transforms as timm_transforms\n\n    timm_transforms._pil_interp = _pil_interp\nexcept:\n    from timm.data.transforms import _pil_interp\n\n\ndef build_loader(config):\n    config.defrost()\n    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)\n    config.freeze()\n    print(f\"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built train dataset\")\n    dataset_val, _ = build_dataset(is_train=False, config=config)\n    print(f\"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully built val dataset\")\n\n    num_tasks = dist.get_world_size()\n    global_rank = dist.get_rank()\n    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':\n        indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())\n        sampler_train = SubsetRandomSampler(indices)\n    else:\n        sampler_train = torch.utils.data.DistributedSampler(\n            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n        )\n\n    if config.TEST.SEQUENTIAL:\n        sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n    else:\n        sampler_val = torch.utils.data.distributed.DistributedSampler(\n            dataset_val, shuffle=config.TEST.SHUFFLE\n        )\n\n    data_loader_train = torch.utils.data.DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=config.DATA.BATCH_SIZE,\n        num_workers=config.DATA.NUM_WORKERS,\n        pin_memory=config.DATA.PIN_MEMORY,\n        drop_last=True,\n    )\n\n    data_loader_val = torch.utils.data.DataLoader(\n        dataset_val, sampler=sampler_val,\n        batch_size=config.DATA.BATCH_SIZE,\n        shuffle=False,\n        num_workers=config.DATA.NUM_WORKERS,\n        pin_memory=config.DATA.PIN_MEMORY,\n        drop_last=False\n    )\n\n    # setup mixup / cutmix\n    mixup_fn = None\n    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None\n    if mixup_active:\n        mixup_fn = Mixup(\n            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,\n            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,\n            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)\n\n    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn\n\n\ndef build_dataset(is_train, config):\n    transform = build_transform(is_train, config)\n    if config.DATA.DATASET == 'imagenet':\n        prefix = 'train' if is_train else 'val'\n        if config.DATA.ZIP_MODE:\n            ann_file = prefix + \"_map.txt\"\n            prefix = prefix + \".zip@/\"\n            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,\n                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')\n        else:\n            root = os.path.join(config.DATA.DATA_PATH, prefix)\n            dataset = datasets.ImageFolder(root, transform=transform)\n        nb_classes = 1000\n    elif config.DATA.DATASET == 'imagenet22K':\n        prefix = 'ILSVRC2011fall_whole'\n        if is_train:\n            ann_file = prefix + \"_map_train.txt\"\n        else:\n            ann_file = prefix + \"_map_val.txt\"\n        dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)\n        nb_classes = 21841\n    else:\n        raise NotImplementedError(\"We only support ImageNet now.\")\n\n    return dataset, nb_classes\n\n\ndef build_transform(is_train, config):\n    resize_im = config.DATA.IMG_SIZE > 32\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=config.DATA.IMG_SIZE,\n            is_training=True,\n            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,\n            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,\n            re_prob=config.AUG.REPROB,\n            re_mode=config.AUG.REMODE,\n            re_count=config.AUG.RECOUNT,\n            interpolation=config.DATA.INTERPOLATION,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)\n        return transform\n\n    t = []\n    if resize_im:\n        if config.TEST.CROP:\n            size = int((256 / 224) * config.DATA.IMG_SIZE)\n            t.append(\n                transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),\n                # to maintain same ratio w.r.t. 224 images\n            )\n            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))\n        else:\n            t.append(\n                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),\n                                  interpolation=_pil_interp(config.DATA.INTERPOLATION))\n            )\n\n    t.append(transforms.ToTensor())\n    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))\n    return transforms.Compose(t)\n"
  },
  {
    "path": "data/cached_image_folder.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport io\nimport os\nimport time\nimport torch.distributed as dist\nimport torch.utils.data as data\nfrom PIL import Image\n\nfrom .zipreader import is_zip_path, ZipReader\n\n\ndef has_file_allowed_extension(filename, extensions):\n    \"\"\"Checks if a file is an allowed extension.\n    Args:\n        filename (string): path to a file\n    Returns:\n        bool: True if the filename ends with a known image extension\n    \"\"\"\n    filename_lower = filename.lower()\n    return any(filename_lower.endswith(ext) for ext in extensions)\n\n\ndef find_classes(dir):\n    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]\n    classes.sort()\n    class_to_idx = {classes[i]: i for i in range(len(classes))}\n    return classes, class_to_idx\n\n\ndef make_dataset(dir, class_to_idx, extensions):\n    images = []\n    dir = os.path.expanduser(dir)\n    for target in sorted(os.listdir(dir)):\n        d = os.path.join(dir, target)\n        if not os.path.isdir(d):\n            continue\n\n        for root, _, fnames in sorted(os.walk(d)):\n            for fname in sorted(fnames):\n                if has_file_allowed_extension(fname, extensions):\n                    path = os.path.join(root, fname)\n                    item = (path, class_to_idx[target])\n                    images.append(item)\n\n    return images\n\n\ndef make_dataset_with_ann(ann_file, img_prefix, extensions):\n    images = []\n    with open(ann_file, \"r\") as f:\n        contents = f.readlines()\n        for line_str in contents:\n            path_contents = [c for c in line_str.split('\\t')]\n            im_file_name = path_contents[0]\n            class_index = int(path_contents[1])\n\n            assert 
str.lower(os.path.splitext(im_file_name)[-1]) in extensions\n            item = (os.path.join(img_prefix, im_file_name), class_index)\n\n            images.append(item)\n\n    return images\n\n\nclass DatasetFolder(data.Dataset):\n    \"\"\"A generic data loader where the samples are arranged in this way: ::\n        root/class_x/xxx.ext\n        root/class_x/xxy.ext\n        root/class_x/xxz.ext\n        root/class_y/123.ext\n        root/class_y/nsdf3.ext\n        root/class_y/asd932_.ext\n    Args:\n        root (string): Root directory path.\n        loader (callable): A function to load a sample given its path.\n        extensions (list[string]): A list of allowed extensions.\n        transform (callable, optional): A function/transform that takes in\n            a sample and returns a transformed version.\n            E.g, ``transforms.RandomCrop`` for images.\n        target_transform (callable, optional): A function/transform that takes\n            in the target and transforms it.\n     Attributes:\n        samples (list): List of (sample path, class_index) tuples\n    \"\"\"\n\n    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,\n                 cache_mode=\"no\"):\n        # image folder mode\n        if ann_file == '':\n            _, class_to_idx = find_classes(root)\n            samples = make_dataset(root, class_to_idx, extensions)\n        # zip mode\n        else:\n            samples = make_dataset_with_ann(os.path.join(root, ann_file),\n                                            os.path.join(root, img_prefix),\n                                            extensions)\n\n        if len(samples) == 0:\n            raise (RuntimeError(\"Found 0 files in subfolders of: \" + root + \"\\n\" +\n                                \"Supported extensions are: \" + \",\".join(extensions)))\n\n        self.root = root\n        self.loader = loader\n        self.extensions = extensions\n\n        
self.samples = samples\n        self.labels = [y_1k for _, y_1k in samples]\n        self.classes = list(set(self.labels))\n\n        self.transform = transform\n        self.target_transform = target_transform\n\n        self.cache_mode = cache_mode\n        if self.cache_mode != \"no\":\n            self.init_cache()\n\n    def init_cache(self):\n        assert self.cache_mode in [\"part\", \"full\"]\n        n_sample = len(self.samples)\n        global_rank = dist.get_rank()\n        world_size = dist.get_world_size()\n\n        samples_bytes = [None for _ in range(n_sample)]\n        start_time = time.time()\n        for index in range(n_sample):\n            if index % (n_sample // 10) == 0:\n                t = time.time() - start_time\n                print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')\n                start_time = time.time()\n            path, target = self.samples[index]\n            if self.cache_mode == \"full\":\n                samples_bytes[index] = (ZipReader.read(path), target)\n            elif self.cache_mode == \"part\" and index % world_size == global_rank:\n                samples_bytes[index] = (ZipReader.read(path), target)\n            else:\n                samples_bytes[index] = (path, target)\n        self.samples = samples_bytes\n\n    def __getitem__(self, index):\n        \"\"\"\n        Args:\n            index (int): Index\n        Returns:\n            tuple: (sample, target) where target is class_index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        sample = self.loader(path)\n        if self.transform is not None:\n            sample = self.transform(sample)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n\n        return sample, target\n\n    def __len__(self):\n        return len(self.samples)\n\n    def __repr__(self):\n        fmt_str = 'Dataset ' + self.__class__.__name__ + 
'\\n'\n        fmt_str += '    Number of datapoints: {}\\n'.format(self.__len__())\n        fmt_str += '    Root Location: {}\\n'.format(self.root)\n        tmp = '    Transforms (if any): '\n        fmt_str += '{0}{1}\\n'.format(tmp, self.transform.__repr__().replace('\\n', '\\n' + ' ' * len(tmp)))\n        tmp = '    Target Transforms (if any): '\n        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\\n', '\\n' + ' ' * len(tmp)))\n        return fmt_str\n\n\nIMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']\n\n\ndef pil_loader(path):\n    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n    if isinstance(path, bytes):\n        img = Image.open(io.BytesIO(path))\n    elif is_zip_path(path):\n        data = ZipReader.read(path)\n        img = Image.open(io.BytesIO(data))\n    else:\n        with open(path, 'rb') as f:\n            img = Image.open(f)\n            return img.convert('RGB')\n    return img.convert('RGB')\n\n\ndef accimage_loader(path):\n    import accimage\n    try:\n        return accimage.Image(path)\n    except IOError:\n        # Potentially a decoding problem, fall back to PIL.Image\n        return pil_loader(path)\n\n\ndef default_img_loader(path):\n    from torchvision import get_image_backend\n    if get_image_backend() == 'accimage':\n        return accimage_loader(path)\n    else:\n        return pil_loader(path)\n\n\nclass CachedImageFolder(DatasetFolder):\n    \"\"\"A generic data loader where the images are arranged in this way: ::\n        root/dog/xxx.png\n        root/dog/xxy.png\n        root/dog/xxz.png\n        root/cat/123.png\n        root/cat/nsdf3.png\n        root/cat/asd932_.png\n    Args:\n        root (string): Root directory path.\n        transform (callable, optional): A function/transform that  takes in an PIL image\n            and returns a transformed version. 
E.g., ``transforms.RandomCrop``\n        target_transform (callable, optional): A function/transform that takes in the\n            target and transforms it.\n        loader (callable, optional): A function to load an image given its path.\n        cache_mode (string, optional): One of \"no\", \"part\" or \"full\"; see ``init_cache``.\n    Attributes:\n        imgs (list): List of (image path, class_index) tuples\n    \"\"\"\n\n    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,\n                 loader=default_img_loader, cache_mode=\"no\"):\n        super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,\n                                                ann_file=ann_file, img_prefix=img_prefix,\n                                                transform=transform, target_transform=target_transform,\n                                                cache_mode=cache_mode)\n        self.imgs = self.samples\n\n    def __getitem__(self, index):\n        \"\"\"\n        Args:\n            index (int): Index\n        Returns:\n            tuple: (image, target) where target is class_index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        image = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(image)\n        else:\n            img = image\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n\n        return img, target\n"
  },
  {
    "path": "data/data_simmim_ft.py",
    "content": "# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Zhenda Xie\n# --------------------------------------------------------\n\nimport os\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader, DistributedSampler\nfrom torchvision import datasets, transforms\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.data import Mixup\nfrom timm.data import create_transform\nfrom timm.data.transforms import _pil_interp\n\n\ndef build_loader_finetune(config):\n    config.defrost()\n    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)\n    config.freeze()\n    dataset_val, _ = build_dataset(is_train=False, config=config)\n\n    num_tasks = dist.get_world_size()\n    global_rank = dist.get_rank()\n    sampler_train = DistributedSampler(\n        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n    )\n    sampler_val = DistributedSampler(\n        dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False\n    )\n\n    data_loader_train = DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=config.DATA.BATCH_SIZE,\n        num_workers=config.DATA.NUM_WORKERS,\n        pin_memory=config.DATA.PIN_MEMORY,\n        drop_last=True,\n    )\n\n    data_loader_val = DataLoader(\n        dataset_val, sampler=sampler_val,\n        batch_size=config.DATA.BATCH_SIZE,\n        num_workers=config.DATA.NUM_WORKERS,\n        pin_memory=config.DATA.PIN_MEMORY,\n        drop_last=False,\n    )\n\n    # setup mixup / cutmix\n    mixup_fn = None\n    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None\n    if mixup_active:\n        mixup_fn = Mixup(\n            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,\n            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,\n            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)\n\n    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn\n\n\ndef build_dataset(is_train, config):\n    transform = build_transform(is_train, config)\n\n    if config.DATA.DATASET == 'imagenet':\n        prefix = 'train' if is_train else 'val'\n        root = os.path.join(config.DATA.DATA_PATH, prefix)\n        dataset = datasets.ImageFolder(root, transform=transform)\n        nb_classes = 1000\n    else:\n        raise NotImplementedError(\"We only support ImageNet now.\")\n\n    return dataset, nb_classes\n\n\ndef build_transform(is_train, config):\n    resize_im = config.DATA.IMG_SIZE > 32\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n            input_size=config.DATA.IMG_SIZE,\n            is_training=True,\n            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,\n            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,\n            re_prob=config.AUG.REPROB,\n            re_mode=config.AUG.REMODE,\n            re_count=config.AUG.RECOUNT,\n            interpolation=config.DATA.INTERPOLATION,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)\n        return transform\n\n    t = []\n    if resize_im:\n        if config.TEST.CROP:\n            size = int((256 / 224) * config.DATA.IMG_SIZE)\n            
t.append(\n                # resize so the center crop keeps the same 224/256 ratio used for 224x224 images\n                transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),\n            )\n            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))\n        else:\n            t.append(\n                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),\n                                  interpolation=_pil_interp(config.DATA.INTERPOLATION))\n            )\n\n    t.append(transforms.ToTensor())\n    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))\n    return transforms.Compose(t)\n"
  },
  {
    "path": "data/data_simmim_pt.py",
    "content": "# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Zhenda Xie\n# --------------------------------------------------------\n\nimport numpy as np\n\nimport torch\nimport torch.distributed as dist\nimport torchvision.transforms as T\nfrom torch.utils.data import DataLoader, DistributedSampler\nfrom torch.utils.data._utils.collate import default_collate\nfrom torchvision.datasets import ImageFolder\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n\n\nclass MaskGenerator:\n    \"\"\"Random patch mask for SimMIM, returned on the model-patch grid (1 = masked).\"\"\"\n\n    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):\n        self.input_size = input_size\n        self.mask_patch_size = mask_patch_size\n        self.model_patch_size = model_patch_size\n        self.mask_ratio = mask_ratio\n\n        assert self.input_size % self.mask_patch_size == 0\n        assert self.mask_patch_size % self.model_patch_size == 0\n\n        self.rand_size = self.input_size // self.mask_patch_size\n        self.scale = self.mask_patch_size // self.model_patch_size\n\n        self.token_count = self.rand_size ** 2\n        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))\n\n    def __call__(self):\n        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]\n        mask = np.zeros(self.token_count, dtype=int)\n        mask[mask_idx] = 1\n\n        # upsample from the mask-patch grid to the model-patch grid\n        mask = mask.reshape((self.rand_size, self.rand_size))\n        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)\n\n        return mask\n\n\nclass SimMIMTransform:\n    def __init__(self, config):\n        self.transform_img = T.Compose([\n            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n            T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. 
/ 3.)),\n            T.RandomHorizontalFlip(),\n            T.ToTensor(),\n            T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD)),\n        ])\n\n        if config.MODEL.TYPE in ['swin', 'swinv2']:\n            model_patch_size = config.MODEL.SWIN.PATCH_SIZE\n        else:\n            raise NotImplementedError(f'SimMIMTransform does not support model type {config.MODEL.TYPE}')\n\n        self.mask_generator = MaskGenerator(\n            input_size=config.DATA.IMG_SIZE,\n            mask_patch_size=config.DATA.MASK_PATCH_SIZE,\n            model_patch_size=model_patch_size,\n            mask_ratio=config.DATA.MASK_RATIO,\n        )\n\n    def __call__(self, img):\n        img = self.transform_img(img)\n        mask = self.mask_generator()\n\n        return img, mask\n\n\ndef collate_fn(batch):\n    # With SimMIMTransform each sample is ((img, mask), target); collate each\n    # element of the inner tuple separately, then append the collated targets.\n    if not isinstance(batch[0][0], tuple):\n        return default_collate(batch)\n    else:\n        batch_num = len(batch)\n        ret = []\n        for item_idx in range(len(batch[0][0])):\n            if batch[0][0][item_idx] is None:\n                ret.append(None)\n            else:\n                ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))\n        ret.append(default_collate([batch[i][1] for i in range(batch_num)]))\n        return ret\n\n\ndef build_loader_simmim(config):\n    transform = SimMIMTransform(config)\n    dataset = ImageFolder(config.DATA.DATA_PATH, transform)\n\n    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)\n    dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS,\n                            pin_memory=True, drop_last=True, collate_fn=collate_fn)\n\n    return dataloader\n"
  },
  {
    "path": "data/imagenet22k_dataset.py",
    "content": "import os\nimport json\nimport torch.utils.data as data\nimport numpy as np\nfrom PIL import Image\n\nimport warnings\n\nwarnings.filterwarnings(\"ignore\", \"(Possibly )?corrupt EXIF data\", UserWarning)\n\n\nclass IN22KDATASET(data.Dataset):\n    def __init__(self, root, ann_file='', transform=None, target_transform=None):\n        super(IN22KDATASET, self).__init__()\n\n        self.data_path = root\n        self.ann_path = os.path.join(self.data_path, ann_file)\n        self.transform = transform\n        self.target_transform = target_transform\n        # id & label: https://github.com/google-research/big_transfer/issues/7\n        # total: 21843; only 21841 classes have images: map 21841->9205; 21842->15027\n        with open(self.ann_path) as f:\n            self.database = json.load(f)\n\n    def _load_image(self, path):\n        try:\n            im = Image.open(path)\n        except Exception:\n            # fall back to a random 224x224 RGB image so one unreadable file does not stop training\n            print(\"ERROR IMG LOADED: \", path)\n            random_img = np.random.rand(224, 224, 3) * 255\n            im = Image.fromarray(np.uint8(random_img))\n        return im\n\n    def __getitem__(self, index):\n        \"\"\"\n        Args:\n            index (int): Index\n        Returns:\n            tuple: (image, target) where target is class_index of the target class.\n        \"\"\"\n        idb = self.database[index]\n\n        # images\n        images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')\n        if self.transform is not None:\n            images = self.transform(images)\n\n        # target\n        target = int(idb[1])\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n\n        return images, target\n\n    def __len__(self):\n        return len(self.database)\n"
  },
  {
    "path": "data/map22kto1k.txt",
    "content": "359\n368\n460\n475\n486\n492\n496\n514\n516\n525\n547\n548\n556\n563\n575\n641\n648\n723\n733\n765\n801\n826\n852\n858\n878\n896\n900\n905\n908\n910\n935\n946\n947\n994\n999\n1003\n1005\n1010\n1027\n1029\n1048\n1055\n1064\n1065\n1069\n1075\n1079\n1081\n1085\n1088\n1093\n1106\n1143\n1144\n1145\n1147\n1168\n1171\n1178\n1187\n1190\n1197\n1205\n1216\n1223\n1230\n1236\n1241\n1245\n1257\n1259\n1260\n1267\n1268\n1269\n1271\n1272\n1273\n1277\n1303\n1344\n1349\n1355\n1357\n1384\n1388\n1391\n1427\n1429\n1432\n1437\n1450\n1461\n1462\n1474\n1502\n1503\n1512\n1552\n1555\n1577\n1584\n1587\n1589\n1599\n1615\n1616\n1681\n1692\n1701\n1716\n1729\n1757\n1759\n1764\n1777\n1786\n1822\n1841\n1842\n1848\n1850\n1856\n1860\n1861\n1864\n1876\n1897\n1898\n1910\n1913\n1918\n1922\n1928\n1932\n1935\n1947\n1951\n1953\n1970\n1977\n1979\n2001\n2017\n2067\n2081\n2087\n2112\n2128\n2135\n2147\n2174\n2175\n2176\n2177\n2178\n2181\n2183\n2184\n2187\n2189\n2190\n2191\n2192\n2193\n2197\n2202\n2203\n2206\n2208\n2209\n2211\n2212\n2213\n2214\n2215\n2216\n2217\n2219\n2222\n2223\n2224\n2225\n2226\n2227\n2228\n2229\n2230\n2236\n2238\n2240\n2241\n2242\n2243\n2244\n2245\n2247\n2248\n2249\n2250\n2251\n2252\n2255\n2256\n2257\n2262\n2263\n2264\n2265\n2266\n2268\n2270\n2271\n2272\n2273\n2275\n2276\n2279\n2280\n2281\n2282\n2285\n2289\n2292\n2295\n2296\n2297\n2298\n2299\n2300\n2301\n2302\n2303\n2304\n2305\n2306\n2309\n2310\n2312\n2313\n2314\n2315\n2316\n2318\n2319\n2321\n2322\n2326\n2329\n2330\n2331\n2332\n2334\n2335\n2336\n2337\n2338\n2339\n2341\n2342\n2343\n2344\n2346\n2348\n2349\n2351\n2352\n2353\n2355\n2357\n2358\n2359\n2360\n2364\n2365\n2368\n2369\n2377\n2382\n2383\n2385\n2397\n2398\n2400\n2402\n2405\n2412\n2421\n2428\n2431\n2432\n2433\n2436\n2441\n2445\n2450\n2453\n2454\n2465\n2469\n2532\n2533\n2538\n2544\n2547\n2557\n2565\n2578\n2612\n2658\n2702\n2722\n2731\n2738\n2741\n2747\n2810\n2818\n2833\n2844\n2845\n2867\n2874\n2882\n2884\n2888\n2889\n3008\n3012\n3019\n3029\n3033\n3042\n3091\n3106\n3138\n315
9\n3164\n3169\n3280\n3296\n3311\n3318\n3320\n3324\n3330\n3366\n3375\n3381\n3406\n3419\n3432\n3434\n3435\n3493\n3495\n3503\n3509\n3511\n3513\n3517\n3521\n3526\n3546\n3554\n3600\n3601\n3606\n3612\n3613\n3616\n3622\n3623\n3627\n3632\n3634\n3636\n3638\n3644\n3646\n3649\n3650\n3651\n3656\n3663\n3673\n3674\n3689\n3690\n3702\n3733\n3769\n3971\n3974\n4065\n4068\n4073\n4102\n4136\n4140\n4151\n4159\n4165\n4207\n4219\n4226\n4249\n4256\n4263\n4270\n4313\n4321\n4378\n4386\n4478\n4508\n4512\n4536\n4542\n4550\n4560\n4562\n4570\n4571\n4572\n4583\n4588\n4594\n4604\n4608\n4623\n4634\n4636\n4646\n4651\n4652\n4686\n4688\n4691\n4699\n4724\n4727\n4737\n4770\n4774\n4789\n4802\n4807\n4819\n4880\n4886\n4908\n4927\n4931\n4936\n4964\n4976\n4993\n5028\n5033\n5043\n5046\n5096\n5111\n5114\n5131\n5132\n5183\n5199\n5235\n5275\n5291\n5293\n5294\n5343\n5360\n5362\n5364\n5390\n5402\n5418\n5428\n5430\n5437\n5443\n5473\n5484\n5486\n5505\n5507\n5508\n5510\n5567\n5578\n5580\n5584\n5606\n5613\n5629\n5672\n5676\n5692\n5701\n5760\n5769\n5770\n5779\n5814\n5850\n5871\n5893\n5911\n5949\n5954\n6005\n6006\n6012\n6017\n6023\n6024\n6040\n6050\n6054\n6087\n6105\n6157\n6235\n6237\n6256\n6259\n6286\n6291\n6306\n6339\n6341\n6343\n6379\n6383\n6393\n6405\n6479\n6511\n6517\n6541\n6561\n6608\n6611\n6615\n6678\n6682\n6707\n6752\n6798\n6850\n6880\n6885\n6890\n6920\n6981\n7000\n7009\n7038\n7049\n7050\n7052\n7073\n7078\n7098\n7111\n7165\n7198\n7204\n7280\n7283\n7286\n7287\n7293\n7294\n7305\n7318\n7341\n7346\n7354\n7382\n7427\n7428\n7435\n7445\n7450\n7455\n7467\n7469\n7497\n7502\n7506\n7514\n7523\n7651\n7661\n7664\n7672\n7679\n7685\n7696\n7730\n7871\n7873\n7895\n7914\n7915\n7920\n7934\n7935\n7949\n8009\n8036\n8051\n8065\n8074\n8090\n8112\n8140\n8164\n8168\n8178\n8182\n8198\n8212\n8216\n8230\n8242\n8288\n8289\n8295\n8318\n8352\n8368\n8371\n8375\n8376\n8401\n8416\n8419\n8436\n8460\n8477\n8478\n8482\n8498\n8500\n8539\n8543\n8552\n8555\n8580\n8584\n8586\n8594\n8598\n8601\n8606\n8610\n8611\n8622\n8627\n8639\n8649\n8650\n8653\n8654\
n8667\n8672\n8673\n8674\n8676\n8684\n8720\n8723\n8750\n8753\n8801\n8815\n8831\n8835\n8842\n8845\n8858\n8897\n8916\n8951\n8954\n8959\n8970\n8976\n8981\n8983\n8989\n8991\n8993\n9019\n9039\n9042\n9043\n9056\n9057\n9070\n9087\n9098\n9106\n9130\n9131\n9155\n9171\n9183\n9198\n9199\n9201\n9204\n9212\n9221\n9225\n9229\n9250\n9260\n9271\n9279\n9295\n9300\n9310\n9322\n9345\n9352\n9376\n9377\n9382\n9392\n9401\n9405\n9441\n9449\n9464\n9475\n9502\n9505\n9514\n9515\n9545\n9567\n9576\n9608\n9609\n9624\n9633\n9639\n9643\n9656\n9674\n9740\n9752\n9760\n9767\n9778\n9802\n9820\n9839\n9879\n9924\n9956\n9961\n9963\n9970\n9997\n10010\n10031\n10040\n10052\n10073\n10075\n10078\n10094\n10097\n10109\n10118\n10121\n10124\n10158\n10226\n10276\n10304\n10307\n10314\n10315\n10332\n10337\n10338\n10413\n10423\n10451\n10463\n10465\n10487\n10519\n10522\n10523\n10532\n10534\n10535\n10551\n10559\n10574\n10583\n10586\n10589\n10612\n10626\n10635\n10638\n10677\n10683\n10726\n10776\n10782\n10783\n10807\n10837\n10840\n10848\n10859\n10871\n10881\n10884\n10908\n10914\n10921\n10936\n10947\n10951\n10952\n10957\n10999\n11003\n11018\n11023\n11025\n11027\n11045\n11055\n11095\n11110\n11137\n5564\n11168\n11186\n11221\n11223\n11242\n11255\n11259\n11279\n11306\n11311\n11331\n11367\n11377\n11389\n11392\n11401\n11407\n11437\n11449\n11466\n11469\n11473\n11478\n11483\n11484\n11507\n11536\n11558\n11566\n11575\n11584\n11594\n11611\n11612\n11619\n11621\n11640\n11643\n11664\n11674\n11689\n11709\n11710\n11716\n11721\n11726\n11729\n11743\n11760\n11771\n11837\n11839\n11856\n11876\n11878\n11884\n11889\n11896\n11917\n11923\n11930\n11944\n11952\n11980\n11984\n12214\n12229\n12239\n12241\n12242\n12247\n12283\n12349\n12369\n12373\n12422\n12560\n12566\n12575\n12688\n12755\n12768\n12778\n12780\n12812\n12832\n12835\n12836\n12843\n12847\n12849\n12850\n12856\n12858\n12873\n12938\n12971\n13017\n13038\n13046\n13059\n13085\n13086\n13088\n13094\n13134\n13182\n13230\n13406\n13444\n13614\n13690\n13698\n13709\n13749\n13804\n13982\n14051\n14059\n14
219\n14246\n14256\n14264\n14294\n14324\n14367\n14389\n14394\n14438\n14442\n14965\n15732\n16744\n18037\n18205\n18535\n18792\n19102\n20019\n20462\n21026\n21045\n21163\n21171\n21181\n21196\n21200\n21369\n21817"
  },
  {
    "path": "data/samplers.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport torch\n\n\nclass SubsetRandomSampler(torch.utils.data.Sampler):\n    r\"\"\"Samples elements randomly from a given list of indices, without replacement.\n\n    Arguments:\n        indices (sequence): a sequence of indices\n    \"\"\"\n\n    def __init__(self, indices):\n        self.epoch = 0\n        self.indices = indices\n\n    def __iter__(self):\n        return (self.indices[i] for i in torch.randperm(len(self.indices)))\n\n    def __len__(self):\n        return len(self.indices)\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n"
  },
  {
    "path": "data/zipreader.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport os\nimport zipfile\nimport io\nimport numpy as np\nfrom PIL import Image\nfrom PIL import ImageFile\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\ndef is_zip_path(img_or_path):\n    \"\"\"Return True if the path points inside a zip archive, i.e. contains '.zip@'.\"\"\"\n    return '.zip@' in img_or_path\n\n\nclass ZipReader(object):\n    \"\"\"Read files addressed as '<archive>.zip@<member path>'; open archives are cached in zip_bank.\"\"\"\n    zip_bank = dict()\n\n    def __init__(self):\n        super(ZipReader, self).__init__()\n\n    @staticmethod\n    def get_zipfile(path):\n        zip_bank = ZipReader.zip_bank\n        if path not in zip_bank:\n            zfile = zipfile.ZipFile(path, 'r')\n            zip_bank[path] = zfile\n        return zip_bank[path]\n\n    @staticmethod\n    def split_zip_style_path(path):\n        # str.find returns -1 when '@' is missing, so the assertion below can actually fire\n        pos_at = path.find('@')\n        assert pos_at != -1, \"character '@' was not found in the given path '%s'\" % path\n\n        zip_path = path[0: pos_at]\n        folder_path = path[pos_at + 1:]\n        folder_path = str.strip(folder_path, '/')\n        return zip_path, folder_path\n\n    @staticmethod\n    def list_folder(path):\n        zip_path, folder_path = ZipReader.split_zip_style_path(path)\n\n        zfile = ZipReader.get_zipfile(zip_path)\n        folder_list = []\n        for file_foler_name in zfile.namelist():\n            file_foler_name = str.strip(file_foler_name, '/')\n            if file_foler_name.startswith(folder_path) and \\\n                    len(os.path.splitext(file_foler_name)[-1]) == 0 and \\\n                    file_foler_name != folder_path:\n                if len(folder_path) == 0:\n                    folder_list.append(file_foler_name)\n                else:\n                    folder_list.append(file_foler_name[len(folder_path) + 
1:])\n\n        return folder_list\n\n    @staticmethod\n    def list_files(path, extension=None):\n        if extension is None:\n            # '.*' acts as a wildcard that matches every file extension\n            extension = ['.*']\n        zip_path, folder_path = ZipReader.split_zip_style_path(path)\n\n        zfile = ZipReader.get_zipfile(zip_path)\n        file_lists = []\n        for file_foler_name in zfile.namelist():\n            file_foler_name = str.strip(file_foler_name, '/')\n            if file_foler_name.startswith(folder_path) and \\\n                    ('.*' in extension or str.lower(os.path.splitext(file_foler_name)[-1]) in extension):\n                if len(folder_path) == 0:\n                    file_lists.append(file_foler_name)\n                else:\n                    file_lists.append(file_foler_name[len(folder_path) + 1:])\n\n        return file_lists\n\n    @staticmethod\n    def read(path):\n        zip_path, path_img = ZipReader.split_zip_style_path(path)\n        zfile = ZipReader.get_zipfile(zip_path)\n        data = zfile.read(path_img)\n        return data\n\n    @staticmethod\n    def imread(path):\n        zip_path, path_img = ZipReader.split_zip_style_path(path)\n        zfile = ZipReader.get_zipfile(zip_path)\n        data = zfile.read(path_img)\n        try:\n            im = Image.open(io.BytesIO(data))\n        except Exception:\n            # fall back to a random 224x224 RGB image if the bytes cannot be decoded\n            print(\"ERROR IMG LOADED: \", path_img)\n            random_img = np.random.rand(224, 224, 3) * 255\n            im = Image.fromarray(np.uint8(random_img))\n        return im\n"
  },
  {
    "path": "get_started.md",
    "content": "# Swin Transformer for Image Classification\n\nThis folder contains the implementation of the Swin Transformer for image classification.\n\n## Model Zoo\n\nPlease refer to [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) for more pre-trained models.\n\n## Usage\n\n### Install\n\nWe recommend using the NVIDIA PyTorch Docker container `nvcr>=21.05`:\nhttps://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.\n\n- Clone this repo:\n\n```bash\ngit clone https://github.com/microsoft/Swin-Transformer.git\ncd Swin-Transformer\n```\n\n- Create a conda virtual environment and activate it:\n\n```bash\nconda create -n swin python=3.7 -y\nconda activate swin\n```\n\n- Install `CUDA>=10.2` with `cudnn>=7` following\n  the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html).\n- Install `PyTorch>=1.8.0` and `torchvision>=0.9.0` with `CUDA>=10.2`:\n\n```bash\nconda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch\n```\n\n- Install `timm==0.4.12`:\n\n```bash\npip install timm==0.4.12\n```\n\n- Install other requirements:\n\n```bash\npip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy\n```\n\n- Install the fused window process kernel for acceleration; it is activated by passing `--fused_window_process` to the running script:\n\n```bash\ncd kernels/window_process\npython setup.py install #--user\n```\n\n### Data preparation\n\nWe use the standard ImageNet dataset, which can be downloaded from http://image-net.org/. We provide the following two ways to\nload data:\n\n- For the standard folder dataset, move the validation images to labeled sub-folders. 
The file structure should look like:\n\n  ```bash\n  $ tree imagenet\n  imagenet\n  ├── train\n  │   ├── class1\n  │   │   ├── img1.jpeg\n  │   │   ├── img2.jpeg\n  │   │   └── ...\n  │   ├── class2\n  │   │   ├── img3.jpeg\n  │   │   └── ...\n  │   └── ...\n  └── val\n      ├── class1\n      │   ├── img4.jpeg\n      │   ├── img5.jpeg\n      │   └── ...\n      ├── class2\n      │   ├── img6.jpeg\n      │   └── ...\n      └── ...\n  ```\n- To speed up reading images from a massive number of small files, we also support zipped ImageNet, which consists of\n  four files:\n    - `train.zip`, `val.zip`: the zipped folders for the train and validation splits.\n    - `train_map.txt`, `val_map.txt`: the relative path of each image in the corresponding zip file and its ground-truth\n      label. Make sure the data folder looks like this:\n\n  ```bash\n  $ tree data\n  data\n  └── ImageNet-Zip\n      ├── train_map.txt\n      ├── train.zip\n      ├── val_map.txt\n      └── val.zip\n  \n  $ head -n 5 data/ImageNet-Zip/val_map.txt\n  ILSVRC2012_val_00000001.JPEG\t65\n  ILSVRC2012_val_00000002.JPEG\t970\n  ILSVRC2012_val_00000003.JPEG\t230\n  ILSVRC2012_val_00000004.JPEG\t809\n  ILSVRC2012_val_00000005.JPEG\t516\n  \n  $ head -n 5 data/ImageNet-Zip/train_map.txt\n  n01440764/n01440764_10026.JPEG\t0\n  n01440764/n01440764_10027.JPEG\t0\n  n01440764/n01440764_10029.JPEG\t0\n  n01440764/n01440764_10040.JPEG\t0\n  n01440764/n01440764_10042.JPEG\t0\n  ```\n- For the ImageNet-22K dataset, make a folder named `fall11_whole` and move all images to labeled sub-folders in this\n  folder. Then download the train-val split\n  files ([ILSVRC2011fall_whole_map_train.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_train.txt)\n  & [ILSVRC2011fall_whole_map_val.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_val.txt)),\n  and put them in the parent directory of `fall11_whole`. 
The file structure should look like:\n\n  ```bash\n    $ tree imagenet22k/\n    imagenet22k/\n    ├── ILSVRC2011fall_whole_map_train.txt\n    ├── ILSVRC2011fall_whole_map_val.txt\n    └── fall11_whole\n        ├── n00004475\n        ├── n00005787\n        ├── n00006024\n        ├── n00006484\n        └── ...\n  ```\n\n### Evaluation\n\nTo evaluate a pre-trained `Swin Transformer` on ImageNet val, run:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \\\n--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> \n```\n\nFor example, to evaluate `Swin-B` on a single GPU:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \\\n--cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>\n```\n\n### Training from scratch on ImageNet-1K\n\nTo train a `Swin Transformer` on ImageNet from scratch, run:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345  main.py \\\n--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]\n```\n\n**Notes**:\n\n- To use zipped ImageNet instead of the folder dataset, add `--zip` to the parameters.\n    - To cache the dataset in memory instead of reading from files every time, add `--cache-mode part`, which will\n      shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU.\n- When GPU memory is insufficient, you can try the following:\n    - Use gradient accumulation by adding `--accumulation-steps <steps>`, setting `<steps>` as needed.\n    - Use gradient checkpointing by adding `--use-checkpoint`; e.g., it saves about 60% of memory when training `Swin-B`.\n      Please refer to [this 
page](https://pytorch.org/docs/stable/checkpoint.html) for more details.\n    - For very large models, we recommend multi-node training with more GPUs; a tutorial can be found\n      on [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html).\n- To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g.,\n  `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5.\n- For additional options, see [config](config.py) and run `python main.py --help` for a detailed message.\n\nFor example, to train `Swin Transformer` with 8 GPUs on a single node for 300 epochs, run:\n\n`Swin-T`:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \\\n--cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 \n```\n\n`Swin-S`:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \\\n--cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 \n```\n\n`Swin-B`:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \\\n--cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \\\n--accumulation-steps 2 [--use-checkpoint]\n```\n\n### Pre-training on ImageNet-22K\n\nFor example, to pre-train a `Swin-B` model on ImageNet-22K:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \\\n--cfg configs/swin/swin_base_patch4_window7_224_22k.yaml --data-path <imagenet22k-path> --batch-size 64 \\\n--accumulation-steps 8 [--use-checkpoint]\n```\n\n### Fine-tuning on higher resolution\n\nFor example, to fine-tune a `Swin-B` model pre-trained at 224x224 resolution to 384x384 resolution:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \\\n--cfg 
configs/swin/swin_base_patch4_window12_384_finetune.yaml --pretrained swin_base_patch4_window7_224.pth \\\n--data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint]\n```\n\n### Fine-tuning from an ImageNet-22K (21K) pre-trained model\n\nFor example, to fine-tune a `Swin-B` model pre-trained on ImageNet-22K (21K):\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \\\n--cfg configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml --pretrained swin_base_patch4_window7_224_22k.pth \\\n--data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint]\n```\n\n### Throughput\n\nTo measure the throughput, run:\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 1 --master_port 12345  main.py \\\n--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --disable_amp\n```\n\n## Mixture-of-Experts Support\n\n### Install [Tutel](https://github.com/microsoft/tutel)\n\n```bash\npython3 -m pip uninstall tutel -y \npython3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@main\n```\n\n### Training Swin-MoE\n\nFor example, to train a `Swin-MoE-S` model with 32 experts on ImageNet-22K with 32 GPUs (4 nodes):\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --nnodes=4 \\\n--node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345  main_moe.py \\\n--cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128\n```\n\n### Evaluating Swin-MoE\n\nTo evaluate a `Swin-MoE-S` with 32 experts on ImageNet-22K with 32 GPUs (4 nodes):\n\n1. 
Download the zip file [swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip), which contains the pre-trained models for each rank, and unzip it to the folder \"swin_moe_small_patch4_window12_192_32expert_32gpu_22k\".\n2. Run the following evaluation command. Note that the checkpoint path should not contain the \".rank\\<x\\>\" suffix.\n\n```bash\npython -m torch.distributed.launch --nproc_per_node 8 --nnodes=4 \\\n--node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345  main_moe.py \\\n--cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128 \\\n--resume swin_moe_small_patch4_window12_192_32expert_32gpu_22k/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.pth \n```\n\nMore Swin-MoE models can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models).\n\n## SimMIM Support\n\n### Evaluating provided models\n\nTo evaluate a provided model on the ImageNet validation set, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \\\n--eval --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>\n```\n\nFor example, to evaluate the `Swin Base` model on a single GPU, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py \\\n--eval --cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --resume simmim_finetune__swin_base__img224_window7__800ep.pth --data-path <imagenet-path>\n```\n\n### Pre-training with SimMIM\nTo pre-train models with `SimMIM`, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_pt.py \\\n--cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]\n```\n\nFor example, to pre-train `Swin Base` 
for 800 epochs on one DGX-2 server, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node 16 main_simmim_pt.py \\\n--cfg configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>]\n```\n\n### Fine-tuning pre-trained models\nTo fine-tune models pre-trained by `SimMIM`, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \\\n--cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]\n```\n\nFor example, to fine-tune `Swin Base` pre-trained by `SimMIM` on one DGX-2 server, run:\n```bash\npython -m torch.distributed.launch --nproc_per_node 16 main_simmim_ft.py \\\n--cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>]\n```"
  },
  {
    "path": "kernels/window_process/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\nsetup(name='swin_window_process',\n    ext_modules=[\n        CUDAExtension('swin_window_process', [\n            'swin_window_process.cpp',\n            'swin_window_process_kernel.cu',\n        ])\n    ],\n    cmdclass={'build_ext': BuildExtension})"
  },
  {
    "path": "kernels/window_process/swin_window_process.cpp",
    "content": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <torch/torch.h>\n#include <torch/extension.h>\n\n\nat::Tensor roll_and_window_partition_forward_cuda(\n    at::Tensor & input, \n    //at::Tensor & output,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size);\n\n\nat::Tensor roll_and_window_partition_backward_cuda(\n    at::Tensor & grad_in, \n    //at::Tensor & grad_out,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size);\n\n\nat::Tensor window_merge_and_roll_forward_cuda(\n    at::Tensor & input, \n    //at::Tensor & output,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size);\n\nat::Tensor window_merge_and_roll_backward_cuda(\n    at::Tensor & grad_in, \n    //at::Tensor & grad_out,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size);\n\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\n\n\nat::Tensor roll_and_window_partition_forward(\n    at::Tensor & input, \n    
//at::Tensor & output,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    CHECK_INPUT(input);\n    return roll_and_window_partition_forward_cuda(input, B, H, W, C, shift_size, window_size);\n}\n\n\nat::Tensor roll_and_window_partition_backward(\n    at::Tensor & grad_in, \n    //at::Tensor & grad_out,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    CHECK_INPUT(grad_in);\n    return roll_and_window_partition_backward_cuda(grad_in, B, H, W, C, shift_size, window_size);\n}\n\n\nat::Tensor window_merge_and_roll_forward(\n    at::Tensor & input, \n    //at::Tensor & output,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    CHECK_INPUT(input);\n    return window_merge_and_roll_forward_cuda(input, B, H, W, C, shift_size, window_size);\n}\n\n\nat::Tensor window_merge_and_roll_backward(\n    at::Tensor & grad_in, \n    //at::Tensor & grad_out,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    CHECK_INPUT(grad_in);\n    return window_merge_and_roll_backward_cuda(grad_in, B, H, W, C, shift_size, window_size);\n}\n\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"roll_and_window_partition_forward\", &roll_and_window_partition_forward, \"torch.roll and window_partition.\");\n    m.def(\"roll_and_window_partition_backward\", &roll_and_window_partition_backward, \"torch.roll and window_partition.\");\n    m.def(\"window_merge_and_roll_forward\", &window_merge_and_roll_forward, \"window merge and torch.roll.\");\n    m.def(\"window_merge_and_roll_backward\", &window_merge_and_roll_backward, \"window merge and torch.roll.\");\n}"
  },
  {
    "path": "kernels/window_process/swin_window_process_kernel.cu",
    "content": "/*\n * Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <torch/extension.h>\n#include <stdio.h>\n\nint best_block_dim(int feat_dim){\n    int best_dim;\n    if (feat_dim < 384){\n        best_dim = 64;\n    }\n    else{\n        if (feat_dim < 1024){\n            best_dim = 128;\n        }\n        else{\n            best_dim = 256;\n        }\n    }\n    return best_dim;\n}\n\n\ntemplate <typename T>\n__global__ void roll_and_window_partition_forward_cuda_kernel(\n    T* input, \n    T* output, \n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size,\n    const int nH,\n    const int nW){\n    // start\n    //bool qual = threadIdx.x < C;\n    int index = threadIdx.x;\n    int offset;\n    for (int i = index; i < C; i += blockDim.x) {\n        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize\n        int input_offset = blockIdx.z / (nH * nW) * H * W * C +\n            (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C + \n            (blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C +\n            i;\n        output[offset] = (T)(__ldg(input + input_offset));\n    }\n}\n\n\ntemplate 
<typename T>\n__global__ void roll_and_window_partition_backward_cuda_kernel(\n    T* grad_in, \n    T* grad_out, \n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size,\n    const int nH,\n    const int nW){\n    // start\n    int index = threadIdx.x;\n    int offset;\n    for (int i = index; i < C; i += blockDim.x) {\n        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize\n        int input_offset = \n        (blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C +\n        (blockIdx.y + shift_size + H ) % H % window_size * window_size * C +\n        (blockIdx.x + shift_size + W ) % W % window_size * C +\n        i;\n        grad_out[offset] = (T)(__ldg(grad_in + input_offset));\n    }\n}\n\n\ntemplate <typename T>\n__global__ void window_merge_and_roll_forward_cuda_kernel(\n    T* input, \n    T* output, \n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size,\n    const int nH,\n    const int nW){\n    // start\n    int index = threadIdx.x;\n    int offset;\n    for (int i = index; i < C; i += blockDim.x) {\n        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize\n        // window index = b * nH * nW + window_row * nW + window_col (same layout as window_partition)\n        int input_offset = \n            (blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nW + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C +\n            (blockIdx.y - shift_size + H) % window_size * window_size * C + \n            (blockIdx.x - shift_size + W) % window_size * C +\n            i;\n        output[offset] = (T)(__ldg(input + input_offset));\n    }\n}\n\n\n\ntemplate <typename T>\n__global__ void window_merge_and_roll_backward_cuda_kernel(\n    T* grad_in, \n    T* grad_out, \n    
const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size,\n    const int nH,\n    const int nW){\n    // start\n    int index = threadIdx.x;\n    int offset;\n    for (int i = index; i < C; i += blockDim.x) {\n        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize\n        int input_offset = \n        (blockIdx.z / (nH * nW)) * H * W * C +\n        (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C +\n        (blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C +\n        i;\n        grad_out[offset] = (T)(__ldg(grad_in + input_offset));\n    }\n}\n\n// input: [B, H, W, C]\n// output: [B*nH*nW, window_size, window_size, C]\nat::Tensor roll_and_window_partition_forward_cuda(\n    at::Tensor & input, \n    //at::Tensor & output,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    \n    int nH = H / window_size;\n    int nW = W / window_size;\n\n    dim3 grid(window_size, window_size, B * nH * nW);\n    //dim3 block((C + 31) / 32 * 32);\n    int blocknum = best_block_dim(C);\n    dim3 block(blocknum);\n\n    at::Tensor output;\n    if (input.scalar_type() == torch::kFloat16){\n        output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));\n    }\n    else{\n        output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));\n    }\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), \"roll_and_window_partition_forward_cuda_kernel\", ([&] {\n        roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(\n            input.data<scalar_t>(),\n            output.data<scalar_t>(),\n            B,\n            H,\n            W,\n            C,\n      
      shift_size,\n            window_size,\n            nH,\n            nW);\n    }));\n    return output;\n}\n\n\n// grad_in: [B*nH*nW, window_size, window_size, C]\n// grad_out: [B, H, W, C]\nat::Tensor roll_and_window_partition_backward_cuda(\n    at::Tensor & grad_in, \n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    \n    int nH = H / window_size;\n    int nW = W / window_size;\n\n    dim3 grid(W, H, B);\n    //dim3 block((C + 31) / 32 * 32);\n    int blocknum = best_block_dim(C);\n    dim3 block(blocknum);\n\n    at::Tensor grad_out;\n    if (grad_in.scalar_type() == torch::kFloat16){\n        grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));\n    }\n    else{\n        grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));\n    }\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), \"roll_and_window_partition_backward_cuda_kernel\", ([&] {\n        roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(\n            grad_in.data<scalar_t>(),\n            grad_out.data<scalar_t>(),\n            B,\n            H,\n            W,\n            C,\n            shift_size,\n            window_size,\n            nH,\n            nW);\n    }));\n    return grad_out;\n}\n\n\n// input: [B*nH*nW, window_size, window_size, C]\n// output: [B, H, W, C]\nat::Tensor window_merge_and_roll_forward_cuda(\n    at::Tensor & input, \n    //at::Tensor & output,\n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    \n    int nH = H / window_size;\n    int nW = W / window_size;\n\n    dim3 grid(W, H, B);\n    //dim3 block((C + 31) / 32 * 32);\n    int blocknum = best_block_dim(C);\n    dim3 block(blocknum);\n\n    //generate output tensor inside\n    at::Tensor output;\n    if 
(input.scalar_type() == torch::kFloat16){\n        output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));\n    }\n    else{\n        output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));\n    }\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), \"window_merge_and_roll_forward_cuda_kernel\", ([&] {\n        window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(\n            input.data<scalar_t>(),\n            output.data<scalar_t>(),\n            B,\n            H,\n            W,\n            C,\n            shift_size,\n            window_size,\n            nH,\n            nW);\n    }));\n    return output;\n}\n\n\nat::Tensor window_merge_and_roll_backward_cuda(\n    at::Tensor & grad_in, \n    const int B,\n    const int H,\n    const int W,\n    const int C,\n    const int shift_size,\n    const int window_size){\n    \n    int nH = H / window_size;\n    int nW = W / window_size;\n\n    dim3 grid(window_size, window_size, B * nH * nW);\n    //dim3 block((C + 31) / 32 * 32);\n    int blocknum = best_block_dim(C);\n    dim3 block(blocknum);\n\n    at::Tensor grad_out;\n    if (grad_in.scalar_type() == torch::kFloat16){\n        grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));\n    }\n    else{\n        grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));\n    }\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), \"window_merge_and_roll_backward_cuda_kernel\", ([&] {\n        window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(\n            grad_in.data<scalar_t>(),\n            grad_out.data<scalar_t>(),\n            B,\n            H,\n            W,\n            C,\n            shift_size,\n            window_size,\n            
nH,\n            nW);\n    }));\n    return grad_out;\n}"
  },
  {
    "path": "kernels/window_process/unit_test.py",
    "content": "# --------------------------------------------------------\n# Fused kernel for window process for SwinTransformer\n# Copyright (c) 2022 Nvidia\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\nimport torch\nimport swin_window_process\nimport random\nimport time\nimport unittest\n\n\nclass WindowProcess(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, B, H, W, C, shift_size, window_size):\n        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)\n\n        ctx.B = B\n        ctx.H = H\n        ctx.W = W \n        ctx.C = C \n        ctx.shift_size = shift_size\n        ctx.window_size = window_size\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_in):\n        B = ctx.B\n        H = ctx.H\n        W = ctx.W \n        C = ctx.C \n        shift_size = ctx.shift_size\n        window_size = ctx.window_size\n\n        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)\n        return grad_out, None, None, None, None, None, None, None\n\n\nclass WindowProcessReverse(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, B, H, W, C, shift_size, window_size):\n        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)\n\n        ctx.B = B\n        ctx.H = H\n        ctx.W = W \n        ctx.C = C \n        ctx.shift_size = shift_size\n        ctx.window_size = window_size\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_in):\n        B = ctx.B\n        H = ctx.H\n        W = ctx.W \n        C = ctx.C \n        shift_size = ctx.shift_size\n        window_size = ctx.window_size\n\n        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)\n        return grad_out, None, None, 
None, None, None, None, None\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\ndef pyt_forward(x, shift_size, window_size):\n    # x in shape(B, H, W, C)\n    # cyclic shift\n    if shift_size > 0:\n        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))\n    else:\n        shifted_x = x\n    # partition windows\n    x_windows = window_partition(shifted_x, window_size)\n    return x_windows\n\n\ndef reverse_pyt_forward(attn_windows, shift_size, window_size, H, W):\n    # x in shape(B*nH*nW, window_size, window_size, C)\n    shifted_x = window_reverse(attn_windows, window_size, H, W)\n    if shift_size > 0:\n        x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))\n    else:\n        x = shifted_x\n    return x\n\n\ndef copy_one_tensor(input, requires_grad=True):\n    input1 = input.clone().detach().requires_grad_(requires_grad).cuda()\n    return input1\n\nclass Test_WindowProcess(unittest.TestCase):\n    def setUp(self):\n        self.B = 192\n        self.H = 56\n        self.W = 56\n        
self.C = 96\n        self.shift_size = 2\n        self.window_size = 7\n        self.nH = self.H // self.window_size\n        self.nW = self.W // self.window_size\n    \n    def test_roll_and_window_partition_forward(self, dtype=torch.float32):\n        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()\n        \n        input1 = copy_one_tensor(input, True)\n        input2 = copy_one_tensor(input, True)\n\n        with torch.no_grad():\n            # reference PyTorch implementation\n            expected = pyt_forward(input1, self.shift_size, self.window_size)\n            # fused kernel\n            fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)\n        \n        self.assertTrue(torch.equal(expected, fused_output))\n        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))\n    \n    def test_roll_and_window_partition_backward(self, dtype=torch.float32):\n        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()\n        d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda()\n        \n        input1 = copy_one_tensor(input, True)\n        input2 = copy_one_tensor(input, True)\n\n        # reference PyTorch implementation\n        expected = pyt_forward(input1, self.shift_size, self.window_size)\n        expected.backward(d_loss_tensor)\n        # fused kernel\n        fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)\n        fused_output.backward(d_loss_tensor)\n        \n        self.assertTrue(torch.equal(expected, fused_output))\n        # compare input gradients so that the backward kernel is actually checked\n        self.assertTrue(torch.equal(input1.grad, input2.grad))\n        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))\n\n    def test_window_merge_and_roll_forward(self, dtype=torch.float32):\n        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, 
requires_grad=True).cuda()\n        \n        input1 = copy_one_tensor(input, True)\n        input2 = copy_one_tensor(input, True)\n\n        with torch.no_grad():\n            # reference PyTorch implementation\n            expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)\n            # fused kernel\n            fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)\n        \n        self.assertTrue(torch.equal(expected, fused_output))\n        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))\n    \n\n    def test_window_merge_and_roll_backward(self, dtype=torch.float32):\n        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()\n        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()\n        \n        input1 = copy_one_tensor(input, True)\n        input2 = copy_one_tensor(input, True)\n\n        # reference PyTorch implementation\n        expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)\n        expected.backward(d_loss_tensor)\n        # fused kernel\n        fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)\n        fused_output.backward(d_loss_tensor)\n        \n        self.assertTrue(torch.equal(expected, fused_output))\n        # compare input gradients so that the backward kernel is actually checked\n        self.assertTrue(torch.equal(input1.grad, input2.grad))\n        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))\n\n    def test_forward_backward_speed(self, dtype=torch.float32, times=1000):\n        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()\n        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()\n        \n        input1 = copy_one_tensor(input, True)\n        input2 = copy_one_tensor(input, True)\n\n       
 # reference PyTorch implementation (official Swin Transformer code)\n        def run_pyt(t=1000):\n            for _ in range(t):\n                expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)\n                expected.backward(d_loss_tensor)\n\n        # fused op\n        def run_fusedop(t=1000):\n            for _ in range(t):\n                fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)\n                fused_output.backward(d_loss_tensor)\n        \n        torch.cuda.synchronize()\n        t1 = time.time()\n        run_pyt(t=times)\n        torch.cuda.synchronize()\n        t2 = time.time()\n        run_fusedop(t=times)\n        torch.cuda.synchronize()\n        t3 = time.time()\n        self.assertTrue((t3 - t2) < (t2 - t1))\n\n        print('Run {} times'.format(times))\n        print('Original time cost: {}'.format(t2 - t1))\n        print('Fused op time cost: {}'.format(t3 - t2))\n    \n    def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):\n        self.test_roll_and_window_partition_forward(dtype=dtype)\n\n    def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):\n        self.test_roll_and_window_partition_backward(dtype=dtype)\n\n    def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):\n        self.test_window_merge_and_roll_forward(dtype=dtype)\n    \n    def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):\n        self.test_window_merge_and_roll_backward(dtype=dtype)\n\n    def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):\n        self.test_forward_backward_speed(dtype=dtype, times=times)\n\n\nif __name__ == '__main__':\n    print('A test passes only if the two tensors are exactly equal (compared with torch.equal).\\n')\n    torch.manual_seed(0)\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "kernels/window_process/window_process.py",
    "content": "# --------------------------------------------------------\n# Fused kernel for window process for SwinTransformer\n# Copyright (c) 2022 Nvidia\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\nimport torch\nimport swin_window_process\n\n\nclass WindowProcess(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, B, H, W, C, shift_size, window_size):\n        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)\n\n        ctx.B = B\n        ctx.H = H\n        ctx.W = W \n        ctx.C = C \n        ctx.shift_size = shift_size\n        ctx.window_size = window_size\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_in):\n        B = ctx.B\n        H = ctx.H\n        W = ctx.W \n        C = ctx.C \n        shift_size = ctx.shift_size\n        window_size = ctx.window_size\n\n        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)\n        return grad_out, None, None, None, None, None, None, None\n\n\nclass WindowProcessReverse(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, B, H, W, C, shift_size, window_size):\n        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)\n\n        ctx.B = B\n        ctx.H = H\n        ctx.W = W \n        ctx.C = C \n        ctx.shift_size = shift_size\n        ctx.window_size = window_size\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_in):\n        B = ctx.B\n        H = ctx.H\n        W = ctx.W \n        C = ctx.C \n        shift_size = ctx.shift_size\n        window_size = ctx.window_size\n\n        #grad_out = ctx.saved_tensors[0]\n        #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda()\n        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, 
C, shift_size, window_size)\n        return grad_out, None, None, None, None, None, None, None\n"
  },
  {
    "path": "logger.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport os\nimport sys\nimport logging\nimport functools\nfrom termcolor import colored\n\n\n@functools.lru_cache()\ndef create_logger(output_dir, dist_rank=0, name=''):\n    # create logger\n    logger = logging.getLogger(name)\n    logger.setLevel(logging.DEBUG)\n    logger.propagate = False\n\n    # create formatter\n    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'\n    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \\\n                colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'\n\n    # create console handlers for master process\n    if dist_rank == 0:\n        console_handler = logging.StreamHandler(sys.stdout)\n        console_handler.setLevel(logging.DEBUG)\n        console_handler.setFormatter(\n            logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))\n        logger.addHandler(console_handler)\n\n    # create file handlers\n    file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')\n    file_handler.setLevel(logging.DEBUG)\n    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))\n    logger.addHandler(file_handler)\n\n    return logger\n"
  },
  {
    "path": "lr_scheduler.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport bisect\n\nimport torch\nfrom timm.scheduler.cosine_lr import CosineLRScheduler\nfrom timm.scheduler.step_lr import StepLRScheduler\nfrom timm.scheduler.scheduler import Scheduler\n\n\ndef build_scheduler(config, optimizer, n_iter_per_epoch):\n    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)\n    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)\n    decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)\n    multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]\n\n    lr_scheduler = None\n    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':\n        lr_scheduler = CosineLRScheduler(\n            optimizer,\n            t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps,\n            t_mul=1.,\n            lr_min=config.TRAIN.MIN_LR,\n            warmup_lr_init=config.TRAIN.WARMUP_LR,\n            warmup_t=warmup_steps,\n            cycle_limit=1,\n            t_in_epochs=False,\n            warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX,\n        )\n    elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':\n        lr_scheduler = LinearLRScheduler(\n            optimizer,\n            t_initial=num_steps,\n            lr_min_rate=0.01,\n            warmup_lr_init=config.TRAIN.WARMUP_LR,\n            warmup_t=warmup_steps,\n            t_in_epochs=False,\n        )\n    elif config.TRAIN.LR_SCHEDULER.NAME == 'step':\n        lr_scheduler = StepLRScheduler(\n            optimizer,\n            decay_t=decay_steps,\n            decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,\n            warmup_lr_init=config.TRAIN.WARMUP_LR,\n            warmup_t=warmup_steps,\n            
t_in_epochs=False,\n        )\n    elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':\n        lr_scheduler = MultiStepLRScheduler(\n            optimizer,\n            milestones=multi_steps,\n            gamma=config.TRAIN.LR_SCHEDULER.GAMMA,\n            warmup_lr_init=config.TRAIN.WARMUP_LR,\n            warmup_t=warmup_steps,\n            t_in_epochs=False,\n        )\n\n    return lr_scheduler\n\n\nclass LinearLRScheduler(Scheduler):\n    def __init__(self,\n                 optimizer: torch.optim.Optimizer,\n                 t_initial: int,\n                 lr_min_rate: float,\n                 warmup_t=0,\n                 warmup_lr_init=0.,\n                 t_in_epochs=True,\n                 noise_range_t=None,\n                 noise_pct=0.67,\n                 noise_std=1.0,\n                 noise_seed=42,\n                 initialize=True,\n                 ) -> None:\n        super().__init__(\n            optimizer, param_group_field=\"lr\",\n            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,\n            initialize=initialize)\n\n        self.t_initial = t_initial\n        self.lr_min_rate = lr_min_rate\n        self.warmup_t = warmup_t\n        self.warmup_lr_init = warmup_lr_init\n        self.t_in_epochs = t_in_epochs\n        if self.warmup_t:\n            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]\n            super().update_groups(self.warmup_lr_init)\n        else:\n            self.warmup_steps = [1 for _ in self.base_values]\n\n    def _get_lr(self, t):\n        if t < self.warmup_t:\n            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]\n        else:\n            t = t - self.warmup_t\n            total_t = self.t_initial - self.warmup_t\n            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]\n        return lrs\n\n    def get_epoch_values(self, epoch: int):\n        if 
self.t_in_epochs:\n            return self._get_lr(epoch)\n        else:\n            return None\n\n    def get_update_values(self, num_updates: int):\n        if not self.t_in_epochs:\n            return self._get_lr(num_updates)\n        else:\n            return None\n\n\nclass MultiStepLRScheduler(Scheduler):\n    def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:\n        super().__init__(optimizer, param_group_field=\"lr\")\n        \n        self.milestones = milestones\n        self.gamma = gamma\n        self.warmup_t = warmup_t\n        self.warmup_lr_init = warmup_lr_init\n        self.t_in_epochs = t_in_epochs\n        if self.warmup_t:\n            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]\n            super().update_groups(self.warmup_lr_init)\n        else:\n            self.warmup_steps = [1 for _ in self.base_values]\n        \n        assert self.warmup_t <= min(self.milestones)\n    \n    def _get_lr(self, t):\n        if t < self.warmup_t:\n            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]\n        else:\n            lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values]\n        return lrs\n\n    def get_epoch_values(self, epoch: int):\n        if self.t_in_epochs:\n            return self._get_lr(epoch)\n        else:\n            return None\n\n    def get_update_values(self, num_updates: int):\n        if not self.t_in_epochs:\n            return self._get_lr(num_updates)\n        else:\n            return None\n"
  },
  {
    "path": "main.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport os\nimport time\nimport json\nimport random\nimport argparse\nimport datetime\nimport numpy as np\n\nimport torch\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\n\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy\nfrom timm.utils import accuracy, AverageMeter\n\nfrom config import get_config\nfrom models import build_model\nfrom data import build_loader\nfrom lr_scheduler import build_scheduler\nfrom optimizer import build_optimizer\nfrom logger import create_logger\nfrom utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \\\n    reduce_tensor\n\n# pytorch major version (1.x or 2.x)\nPYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])\n\n\ndef parse_option():\n    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)\n    parser.add_argument('--cfg', type=str, required=True, metavar=\"FILE\", help='path to config file', )\n    parser.add_argument(\n        \"--opts\",\n        help=\"Modify config options by adding 'KEY VALUE' pairs. 
\",\n        default=None,\n        nargs='+',\n    )\n\n    # easy config modification\n    parser.add_argument('--batch-size', type=int, help=\"batch size for single GPU\")\n    parser.add_argument('--data-path', type=str, help='path to dataset')\n    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')\n    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],\n                        help='no: no cache, '\n                             'full: cache all data, '\n                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')\n    parser.add_argument('--pretrained',\n                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')\n    parser.add_argument('--resume', help='resume from checkpoint')\n    parser.add_argument('--accumulation-steps', type=int, help=\"gradient accumulation steps\")\n    parser.add_argument('--use-checkpoint', action='store_true',\n                        help=\"whether to use gradient checkpointing to save memory\")\n    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')\n    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],\n                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')\n    parser.add_argument('--output', default='output', type=str, metavar='PATH',\n                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')\n    parser.add_argument('--tag', help='tag of experiment')\n    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')\n    parser.add_argument('--throughput', action='store_true', help='Test throughput only')\n\n    # distributed training\n    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead\n    # (see 
https://pytorch.org/docs/stable/distributed.html#launch-utility)\n    if PYTORCH_MAJOR_VERSION == 1:\n        parser.add_argument(\"--local_rank\", type=int, required=True, help='local rank for DistributedDataParallel')\n\n    # for acceleration\n    parser.add_argument('--fused_window_process', action='store_true',\n                        help='Fused window shift & window partition, similar for reversed part.')\n    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')\n    ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb\n    parser.add_argument('--optim', type=str,\n                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')\n\n    args, unparsed = parser.parse_known_args()\n\n    config = get_config(args)\n\n    return args, config\n\n\ndef main(config):\n    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)\n\n    logger.info(f\"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}\")\n    model = build_model(config)\n    logger.info(str(model))\n\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    logger.info(f\"number of params: {n_parameters}\")\n    if hasattr(model, 'flops'):\n        flops = model.flops()\n        logger.info(f\"number of GFLOPs: {flops / 1e9}\")\n\n    model.cuda()\n    model_without_ddp = model\n\n    optimizer = build_optimizer(config, model)\n    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)\n    loss_scaler = NativeScalerWithGradNormCount()\n\n    if config.TRAIN.ACCUMULATION_STEPS > 1:\n        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)\n    else:\n        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))\n\n    if config.AUG.MIXUP > 0.:\n        # smoothing is handled with mixup 
label transform\n        criterion = SoftTargetCrossEntropy()\n    elif config.MODEL.LABEL_SMOOTHING > 0.:\n        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)\n    else:\n        criterion = torch.nn.CrossEntropyLoss()\n\n    max_accuracy = 0.0\n\n    if config.TRAIN.AUTO_RESUME:\n        resume_file = auto_resume_helper(config.OUTPUT)\n        if resume_file:\n            if config.MODEL.RESUME:\n                logger.warning(f\"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}\")\n            config.defrost()\n            config.MODEL.RESUME = resume_file\n            config.freeze()\n            logger.info(f'auto resuming from {resume_file}')\n        else:\n            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')\n\n    if config.MODEL.RESUME:\n        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n        if config.EVAL_MODE:\n            return\n\n    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):\n        load_pretrained(config, model_without_ddp, logger)\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n\n    if config.THROUGHPUT_MODE:\n        throughput(data_loader_val, model, logger)\n        return\n\n    logger.info(\"Start training\")\n    start_time = time.time()\n    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):\n        data_loader_train.sampler.set_epoch(epoch)\n\n        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,\n                        loss_scaler)\n        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 
or epoch == (config.TRAIN.EPOCHS - 1)):\n            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,\n                            logger)\n\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n        max_accuracy = max(max_accuracy, acc1)\n        logger.info(f'Max accuracy: {max_accuracy:.2f}%')\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    logger.info('Training time {}'.format(total_time_str))\n\n\ndef train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):\n    model.train()\n    optimizer.zero_grad()\n\n    num_steps = len(data_loader)\n    batch_time = AverageMeter()\n    loss_meter = AverageMeter()\n    norm_meter = AverageMeter()\n    scaler_meter = AverageMeter()\n\n    start = time.time()\n    end = time.time()\n    for idx, (samples, targets) in enumerate(data_loader):\n        samples = samples.cuda(non_blocking=True)\n        targets = targets.cuda(non_blocking=True)\n\n        if mixup_fn is not None:\n            samples, targets = mixup_fn(samples, targets)\n\n        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):\n            outputs = model(samples)\n        loss = criterion(outputs, targets)\n        loss = loss / config.TRAIN.ACCUMULATION_STEPS\n\n        # this attribute is added by timm on one optimizer (adahessian)\n        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,\n                                parameters=model.parameters(), create_graph=is_second_order,\n                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)\n        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 
0:\n            optimizer.zero_grad()\n            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)\n        loss_scale_value = loss_scaler.state_dict()[\"scale\"]\n\n        torch.cuda.synchronize()\n\n        loss_meter.update(loss.item(), targets.size(0))\n        if grad_norm is not None:  # loss_scaler return None if not update\n            norm_meter.update(grad_norm)\n        scaler_meter.update(loss_scale_value)\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if idx % config.PRINT_FREQ == 0:\n            lr = optimizer.param_groups[0]['lr']\n            wd = optimizer.param_groups[0]['weight_decay']\n            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)\n            etas = batch_time.avg * (num_steps - idx)\n            logger.info(\n                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\\t'\n                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\\t wd {wd:.4f}\\t'\n                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\\t'\n                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\\t'\n                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\\t'\n                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\\t'\n                f'mem {memory_used:.0f}MB')\n    epoch_time = time.time() - start\n    logger.info(f\"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}\")\n\n\n@torch.no_grad()\ndef validate(config, data_loader, model):\n    criterion = torch.nn.CrossEntropyLoss()\n    model.eval()\n\n    batch_time = AverageMeter()\n    loss_meter = AverageMeter()\n    acc1_meter = AverageMeter()\n    acc5_meter = AverageMeter()\n\n    end = time.time()\n    for idx, (images, target) in enumerate(data_loader):\n        images = images.cuda(non_blocking=True)\n        target = target.cuda(non_blocking=True)\n\n        # compute output\n        with 
torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):\n            output = model(images)\n\n        # measure accuracy and record loss\n        loss = criterion(output, target)\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n        acc1 = reduce_tensor(acc1)\n        acc5 = reduce_tensor(acc5)\n        loss = reduce_tensor(loss)\n\n        loss_meter.update(loss.item(), target.size(0))\n        acc1_meter.update(acc1.item(), target.size(0))\n        acc5_meter.update(acc5.item(), target.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if idx % config.PRINT_FREQ == 0:\n            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)\n            logger.info(\n                f'Test: [{idx}/{len(data_loader)}]\\t'\n                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\\t'\n                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\\t'\n                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\\t'\n                f'Mem {memory_used:.0f}MB')\n    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')\n    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg\n\n\n@torch.no_grad()\ndef throughput(data_loader, model, logger):\n    model.eval()\n\n    for idx, (images, _) in enumerate(data_loader):\n        images = images.cuda(non_blocking=True)\n        batch_size = images.shape[0]\n        for i in range(50):\n            model(images)\n        torch.cuda.synchronize()\n        logger.info(f\"throughput averaged with 30 times\")\n        tic1 = time.time()\n        for i in range(30):\n            model(images)\n        torch.cuda.synchronize()\n        tic2 = time.time()\n        logger.info(f\"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}\")\n        return\n\n\nif __name__ == '__main__':\n    args, config = 
parse_option()\n\n    if config.AMP_OPT_LEVEL:\n        print(\"[warning] Apex amp has been deprecated, please use pytorch amp instead!\")\n\n    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ['WORLD_SIZE'])\n        print(f\"RANK and WORLD_SIZE in environ: {rank}/{world_size}\")\n    else:\n        rank = -1\n        world_size = -1\n    torch.cuda.set_device(config.LOCAL_RANK)\n    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)\n    torch.distributed.barrier()\n\n    seed = config.SEED + dist.get_rank()\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    cudnn.benchmark = True\n\n    # linear scale the learning rate according to total batch size, may not be optimal\n    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    # gradient accumulation also need to scale the learning rate\n    if config.TRAIN.ACCUMULATION_STEPS > 1:\n        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS\n    config.defrost()\n    config.TRAIN.BASE_LR = linear_scaled_lr\n    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr\n    config.TRAIN.MIN_LR = linear_scaled_min_lr\n    config.freeze()\n\n    os.makedirs(config.OUTPUT, exist_ok=True)\n    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f\"{config.MODEL.NAME}\")\n\n    if dist.get_rank() == 0:\n        path = 
os.path.join(config.OUTPUT, \"config.json\")\n        with open(path, \"w\") as f:\n            f.write(config.dump())\n        logger.info(f\"Full config saved to {path}\")\n\n    # print config\n    logger.info(config.dump())\n    logger.info(json.dumps(vars(args)))\n\n    main(config)\n"
  },
  {
    "path": "main_moe.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nfrom tutel import system\n\nimport os\nimport time\nimport json\nimport random\nimport argparse\nimport datetime\nimport numpy as np\nfrom functools import partial\nimport torch\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\n\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy\nfrom timm.utils import accuracy, AverageMeter\n\nfrom config import get_config\nfrom models import build_model\nfrom data import build_loader\nfrom lr_scheduler import build_scheduler\nfrom optimizer import build_optimizer\nfrom logger import create_logger\nfrom utils import NativeScalerWithGradNormCount, reduce_tensor\nfrom utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_resume_helper, hook_scale_grad\n\nassert torch.__version__ >= '1.8.0', \"DDP-based MoE requires Pytorch >= 1.8.0\"\n\n# pytorch major version (1.x or 2.x)\nPYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])\n\n\ndef parse_option():\n    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)\n    parser.add_argument('--cfg', type=str, required=True, metavar=\"FILE\", help='path to config file', )\n    parser.add_argument(\n        \"--opts\",\n        help=\"Modify config options by adding 'KEY VALUE' pairs. 
\",\n        default=None,\n        nargs='+',\n    )\n\n    # easy config modification\n    parser.add_argument('--batch-size', type=int, help=\"batch size for single GPU\")\n    parser.add_argument('--data-path', type=str, help='path to dataset')\n    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')\n    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],\n                        help='no: no cache, '\n                             'full: cache all data, '\n                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')\n    parser.add_argument('--pretrained',\n                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')\n    parser.add_argument('--resume', help='resume from checkpoint')\n    parser.add_argument('--accumulation-steps', type=int, help=\"gradient accumulation steps\")\n    parser.add_argument('--use-checkpoint', action='store_true',\n                        help=\"whether to use gradient checkpointing to save memory\")\n    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')\n    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],\n                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')\n    parser.add_argument('--output', default='output', type=str, metavar='PATH',\n                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')\n    parser.add_argument('--tag', help='tag of experiment')\n    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')\n    parser.add_argument('--throughput', action='store_true', help='Test throughput only')\n\n    # distributed training\n    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead\n    # (see 
https://pytorch.org/docs/stable/distributed.html#launch-utility)\n    if PYTORCH_MAJOR_VERSION == 1:\n        parser.add_argument(\"--local_rank\", type=int, required=True, help='local rank for DistributedDataParallel')\n\n    args, unparsed = parser.parse_known_args()\n\n    config = get_config(args)\n\n    return args, config\n\n\ndef main(config):\n    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)\n\n    logger.info(f\"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}\")\n    model = build_model(config)\n    logger.info(str(model))\n\n    # For Tutel MoE\n    for name, param in model.named_parameters():\n        if param.requires_grad == True and hasattr(param, 'skip_allreduce') and param.skip_allreduce is True:\n            model.add_param_to_skip_allreduce(name)\n            param.register_hook(partial(hook_scale_grad, dist.get_world_size()))\n            logger.info(f\"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {dist.get_world_size()} for grad\")\n\n    n_parameters_single = sum(p.numel() * model.sharded_count if hasattr(p, 'skip_allreduce')\n                              else p.numel() for p in model.parameters() if p.requires_grad)\n    logger.info(f\"number of params single: {n_parameters_single}\")\n    n_parameters_whole = sum(p.numel() * model.sharded_count * model.global_experts if hasattr(p, 'skip_allreduce')\n                             else p.numel() for p in model.parameters() if p.requires_grad)\n    logger.info(f\"number of params whole: {n_parameters_whole}\")\n    if hasattr(model, 'flops'):\n        flops = model.flops()\n        logger.info(f\"number of GFLOPs: {flops / 1e9}\")\n\n    model.cuda(config.LOCAL_RANK)\n    model_without_ddp = model\n\n    optimizer = build_optimizer(config, model)\n    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)\n    loss_scaler = NativeScalerWithGradNormCount()\n\n    if 
config.TRAIN.ACCUMULATION_STEPS > 1:\n        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)\n    else:\n        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))\n\n    if config.AUG.MIXUP > 0.:\n        # smoothing is handled with mixup label transform\n        criterion = SoftTargetCrossEntropy()\n    elif config.MODEL.LABEL_SMOOTHING > 0.:\n        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)\n    else:\n        criterion = torch.nn.CrossEntropyLoss()\n\n    max_accuracy = 0.0\n\n    if config.TRAIN.AUTO_RESUME:\n        resume_file = auto_resume_helper(config.OUTPUT, config.TRAIN.MOE.SAVE_MASTER)\n        if resume_file:\n            if config.MODEL.RESUME:\n                logger.warning(f\"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}\")\n            config.defrost()\n            config.MODEL.RESUME = resume_file\n            config.freeze()\n            logger.info(f'auto resuming from {resume_file}')\n        else:\n            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')\n\n    if config.MODEL.RESUME:\n        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n        if config.EVAL_MODE:\n            return\n\n    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):\n        load_pretrained(config, model_without_ddp, logger)\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n        if config.EVAL_MODE:\n            return\n\n    if config.THROUGHPUT_MODE:\n        throughput(data_loader_val, model, logger)\n        return\n\n    
logger.info(\"Start training\")\n    start_time = time.time()\n    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):\n        data_loader_train.sampler.set_epoch(epoch)\n\n        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,\n                        loss_scaler)\n        if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):\n            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,\n                            logger)\n\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n        max_accuracy = max(max_accuracy, acc1)\n        logger.info(f'Max accuracy: {max_accuracy:.2f}%')\n    save_checkpoint(config, 'final', model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,\n                    logger, zero_redundancy=True)\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    logger.info('Training time {}'.format(total_time_str))\n\n\ndef train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):\n    model.train()\n    optimizer.zero_grad()\n\n    num_steps = len(data_loader)\n    batch_time = AverageMeter()\n    loss_meter = AverageMeter()\n    loss_aux_meter = AverageMeter()\n    loss_cls_meter = AverageMeter()\n    norm_meter = AverageMeter()\n    scaler_meter = AverageMeter()\n\n    start = time.time()\n    end = time.time()\n    for idx, (samples, targets) in enumerate(data_loader):\n        samples = samples.cuda(non_blocking=True)\n        targets = targets.cuda(non_blocking=True)\n\n        if mixup_fn is not None:\n            samples, targets = mixup_fn(samples, targets)\n\n        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):\n            outputs, l_aux 
= model(samples)\n        l_cls = criterion(outputs, targets)\n        loss = l_cls + l_aux\n        loss = loss / config.TRAIN.ACCUMULATION_STEPS\n\n        # this attribute is added by timm on one optimizer (adahessian)\n        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,\n                                parameters=model.parameters(), create_graph=is_second_order,\n                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)\n        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:\n            optimizer.zero_grad()\n            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)\n        loss_scale_value = loss_scaler.state_dict()[\"scale\"]\n\n        torch.cuda.synchronize()\n\n        loss_meter.update(loss.item(), targets.size(0))\n        loss_cls_meter.update(l_cls.item(), targets.size(0))\n        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), targets.size(0))\n        if grad_norm is not None:  # loss_scaler return None if not update\n            norm_meter.update(grad_norm)\n        scaler_meter.update(loss_scale_value)\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if idx % config.PRINT_FREQ == 0:\n            lr = optimizer.param_groups[0]['lr']\n            wd = optimizer.param_groups[0]['weight_decay']\n            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)\n            etas = batch_time.avg * (num_steps - idx)\n            logger.info(\n                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\\t'\n                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\\t wd {wd:.4f}\\t'\n                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\\t'\n                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\\t'\n                
f'loss-cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\\t'\n                f'loss-aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\\t'\n                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\\t'\n                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\\t'\n                f'mem {memory_used:.0f}MB')\n    epoch_time = time.time() - start\n    logger.info(f\"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}\")\n\n\n@torch.no_grad()\ndef validate(config, data_loader, model):\n    criterion = torch.nn.CrossEntropyLoss()\n    model.eval()\n\n    batch_time = AverageMeter()\n    loss_cls_meter = AverageMeter()\n    loss_aux_meter = AverageMeter()\n    acc1_meter = AverageMeter()\n    acc5_meter = AverageMeter()\n\n    end = time.time()\n    for idx, (images, target) in enumerate(data_loader):\n        images = images.cuda(non_blocking=True)\n        target = target.cuda(non_blocking=True)\n\n        # compute output\n        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):\n            output, l_aux = model(images)\n\n        # measure accuracy and record loss\n        l_cls = criterion(output, target)\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n        acc1 = reduce_tensor(acc1)\n        acc5 = reduce_tensor(acc5)\n\n        loss_cls_meter.update(l_cls.item(), target.size(0))\n        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), target.size(0))\n        acc1_meter.update(acc1.item(), target.size(0))\n        acc5_meter.update(acc5.item(), target.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if idx % config.PRINT_FREQ == 0:\n            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)\n            logger.info(\n                f'Test: [{idx}/{len(data_loader)}]\\t'\n                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n   
             f'Loss-Cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\\t'\n                f'Loss-Aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\\t'\n                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\\t'\n                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\\t'\n                f'Mem {memory_used:.0f}MB')\n    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')\n    return acc1_meter.avg, acc5_meter.avg, loss_cls_meter.avg\n\n\n@torch.no_grad()\ndef throughput(data_loader, model, logger):\n    model.eval()\n\n    for idx, (images, _) in enumerate(data_loader):\n        images = images.cuda(non_blocking=True)\n        batch_size = images.shape[0]\n        for i in range(50):\n            model(images)\n        torch.cuda.synchronize()\n        logger.info(f\"throughput averaged with 30 times\")\n        tic1 = time.time()\n        for i in range(30):\n            model(images)\n        torch.cuda.synchronize()\n        tic2 = time.time()\n        logger.info(f\"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}\")\n        return\n\n\nif __name__ == '__main__':\n    args, config = parse_option()\n\n    if config.AMP_OPT_LEVEL:\n        print(\"[warning] Apex amp has been deprecated, please use pytorch amp instead!\")\n\n    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ['WORLD_SIZE'])\n        print(f\"RANK and WORLD_SIZE in environ: {rank}/{world_size}\")\n    else:\n        rank = -1\n        world_size = -1\n    torch.cuda.set_device(config.LOCAL_RANK)\n    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)\n    torch.distributed.barrier()\n\n    seed = config.SEED + dist.get_rank()\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    cudnn.benchmark = True\n\n    # 
linear scale the learning rate according to total batch size, may not be optimal\n    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    # gradient accumulation also needs to scale the learning rate\n    if config.TRAIN.ACCUMULATION_STEPS > 1:\n        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS\n    config.defrost()\n    config.TRAIN.BASE_LR = linear_scaled_lr\n    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr\n    config.TRAIN.MIN_LR = linear_scaled_min_lr\n    config.freeze()\n\n    os.makedirs(config.OUTPUT, exist_ok=True)\n    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f\"{config.MODEL.NAME}\")\n\n    if dist.get_rank() == 0:\n        path = os.path.join(config.OUTPUT, \"config.json\")\n        with open(path, \"w\") as f:\n            f.write(config.dump())\n        logger.info(f\"Full config saved to {path}\")\n\n    # print config\n    logger.info(config.dump())\n    logger.info(json.dumps(vars(args)))\n\n    main(config)\n"
  },
  {
    "path": "main_simmim_ft.py",
    "content": "# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# Modified by Zhenda Xie\n# --------------------------------------------------------\n\nimport os\nimport time\nimport argparse\nimport datetime\nimport numpy as np\n\nimport torch\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\nimport torch.cuda.amp as amp\n\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy\nfrom timm.utils import accuracy, AverageMeter\n\nfrom config import get_config\nfrom models import build_model\nfrom data import build_loader\nfrom lr_scheduler import build_scheduler\nfrom optimizer import build_optimizer\nfrom logger import create_logger\nfrom utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \\\n    reduce_tensor\n\n# pytorch major version (1.x or 2.x)\nPYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])\n\n\ndef parse_option():\n    parser = argparse.ArgumentParser('SimMIM fine-tuning script', add_help=False)\n    parser.add_argument('--cfg', type=str, required=True, metavar=\"FILE\", help='path to config file', )\n    parser.add_argument(\n        \"--opts\",\n        help=\"Modify config options by adding 'KEY VALUE' pairs. 
\",\n        default=None,\n        nargs='+',\n    )\n\n    # easy config modification\n    parser.add_argument('--batch-size', type=int, help=\"batch size for single GPU\")\n    parser.add_argument('--data-path', type=str, help='path to dataset')\n    parser.add_argument('--pretrained', type=str, help='path to pre-trained model')\n    parser.add_argument('--resume', help='resume from checkpoint')\n    parser.add_argument('--accumulation-steps', type=int, help=\"gradient accumulation steps\")\n    parser.add_argument('--use-checkpoint', action='store_true',\n                        help=\"whether to use gradient checkpointing to save memory\")\n    parser.add_argument('--enable-amp', action='store_true')\n    parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')\n    parser.set_defaults(enable_amp=True)\n    parser.add_argument('--output', default='output', type=str, metavar='PATH',\n                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')\n    parser.add_argument('--tag', help='tag of experiment')\n    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')\n    parser.add_argument('--throughput', action='store_true', help='Test throughput only')\n\n    # distributed training\n    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead\n    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)\n    if PYTORCH_MAJOR_VERSION == 1:\n        parser.add_argument(\"--local_rank\", type=int, required=True, help='local rank for DistributedDataParallel')\n\n    args = parser.parse_args()\n\n    config = get_config(args)\n\n    return args, config\n\n\ndef main(config):\n    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True,\n                                                                                            is_pretrain=False)\n\n    logger.info(f\"Creating 
model:{config.MODEL.TYPE}/{config.MODEL.NAME}\")\n    model = build_model(config, is_pretrain=False)\n    model.cuda()\n    logger.info(str(model))\n\n    optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False)\n    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)\n    model_without_ddp = model.module\n\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    logger.info(f\"number of params: {n_parameters}\")\n    if hasattr(model_without_ddp, 'flops'):\n        flops = model_without_ddp.flops()\n        logger.info(f\"number of GFLOPs: {flops / 1e9}\")\n\n    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))\n    scaler = amp.GradScaler()\n\n    if config.AUG.MIXUP > 0.:\n        # smoothing is handled with mixup label transform\n        criterion = SoftTargetCrossEntropy()\n    elif config.MODEL.LABEL_SMOOTHING > 0.:\n        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)\n    else:\n        criterion = torch.nn.CrossEntropyLoss()\n\n    max_accuracy = 0.0\n\n    if config.TRAIN.AUTO_RESUME:\n        resume_file = auto_resume_helper(config.OUTPUT, logger)\n        if resume_file:\n            if config.MODEL.RESUME:\n                logger.warning(f\"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}\")\n            config.defrost()\n            config.MODEL.RESUME = resume_file\n            config.freeze()\n            logger.info(f'auto resuming from {resume_file}')\n        else:\n            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')\n\n    if config.MODEL.RESUME:\n        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: 
{acc1:.1f}%\")\n        if config.EVAL_MODE:\n            return\n\n    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):\n        load_pretrained(config, model_without_ddp, logger)\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n\n    if config.THROUGHPUT_MODE:\n        throughput(data_loader_val, model, logger)\n        return\n\n    logger.info(\"Start training\")\n    start_time = time.time()\n    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):\n        data_loader_train.sampler.set_epoch(epoch)\n\n        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler)\n        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):\n            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger)\n\n        acc1, acc5, loss = validate(config, data_loader_val, model)\n        logger.info(f\"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%\")\n        max_accuracy = max(max_accuracy, acc1)\n        logger.info(f'Max accuracy: {max_accuracy:.2f}%')\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    logger.info('Training time {}'.format(total_time_str))\n\n\ndef train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler):\n    model.train()\n    optimizer.zero_grad()\n\n    logger.info(f'Current learning rate for different parameter groups: {[it[\"lr\"] for it in optimizer.param_groups]}')\n\n    num_steps = len(data_loader)\n    batch_time = AverageMeter()\n    loss_meter = AverageMeter()\n    norm_meter = AverageMeter()\n    loss_scale_meter = AverageMeter()\n\n    start = time.time()\n    end = time.time()\n    for idx, (samples, 
targets) in enumerate(data_loader):\n        samples = samples.cuda(non_blocking=True)\n        targets = targets.cuda(non_blocking=True)\n\n        if mixup_fn is not None:\n            samples, targets = mixup_fn(samples, targets)\n\n        with amp.autocast(enabled=config.ENABLE_AMP):\n            outputs = model(samples)\n\n        if config.TRAIN.ACCUMULATION_STEPS > 1:\n            loss = criterion(outputs, targets)\n            loss = loss / config.TRAIN.ACCUMULATION_STEPS\n            scaler.scale(loss).backward()\n            if config.TRAIN.CLIP_GRAD:\n                scaler.unscale_(optimizer)\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)\n            else:\n                grad_norm = get_grad_norm(model.parameters())\n            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:\n                scaler.step(optimizer)\n                optimizer.zero_grad()\n                scaler.update()\n                lr_scheduler.step_update(epoch * num_steps + idx)\n        else:\n            loss = criterion(outputs, targets)\n            optimizer.zero_grad()\n            scaler.scale(loss).backward()\n            if config.TRAIN.CLIP_GRAD:\n                scaler.unscale_(optimizer)\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)\n            else:\n                grad_norm = get_grad_norm(model.parameters())\n            scaler.step(optimizer)\n            scaler.update()\n            lr_scheduler.step_update(epoch * num_steps + idx)\n\n        torch.cuda.synchronize()\n\n        loss_meter.update(loss.item(), targets.size(0))\n        norm_meter.update(grad_norm)\n        loss_scale_meter.update(scaler.get_scale())\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if idx % config.PRINT_FREQ == 0:\n            lr = optimizer.param_groups[-1]['lr']\n            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)\n            etas = batch_time.avg * 
(num_steps - idx)\n            logger.info(\n                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\\t'\n                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\\t'\n                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\\t'\n                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\\t'\n                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\\t'\n                f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\\t'\n                f'mem {memory_used:.0f}MB')\n    epoch_time = time.time() - start\n    logger.info(f\"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}\")\n\n\n@torch.no_grad()\ndef validate(config, data_loader, model):\n    criterion = torch.nn.CrossEntropyLoss()\n    model.eval()\n\n    batch_time = AverageMeter()\n    loss_meter = AverageMeter()\n    acc1_meter = AverageMeter()\n    acc5_meter = AverageMeter()\n\n    end = time.time()\n    for idx, (images, target) in enumerate(data_loader):\n        images = images.cuda(non_blocking=True)\n        target = target.cuda(non_blocking=True)\n\n        # compute output\n        output = model(images)\n\n        # measure accuracy and record loss\n        loss = criterion(output, target)\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n        acc1 = reduce_tensor(acc1)\n        acc5 = reduce_tensor(acc5)\n        loss = reduce_tensor(loss)\n\n        loss_meter.update(loss.item(), target.size(0))\n        acc1_meter.update(acc1.item(), target.size(0))\n        acc5_meter.update(acc5.item(), target.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if idx % config.PRINT_FREQ == 0:\n            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)\n            logger.info(\n                f'Test: [{idx}/{len(data_loader)}]\\t'\n                f'Time {batch_time.val:.3f} 
({batch_time.avg:.3f})\\t'\n                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\\t'\n                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\\t'\n                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\\t'\n                f'Mem {memory_used:.0f}MB')\n    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')\n    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg\n\n\n@torch.no_grad()\ndef throughput(data_loader, model, logger):\n    model.eval()\n\n    for idx, (images, _) in enumerate(data_loader):\n        images = images.cuda(non_blocking=True)\n        batch_size = images.shape[0]\n        for i in range(50):\n            model(images)\n        torch.cuda.synchronize()\n        logger.info(f\"throughput averaged with 30 times\")\n        tic1 = time.time()\n        for i in range(30):\n            model(images)\n        torch.cuda.synchronize()\n        tic2 = time.time()\n        logger.info(f\"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}\")\n        return\n\n\nif __name__ == '__main__':\n    _, config = parse_option()\n\n    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ['WORLD_SIZE'])\n        print(f\"RANK and WORLD_SIZE in environ: {rank}/{world_size}\")\n    else:\n        rank = -1\n        world_size = -1\n    torch.cuda.set_device(config.LOCAL_RANK)\n    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)\n    torch.distributed.barrier()\n\n    seed = config.SEED + dist.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    cudnn.benchmark = True\n\n    # linear scale the learning rate according to total batch size, may not be optimal\n    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * 
config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    # gradient accumulation also needs to scale the learning rate\n    if config.TRAIN.ACCUMULATION_STEPS > 1:\n        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS\n    config.defrost()\n    config.TRAIN.BASE_LR = linear_scaled_lr\n    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr\n    config.TRAIN.MIN_LR = linear_scaled_min_lr\n    config.freeze()\n\n    os.makedirs(config.OUTPUT, exist_ok=True)\n    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f\"{config.MODEL.NAME}\")\n\n    if dist.get_rank() == 0:\n        path = os.path.join(config.OUTPUT, \"config.json\")\n        with open(path, \"w\") as f:\n            f.write(config.dump())\n        logger.info(f\"Full config saved to {path}\")\n\n    # print config\n    logger.info(config.dump())\n\n    main(config)\n"
  },
  {
    "path": "main_simmim_pt.py",
    "content": "# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# Modified by Zhenda Xie\n# --------------------------------------------------------\n\nimport os\nimport time\nimport argparse\nimport datetime\nimport numpy as np\n\nimport torch\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\nimport torch.cuda.amp as amp\nfrom timm.utils import AverageMeter\n\nfrom config import get_config\nfrom models import build_model\nfrom data import build_loader\nfrom lr_scheduler import build_scheduler\nfrom optimizer import build_optimizer\nfrom logger import create_logger\nfrom utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper\n\n# pytorch major version (1.x or 2.x)\nPYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])\n\n\ndef parse_option():\n    parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False)\n    parser.add_argument('--cfg', type=str, required=True, metavar=\"FILE\", help='path to config file', )\n    parser.add_argument(\n        \"--opts\",\n        help=\"Modify config options by adding 'KEY VALUE' pairs. 
\",\n        default=None,\n        nargs='+',\n    )\n\n    # easy config modification\n    parser.add_argument('--batch-size', type=int, help=\"batch size for single GPU\")\n    parser.add_argument('--data-path', type=str, help='path to dataset')\n    parser.add_argument('--resume', help='resume from checkpoint')\n    parser.add_argument('--accumulation-steps', type=int, help=\"gradient accumulation steps\")\n    parser.add_argument('--use-checkpoint', action='store_true',\n                        help=\"whether to use gradient checkpointing to save memory\")\n    parser.add_argument('--enable-amp', action='store_true')\n    parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')\n    parser.set_defaults(enable_amp=True)\n    parser.add_argument('--output', default='output', type=str, metavar='PATH',\n                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')\n    parser.add_argument('--tag', help='tag of experiment')\n\n    # distributed training\n    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead\n    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)\n    if PYTORCH_MAJOR_VERSION == 1:\n        parser.add_argument(\"--local_rank\", type=int, required=True, help='local rank for DistributedDataParallel')\n\n    args = parser.parse_args()\n\n    config = get_config(args)\n\n    return args, config\n\n\ndef main(config):\n    data_loader_train = build_loader(config, simmim=True, is_pretrain=True)\n\n    logger.info(f\"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}\")\n    model = build_model(config, is_pretrain=True)\n    model.cuda()\n    logger.info(str(model))\n\n    optimizer = build_optimizer(config, model, simmim=True, is_pretrain=True)\n    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)\n    model_without_ddp = model.module\n\n    n_parameters = sum(p.numel() for p in 
model.parameters() if p.requires_grad)\n    logger.info(f\"number of params: {n_parameters}\")\n    if hasattr(model_without_ddp, 'flops'):\n        flops = model_without_ddp.flops()\n        logger.info(f\"number of GFLOPs: {flops / 1e9}\")\n\n    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))\n    scaler = amp.GradScaler()\n\n    if config.TRAIN.AUTO_RESUME:\n        resume_file = auto_resume_helper(config.OUTPUT, logger)\n        if resume_file:\n            if config.MODEL.RESUME:\n                logger.warning(f\"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}\")\n            config.defrost()\n            config.MODEL.RESUME = resume_file\n            config.freeze()\n            logger.info(f'auto resuming from {resume_file}')\n        else:\n            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')\n\n    if config.MODEL.RESUME:\n        load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)\n\n    logger.info(\"Start training\")\n    start_time = time.time()\n    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):\n        data_loader_train.sampler.set_epoch(epoch)\n\n        train_one_epoch(config, model, data_loader_train, optimizer, epoch, lr_scheduler, scaler)\n        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):\n            save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, scaler, logger)\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    logger.info('Training time {}'.format(total_time_str))\n\n\ndef train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler):\n    model.train()\n    optimizer.zero_grad()\n\n    num_steps = len(data_loader)\n    batch_time = AverageMeter()\n    loss_meter = AverageMeter()\n    norm_meter = AverageMeter()\n    
loss_scale_meter = AverageMeter()\n\n    start = time.time()\n    end = time.time()\n    for idx, (img, mask, _) in enumerate(data_loader):\n        img = img.cuda(non_blocking=True)\n        mask = mask.cuda(non_blocking=True)\n\n        with amp.autocast(enabled=config.ENABLE_AMP):\n            loss = model(img, mask)\n\n        if config.TRAIN.ACCUMULATION_STEPS > 1:\n            loss = loss / config.TRAIN.ACCUMULATION_STEPS\n            scaler.scale(loss).backward()\n            if config.TRAIN.CLIP_GRAD:\n                scaler.unscale_(optimizer)\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)\n            else:\n                grad_norm = get_grad_norm(model.parameters())\n            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:\n                scaler.step(optimizer)\n                optimizer.zero_grad()\n                scaler.update()\n                lr_scheduler.step_update(epoch * num_steps + idx)\n        else:\n            optimizer.zero_grad()\n            scaler.scale(loss).backward()\n            if config.TRAIN.CLIP_GRAD:\n                scaler.unscale_(optimizer)\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)\n            else:\n                grad_norm = get_grad_norm(model.parameters())\n            scaler.step(optimizer)\n            scaler.update()\n            lr_scheduler.step_update(epoch * num_steps + idx)\n\n        torch.cuda.synchronize()\n\n        loss_meter.update(loss.item(), img.size(0))\n        norm_meter.update(grad_norm)\n        loss_scale_meter.update(scaler.get_scale())\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if idx % config.PRINT_FREQ == 0:\n            lr = optimizer.param_groups[0]['lr']\n            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)\n            etas = batch_time.avg * (num_steps - idx)\n            logger.info(\n       
         f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\\t'\n                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\\t'\n                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\\t'\n                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\\t'\n                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\\t'\n                f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\\t'\n                f'mem {memory_used:.0f}MB')\n    epoch_time = time.time() - start\n    logger.info(f\"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}\")\n\n\nif __name__ == '__main__':\n    _, config = parse_option()\n\n    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        rank = int(os.environ[\"RANK\"])\n        world_size = int(os.environ['WORLD_SIZE'])\n        print(f\"RANK and WORLD_SIZE in environ: {rank}/{world_size}\")\n    else:\n        rank = -1\n        world_size = -1\n    torch.cuda.set_device(config.LOCAL_RANK)\n    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)\n    torch.distributed.barrier()\n\n    seed = config.SEED + dist.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    cudnn.benchmark = True\n\n    # linear scale the learning rate according to total batch size, may not be optimal\n    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0\n    # gradient accumulation also needs to scale the learning rate\n    if config.TRAIN.ACCUMULATION_STEPS > 1:\n        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_warmup_lr = linear_scaled_warmup_lr * 
config.TRAIN.ACCUMULATION_STEPS\n        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS\n    config.defrost()\n    config.TRAIN.BASE_LR = linear_scaled_lr\n    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr\n    config.TRAIN.MIN_LR = linear_scaled_min_lr\n    config.freeze()\n\n    os.makedirs(config.OUTPUT, exist_ok=True)\n    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f\"{config.MODEL.NAME}\")\n\n    if dist.get_rank() == 0:\n        path = os.path.join(config.OUTPUT, \"config.json\")\n        with open(path, \"w\") as f:\n            f.write(config.dump())\n        logger.info(f\"Full config saved to {path}\")\n\n    # print config\n    logger.info(config.dump())\n\n    main(config)\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from .build import build_model"
  },
  {
    "path": "models/build.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nfrom .swin_transformer import SwinTransformer\nfrom .swin_transformer_v2 import SwinTransformerV2\nfrom .swin_transformer_moe import SwinTransformerMoE\nfrom .swin_mlp import SwinMLP\nfrom .simmim import build_simmim\n\n\ndef build_model(config, is_pretrain=False):\n    model_type = config.MODEL.TYPE\n\n    # accelerate layernorm\n    if config.FUSED_LAYERNORM:\n        try:\n            import apex as amp\n            layernorm = amp.normalization.FusedLayerNorm\n        except:\n            layernorm = None\n            print(\"To use FusedLayerNorm, please install apex.\")\n    else:\n        import torch.nn as nn\n        layernorm = nn.LayerNorm\n\n    if is_pretrain:\n        model = build_simmim(config)\n        return model\n\n    if model_type == 'swin':\n        model = SwinTransformer(img_size=config.DATA.IMG_SIZE,\n                                patch_size=config.MODEL.SWIN.PATCH_SIZE,\n                                in_chans=config.MODEL.SWIN.IN_CHANS,\n                                num_classes=config.MODEL.NUM_CLASSES,\n                                embed_dim=config.MODEL.SWIN.EMBED_DIM,\n                                depths=config.MODEL.SWIN.DEPTHS,\n                                num_heads=config.MODEL.SWIN.NUM_HEADS,\n                                window_size=config.MODEL.SWIN.WINDOW_SIZE,\n                                mlp_ratio=config.MODEL.SWIN.MLP_RATIO,\n                                qkv_bias=config.MODEL.SWIN.QKV_BIAS,\n                                qk_scale=config.MODEL.SWIN.QK_SCALE,\n                                drop_rate=config.MODEL.DROP_RATE,\n                                drop_path_rate=config.MODEL.DROP_PATH_RATE,\n                                
ape=config.MODEL.SWIN.APE,\n                                norm_layer=layernorm,\n                                patch_norm=config.MODEL.SWIN.PATCH_NORM,\n                                use_checkpoint=config.TRAIN.USE_CHECKPOINT,\n                                fused_window_process=config.FUSED_WINDOW_PROCESS)\n    elif model_type == 'swinv2':\n        model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE,\n                                  patch_size=config.MODEL.SWINV2.PATCH_SIZE,\n                                  in_chans=config.MODEL.SWINV2.IN_CHANS,\n                                  num_classes=config.MODEL.NUM_CLASSES,\n                                  embed_dim=config.MODEL.SWINV2.EMBED_DIM,\n                                  depths=config.MODEL.SWINV2.DEPTHS,\n                                  num_heads=config.MODEL.SWINV2.NUM_HEADS,\n                                  window_size=config.MODEL.SWINV2.WINDOW_SIZE,\n                                  mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,\n                                  qkv_bias=config.MODEL.SWINV2.QKV_BIAS,\n                                  drop_rate=config.MODEL.DROP_RATE,\n                                  drop_path_rate=config.MODEL.DROP_PATH_RATE,\n                                  ape=config.MODEL.SWINV2.APE,\n                                  patch_norm=config.MODEL.SWINV2.PATCH_NORM,\n                                  use_checkpoint=config.TRAIN.USE_CHECKPOINT,\n                                  pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES)\n    elif model_type == 'swin_moe':\n        model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE,\n                                   patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE,\n                                   in_chans=config.MODEL.SWIN_MOE.IN_CHANS,\n                                   num_classes=config.MODEL.NUM_CLASSES,\n                                   embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM,\n                   
                depths=config.MODEL.SWIN_MOE.DEPTHS,\n                                   num_heads=config.MODEL.SWIN_MOE.NUM_HEADS,\n                                   window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE,\n                                   mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO,\n                                   qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS,\n                                   qk_scale=config.MODEL.SWIN_MOE.QK_SCALE,\n                                   drop_rate=config.MODEL.DROP_RATE,\n                                   drop_path_rate=config.MODEL.DROP_PATH_RATE,\n                                   ape=config.MODEL.SWIN_MOE.APE,\n                                   patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM,\n                                   mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS,\n                                   init_std=config.MODEL.SWIN_MOE.INIT_STD,\n                                   use_checkpoint=config.TRAIN.USE_CHECKPOINT,\n                                   pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES,\n                                   moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS,\n                                   num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS,\n                                   top_value=config.MODEL.SWIN_MOE.TOP_VALUE,\n                                   capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR,\n                                   cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER,\n                                   normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE,\n                                   use_bpr=config.MODEL.SWIN_MOE.USE_BPR,\n                                   is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS,\n                                   gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE,\n                                   cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM,\n                                   
cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T,\n                                   moe_drop=config.MODEL.SWIN_MOE.MOE_DROP,\n                                   aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT)\n    elif model_type == 'swin_mlp':\n        model = SwinMLP(img_size=config.DATA.IMG_SIZE,\n                        patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE,\n                        in_chans=config.MODEL.SWIN_MLP.IN_CHANS,\n                        num_classes=config.MODEL.NUM_CLASSES,\n                        embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM,\n                        depths=config.MODEL.SWIN_MLP.DEPTHS,\n                        num_heads=config.MODEL.SWIN_MLP.NUM_HEADS,\n                        window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE,\n                        mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO,\n                        drop_rate=config.MODEL.DROP_RATE,\n                        drop_path_rate=config.MODEL.DROP_PATH_RATE,\n                        ape=config.MODEL.SWIN_MLP.APE,\n                        patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM,\n                        use_checkpoint=config.TRAIN.USE_CHECKPOINT)\n    else:\n        raise NotImplementedError(f\"Unknown model: {model_type}\")\n\n    return model\n"
  },
  {
    "path": "models/simmim.py",
    "content": "\n\n# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Zhenda Xie\n# --------------------------------------------------------\n\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import trunc_normal_\n\nfrom .swin_transformer import SwinTransformer\nfrom .swin_transformer_v2 import SwinTransformerV2\n\n\ndef norm_targets(targets, patch_size):\n    assert patch_size % 2 == 1\n    \n    targets_ = targets\n    targets_count = torch.ones_like(targets)\n\n    targets_square = targets ** 2.\n    \n    targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False)\n    targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False)\n    targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=True) * (patch_size ** 2)\n    \n    targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1))\n    targets_var = torch.clamp(targets_var, min=0.)\n    \n    targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5\n    \n    return targets_\n\n\nclass SwinTransformerForSimMIM(SwinTransformer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        assert self.num_classes == 0\n\n        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n        trunc_normal_(self.mask_token, mean=0., std=.02)\n\n    def forward(self, x, mask):\n        x = self.patch_embed(x)\n\n        assert mask is not None\n        B, L, _ = x.shape\n\n        mask_tokens = self.mask_token.expand(B, L, -1)\n        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)\n        x = x * (1. 
- w) + mask_tokens * w\n\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n        x = self.norm(x)\n\n        x = x.transpose(1, 2)\n        B, C, L = x.shape\n        H = W = int(L ** 0.5)\n        x = x.reshape(B, C, H, W)\n        return x\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return super().no_weight_decay() | {'mask_token'}\n\n\nclass SwinTransformerV2ForSimMIM(SwinTransformerV2):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n        assert self.num_classes == 0\n\n        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))\n        trunc_normal_(self.mask_token, mean=0., std=.02)\n\n    def forward(self, x, mask):\n        x = self.patch_embed(x)\n\n        assert mask is not None\n        B, L, _ = x.shape\n\n        mask_tokens = self.mask_token.expand(B, L, -1)\n        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)\n        x = x * (1. 
- w) + mask_tokens * w\n\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n        x = self.norm(x)\n\n        x = x.transpose(1, 2)\n        B, C, L = x.shape\n        H = W = int(L ** 0.5)\n        x = x.reshape(B, C, H, W)\n        return x\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return super().no_weight_decay() | {'mask_token'}\n\n\nclass SimMIM(nn.Module):\n    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):\n        super().__init__()\n        self.config = config\n        self.encoder = encoder\n        self.encoder_stride = encoder_stride\n\n        self.decoder = nn.Sequential(\n            nn.Conv2d(\n                in_channels=self.encoder.num_features,\n                out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),\n            nn.PixelShuffle(self.encoder_stride),\n        )\n\n        self.in_chans = in_chans\n        self.patch_size = patch_size\n\n    def forward(self, x, mask):\n        z = self.encoder(x, mask)\n        x_rec = self.decoder(z)\n\n        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()\n        \n        # normalize the reconstruction targets when enabled in the config\n        if self.config.NORM_TARGET.ENABLE:\n            x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE)\n        \n        loss_recon = F.l1_loss(x, x_rec, reduction='none')\n        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans\n        return loss\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        if hasattr(self.encoder, 'no_weight_decay'):\n            return {'encoder.' + i for i in self.encoder.no_weight_decay()}\n        return {}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        if hasattr(self.encoder, 'no_weight_decay_keywords'):\n            return {'encoder.' 
+ i for i in self.encoder.no_weight_decay_keywords()}\n        return {}\n\n\ndef build_simmim(config):\n    model_type = config.MODEL.TYPE\n    if model_type == 'swin':\n        encoder = SwinTransformerForSimMIM(\n            img_size=config.DATA.IMG_SIZE,\n            patch_size=config.MODEL.SWIN.PATCH_SIZE,\n            in_chans=config.MODEL.SWIN.IN_CHANS,\n            num_classes=0,\n            embed_dim=config.MODEL.SWIN.EMBED_DIM,\n            depths=config.MODEL.SWIN.DEPTHS,\n            num_heads=config.MODEL.SWIN.NUM_HEADS,\n            window_size=config.MODEL.SWIN.WINDOW_SIZE,\n            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,\n            qkv_bias=config.MODEL.SWIN.QKV_BIAS,\n            qk_scale=config.MODEL.SWIN.QK_SCALE,\n            drop_rate=config.MODEL.DROP_RATE,\n            drop_path_rate=config.MODEL.DROP_PATH_RATE,\n            ape=config.MODEL.SWIN.APE,\n            patch_norm=config.MODEL.SWIN.PATCH_NORM,\n            use_checkpoint=config.TRAIN.USE_CHECKPOINT)\n        encoder_stride = 32\n        in_chans = config.MODEL.SWIN.IN_CHANS\n        patch_size = config.MODEL.SWIN.PATCH_SIZE\n    elif model_type == 'swinv2':\n        encoder = SwinTransformerV2ForSimMIM(\n            img_size=config.DATA.IMG_SIZE,\n            patch_size=config.MODEL.SWINV2.PATCH_SIZE,\n            in_chans=config.MODEL.SWINV2.IN_CHANS,\n            num_classes=0,\n            embed_dim=config.MODEL.SWINV2.EMBED_DIM,\n            depths=config.MODEL.SWINV2.DEPTHS,\n            num_heads=config.MODEL.SWINV2.NUM_HEADS,\n            window_size=config.MODEL.SWINV2.WINDOW_SIZE,\n            mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,\n            qkv_bias=config.MODEL.SWINV2.QKV_BIAS,\n            drop_rate=config.MODEL.DROP_RATE,\n            drop_path_rate=config.MODEL.DROP_PATH_RATE,\n            ape=config.MODEL.SWINV2.APE,\n            patch_norm=config.MODEL.SWINV2.PATCH_NORM,\n            use_checkpoint=config.TRAIN.USE_CHECKPOINT)\n        encoder_stride = 
32\n        in_chans = config.MODEL.SWINV2.IN_CHANS\n        patch_size = config.MODEL.SWINV2.PATCH_SIZE\n    else:\n        raise NotImplementedError(f\"Unknown pre-train model: {model_type}\")\n\n    model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans, patch_size=patch_size)\n\n    return model"
  },
  {
    "path": "models/swin_mlp.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass SwinMLPBlock(nn.Module):\n    r\"\"\" Swin MLP Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if the window size is at least the input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in [0, window_size)\"\n\n        self.padding = [self.window_size - self.shift_size, self.shift_size,\n                        self.window_size - self.shift_size, self.shift_size]  # P_l,P_r,P_t,P_b\n\n        self.norm1 = norm_layer(dim)\n        # use group convolution to implement multi-head MLP\n        self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2,\n                                     self.num_heads * 
self.window_size ** 2,\n                                     kernel_size=1,\n                                     groups=self.num_heads)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # shift\n        if self.shift_size > 0:\n            P_l, P_r, P_t, P_b = self.padding\n            shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], \"constant\", 0)\n        else:\n            shifted_x = x\n        _, _H, _W, _ = shifted_x.shape\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # Window/Shifted-Window Spatial MLP\n        x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads)\n        x_windows_heads = x_windows_heads.transpose(1, 2)  # nW*B, nH, window_size*window_size, C//nH\n        x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size,\n                                                  C // self.num_heads)\n        spatial_mlp_windows = self.spatial_mlp(x_windows_heads)  # nW*B, nH*window_size*window_size, C//nH\n        spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size,\n                                                       C // self.num_heads).transpose(1, 2)\n        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, 
self.window_size * self.window_size, C)\n\n        # merge windows\n        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W)  # B H' W' C\n\n        # reverse shift\n        if self.shift_size > 0:\n            P_l, P_r, P_t, P_b = self.padding\n            x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n\n        # Window/Shifted-Window Spatial MLP\n        if self.shift_size > 0:\n            nW = (H / self.window_size + 1) * (W / self.window_size + 1)\n        else:\n            nW = H * W / self.window_size / self.window_size\n        flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) is not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin MLP layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., drop=0., drop_path=0.,\n                 norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinMLPBlock(dim=dim, input_resolution=input_resolution,\n                         num_heads=num_heads, window_size=window_size,\n                         shift_size=0 if (i % 2 == 0) else window_size // 2,\n                         mlp_ratio=mlp_ratio,\n                         drop=drop,\n                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                         norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += 
self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass 
SwinMLP(nn.Module):\n    r\"\"\" Swin MLP\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin MLP layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        drop_rate (float): Dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               
num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               drop=drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv1d)):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += 
layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
  {
    "path": "models/swin_transformer.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\ntry:\n    import os, sys\n\n    kernel_path = os.path.abspath(os.path.join('..'))\n    sys.path.append(kernel_path)\n    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse\n\nexcept Exception:\n    WindowProcess = None\n    WindowProcessReverse = None\n    print(\"[Warning] Fused window process has not been installed. Please refer to get_started.md for installation.\")\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: 
(num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window-based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both shifted and non-shifted windows.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of 
(num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input 
channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. 
Default: False\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,\n                 fused_window_process=False):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is not smaller than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n        self.fused_window_process = fused_window_process\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            if not self.fused_window_process:\n                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n                # partition windows\n                x_windows = 
window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n            else:\n                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)\n        else:\n            shifted_x = x\n            # partition windows\n            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            if not self.fused_window_process:\n                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n            else:\n                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)\n        else:\n            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n            x = shifted_x\n        x = x.view(B, H * W, C)\n        x = shortcut + self.drop_path(x)\n\n        # FFN\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * 
self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) is not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input 
resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. 
Default: False\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n                 fused_window_process=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer,\n                                 fused_window_process=fused_window_process)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in 
self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n       
 if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default: 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, fused_window_process=False, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** 
i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint,\n                               fused_window_process=fused_window_process)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        
return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
  {
    "path": "models/swin_transformer_moe.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer MoE\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nimport numpy as np\n\ntry:\n    from tutel import moe as tutel_moe\nexcept:\n    tutel_moe = None\n    print(\"Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.\")\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,\n                 mlp_fc2_bias=True):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass MoEMlp(nn.Module):\n    def __init__(self, in_features, hidden_features, num_local_experts, top_value, capacity_factor=1.25,\n                 cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True,\n                 gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, init_std=0.02,\n                 mlp_fc2_bias=True):\n        super().__init__()\n\n        self.in_features = in_features\n        self.hidden_features = hidden_features\n        self.num_local_experts = num_local_experts\n        self.top_value = top_value\n        
self.capacity_factor = capacity_factor\n        self.cosine_router = cosine_router\n        self.normalize_gate = normalize_gate\n        self.use_bpr = use_bpr\n        self.init_std = init_std\n        self.mlp_fc2_bias = mlp_fc2_bias\n\n        self.dist_rank = dist.get_rank()\n\n        self._dropout = nn.Dropout(p=moe_drop)\n\n        _gate_type = {'type': 'cosine_top' if cosine_router else 'top',\n                      'k': top_value, 'capacity_factor': capacity_factor,\n                      'gate_noise': gate_noise, 'fp32_gate': True}\n        if cosine_router:\n            _gate_type['proj_dim'] = cosine_router_dim\n            _gate_type['init_t'] = cosine_router_init_t\n        self._moe_layer = tutel_moe.moe_layer(\n            gate_type=_gate_type,\n            model_dim=in_features,\n            experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_features,\n                     'activation_fn': lambda x: self._dropout(F.gelu(x))},\n            scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True),\n            seeds=(1, self.dist_rank + 1, self.dist_rank + 1),\n            batch_prioritized_routing=use_bpr,\n            normalize_gate=normalize_gate,\n            is_gshard_loss=is_gshard_loss,\n\n        )\n        if not self.mlp_fc2_bias:\n            self._moe_layer.experts.batched_fc2_bias.requires_grad = False\n\n    def forward(self, x):\n        x = self._moe_layer(x)\n        return x, x.l_aux\n\n    def extra_repr(self) -> str:\n        return f'[Statistics-{self.dist_rank}] param count for MoE, ' \\\n               f'in_features = {self.in_features}, hidden_features = {self.hidden_features}, ' \\\n               f'num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, ' \\\n               f'cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}'\n\n    def _init_weights(self):\n        if hasattr(self._moe_layer, 
\"experts\"):\n            trunc_normal_(self._moe_layer.experts.batched_fc1_w, std=self.init_std)\n            trunc_normal_(self._moe_layer.experts.batched_fc2_w, std=self.init_std)\n            nn.init.constant_(self._moe_layer.experts.batched_fc1_bias, 0)\n            nn.init.constant_(self._moe_layer.experts.batched_fc2_bias, 0)\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,\n                 pretrained_window_size=[0, 0]):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.pretrained_window_size = pretrained_window_size\n        self.num_heads = num_heads\n\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # mlp to generate continuous relative position bias\n        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),\n                                     nn.ReLU(inplace=True),\n                                     nn.Linear(512, num_heads, bias=False))\n\n        # get relative_coords_table\n        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)\n        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)\n        relative_coords_table = torch.stack(\n            torch.meshgrid([relative_coords_h,\n                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2\n        if pretrained_window_size[0] > 0:\n            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)\n        else:\n            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)\n        relative_coords_table *= 8  # normalize to -8, 8\n        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(\n            torch.abs(relative_coords_table) + 1.0) / np.log2(8)\n\n        
self.register_buffer(\"relative_coords_table\", relative_coords_table)\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)\n        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, ' \\\n               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True\n        init_std (float): Initialization std. Default: 0.02\n        pretrained_window_size (int): Window size in pre-training. Default: 0\n        is_moe (bool): If True, this block is a MoE block. Default: False\n        num_local_experts (int): Number of local experts on each device (GPU). Default: 1\n        top_value (int): The value of k in top-k gating. Default: 1\n        capacity_factor (float): The capacity factor in MoE. Default: 1.25\n        cosine_router (bool): Whether to use cosine router. Default: False\n        normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False\n        use_bpr (bool): Whether to use batch-prioritized routing. Default: True\n        is_gshard_loss (bool): If True, use the GShard balance loss.\n                               If False, use the load and importance losses from \"arXiv:1701.06538\". Default: True\n        gate_noise (float): The noise ratio in top-k gating. Default: 1.0\n        cosine_router_dim (int): Projection dimension in cosine router. Default: 256\n        cosine_router_init_t (float): Initialization temperature in cosine router. Default: 0.5\n        moe_drop (float): Dropout rate in MoE. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, mlp_fc2_bias=True, init_std=0.02, pretrained_window_size=0,\n                 is_moe=False, num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,\n                 normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,\n                 cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        self.is_moe = is_moe\n        self.capacity_factor = capacity_factor\n        self.top_value = top_value\n\n        if min(self.input_resolution) <= self.window_size:\n            # if the window size is not smaller than the input resolution, don't partition windows or shift\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in [0, window_size)\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,\n            pretrained_window_size=to_2tuple(pretrained_window_size))\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        if self.is_moe:\n            self.mlp = MoEMlp(in_features=dim,\n                              hidden_features=mlp_hidden_dim,\n                              num_local_experts=num_local_experts,\n                              top_value=top_value,\n                              capacity_factor=capacity_factor,\n                              cosine_router=cosine_router,\n                              normalize_gate=normalize_gate,\n                              use_bpr=use_bpr,\n                              is_gshard_loss=is_gshard_loss,\n                              gate_noise=gate_noise,\n                              cosine_router_dim=cosine_router_dim,\n                              cosine_router_init_t=cosine_router_init_t,\n                              moe_drop=moe_drop,\n                              mlp_fc2_bias=mlp_fc2_bias,\n                              init_std=init_std)\n        else:\n            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,\n                           mlp_fc2_bias=mlp_fc2_bias)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, 
window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n        x = shortcut + self.drop_path(x)\n\n        # FFN\n        shortcut = x\n        x = self.norm2(x)\n        if self.is_moe:\n            x, l_aux = self.mlp(x)\n            x = shortcut + self.drop_path(x)\n            return x, l_aux\n        else:\n            x = shortcut + self.drop_path(self.mlp(x))\n            return x\n\n    
def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        if self.is_moe:\n            flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio * self.capacity_factor * self.top_value\n        else:\n            flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) is not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. 
Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | list[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True\n        init_std (float): Initialization std. Default: 0.02\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n        pretrained_window_size (int): Local window size in pre-training. Default: 0\n        moe_block (list[int]): Indices of the blocks in this layer that are MoE blocks. Default: [-1] (no MoE blocks)\n        num_local_experts (int): Number of local experts on each device (GPU). Default: 1\n        top_value (int): The value of k in top-k gating. Default: 1\n        capacity_factor (float): The capacity factor in MoE. Default: 1.25\n        cosine_router (bool): Whether to use cosine router. Default: False\n        normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False\n        use_bpr (bool): Whether to use batch-prioritized routing. Default: True\n        is_gshard_loss (bool): If True, use the GShard balance loss.\n                               If False, use the load and importance losses from \"arXiv:1701.06538\". Default: True\n        gate_noise (float): The noise ratio in top-k gating. Default: 1.0\n        cosine_router_dim (int): Projection dimension in cosine router. Default: 256\n        cosine_router_init_t (float): Initialization temperature in cosine router. Default: 0.5\n        moe_drop (float): Dropout rate in MoE. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None,\n                 mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_size=0,\n                 moe_block=[-1], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False,\n                 normalize_gate=False, use_bpr=True, is_gshard_loss=True,\n                 cosine_router_dim=256, cosine_router_init_t=0.5, gate_noise=1.0, moe_drop=0.0):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer,\n                                 mlp_fc2_bias=mlp_fc2_bias,\n                                 init_std=init_std,\n                                 pretrained_window_size=pretrained_window_size,\n\n                                 is_moe=True if i in moe_block else False,\n                                 num_local_experts=num_local_experts,\n                                 top_value=top_value,\n                                 capacity_factor=capacity_factor,\n                                 
cosine_router=cosine_router,\n                                 normalize_gate=normalize_gate,\n                                 use_bpr=use_bpr,\n                                 is_gshard_loss=is_gshard_loss,\n                                 gate_noise=gate_noise,\n                                 cosine_router_dim=cosine_router_dim,\n                                 cosine_router_init_t=cosine_router_init_t,\n                                 moe_drop=moe_drop)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        l_aux = 0.0\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                out = checkpoint.checkpoint(blk, x)\n            else:\n                out = blk(x)\n            if isinstance(out, tuple):\n                x = out[0]\n                cur_l_aux = out[1]\n                l_aux = cur_l_aux + l_aux\n            else:\n                x = out\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x, l_aux\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformerMoE(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. 
Default: 3\n        num_classes (int): Number of classes for the classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (list[int]): Depth of each Swin Transformer layer. Default: [2, 2, 6, 2]\n        num_heads (list[int]): Number of attention heads in each layer. Default: [3, 6, 12, 24]\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True\n        init_std (float): Initialization std. Default: 0.02\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n        pretrained_window_sizes (list[int]): Pretrained window size of each layer. Default: [0, 0, 0, 0]\n        moe_blocks (list[list[int]]): Indices of the MoE blocks in each layer. Default: [[-1], [-1], [-1], [-1]] (no MoE blocks)\n        num_local_experts (int): Number of local experts on each device (GPU). A negative value -n shards each\n                                 expert across n devices. Default: 1\n        top_value (int): The value of k in top-k gating. Default: 1\n        capacity_factor (float): The capacity factor in MoE. Default: 1.25\n        cosine_router (bool): Whether to use cosine router. Default: False\n        normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False\n        use_bpr (bool): Whether to use batch-prioritized routing. 
Default: True\n        is_gshard_loss (bool): If True, use the GShard balance loss.\n                               If False, use the load and importance losses from \"arXiv:1701.06538\". Default: True\n        gate_noise (float): The noise ratio in top-k gating. Default: 1.0\n        cosine_router_dim (int): Projection dimension in cosine router. Default: 256\n        cosine_router_init_t (float): Initialization temperature in cosine router. Default: 0.5\n        moe_drop (float): Dropout rate in MoE. Default: 0.0\n        aux_loss_weight (float): Weight of the auxiliary load-balancing loss. Default: 0.01\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0],\n                 moe_blocks=[[-1], [-1], [-1], [-1]], num_local_experts=1, top_value=1, capacity_factor=1.25,\n                 cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0,\n                 cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, aux_loss_weight=0.01, **kwargs):\n        super().__init__()\n        self._ddp_params_and_buffers_to_ignore = list()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n        self.init_std = init_std\n        self.aux_loss_weight = aux_loss_weight\n        self.num_local_experts = num_local_experts\n        self.global_experts = num_local_experts * dist.get_world_size() if 
num_local_experts > 0 \\\n            else dist.get_world_size() // (-num_local_experts)\n        self.sharded_count = (1.0 / num_local_experts) if num_local_experts > 0 else (-num_local_experts)\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=self.init_std)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer 
< self.num_layers - 1) else None,\n                               mlp_fc2_bias=mlp_fc2_bias,\n                               init_std=init_std,\n                               use_checkpoint=use_checkpoint,\n                               pretrained_window_size=pretrained_window_sizes[i_layer],\n\n                               moe_block=moe_blocks[i_layer],\n                               num_local_experts=num_local_experts,\n                               top_value=top_value,\n                               capacity_factor=capacity_factor,\n                               cosine_router=cosine_router,\n                               normalize_gate=normalize_gate,\n                               use_bpr=use_bpr,\n                               is_gshard_loss=is_gshard_loss,\n                               gate_noise=gate_noise,\n                               cosine_router_dim=cosine_router_dim,\n                               cosine_router_init_t=cosine_router_init_t,\n                               moe_drop=moe_drop)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=self.init_std)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, MoEMlp):\n            m._init_weights()\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {\"cpb_mlp\", 'relative_position_bias_table', 'fc1_bias', 
'fc2_bias',\n                'temperature', 'cosine_projector', 'sim_matrix'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n        l_aux = 0.0\n        for layer in self.layers:\n            x, cur_l_aux = layer(x)\n            l_aux = cur_l_aux + l_aux\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x, l_aux\n\n    def forward(self, x):\n        x, l_aux = self.forward_features(x)\n        x = self.head(x)\n        return x, l_aux * self.aux_loss_weight\n\n    def add_param_to_skip_allreduce(self, param_name):\n        self._ddp_params_and_buffers_to_ignore.append(param_name)\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
  {
    "path": "models/swin_transformer_v2.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer V2\n# Copyright (c) 2022 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nimport numpy as np\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, 
window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window-based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both shifted and non-shifted windows.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,\n                 pretrained_window_size=[0, 0]):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.pretrained_window_size = pretrained_window_size\n        self.num_heads = num_heads\n\n        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)\n\n        # mlp to generate continuous relative position bias\n        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),\n                                     nn.ReLU(inplace=True),\n                                     nn.Linear(512, num_heads, bias=False))\n\n        # get relative_coords_table\n        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)\n        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)\n        relative_coords_table = torch.stack(\n            torch.meshgrid([relative_coords_h,\n                            relative_coords_w])).permute(1, 2, 
0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2\n        if pretrained_window_size[0] > 0:\n            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)\n        else:\n            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)\n            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)\n        relative_coords_table *= 8  # normalize to -8, 8\n        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(\n            torch.abs(relative_coords_table) + 1.0) / np.log2(8)\n\n        self.register_buffer(\"relative_coords_table\", relative_coords_table)\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=False)\n        if qkv_bias:\n            self.q_bias = nn.Parameter(torch.zeros(dim))\n            self.v_bias = nn.Parameter(torch.zeros(dim))\n        else:\n            self.q_bias = None\n            self.v_bias = None\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        
self.proj_drop = nn.Dropout(proj_drop)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv_bias = None\n        if self.q_bias is not None:\n            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))\n        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)\n        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        # cosine attention\n        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))\n        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()\n        attn = attn * logit_scale\n\n        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)\n        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 
2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, ' \\\n               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n        pretrained_window_size (int): Window size in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,\n            pretrained_window_size=to_2tuple(pretrained_window_size))\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = 
x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n        x = shortcut + self.drop_path(self.norm1(x))\n\n        # FFN\n        x = x + self.drop_path(self.norm2(self.mlp(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(2 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.reduction(x)\n        x = self.norm(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        flops += H * W * self.dim // 2\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. 
Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n        pretrained_window_size (int): Local window size in pre-training.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n                 pretrained_window_size=0):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer,\n                                 pretrained_window_size=pretrained_window_size)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if 
self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n    def _init_respostnorm(self):\n        for blk in self.blocks:\n            nn.init.constant_(blk.norm1.bias, 0)\n            nn.init.constant_(blk.norm1.weight, 0)\n            nn.init.constant_(blk.norm2.bias, 0)\n            nn.init.constant_(blk.norm2.weight, 0)\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformerV2(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. 
Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n       
                                          patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint,\n                               pretrained_window_size=pretrained_window_sizes[i_layer])\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n        for bly in self.layers:\n            bly._init_respostnorm()\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {\"cpb_mlp\", \"logit_scale\", 'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in 
self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n"
  },
  {
    "path": "optimizer.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nfrom functools import partial\nfrom torch import optim as optim\n\ntry:\n    from apex.optimizers import FusedAdam, FusedLAMB\nexcept:\n    FusedAdam = None\n    FusedLAMB = None\n    print(\"To use FusedLAMB or FusedAdam, please install apex.\")\n\n\ndef build_optimizer(config, model, simmim=False, is_pretrain=False):\n    \"\"\"\n    Build optimizer, set weight decay of normalization to 0 by default.\n    \"\"\"\n    skip = {}\n    skip_keywords = {}\n    if hasattr(model, 'no_weight_decay'):\n        skip = model.no_weight_decay()\n    if hasattr(model, 'no_weight_decay_keywords'):\n        skip_keywords = model.no_weight_decay_keywords()\n    if simmim:\n        if is_pretrain:\n            parameters = get_pretrain_param_groups(model, skip, skip_keywords)\n        else:\n            depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS\n            num_layers = sum(depths)\n            get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths)\n            scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2)))\n            parameters = get_finetune_param_groups(model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords)\n    else:\n        parameters = set_weight_decay(model, skip, skip_keywords)\n\n    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()\n    optimizer = None\n    if opt_lower == 'sgd':\n        optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,\n                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)\n    elif opt_lower == 'adamw':\n        optimizer = 
optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,\n                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)\n    elif opt_lower == 'fused_adam':\n        optimizer = FusedAdam(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,\n                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)\n    elif opt_lower == 'fused_lamb':\n        optimizer = FusedLAMB(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,\n                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)\n\n    return optimizer\n\n\ndef set_weight_decay(model, skip_list=(), skip_keywords=()):\n    has_decay = []\n    no_decay = []\n\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue  # frozen weights\n        if len(param.shape) == 1 or name.endswith(\".bias\") or (name in skip_list) or \\\n                check_keywords_in_name(name, skip_keywords):\n            no_decay.append(param)\n            # print(f\"{name} has no weight decay\")\n        else:\n            has_decay.append(param)\n    return [{'params': has_decay},\n            {'params': no_decay, 'weight_decay': 0.}]\n\n\ndef check_keywords_in_name(name, keywords=()):\n    isin = False\n    for keyword in keywords:\n        if keyword in name:\n            isin = True\n    return isin\n\n\ndef get_pretrain_param_groups(model, skip_list=(), skip_keywords=()):\n    has_decay = []\n    no_decay = []\n    has_decay_name = []\n    no_decay_name = []\n    \n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n        if len(param.shape) == 1 or name.endswith(\".bias\") or (name in skip_list) or \\\n                check_keywords_in_name(name, skip_keywords):\n            no_decay.append(param)\n            
no_decay_name.append(name)\n        else:\n            has_decay.append(param)\n            has_decay_name.append(name)\n    return [{'params': has_decay},\n            {'params': no_decay, 'weight_decay': 0.}]\n\n\ndef get_swin_layer(name, num_layers, depths):\n    if name in (\"mask_token\",):\n        return 0\n    elif name.startswith(\"patch_embed\"):\n        return 0\n    elif name.startswith(\"layers\"):\n        layer_id = int(name.split('.')[1])\n        block_id = name.split('.')[3]\n        if block_id == 'reduction' or block_id == 'norm':\n            return sum(depths[:layer_id + 1])\n        layer_id = sum(depths[:layer_id]) + int(block_id)\n        return layer_id + 1\n    else:\n        return num_layers - 1\n\n\ndef get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()):\n    parameter_group_names = {}\n    parameter_group_vars = {}\n\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n        if len(param.shape) == 1 or name.endswith(\".bias\") or (name in skip_list) or \\\n                check_keywords_in_name(name, skip_keywords):\n            group_name = \"no_decay\"\n            this_weight_decay = 0.\n        else:\n            group_name = \"decay\"\n            this_weight_decay = weight_decay\n        if get_layer_func is not None:\n            layer_id = get_layer_func(name)\n            group_name = \"layer_%d_%s\" % (layer_id, group_name)\n        else:\n            layer_id = None\n\n        if group_name not in parameter_group_names:\n            if scales is not None:\n                scale = scales[layer_id]\n            else:\n                scale = 1.\n\n            parameter_group_names[group_name] = {\n                \"group_name\": group_name,\n                \"weight_decay\": this_weight_decay,\n                \"params\": [],\n                \"lr\": lr * scale,\n                \"lr_scale\": scale,\n            
}\n            parameter_group_vars[group_name] = {\n                \"group_name\": group_name,\n                \"weight_decay\": this_weight_decay,\n                \"params\": [],\n                \"lr\": lr * scale,\n                \"lr_scale\": scale\n            }\n\n        parameter_group_vars[group_name][\"params\"].append(param)\n        parameter_group_names[group_name][\"params\"].append(name)\n    return list(parameter_group_vars.values())\n"
  },
  {
    "path": "utils.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport os\nimport torch\nimport torch.distributed as dist\n\ntry:\n    from torch._six import inf\nexcept ImportError:\n    from torch import inf\n\n\ndef load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):\n    logger.info(f\"==============> Resuming from {config.MODEL.RESUME}....................\")\n    if config.MODEL.RESUME.startswith('https'):\n        checkpoint = torch.hub.load_state_dict_from_url(\n            config.MODEL.RESUME, map_location='cpu', check_hash=True)\n    else:\n        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')\n    msg = model.load_state_dict(checkpoint['model'], strict=False)\n    logger.info(msg)\n    max_accuracy = 0.0\n    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n        config.defrost()\n        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1\n        config.freeze()\n        if 'scaler' in checkpoint:\n            loss_scaler.load_state_dict(checkpoint['scaler'])\n        logger.info(f\"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})\")\n        if 'max_accuracy' in checkpoint:\n            max_accuracy = checkpoint['max_accuracy']\n\n    del checkpoint\n    torch.cuda.empty_cache()\n    return max_accuracy\n\n\ndef load_pretrained(config, model, logger):\n    logger.info(f\"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......\")\n    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')\n    state_dict = checkpoint['model']\n\n    # delete relative_position_index 
since we always re-init it\n    relative_position_index_keys = [k for k in state_dict.keys() if \"relative_position_index\" in k]\n    for k in relative_position_index_keys:\n        del state_dict[k]\n\n    # delete relative_coords_table since we always re-init it\n    relative_position_index_keys = [k for k in state_dict.keys() if \"relative_coords_table\" in k]\n    for k in relative_position_index_keys:\n        del state_dict[k]\n\n    # delete attn_mask since we always re-init it\n    attn_mask_keys = [k for k in state_dict.keys() if \"attn_mask\" in k]\n    for k in attn_mask_keys:\n        del state_dict[k]\n\n    # bicubic interpolate relative_position_bias_table if not match\n    relative_position_bias_table_keys = [k for k in state_dict.keys() if \"relative_position_bias_table\" in k]\n    for k in relative_position_bias_table_keys:\n        relative_position_bias_table_pretrained = state_dict[k]\n        relative_position_bias_table_current = model.state_dict()[k]\n        L1, nH1 = relative_position_bias_table_pretrained.size()\n        L2, nH2 = relative_position_bias_table_current.size()\n        if nH1 != nH2:\n            logger.warning(f\"Error in loading {k}, passing......\")\n        else:\n            if L1 != L2:\n                # bicubic interpolate relative_position_bias_table if not match\n                S1 = int(L1 ** 0.5)\n                S2 = int(L2 ** 0.5)\n                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(\n                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),\n                    mode='bicubic')\n                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n    # bicubic interpolate absolute_pos_embed if not match\n    absolute_pos_embed_keys = [k for k in state_dict.keys() if \"absolute_pos_embed\" in k]\n    for k in absolute_pos_embed_keys:\n        # dpe\n        
absolute_pos_embed_pretrained = state_dict[k]\n        absolute_pos_embed_current = model.state_dict()[k]\n        _, L1, C1 = absolute_pos_embed_pretrained.size()\n        _, L2, C2 = absolute_pos_embed_current.size()\n        if C1 != C2:\n            logger.warning(f\"Error in loading {k}, passing......\")\n        else:\n            if L1 != L2:\n                S1 = int(L1 ** 0.5)\n                S2 = int(L2 ** 0.5)\n                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)\n                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)\n                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(\n                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')\n                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)\n                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)\n                state_dict[k] = absolute_pos_embed_pretrained_resized\n\n    # check classifier, if not match, then re-init classifier to zero\n    head_bias_pretrained = state_dict['head.bias']\n    Nc1 = head_bias_pretrained.shape[0]\n    Nc2 = model.head.bias.shape[0]\n    if (Nc1 != Nc2):\n        if Nc1 == 21841 and Nc2 == 1000:\n            logger.info(\"loading ImageNet-22K weight to ImageNet-1K ......\")\n            map22kto1k_path = f'data/map22kto1k.txt'\n            with open(map22kto1k_path) as f:\n                map22kto1k = f.readlines()\n            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]\n            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]\n            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]\n        else:\n            torch.nn.init.constant_(model.head.bias, 0.)\n            torch.nn.init.constant_(model.head.weight, 0.)\n            del state_dict['head.weight']\n         
   del state_dict['head.bias']\n            logger.warning(f\"Error in loading classifier head, re-init classifier head to 0\")\n\n    msg = model.load_state_dict(state_dict, strict=False)\n    logger.warning(msg)\n\n    logger.info(f\"=> loaded successfully '{config.MODEL.PRETRAINED}'\")\n\n    del checkpoint\n    torch.cuda.empty_cache()\n\n\ndef save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger):\n    save_state = {'model': model.state_dict(),\n                  'optimizer': optimizer.state_dict(),\n                  'lr_scheduler': lr_scheduler.state_dict(),\n                  'max_accuracy': max_accuracy,\n                  'scaler': loss_scaler.state_dict(),\n                  'epoch': epoch,\n                  'config': config}\n\n    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')\n    logger.info(f\"{save_path} saving......\")\n    torch.save(save_state, save_path)\n    logger.info(f\"{save_path} saved !!!\")\n\n\ndef get_grad_norm(parameters, norm_type=2):\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = list(filter(lambda p: p.grad is not None, parameters))\n    norm_type = float(norm_type)\n    total_norm = 0\n    for p in parameters:\n        param_norm = p.grad.data.norm(norm_type)\n        total_norm += param_norm.item() ** norm_type\n    total_norm = total_norm ** (1. 
/ norm_type)\n    return total_norm\n\n\ndef auto_resume_helper(output_dir):\n    checkpoints = os.listdir(output_dir)\n    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]\n    print(f\"All checkpoints found in {output_dir}: {checkpoints}\")\n    if len(checkpoints) > 0:\n        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)\n        print(f\"The latest checkpoint found: {latest_checkpoint}\")\n        resume_file = latest_checkpoint\n    else:\n        resume_file = None\n    return resume_file\n\n\ndef reduce_tensor(tensor):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= dist.get_world_size()\n    return rt\n\n\ndef ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p.grad is not None]\n    norm_type = float(norm_type)\n    if len(parameters) == 0:\n        return torch.tensor(0.)\n    device = parameters[0].grad.device\n    if norm_type == inf:\n        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)\n    else:\n        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(),\n                                                        norm_type).to(device) for p in parameters]), norm_type)\n    return total_norm\n\n\nclass NativeScalerWithGradNormCount:\n    state_dict_key = \"amp_scaler\"\n\n    def __init__(self):\n        self._scaler = torch.cuda.amp.GradScaler()\n\n    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):\n        self._scaler.scale(loss).backward(create_graph=create_graph)\n        if update_grad:\n            if clip_grad is not None:\n                assert parameters is not None\n                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params 
in-place\n                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)\n            else:\n                self._scaler.unscale_(optimizer)\n                norm = ampscaler_get_grad_norm(parameters)\n            self._scaler.step(optimizer)\n            self._scaler.update()\n        else:\n            norm = None\n        return norm\n\n    def state_dict(self):\n        return self._scaler.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self._scaler.load_state_dict(state_dict)\n"
  },
  {
    "path": "utils_moe.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------\n\nimport os\nimport torch\nimport torch.distributed as dist\n\n\ndef split_moe_model_state_dict(moe_keys, model_state_dict):\n    moe_model_state_dict = {}\n    non_moe_model_state_dict = {}\n    for (k, v) in model_state_dict.items():\n        if k in moe_keys:\n            moe_model_state_dict[k] = v\n        else:\n            non_moe_model_state_dict[k] = v\n    return moe_model_state_dict, non_moe_model_state_dict\n\n\ndef merge_moe_model_state_dict(moe_model_state_dict, non_moe_model_state_dict):\n    model_state_dict = {}\n    model_state_dict.update(moe_model_state_dict)\n    model_state_dict.update(non_moe_model_state_dict)\n    return model_state_dict\n\n\ndef load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):\n    global_rank = dist.get_rank()\n    logger.info(f\"==============> Rank[{global_rank}] Resuming from {config.MODEL.RESUME}....................\")\n    if config.MODEL.RESUME.endswith(f'.pth'):\n        if config.TRAIN.MOE.SAVE_MASTER:\n            resume_path = config.MODEL.RESUME + f'.global'\n        else:\n            resume_path = config.MODEL.RESUME + f'.rank{global_rank}'\n        logger.info(f\"===> Rank[{global_rank}] Re-formatting checkpoint name to {resume_path}......\")\n    else:\n        resume_path = config.MODEL.RESUME\n\n    checkpoint = torch.load(resume_path, map_location='cpu')\n    msg = model.load_state_dict(checkpoint['model'], strict=False)\n    logger.info(msg)\n    max_accuracy = 0.0\n    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n      
  config.defrost()\n        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1\n        config.freeze()\n        if 'scaler' in checkpoint:\n            loss_scaler.load_state_dict(checkpoint['scaler'])\n        logger.info(f\"=>Rank[{global_rank}] loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})\")\n        if 'max_accuracy' in checkpoint:\n            max_accuracy = checkpoint['max_accuracy']\n\n    del checkpoint\n    torch.cuda.empty_cache()\n    return max_accuracy\n\n\ndef load_pretrained(config, model, logger):\n    global_rank = dist.get_rank()\n    logger.info(f\"==============> Rank[{global_rank}] Loading weight {config.MODEL.PRETRAINED} for fine-tuning......\")\n    if config.MODEL.PRETRAINED.endswith(f'.pth'):\n        if config.TRAIN.MOE.SAVE_MASTER:\n            pretrained_path = config.MODEL.PRETRAINED + f'.global'\n        else:\n            pretrained_path = config.MODEL.PRETRAINED + f'.rank{global_rank}'\n        logger.info(f\"===> Rank[{global_rank}] Re-formatting checkpoint name to {pretrained_path}......\")\n    else:\n        pretrained_path = config.MODEL.PRETRAINED\n\n    if pretrained_path.endswith(f'.rank{global_rank}'):\n        checkpoint = torch.load(pretrained_path, map_location='cpu')\n        if os.path.exists(pretrained_path.replace(f'.rank{global_rank}', f'.master')):\n            checkpoint_master = torch.load(pretrained_path.replace(f'.rank{global_rank}', f'.master'),\n                                           map_location='cpu')\n            state_dict = merge_moe_model_state_dict(checkpoint['model'], checkpoint_master['model'])\n        else:\n            state_dict = checkpoint['model']\n    elif pretrained_path.endswith(f'.pth.global'):\n        checkpoint = torch.load(pretrained_path, map_location='cpu')\n        state_dict = checkpoint['model']\n    else:\n        raise NotImplementedError(f\"{config.MODEL.PRETRAINED} file error...\")\n\n    # delete relative_position_index since we always re-init 
it\n    relative_position_index_keys = [k for k in state_dict.keys() if \"relative_position_index\" in k]\n    for k in relative_position_index_keys:\n        del state_dict[k]\n\n    # delete relative_coords_table since we always re-init it\n    relative_position_index_keys = [k for k in state_dict.keys() if \"relative_coords_table\" in k]\n    for k in relative_position_index_keys:\n        del state_dict[k]\n\n    # delete attn_mask since we always re-init it\n    attn_mask_keys = [k for k in state_dict.keys() if \"attn_mask\" in k]\n    for k in attn_mask_keys:\n        del state_dict[k]\n\n    # bicubic interpolate relative_position_bias_table if not match\n    relative_position_bias_table_keys = [k for k in state_dict.keys() if \"relative_position_bias_table\" in k]\n    for k in relative_position_bias_table_keys:\n        relative_position_bias_table_pretrained = state_dict[k]\n        relative_position_bias_table_current = model.state_dict()[k]\n        L1, nH1 = relative_position_bias_table_pretrained.size()\n        L2, nH2 = relative_position_bias_table_current.size()\n        if nH1 != nH2:\n            logger.warning(f\"Error in loading {k}, passing......\")\n        else:\n            if L1 != L2:\n                # bicubic interpolate relative_position_bias_table if not match\n                S1 = int(L1 ** 0.5)\n                S2 = int(L2 ** 0.5)\n                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(\n                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),\n                    mode='bicubic')\n                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)\n\n    # bicubic interpolate absolute_pos_embed if not match\n    absolute_pos_embed_keys = [k for k in state_dict.keys() if \"absolute_pos_embed\" in k]\n    for k in absolute_pos_embed_keys:\n        # dpe\n        absolute_pos_embed_pretrained = 
state_dict[k]\n        absolute_pos_embed_current = model.state_dict()[k]\n        _, L1, C1 = absolute_pos_embed_pretrained.size()\n        _, L2, C2 = absolute_pos_embed_current.size()\n        if C1 != C2:\n            logger.warning(f\"Error in loading {k}, passing......\")\n        else:\n            if L1 != L2:\n                S1 = int(L1 ** 0.5)\n                S2 = int(L2 ** 0.5)\n                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)\n                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)\n                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(\n                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')\n                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)\n                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)\n                state_dict[k] = absolute_pos_embed_pretrained_resized\n\n    # check classifier, if not match, then re-init classifier to zero\n    head_bias_pretrained = state_dict['head.bias']\n    Nc1 = head_bias_pretrained.shape[0]\n    Nc2 = model.head.bias.shape[0]\n    if (Nc1 != Nc2):\n        if Nc1 == 21841 and Nc2 == 1000:\n            logger.info(\"loading ImageNet-22K weight to ImageNet-1K ......\")\n            map22kto1k_path = f'data/map22kto1k.txt'\n            with open(map22kto1k_path) as f:\n                map22kto1k = f.readlines()\n            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]\n            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]\n            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]\n        else:\n            torch.nn.init.constant_(model.head.bias, 0.)\n            torch.nn.init.constant_(model.head.weight, 0.)\n            del state_dict['head.weight']\n            del 
state_dict['head.bias']\n            logger.warning(f\"Error in loading classifier head, re-init classifier head to 0\")\n\n    msg = model.load_state_dict(state_dict, strict=False)\n    logger.warning(msg)\n\n    logger.info(f\"=> loaded successfully '{config.MODEL.PRETRAINED}'\")\n\n    del checkpoint\n    torch.cuda.empty_cache()\n\n\ndef save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger,\n                    zero_redundancy=False):\n    global_rank = dist.get_rank()\n\n    if zero_redundancy:\n        if config.TRAIN.MOE.SAVE_MASTER:\n            save_state = {'model': model.state_dict()}\n            if global_rank == 0:\n                save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global')\n                logger.info(f\"{save_path} saving......\")\n                torch.save(save_state, save_path)\n                logger.info(f\"{save_path} saved !!!\")\n        else:\n            moe_model_state_dict, non_moe_model_state_dict = \\\n                split_moe_model_state_dict(model._ddp_params_and_buffers_to_ignore, model.state_dict())\n            save_state = {'model': moe_model_state_dict}\n            save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}')\n            logger.info(f\"{save_path} saving......\")\n            torch.save(save_state, save_path)\n            logger.info(f\"{save_path} saved !!!\")\n            if global_rank == 0:\n                save_state_master = {'model': non_moe_model_state_dict}\n                save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.master')\n                logger.info(f\"{save_path} saving......\")\n                torch.save(save_state_master, save_path)\n                logger.info(f\"{save_path} saved !!!\")\n    else:\n        save_state = {'model': model.state_dict(),\n                      'optimizer': optimizer.state_dict(),\n                      'lr_scheduler': 
lr_scheduler.state_dict(),\n                      'max_accuracy': max_accuracy,\n                      'scaler': loss_scaler.state_dict(),\n                      'epoch': epoch,\n                      'config': config}\n        if config.TRAIN.MOE.SAVE_MASTER:\n            if global_rank == 0:\n                save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global')\n                logger.info(f\"{save_path} saving......\")\n                torch.save(save_state, save_path)\n                logger.info(f\"{save_path} saved !!!\")\n        else:\n            save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}')\n            logger.info(f\"{save_path} saving......\")\n            torch.save(save_state, save_path)\n            logger.info(f\"{save_path} saved !!!\")\n\n\ndef auto_resume_helper(output_dir, save_master=False):\n    global_rank = dist.get_rank()\n    checkpoints = os.listdir(output_dir)\n    if not save_master:\n        master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith(f'pth.rank0')]\n    else:\n        master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith(f'pth.global')]\n    print(f\"All master checkpoints found in {output_dir}: {master_checkpoints}\")\n    if len(master_checkpoints) > 0:\n        latest_master_checkpoint = max([os.path.join(output_dir, d) for d in master_checkpoints], key=os.path.getmtime)\n        latest_checkpoint = latest_master_checkpoint.replace('pth.rank0', f'pth.rank{global_rank}')\n        print(f\"The latest checkpoint found: {latest_checkpoint}\")\n        resume_file = latest_checkpoint\n    else:\n        resume_file = None\n    return resume_file\n\n\ndef hook_scale_grad(scale, tensor):\n    return tensor / scale\n"
  },
  {
    "path": "utils_simmim.py",
    "content": "# --------------------------------------------------------\n# SimMIM\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# Modified by Zhenda Xie\n# --------------------------------------------------------\n\nimport os\nimport torch\nimport torch.distributed as dist\nimport numpy as np\nfrom scipy import interpolate\n\n\ndef load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):\n    logger.info(f\">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........\")\n    if config.MODEL.RESUME.startswith('https'):\n        checkpoint = torch.hub.load_state_dict_from_url(\n            config.MODEL.RESUME, map_location='cpu', check_hash=True)\n    else:\n        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')\n\n    # re-map keys due to name change (only for loading provided models)\n    rpe_mlp_keys = [k for k in checkpoint['model'].keys() if \"rpe_mlp\" in k]\n    for k in rpe_mlp_keys:\n        checkpoint['model'][k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k)\n    \n    msg = model.load_state_dict(checkpoint['model'], strict=False)\n    logger.info(msg)\n\n    max_accuracy = 0.0\n    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'scaler' in checkpoint and 'epoch' in checkpoint:\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n        scaler.load_state_dict(checkpoint['scaler'])\n\n        config.defrost()\n        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1\n        config.freeze()\n\n        logger.info(f\"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})\")\n        if 'max_accuracy' in checkpoint:\n            max_accuracy = checkpoint['max_accuracy']\n        else:\n            max_accuracy = 0.0\n\n    del checkpoint\n    torch.cuda.empty_cache()\n    return max_accuracy\n\n\ndef 
save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger):\n    save_state = {'model': model.state_dict(),\n                  'optimizer': optimizer.state_dict(),\n                  'lr_scheduler': lr_scheduler.state_dict(),\n                  'scaler': scaler.state_dict(),\n                  'max_accuracy': max_accuracy,\n                  'epoch': epoch,\n                  'config': config}\n\n    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')\n    logger.info(f\"{save_path} saving......\")\n    torch.save(save_state, save_path)\n    logger.info(f\"{save_path} saved !!!\")\n\n\ndef get_grad_norm(parameters, norm_type=2):\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = list(filter(lambda p: p.grad is not None, parameters))\n    norm_type = float(norm_type)\n    total_norm = 0\n    for p in parameters:\n        param_norm = p.grad.data.norm(norm_type)\n        total_norm += param_norm.item() ** norm_type\n    total_norm = total_norm ** (1. 
/ norm_type)\n    return total_norm\n\n\ndef auto_resume_helper(output_dir, logger):\n    checkpoints = os.listdir(output_dir)\n    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]\n    logger.info(f\"All checkpoints found in {output_dir}: {checkpoints}\")\n    if len(checkpoints) > 0:\n        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)\n        logger.info(f\"The latest checkpoint found: {latest_checkpoint}\")\n        resume_file = latest_checkpoint\n    else:\n        resume_file = None\n    return resume_file\n\n\ndef reduce_tensor(tensor):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= dist.get_world_size()\n    return rt\n\n\ndef load_pretrained(config, model, logger):\n    logger.info(f\">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........\")\n    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')\n    checkpoint_model = checkpoint['model']\n    \n    if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]):\n        checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')}\n        logger.info('Detect pre-trained model, remove [encoder.] 
prefix.')\n    else:\n        logger.info('Detect non-pre-trained model, pass without doing anything.')\n\n    if config.MODEL.TYPE in ['swin', 'swinv2']:\n        logger.info(f\">>>>>>>>>> Remapping pre-trained keys for SWIN ..........\")\n        checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger)\n    else:\n        raise NotImplementedError\n\n    msg = model.load_state_dict(checkpoint_model, strict=False)\n    logger.info(msg)\n    \n    del checkpoint\n    torch.cuda.empty_cache()\n    logger.info(f\">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'\")\n    \n\ndef remap_pretrained_keys_swin(model, checkpoint_model, logger):\n    state_dict = model.state_dict()\n    \n    # Geometric interpolation when the pre-trained window size mismatches the fine-tuned window size\n    all_keys = list(checkpoint_model.keys())\n    for key in all_keys:\n        if \"relative_position_bias_table\" in key:\n            relative_position_bias_table_pretrained = checkpoint_model[key]\n            relative_position_bias_table_current = state_dict[key]\n            L1, nH1 = relative_position_bias_table_pretrained.size()\n            L2, nH2 = relative_position_bias_table_current.size()\n            if nH1 != nH2:\n                logger.info(f\"Error in loading {key}, passing......\")\n            else:\n                if L1 != L2:\n                    logger.info(f\"{key}: Interpolate relative_position_bias_table using geo.\")\n                    src_size = int(L1 ** 0.5)\n                    dst_size = int(L2 ** 0.5)\n\n                    def geometric_progression(a, r, n):\n                        return a * (1.0 - r ** n) / (1.0 - r)\n\n                    left, right = 1.01, 1.5\n                    while right - left > 1e-6:\n                        q = (left + right) / 2.0\n                        gp = geometric_progression(1, q, src_size // 2)\n                        if gp > dst_size // 2:\n                            right = q\n                
        else:\n                            left = q\n\n                    # if q > 1.090307:\n                    #     q = 1.090307\n\n                    dis = []\n                    cur = 1\n                    for i in range(src_size // 2):\n                        dis.append(cur)\n                        cur += q ** (i + 1)\n\n                    r_ids = [-_ for _ in reversed(dis)]\n\n                    x = r_ids + [0] + dis\n                    y = r_ids + [0] + dis\n\n                    t = dst_size // 2.0\n                    dx = np.arange(-t, t + 0.1, 1.0)\n                    dy = np.arange(-t, t + 0.1, 1.0)\n\n                    logger.info(\"Original positions = %s\" % str(x))\n                    logger.info(\"Target positions = %s\" % str(dx))\n\n                    all_rel_pos_bias = []\n\n                    for i in range(nH1):\n                        z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy()\n                        # scipy.interpolate.interp2d was removed in SciPy 1.14; a cubic RectBivariateSpline (s=0)\n                        # interpolates the same regular grid. x == y and dx == dy here, so z needs no transpose.\n                        f_cubic = interpolate.RectBivariateSpline(x, y, z, kx=3, ky=3)\n                        all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(\n                            relative_position_bias_table_pretrained.device))\n\n                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)\n                    checkpoint_model[key] = new_rel_pos_bias\n\n    # delete relative_position_index since we always re-init it\n    relative_position_index_keys = [k for k in checkpoint_model.keys() if \"relative_position_index\" in k]\n    for k in relative_position_index_keys:\n        del checkpoint_model[k]\n\n    # delete relative_coords_table since we always re-init it\n    relative_coords_table_keys = [k for k in checkpoint_model.keys() if \"relative_coords_table\" in k]\n    for k in relative_coords_table_keys:\n        del checkpoint_model[k]\n\n    # re-map keys due to name change\n    rpe_mlp_keys = [k for k in checkpoint_model.keys() if \"rpe_mlp\" in 
k]\n    for k in rpe_mlp_keys:\n        checkpoint_model[k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint_model.pop(k)\n\n    # delete attn_mask since we always re-init it\n    attn_mask_keys = [k for k in checkpoint_model.keys() if \"attn_mask\" in k]\n    for k in attn_mask_keys:\n        del checkpoint_model[k]\n\n    return checkpoint_model\n"
  }
]