Repository: microsoft/Swin-Transformer Branch: main Commit: f82860bfb522 Files: 86 Total size: 392.8 KB Directory structure: gitextract_gk47a2w1/ ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MODELHUB.md ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── config.py ├── configs/ │ ├── simmim/ │ │ ├── simmim_finetune__swin_base__img224_window7__800ep.yaml │ │ ├── simmim_finetune__swinv2_base__img224_window14__800ep.yaml │ │ ├── simmim_pretrain__swin_base__img192_window6__800ep.yaml │ │ └── simmim_pretrain__swinv2_base__img192_window12__800ep.yaml │ ├── swin/ │ │ ├── swin_base_patch4_window12_384_22kto1k_finetune.yaml │ │ ├── swin_base_patch4_window12_384_finetune.yaml │ │ ├── swin_base_patch4_window7_224.yaml │ │ ├── swin_base_patch4_window7_224_22k.yaml │ │ ├── swin_base_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_large_patch4_window12_384_22kto1k_finetune.yaml │ │ ├── swin_large_patch4_window7_224_22k.yaml │ │ ├── swin_large_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_small_patch4_window7_224.yaml │ │ ├── swin_small_patch4_window7_224_22k.yaml │ │ ├── swin_small_patch4_window7_224_22kto1k_finetune.yaml │ │ ├── swin_tiny_c24_patch4_window8_256.yaml │ │ ├── swin_tiny_patch4_window7_224.yaml │ │ ├── swin_tiny_patch4_window7_224_22k.yaml │ │ └── swin_tiny_patch4_window7_224_22kto1k_finetune.yaml │ ├── swinmlp/ │ │ ├── swin_mlp_base_patch4_window7_224.yaml │ │ ├── swin_mlp_tiny_c12_patch4_window8_256.yaml │ │ ├── swin_mlp_tiny_c24_patch4_window8_256.yaml │ │ └── swin_mlp_tiny_c6_patch4_window8_256.yaml │ ├── swinmoe/ │ │ ├── swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml │ │ ├── swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml │ │ ├── swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml │ │ ├── swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml │ │ ├── swin_moe_base_patch4_window12_192_densebaseline_22k.yaml │ │ ├── swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml │ │ ├── 
swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml │ │ ├── swin_moe_small_patch4_window12_192_64expert_64gpu_22k.yaml │ │ ├── swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml │ │ ├── swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml │ │ └── swin_moe_small_patch4_window12_192_densebaseline_22k.yaml │ └── swinv2/ │ ├── swinv2_base_patch4_window12_192_22k.yaml │ ├── swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml │ ├── swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml │ ├── swinv2_base_patch4_window16_256.yaml │ ├── swinv2_base_patch4_window8_256.yaml │ ├── swinv2_large_patch4_window12_192_22k.yaml │ ├── swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml │ ├── swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml │ ├── swinv2_small_patch4_window16_256.yaml │ ├── swinv2_small_patch4_window8_256.yaml │ ├── swinv2_tiny_patch4_window16_256.yaml │ └── swinv2_tiny_patch4_window8_256.yaml ├── data/ │ ├── __init__.py │ ├── build.py │ ├── cached_image_folder.py │ ├── data_simmim_ft.py │ ├── data_simmim_pt.py │ ├── imagenet22k_dataset.py │ ├── map22kto1k.txt │ ├── samplers.py │ └── zipreader.py ├── get_started.md ├── kernels/ │ └── window_process/ │ ├── setup.py │ ├── swin_window_process.cpp │ ├── swin_window_process_kernel.cu │ ├── unit_test.py │ └── window_process.py ├── logger.py ├── lr_scheduler.py ├── main.py ├── main_moe.py ├── main_simmim_ft.py ├── main_simmim_pt.py ├── models/ │ ├── __init__.py │ ├── build.py │ ├── simmim.py │ ├── swin_mlp.py │ ├── swin_transformer.py │ ├── swin_transformer_moe.py │ └── swin_transformer_v2.py ├── optimizer.py ├── utils.py ├── utils_moe.py └── utils_simmim.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # launch bash 
*.sh # nsight system report files *.nsys-rep *.sqlite # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Microsoft Open Source Code of Conduct This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). Resources: - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) Microsoft Corporation. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE ================================================ FILE: MODELHUB.md ================================================ Access code for `baidu` is `swin`. ## ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: | | Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) | | Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) | | Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) | | Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) | | Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-B | ImageNet-22K | 384x384 
| 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) | | Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) | ## ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models | name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model | |:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: | | SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) | | SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) | | SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) | | SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) | | SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) | | SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) | | SwinV2-B\* | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 | 88M | 21.8G | 174 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) | | SwinV2-B\* | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) | | SwinV2-L\* | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) | | SwinV2-L\* | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) | Note: - SwinV2-B\* (SwinV2-L\*) models with input resolutions of 256x256 and 384x384 are both fine-tuned from the same pre-trained model, which was pre-trained at a smaller input resolution of 192x192. - SwinV2-B\* (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2 while SwinV2-L\* (384x384) achieves 78.31. ## ImageNet-1K Pretrained Swin MLP Models | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS | 1K model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) | | [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) | | [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G | 231 | [timm](https://github.com/rwightman/pytorch-image-models) | | Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) | | SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swinmlp/swin_mlp_tiny_c24_patch4_window8_256.yaml) | | SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swinmlp/swin_mlp_tiny_c12_patch4_window8_256.yaml) | | SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swinmlp/swin_mlp_tiny_c6_patch4_window8_256.yaml) | | SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swinmlp/swin_mlp_base_patch4_window7_224.yaml) | Note: C24 means each head has 24 channels. 
## ImageNet-22K Pretrained Swin-MoE Models | name | #experts | k | router | resolution | window | IN-22K acc@1 | IN-1K/ft acc@1 | IN-1K/5-shot acc@1 | 22K model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | Swin-MoE-S | 1 (dense) | - | - | 192x192 | 8x8 | 35.5| 83.5 | 70.3 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_densebaseline_22k.zip)/[baidu](https://pan.baidu.com/s/1O1m9jT2pGoago_RiRX914w?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml) | | Swin-MoE-S | 8 | 1 | Linear | 192x192 | 8x8 | 36.8 | 84.5 | 75.2 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/198IlYUrWOxEUp7wNdoJT5Q?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml) | | Swin-MoE-S | 16 | 1 | Linear |192x192 | 8x8 | 37.6 | 84.9 | 76.5 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1vRQweedtT42VwMTqe9-r2A?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml) | | Swin-MoE-S | 32 | 1 | Linear | 192x192 | 8x8 | 37.4 | 84.7 | 75.9 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1i7rImt5pwO8gJC-PRRuZwQ?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml) | | Swin-MoE-S | 32 | 1 | Cosine | 192x192 | 8x8 | 37.2 | 84.3 | 75.2 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1Yghr_12ntSrv01I9yatPDQ?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml) | | Swin-MoE-S | 64 | 1 | Linear | 192x192 | 8x8 | 37.8 | 84.7 | 75.7 | - | | Swin-MoE-S | 128 | 1 | Linear | 192x192 | 8x8 | 37.4 | 84.5 | 75.4 | - | | Swin-MoE-B | 1 (dense) | - | - | 192x192 | 8x8 | 37.3 | 85.1 | 75.9 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml) | | Swin-MoE-B | 8 | 1 | Linear | 192x192 | 8x8 | 38.1 | 85.3 | 77.2 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml) | | Swin-MoE-B | 16 | 1 | Linear | 192x192 | 8x8 | 38.7 | 85.5 | 78.2 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml) | | Swin-MoE-B | 32 | 1 | Linear | 192x192 | 8x8 | 38.6 | 85.5 | 77.9 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml) | | Swin-MoE-B | 32 | 1 | Cosine | 192x192 | 8x8 | 38.5 | 85.3 | 77.3 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml) | | Swin-MoE-B | 32 | 2 | Linear | 192x192 | 8x8 | 38.6 | 85.5 | 78.7 | - | ## SimMIM Pretrained Swin-V2 Models > Please note that all SimMIM pretrained Swin-V2 models will be stored in the Huggingface repository starting July 2024. For more details, refer to the [huggingface repository](https://huggingface.co/zdaxie/SimMIM). - **Model size** only includes the backbone weights and excludes weights in the decoders/classification heads. - **Batch size** for all models is set to 2048. - **Validation loss** is calculated on the ImageNet-1K validation set. - **Fine-tuned acc@1** refers to the top-1 accuracy on the ImageNet-1K validation set after fine-tuning. 
| name | model size | pre-train dataset | pre-train iterations | validation loss | fine-tuned acc@1 | pre-trained model | fine-tuned model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | SwinV2-Small | 49M | ImageNet-1K 10% | 125k | 0.4820 | 82.69 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_125k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K 10% | 250k | 0.4961 | 83.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_250k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K 10% | 500k | 0.5115 | 83.17 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_500k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K 20% | 125k | 0.4751 | 83.05 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_125k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K 20% | 250k | 0.4722 | 83.56 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_250k.pth?download=true) | | 
SwinV2-Small | 49M | ImageNet-1K 20% | 500k | 0.4734 | 83.75 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_500k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K 50% | 125k | 0.4732 | 83.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_125k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K 50% | 250k | 0.4681 | 83.67 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_250k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K 50% | 500k | 0.4646 | 83.96 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_500k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K | 125k | 0.4728 | 82.92 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_125k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K | 250k | 0.4674 | 83.66 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_250k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_250k.pth?download=true) | | SwinV2-Small | 49M | ImageNet-1K | 500k | 0.4641 | 84.08 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_500k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 10% | 125k | 0.4822 | 83.33 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_125k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 10% | 250k | 0.4997 | 83.60 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_250k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 10% | 500k | 0.5112 | 83.41 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_500k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 20% | 125k | 0.4703 | 83.86 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_125k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 20% | 250k | 0.4679 | 84.37 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_250k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 20% | 500k | 0.4711 | 84.61 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_500k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 50% | 125k | 0.4683 | 84.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_125k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 50% | 250k | 0.4633 | 84.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_250k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K 50% | 500k | 0.4598 | 84.95 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_500k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K | 125k | 0.4680 | 84.13 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_125k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_125k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K | 250k | 0.4626 | 84.65 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_250k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-1K | 500k | 0.4588 | 85.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_500k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-22K | 125k | 0.4695 | 84.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_125k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-22K | 250k | 0.4649 | 84.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_250k.pth?download=true) | | SwinV2-Base | 87M | ImageNet-22K | 500k | 0.4614 | 85.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_500k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 10% | 125k | 0.4995 | 83.69 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_125k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 10% | 250k | 0.5140 | 83.66 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_250k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 10% | 500k | 0.5150 | 83.50 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_500k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 20% | 125k | 0.4675 | 84.38 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_125k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 20% | 250k | 0.4746 | 84.71 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_250k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 20% | 500k | 0.4960 | 84.59 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_500k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_500k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 50% | 125k | 0.4622 | 84.78 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_125k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 50% | 250k | 0.4566 | 85.38 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_250k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K 50% | 500k | 0.4530 | 85.80 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_500k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K | 125k | 0.4611 | 84.98 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_125k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K | 250k | 0.4552 | 85.45 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_250k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-1K | 500k | 0.4507 | 85.91 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_500k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-22K | 125k | 0.4649 | 84.61 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_125k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-22K | 250k | 0.4586 | 85.39 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_250k.pth?download=true) | | SwinV2-Large | 195M | ImageNet-22K | 500k | 0.4536 | 85.81 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_500k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K 20% | 125k | 0.4789 | 84.35 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_125k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K 20% | 250k | 0.5038 | 84.16 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_250k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_250k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K 20% | 500k | 0.5071 | 83.44 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_500k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K 50% | 125k | 0.4549 | 85.09 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_125k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K 50% | 250k | 0.4511 | 85.64 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_250k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K 50% | 500k | 0.4559 | 85.69 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_500k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K | 125k | 0.4531 | 85.23 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_125k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K | 250k | 0.4464 | 85.90 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_250k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-1K | 500k | 0.4416 | 86.34 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_500k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-22K | 125k | 0.4564 | 85.14 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_125k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-22K | 250k | 0.4499 | 85.86 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_250k.pth?download=true) | | SwinV2-Huge | 655M | ImageNet-22K | 500k | 0.4444 | 86.27 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_500k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-1K 50% | 125k | 0.4534 | 85.44 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_125k.pth?download=true) | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_125k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-1K 50% | 250k | 0.4515 | 85.76 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_250k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-1K 50% | 500k | 0.4719 | 85.51 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_500k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-1K | 125k | 0.4513 | 85.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_125k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-1K | 250k | 0.4442 | 86.12 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_250k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-1K | 500k | 0.4395 | 86.46 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_500k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-22K | 125k | 0.4544 | 85.39 | 
[huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_125k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-22K | 250k | 0.4475 | 85.96 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_250k.pth?download=true) | | SwinV2-giant | 1.06B | ImageNet-22K | 500k | 0.4416 | 86.53 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_500k.pth?download=true) | ## SimMIM Pretrained Swin-V1 Models **ImageNet-1K Pre-trained and Fine-tuned Models** | name | pre-train epochs | pre-train resolution | fine-tune resolution | acc@1 | pre-trained model | fine-tuned model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | Swin-Base | 100 | 192x192 | 192x192 | 82.8 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1RsgHfjB4B1ZYblXEQVT-FPX3WSvBrxcs/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img192_window6__100ep.yaml) | | Swin-Base | 100 | 192x192 | 224x224 | 83.5 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | 
[google](https://drive.google.com/file/d/1mb43BkW56F5smwiX-g7QUUD7f1Rftq8u/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img224_window7__100ep.yaml) | | Swin-Base | 800 | 192x192 | 224x224 | 84.0 | [google](https://drive.google.com/file/d/15zENvGjHlM71uKQ3d2FbljWPubtrPtjl/view?usp=sharing)/[config](configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml) | [google](https://drive.google.com/file/d/1xEKyfMTsdh6TfnYhk5vbw0Yz7a-viZ0w/view?usp=sharing)/[config](configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml) | | Swin-Large | 800 | 192x192 | 224x224 | 85.4 | [google](https://drive.google.com/file/d/1qDxrTl2YUDB0505_4QrU5LU2R1kKmcBP/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_pretrain__swin_large__img192_window12__800ep.yaml) | [google](https://drive.google.com/file/d/1mf0ZpXttEvFsH87Www4oQ-t8Kwr0x485/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_finetune__swin_large__img224_window14__800ep.yaml) | | SwinV2-Huge | 800 | 192x192 | 224x224 | 85.7 | / | / | | SwinV2-Huge | 800 | 192x192 | 512x512 | 87.1 | / | / | ================================================ FILE: README.md ================================================ # Swin Transformer [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-v2-scaling-up-capacity-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-v2-scaling-up-capacity-and) 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-v2-scaling-up-capacity-and) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=swin-transformer-v2-scaling-up-capacity-and) This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf) as well as the follow-ups. It currently includes code and models for the following tasks: > **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start. > **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection). > **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation). > **Video Action Recognition**: See [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). > **Semi-Supervised Object Detection**: See [Soft Teacher](https://github.com/microsoft/SoftTeacher). > **SSL: Contrastive Learning**: See [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL). > **SSL: Masked Image Modeling**: See [get_started.md#simmim-support](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md#simmim-support). > **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions. > **Feature-Distillation**: See [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation). ## Updates ***12/29/2022*** 1. 
**Nvidia**'s [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md) now supports Swin Transformer V2 inference, which brings significant speed improvements on `T4 and A100 GPUs`. ***11/30/2022*** 1. Models and code of **Feature Distillation** are released. Please refer to [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation) for details, and the checkpoints (FD-EsViT-Swin-B, FD-DeiT-ViT-B, FD-DINO-ViT-B, FD-CLIP-ViT-B, FD-CLIP-ViT-L). ***09/24/2022*** 1. Merged [SimMIM](https://github.com/microsoft/SimMIM), which is a **Masked Image Modeling** based pre-training approach applicable to Swin and SwinV2 (and also applicable to ViT and ResNet). Please refer to [get started with SimMIM](get_started.md#simmim-support) to play with SimMIM pre-training. 2. Released a series of Swin and SwinV2 models pre-trained using the SimMIM approach (see [MODELHUB for SimMIM](MODELHUB.md#simmim-pretrained-swin-v2-models)), with model size ranging from SwinV2-Small-50M to SwinV2-giant-1B, data size ranging from ImageNet-1K-10% to ImageNet-22K, and iterations from 125k to 500k. You may leverage these models to study the properties of MIM methods. Please look into the [data scaling](https://arxiv.org/abs/2206.04664) paper for more details. ***07/09/2022*** `News`: 1. SwinV2-G achieves `61.4 mIoU` on ADE20K semantic segmentation (+1.5 mIoU over the previous SwinV2-G model), using an additional [feature distillation (FD)](https://github.com/SwinTransformer/Feature-Distillation) approach, **setting a new record** on this benchmark. FD is an approach that can generally improve the fine-tuning performance of various pre-trained models, including DeiT, DINO, and CLIP. Particularly, it improves CLIP pre-trained ViT-L by +1.6% to reach `89.0%` on ImageNet-1K image classification, which is **the most accurate ViT-L model**. 2. 
Merged a PR from **Nvidia** that links to faster Swin Transformer inference with significant speed improvements on `T4 and A100 GPUs`. 3. Merged a PR from **Nvidia** that enables an option to use `pure FP16 (Apex O2)` in training while largely maintaining accuracy. ***06/03/2022*** 1. Added **Swin-MoE**, the Mixture-of-Experts variant of Swin Transformer implemented using [Tutel](https://github.com/microsoft/tutel) (an optimized Mixture-of-Experts implementation). **Swin-MoE** is introduced in the [TuTel](https://arxiv.org/abs/2206.03382) paper. ***05/12/2022*** 1. Pretrained models of [Swin Transformer V2](https://arxiv.org/abs/2111.09883) on ImageNet-1K and ImageNet-22K are released. 2. ImageNet-22K pretrained models for Swin-V1-Tiny and Swin-V2-Small are released. ***03/02/2022*** 1. Swin Transformer V2 and SimMIM were accepted to CVPR 2022. [SimMIM](https://github.com/microsoft/SimMIM) is a self-supervised pre-training approach based on masked image modeling, a key technique that made it possible to train the 3-billion-parameter Swin V2 model using `40x less labelled data` than previous billion-scale models based on JFT-3B. ***02/09/2022*** 1. Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/Swin-Transformer) ***10/12/2021*** 1. Swin Transformer received the ICCV 2021 best paper award (Marr Prize). ***08/09/2021*** 1. [Soft Teacher](https://arxiv.org/pdf/2106.09018v2.pdf) will appear at ICCV2021. The code will be released at [GitHub Repo](https://github.com/microsoft/SoftTeacher). `Soft Teacher` is an end-to-end semi-supervised object detection method, achieving a new record on the COCO test-dev: `61.3 box AP` and `53.0 mask AP`. ***07/03/2021*** 1. 
Added **Swin MLP**, which is an adaptation of `Swin Transformer` that replaces all multi-head self-attention (MHSA) blocks with MLP layers (more precisely, group linear layers). The shifted window configuration can also significantly improve the performance of vanilla MLP architectures. ***06/25/2021*** 1. [Video Swin Transformer](https://arxiv.org/abs/2106.13230) is released at [Video-Swin-Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). `Video Swin Transformer` achieves state-of-the-art accuracy on a broad range of video recognition benchmarks, including action recognition (`84.9` top-1 accuracy on Kinetics-400 and `86.1` top-1 accuracy on Kinetics-600 with `~20x` less pre-training data and `~3x` smaller model size) and temporal modeling (`69.6` top-1 accuracy on Something-Something v2). ***05/12/2021*** 1. Used as a backbone for `Self-Supervised Learning`: [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL). Using Swin-Transformer as the backbone for self-supervised learning enables us to evaluate the transferring performance of the learnt representations on down-stream tasks, which was missing in previous works due to the use of ViT/DeiT, which had not been well tamed for down-stream tasks. ***04/12/2021*** Initial commits: 1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided. 2. 
The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided. 3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net). ## Introduction **Swin Transformer** (the name `Swin` stands for **S**hifted **win**dow) is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection. Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin. 
![teaser](figures/teaser.png) ## Main Results on ImageNet with Pretrained Models **ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models** | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: | | Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) | | Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) | | Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) | | Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) | | Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) | | Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) | | Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) | **ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models** | name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model | |:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: | | SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) | | SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) | | SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) | | SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) | | SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) | | SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) | | SwinV2-B\* | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 | 88M | 21.8G | 174 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) | | SwinV2-B\* | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) | | SwinV2-L\* | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) | | SwinV2-L\* | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | 
[github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) | Note: - SwinV2-B\* (SwinV2-L\*) with input resolutions of 256x256 and 384x384 are both fine-tuned from the same pre-training model using a smaller input resolution of 192x192. - SwinV2-B\* (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2 while SwinV2-L\* (384x384) achieves 78.31. **ImageNet-1K Pretrained Swin MLP Models** | name | pretrain | resolution | acc@1 | acc@5 | #params | FLOPs | FPS | 1K model | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) | | [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) | | [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G | 231 | [timm](https://github.com/rwightman/pytorch-image-models) | | Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) | | SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swinmlp/swin_mlp_tiny_c24_patch4_window8_256.yaml) | | SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | 
[github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swinmlp/swin_mlp_tiny_c12_patch4_window8_256.yaml) | | SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swinmlp/swin_mlp_tiny_c6_patch4_window8_256.yaml) | | SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swinmlp/swin_mlp_base_patch4_window7_224.yaml) | Note: access code for `baidu` is `swin`. C24 means each head has 24 channels. **ImageNet-22K Pretrained Swin-MoE Models** - Please refer to [get_started](get_started.md#mixture-of-experts-support) for instructions on running Swin-MoE. 
- Pretrained models for Swin-MoE can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) ## Main Results on Downstream Tasks **COCO Object Detection (2017 val)** | Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G | | Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G | | Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G | | Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 107M | 838G | | Swin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.9 | 45.0 | 145M | 982G | | Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G | | Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G | | Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G | | Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G | | Swin-L | HTC++* | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - | Note: * indicates multi-scale testing. 
**ADE20K Semantic Segmentation (val)** | Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G | | Swin-S | UPerNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G | | Swin-B | UPerNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G | | Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G | | Swin-L | UPerNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G | ## Citing Swin Transformer ``` @inproceedings{liu2021Swin, title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining}, booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, year={2021} } ``` ## Citing Local Relation Networks (the first full-attention visual backbone) ``` @inproceedings{hu2019local, title={Local Relation Networks for Image Recognition}, author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen}, booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, pages={3464--3473}, year={2019} } ``` ## Citing Swin Transformer V2 ``` @inproceedings{liu2021swinv2, title={Swin Transformer V2: Scaling Up Capacity and Resolution}, author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, year={2022} } ``` ## Citing SimMIM (a self-supervised approach that enables SwinV2-G) ``` @inproceedings{xie2021simmim, title={SimMIM: A Simple Framework for Masked Image Modeling}, author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong 
and Bao, Jianmin and Yao, Zhuliang and Dai, Qi and Hu, Han}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, year={2022} } ``` ## Citing SimMIM-data-scaling ``` @article{xie2022data, title={On Data Scaling in Masked Image Modeling}, author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Wei, Yixuan and Dai, Qi and Hu, Han}, journal={arXiv preprint arXiv:2206.04664}, year={2022} } ``` ## Citing Swin-MoE ``` @misc{hwang2022tutel, title={Tutel: Adaptive Mixture-of-Experts at Scale}, author={Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong}, year={2022}, eprint={2206.03382}, archivePrefix={arXiv} } ``` ## Getting Started - For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions. - For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection). - For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation). - For **Self-Supervised Learning**, please see [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL). - For **Video Recognition**, please see [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). ## Third-party Usage and Experiments ***In this paragraph, we cross-link third-party repositories which use Swin and report results. 
You can let us know by raising an issue*** (`Note: please report accuracy numbers and provide trained models in your new repository to help others get a sense of correctness and model behavior`) [12/29/2022] Swin Transformers (V2) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md) [06/30/2022] Swin Transformers (V1) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md) [05/12/2022] Swin Transformers (V1) implemented in TensorFlow with the pre-trained parameters ported into them. Find the implementation, TensorFlow weights, and code examples in [this repository](https://github.com/sayakpaul/swin-transformers-tf/). [04/06/2022] Swin Transformer for Audio Classification: [Hierarchical Token Semantic Audio Transformer](https://github.com/RetroCirce/HTS-Audio-Transformer). [12/21/2021] Swin Transformer for StyleGAN: [StyleSwin](https://github.com/microsoft/StyleSwin) [12/13/2021] Swin Transformer for Face Recognition: [FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo) [08/29/2021] Swin Transformer for Image Restoration: [SwinIR](https://github.com/JingyunLiang/SwinIR) [08/12/2021] Swin Transformer for person reID: [https://github.com/layumi/Person_reID_baseline_pytorch](https://github.com/layumi/Person_reID_baseline_pytorch) [06/29/2021] Swin-Transformer in PaddleClas and inference based on whl package: [https://github.com/PaddlePaddle/PaddleClas](https://github.com/PaddlePaddle/PaddleClas) [04/14/2021] Swin for RetinaNet in Detectron2: https://github.com/xiaohu2015/SwinT_detectron2. [04/16/2021] Included in a famous model zoo: https://github.com/rwightman/pytorch-image-models. [04/20/2021] Swin-Transformer classifier inference using TorchServe: https://github.com/kamalkraj/Swin-Transformer-Serve ## Contributing This project welcomes contributions and suggestions. 
Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. ## Trademarks This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies. ================================================ FILE: SECURITY.md ================================================ ## Security Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. ## Reporting Security Issues **Please do not report security vulnerabilities through public GitHub issues.** Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) * Full paths of source file(s) related to the manifestation of the issue * The location of the affected source code (tag/branch/commit or direct URL) * Any special configuration required to reproduce the issue * Step-by-step instructions to reproduce the issue * Proof-of-concept or exploit code (if possible) * Impact of the issue, including how an attacker might exploit the issue This information will help us triage your report more quickly. If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. 
Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. ## Preferred Languages We prefer all communications to be in English. ## Policy Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). ================================================ FILE: SUPPORT.md ================================================ # TODO: The maintainer of this repo has not yet edited this file **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? - **No CSS support:** Fill out this template with information about how to file issues and get help. - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* # Support ## How to file issues and get help This project uses GitHub Issues to track bugs and feature requests. Please search the existing issues before filing new issues to avoid duplicates. For new issues, file your bug or feature request as a new Issue. For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER CHANNEL. WHERE WILL YOU HELP PEOPLE?**. ## Microsoft Support Policy Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
================================================ FILE: config.py ================================================
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
#
# Default configuration tree (`_C`) shared by the training / evaluation
# scripts. The values below are only defaults: get_config() clones `_C` and
# update_config() overlays a YAML file, `args.opts` and command-line flags
# on the clone, so `_C` itself is not altered.

import os
import torch
import yaml
from yacs.config import CfgNode as CN

# pytorch major version (1.x or 2.x); update_config() uses it to choose between
# args.local_rank (1.x) and the LOCAL_RANK environment variable (otherwise).
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])

_C = CN()

# Base config files: YAML paths, relative to the including file, that
# _update_config_from_file() merges before the file itself ('' means none).
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name (the ImageNet-22K configs set 'imagenet22K')
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading workers
_C.DATA.NUM_WORKERS = 8
# [SimMIM] Mask patch size for MaskGenerator
_C.DATA.MASK_PATCH_SIZE = 32
# [SimMIM] Mask ratio for MaskGenerator
_C.DATA.MASK_RATIO = 0.6

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type (the shipped configs use 'swin', 'swinv2', 'swin_moe', 'swin_mlp')
_C.MODEL.TYPE = 'swin'
# Model name; update_config() also uses it as a sub-directory of OUTPUT
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Pretrained weight from checkpoint, could be imagenet22k pretrained weight
# could be overwritten by command line argument
_C.MODEL.PRETRAINED = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
# Per-stage block counts and head counts (one entry per stage)
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
# NOTE(review): None presumably means "use the model's default attention
# scale" -- confirm in models/swin_transformer.py
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True

# Swin Transformer V2 parameters
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
_C.MODEL.SWINV2.IN_CHANS = 3
_C.MODEL.SWINV2.EMBED_DIM = 96
_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWINV2.WINDOW_SIZE = 7
_C.MODEL.SWINV2.MLP_RATIO = 4.
_C.MODEL.SWINV2.QKV_BIAS = True
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
# Per-stage window sizes of the pre-training run (e.g. [12, 12, 12, 6] in the
# SimMIM fine-tune config); 0 presumably means "no pre-trained window" --
# confirm in models/swin_transformer_v2.py
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

# Swin Transformer MoE parameters
_C.MODEL.SWIN_MOE = CN()
_C.MODEL.SWIN_MOE.PATCH_SIZE = 4
_C.MODEL.SWIN_MOE.IN_CHANS = 3
_C.MODEL.SWIN_MOE.EMBED_DIM = 96
_C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MOE.WINDOW_SIZE = 7
_C.MODEL.SWIN_MOE.MLP_RATIO = 4.
_C.MODEL.SWIN_MOE.QKV_BIAS = True
_C.MODEL.SWIN_MOE.QK_SCALE = None
_C.MODEL.SWIN_MOE.APE = False
_C.MODEL.SWIN_MOE.PATCH_NORM = True
_C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True
_C.MODEL.SWIN_MOE.INIT_STD = 0.02
_C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
# One list per stage; presumably block indices that use an MoE layer, with
# [-1] meaning none -- confirm in models/swin_transformer_moe.py
_C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]]
_C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1
_C.MODEL.SWIN_MOE.TOP_VALUE = 1
_C.MODEL.SWIN_MOE.CAPACITY_FACTOR = 1.25
_C.MODEL.SWIN_MOE.COSINE_ROUTER = False
_C.MODEL.SWIN_MOE.NORMALIZE_GATE = False
_C.MODEL.SWIN_MOE.USE_BPR = True
_C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False
_C.MODEL.SWIN_MOE.GATE_NOISE = 1.0
_C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256
_C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5
_C.MODEL.SWIN_MOE.MOE_DROP = 0.0
_C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01

# Swin MLP parameters
_C.MODEL.SWIN_MLP = CN()
_C.MODEL.SWIN_MLP.PATCH_SIZE = 4
_C.MODEL.SWIN_MLP.IN_CHANS = 3
_C.MODEL.SWIN_MLP.EMBED_DIM = 96
_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
_C.MODEL.SWIN_MLP.MLP_RATIO = 4.
_C.MODEL.SWIN_MLP.APE = False
_C.MODEL.SWIN_MLP.PATCH_NORM = True

# [SimMIM] Norm target during training
_C.MODEL.SIMMIM = CN()
_C.MODEL.SIMMIM.NORM_TARGET = CN()
_C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False
# NOTE(review): presumably the local window used to normalize the target --
# confirm in models/simmim.py
_C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 1
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler ('cosine' by default; the SimMIM pre-training configs use 'multistep')
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# warmup_prefix used in CosineLRScheduler
_C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True
# [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []

# Optimizer (name may be overridden by args.optim, see update_config)
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

# MoE
_C.TRAIN.MOE = CN()
# Only save model on master device
_C.TRAIN.MOE.SAVE_MASTER = False

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
# NOTE(review): the default is a RandAugment-style policy string ('rand-...');
# its parser is not visible here -- confirm in data/build.py
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True
# Whether to use SequentialSampler as validation sampler
_C.TEST.SEQUENTIAL = False
# NOTE(review): presumably whether to shuffle the validation data -- confirm
# in data/build.py
_C.TEST.SHUFFLE = False

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
# (args.enable_amp in update_config). NOTE: this is a separate flag from
# AMP_ENABLE below.
_C.ENABLE_AMP = False
# Enable Pytorch automatic mixed precision (amp); update_config() clears it for
# args.disable_amp or an apex opt level of 'O0'.
_C.AMP_ENABLE = True # [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2') _C.AMP_OPT_LEVEL = '' # Path to output folder, overwritten by command line argument _C.OUTPUT = '' # Tag of experiment, overwritten by command line argument _C.TAG = 'default' # Frequency to save checkpoint _C.SAVE_FREQ = 1 # Frequency to logging info _C.PRINT_FREQ = 10 # Fixed random seed _C.SEED = 0 # Perform evaluation only, overwritten by command line argument _C.EVAL_MODE = False # Test throughput only, overwritten by command line argument _C.THROUGHPUT_MODE = False # local rank for DistributedDataParallel, given by command line argument _C.LOCAL_RANK = 0 # for acceleration _C.FUSED_WINDOW_PROCESS = False _C.FUSED_LAYERNORM = False def _update_config_from_file(config, cfg_file): config.defrost() with open(cfg_file, 'r') as f: yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) for cfg in yaml_cfg.setdefault('BASE', ['']): if cfg: _update_config_from_file( config, os.path.join(os.path.dirname(cfg_file), cfg) ) print('=> merge config from {}'.format(cfg_file)) config.merge_from_file(cfg_file) config.freeze() def update_config(config, args): _update_config_from_file(config, args.cfg) config.defrost() if args.opts: config.merge_from_list(args.opts) def _check_args(name): if hasattr(args, name) and eval(f'args.{name}'): return True return False # merge from specific arguments if _check_args('batch_size'): config.DATA.BATCH_SIZE = args.batch_size if _check_args('data_path'): config.DATA.DATA_PATH = args.data_path if _check_args('zip'): config.DATA.ZIP_MODE = True if _check_args('cache_mode'): config.DATA.CACHE_MODE = args.cache_mode if _check_args('pretrained'): config.MODEL.PRETRAINED = args.pretrained if _check_args('resume'): config.MODEL.RESUME = args.resume if _check_args('accumulation_steps'): config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps if _check_args('use_checkpoint'): config.TRAIN.USE_CHECKPOINT = True if 
_check_args('amp_opt_level'): print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") if args.amp_opt_level == 'O0': config.AMP_ENABLE = False if _check_args('disable_amp'): config.AMP_ENABLE = False if _check_args('output'): config.OUTPUT = args.output if _check_args('tag'): config.TAG = args.tag if _check_args('eval'): config.EVAL_MODE = True if _check_args('throughput'): config.THROUGHPUT_MODE = True # [SimMIM] if _check_args('enable_amp'): config.ENABLE_AMP = args.enable_amp # for acceleration if _check_args('fused_window_process'): config.FUSED_WINDOW_PROCESS = True if _check_args('fused_layernorm'): config.FUSED_LAYERNORM = True ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb] if _check_args('optim'): config.TRAIN.OPTIMIZER.NAME = args.optim # set local rank for distributed training if PYTORCH_MAJOR_VERSION == 1: config.LOCAL_RANK = args.local_rank else: config.LOCAL_RANK = int(os.environ['LOCAL_RANK']) # output folder config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) config.freeze() def get_config(args): """Get a yacs CfgNode object with default values.""" # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern config = _C.clone() update_config(config, args) return config ================================================ FILE: configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml ================================================ MODEL: TYPE: swin NAME: simmim_finetune DROP_PATH_RATE: 0.1 SWIN: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 7 DATA: IMG_SIZE: 224 TRAIN: EPOCHS: 100 WARMUP_EPOCHS: 20 BASE_LR: 1.25e-3 WARMUP_LR: 2.5e-7 MIN_LR: 2.5e-7 WEIGHT_DECAY: 0.05 LAYER_DECAY: 0.8 PRINT_FREQ: 100 SAVE_FREQ: 5 TAG: simmim_finetune__swin_base__img224_window7__800ep ================================================ FILE: 
configs/simmim/simmim_finetune__swinv2_base__img224_window14__800ep.yaml ================================================ MODEL: TYPE: swinv2 NAME: simmim_finetune DROP_PATH_RATE: 0.1 SWINV2: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 14 PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] DATA: IMG_SIZE: 224 TRAIN: EPOCHS: 100 WARMUP_EPOCHS: 20 BASE_LR: 1.25e-3 WARMUP_LR: 2.5e-7 MIN_LR: 2.5e-7 WEIGHT_DECAY: 0.05 LAYER_DECAY: 0.75 PRINT_FREQ: 100 SAVE_FREQ: 5 TAG: simmim_finetune__swinv2_base__img224_window14__800ep ================================================ FILE: configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml ================================================ MODEL: TYPE: swin NAME: simmim_pretrain DROP_PATH_RATE: 0.0 SWIN: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 6 DATA: IMG_SIZE: 192 MASK_PATCH_SIZE: 32 MASK_RATIO: 0.6 TRAIN: EPOCHS: 800 WARMUP_EPOCHS: 10 BASE_LR: 1e-4 WARMUP_LR: 5e-7 WEIGHT_DECAY: 0.05 LR_SCHEDULER: NAME: 'multistep' GAMMA: 0.1 MULTISTEPS: [700,] PRINT_FREQ: 100 SAVE_FREQ: 5 TAG: simmim_pretrain__swin_base__img192_window6__800ep ================================================ FILE: configs/simmim/simmim_pretrain__swinv2_base__img192_window12__800ep.yaml ================================================ MODEL: TYPE: swinv2 NAME: simmim_pretrain DROP_PATH_RATE: 0.1 SIMMIM: NORM_TARGET: ENABLE: True PATCH_SIZE: 47 SWINV2: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 DATA: IMG_SIZE: 192 MASK_PATCH_SIZE: 32 MASK_RATIO: 0.6 TRAIN: EPOCHS: 800 WARMUP_EPOCHS: 10 BASE_LR: 1e-4 WARMUP_LR: 5e-7 WEIGHT_DECAY: 0.05 LR_SCHEDULER: NAME: 'multistep' GAMMA: 0.1 MULTISTEPS: [700,] PRINT_FREQ: 100 SAVE_FREQ: 5 TAG: simmim_pretrain__swinv2_base__img192_window12__800ep ================================================ FILE: configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml ================================================ 
DATA: IMG_SIZE: 384 MODEL: TYPE: swin NAME: swin_base_patch4_window12_384_22kto1k_finetune DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 TEST: CROP: False ================================================ FILE: configs/swin/swin_base_patch4_window12_384_finetune.yaml ================================================ DATA: IMG_SIZE: 384 MODEL: TYPE: swin NAME: swin_base_patch4_window12_384_finetune DROP_PATH_RATE: 0.5 SWIN: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 TEST: CROP: False ================================================ FILE: configs/swin/swin_base_patch4_window7_224.yaml ================================================ MODEL: TYPE: swin NAME: swin_base_patch4_window7_224 DROP_PATH_RATE: 0.5 SWIN: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 7 ================================================ FILE: configs/swin/swin_base_patch4_window7_224_22k.yaml ================================================ DATA: DATASET: imagenet22K MODEL: TYPE: swin NAME: swin_base_patch4_window7_224_22k DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 0.05 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 ================================================ FILE: configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml ================================================ MODEL: TYPE: swin NAME: swin_base_patch4_window7_224_22kto1k_finetune DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 
2e-08 MIN_LR: 2e-07 ================================================ FILE: configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml ================================================ DATA: IMG_SIZE: 384 MODEL: TYPE: swin NAME: swin_large_patch4_window12_384_22kto1k_finetune DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 192 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 6, 12, 24, 48 ] WINDOW_SIZE: 12 TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 TEST: CROP: False ================================================ FILE: configs/swin/swin_large_patch4_window7_224_22k.yaml ================================================ DATA: DATASET: imagenet22K MODEL: TYPE: swin NAME: swin_large_patch4_window7_224_22k DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 192 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 6, 12, 24, 48 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 0.05 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 ================================================ FILE: configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml ================================================ MODEL: TYPE: swin NAME: swin_large_patch4_window7_224_22kto1k_finetune DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 192 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 6, 12, 24, 48 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 ================================================ FILE: configs/swin/swin_small_patch4_window7_224.yaml ================================================ MODEL: TYPE: swin NAME: swin_small_patch4_window7_224 DROP_PATH_RATE: 0.3 SWIN: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 ================================================ FILE: configs/swin/swin_small_patch4_window7_224_22k.yaml ================================================ DATA: DATASET: imagenet22K MODEL: TYPE: swin NAME: swin_small_patch4_window7_224_22k 
DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 0.05 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 ================================================ FILE: configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml ================================================ MODEL: TYPE: swin NAME: swin_small_patch4_window7_224_22kto1k_finetune DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 ================================================ FILE: configs/swin/swin_tiny_c24_patch4_window8_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swin NAME: swin_tiny_c24_patch4_window8_256 DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 8 ================================================ FILE: configs/swin/swin_tiny_patch4_window7_224.yaml ================================================ MODEL: TYPE: swin NAME: swin_tiny_patch4_window7_224 DROP_PATH_RATE: 0.2 SWIN: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 ================================================ FILE: configs/swin/swin_tiny_patch4_window7_224_22k.yaml ================================================ DATA: DATASET: imagenet22K MODEL: TYPE: swin NAME: swin_tiny_patch4_window7_224_22k DROP_PATH_RATE: 0.1 SWIN: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 0.05 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 ================================================ FILE: configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml ================================================ MODEL: TYPE: swin NAME: 
swin_tiny_patch4_window7_224_22kto1k_finetune DROP_PATH_RATE: 0.1 SWIN: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 7 TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 ================================================ FILE: configs/swinmlp/swin_mlp_base_patch4_window7_224.yaml ================================================ MODEL: TYPE: swin_mlp NAME: swin_mlp_base_patch4_window7_224 DROP_PATH_RATE: 0.5 SWIN_MLP: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 7 ================================================ FILE: configs/swinmlp/swin_mlp_tiny_c12_patch4_window8_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swin_mlp NAME: swin_mlp_tiny_c12_patch4_window8_256 DROP_PATH_RATE: 0.2 SWIN_MLP: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 8, 16, 32, 64 ] WINDOW_SIZE: 8 ================================================ FILE: configs/swinmlp/swin_mlp_tiny_c24_patch4_window8_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swin_mlp NAME: swin_mlp_tiny_c24_patch4_window8_256 DROP_PATH_RATE: 0.2 SWIN_MLP: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 8 ================================================ FILE: configs/swinmlp/swin_mlp_tiny_c6_patch4_window8_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swin_mlp NAME: swin_mlp_tiny_c6_patch4_window8_256 DROP_PATH_RATE: 0.2 SWIN_MLP: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 16, 32, 64, 128 ] WINDOW_SIZE: 8 ================================================ FILE: configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_base_patch4_window12_192_16expert_32gpu_22k DROP_PATH_RATE: 0.3 SWIN_MOE: EMBED_DIM: 
128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: -2 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_base_patch4_window12_192_32expert_32gpu_22k DROP_PATH_RATE: 0.3 SWIN_MOE: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: 1 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_base_patch4_window12_192_8expert_32gpu_22k DROP_PATH_RATE: 0.3 SWIN_MOE: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: -4 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 
MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k DROP_PATH_RATE: 0.3 SWIN_MOE: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: 1 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 COSINE_ROUTER: True IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_base_patch4_window12_192_densebaseline_22k DROP_PATH_RATE: 0.2 SWIN_MOE: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ] TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 MOE: SAVE_MASTER: True TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_small_patch4_window12_192_16expert_32gpu_22k DROP_PATH_RATE: 0.2 SWIN_MOE: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 
0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: -2 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_small_patch4_window12_192_32expert_32gpu_22k DROP_PATH_RATE: 0.2 SWIN_MOE: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: 1 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_small_patch4_window12_192_64expert_64gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_small_patch4_window12_192_64expert_64gpu_22k DROP_PATH_RATE: 0.2 SWIN_MOE: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: 1 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: 
configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_small_patch4_window12_192_8expert_32gpu_22k DROP_PATH_RATE: 0.2 SWIN_MOE: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: -4 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k DROP_PATH_RATE: 0.2 SWIN_MOE: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 12 MLP_FC2_BIAS: False INIT_STD: 0.005 MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] NUM_LOCAL_EXPERTS: 1 TOP_VALUE: 1 CAPACITY_FACTOR: 1.25 COSINE_ROUTER: True IS_GSHARD_LOSS: False MOE_DROP: 0.1 AUX_LOSS_WEIGHT: 0.01 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 TEST: SHUFFLE: True ================================================ FILE: configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swin_moe NAME: swin_moe_small_patch4_window12_192_densebaseline_22k DROP_PATH_RATE: 0.2 SWIN_MOE: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 12 
MLP_FC2_BIAS: False MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ] TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 10 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 CLIP_GRAD: 3.0 MOE: SAVE_MASTER: True TEST: SHUFFLE: True ================================================ FILE: configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swinv2 NAME: swinv2_base_patch4_window12_192_22k DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 12 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 ================================================ FILE: configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swinv2 NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 16 PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 ================================================ FILE: configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml ================================================ DATA: IMG_SIZE: 384 MODEL: TYPE: swinv2 NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 24 PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 TEST: CROP: False ================================================ FILE: configs/swinv2/swinv2_base_patch4_window16_256.yaml ================================================ DATA: IMG_SIZE: 256 
MODEL: TYPE: swinv2 NAME: swinv2_base_patch4_window16_256 DROP_PATH_RATE: 0.5 SWINV2: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 16 ================================================ FILE: configs/swinv2/swinv2_base_patch4_window8_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swinv2 NAME: swinv2_base_patch4_window8_256 DROP_PATH_RATE: 0.5 SWINV2: EMBED_DIM: 128 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 4, 8, 16, 32 ] WINDOW_SIZE: 8 ================================================ FILE: configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml ================================================ DATA: DATASET: imagenet22K IMG_SIZE: 192 MODEL: TYPE: swinv2 NAME: swinv2_large_patch4_window12_192_22k DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 192 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 6, 12, 24, 48 ] WINDOW_SIZE: 12 TRAIN: EPOCHS: 90 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 0.1 BASE_LR: 1.25e-4 # 4096 batch-size WARMUP_LR: 1.25e-7 MIN_LR: 1.25e-6 ================================================ FILE: configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swinv2 NAME: swinv2_large_patch4_window12to16_192to256_22kto1k_ft DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 192 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 6, 12, 24, 48 ] WINDOW_SIZE: 16 PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] TRAIN: EPOCHS: 30 WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 ================================================ FILE: configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml ================================================ DATA: IMG_SIZE: 384 MODEL: TYPE: swinv2 NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 192 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 6, 12, 24, 48 ] WINDOW_SIZE: 24 PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] TRAIN: EPOCHS: 30
WARMUP_EPOCHS: 5 WEIGHT_DECAY: 1e-8 BASE_LR: 2e-05 WARMUP_LR: 2e-08 MIN_LR: 2e-07 TEST: CROP: False ================================================ FILE: configs/swinv2/swinv2_small_patch4_window16_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swinv2 NAME: swinv2_small_patch4_window16_256 DROP_PATH_RATE: 0.3 SWINV2: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 16 ================================================ FILE: configs/swinv2/swinv2_small_patch4_window8_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swinv2 NAME: swinv2_small_patch4_window8_256 DROP_PATH_RATE: 0.3 SWINV2: EMBED_DIM: 96 DEPTHS: [ 2, 2, 18, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 8 ================================================ FILE: configs/swinv2/swinv2_tiny_patch4_window16_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swinv2 NAME: swinv2_tiny_patch4_window16_256 DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 16 ================================================ FILE: configs/swinv2/swinv2_tiny_patch4_window8_256.yaml ================================================ DATA: IMG_SIZE: 256 MODEL: TYPE: swinv2 NAME: swinv2_tiny_patch4_window8_256 DROP_PATH_RATE: 0.2 SWINV2: EMBED_DIM: 96 DEPTHS: [ 2, 2, 6, 2 ] NUM_HEADS: [ 3, 6, 12, 24 ] WINDOW_SIZE: 8 ================================================ FILE: data/__init__.py ================================================ from .build import build_loader as _build_loader from .data_simmim_pt import build_loader_simmim from .data_simmim_ft import build_loader_finetune def build_loader(config, simmim=False, is_pretrain=False): if not simmim: return _build_loader(config) if is_pretrain: return build_loader_simmim(config) else: return build_loader_finetune(config) ================================================ FILE: 
data/build.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import os import torch import numpy as np import torch.distributed as dist from torchvision import datasets, transforms from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import Mixup from timm.data import create_transform from .cached_image_folder import CachedImageFolder from .imagenet22k_dataset import IN22KDATASET from .samplers import SubsetRandomSampler try: from torchvision.transforms import InterpolationMode def _pil_interp(method): if method == 'bicubic': return InterpolationMode.BICUBIC elif method == 'lanczos': return InterpolationMode.LANCZOS elif method == 'hamming': return InterpolationMode.HAMMING else: # default bilinear, do we want to allow nearest? 
return InterpolationMode.BILINEAR import timm.data.transforms as timm_transforms timm_transforms._pil_interp = _pil_interp except: from timm.data.transforms import _pil_interp def build_loader(config): config.defrost() dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) config.freeze() print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") dataset_val, _ = build_dataset(is_train=False, config=config) print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") num_tasks = dist.get_world_size() global_rank = dist.get_rank() if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) sampler_train = SubsetRandomSampler(indices) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) if config.TEST.SEQUENTIAL: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_val = torch.utils.data.distributed.DistributedSampler( dataset_val, shuffle=config.TEST.SHUFFLE ) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=config.DATA.BATCH_SIZE, num_workers=config.DATA.NUM_WORKERS, pin_memory=config.DATA.PIN_MEMORY, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader( dataset_val, sampler=sampler_val, batch_size=config.DATA.BATCH_SIZE, shuffle=False, num_workers=config.DATA.NUM_WORKERS, pin_memory=config.DATA.PIN_MEMORY, drop_last=False ) # setup mixup / cutmix mixup_fn = None mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None if mixup_active: mixup_fn = Mixup( mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn def build_dataset(is_train, config): transform = build_transform(is_train, config) if config.DATA.DATASET == 'imagenet': prefix = 'train' if is_train else 'val' if config.DATA.ZIP_MODE: ann_file = prefix + "_map.txt" prefix = prefix + ".zip@/" dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, cache_mode=config.DATA.CACHE_MODE if is_train else 'part') else: root = os.path.join(config.DATA.DATA_PATH, prefix) dataset = datasets.ImageFolder(root, transform=transform) nb_classes = 1000 elif config.DATA.DATASET == 'imagenet22K': prefix = 'ILSVRC2011fall_whole' if is_train: ann_file = prefix + "_map_train.txt" else: ann_file = prefix + "_map_val.txt" dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform) nb_classes = 21841 else: raise NotImplementedError("We only support ImageNet Now.") return dataset, nb_classes def build_transform(is_train, config): resize_im = config.DATA.IMG_SIZE > 32 if is_train: # this should always dispatch to transforms_imagenet_train transform = create_transform( input_size=config.DATA.IMG_SIZE, is_training=True, color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, re_prob=config.AUG.REPROB, re_mode=config.AUG.REMODE, re_count=config.AUG.RECOUNT, interpolation=config.DATA.INTERPOLATION, ) if not resize_im: # replace RandomResizedCropAndInterpolation with # RandomCrop transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) return transform t 
= [] if resize_im: if config.TEST.CROP: size = int((256 / 224) * config.DATA.IMG_SIZE) t.append( transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), # to maintain same ratio w.r.t. 224 images ) t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) else: t.append( transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), interpolation=_pil_interp(config.DATA.INTERPOLATION)) ) t.append(transforms.ToTensor()) t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) return transforms.Compose(t) ================================================ FILE: data/cached_image_folder.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import io import os import time import torch.distributed as dist import torch.utils.data as data from PIL import Image from .zipreader import is_zip_path, ZipReader def has_file_allowed_extension(filename, extensions): """Checks if a file is an allowed extension. 
Args: filename (string): path to a file Returns: bool: True if the filename ends with a known image extension """ filename_lower = filename.lower() return any(filename_lower.endswith(ext) for ext in extensions) def find_classes(dir): classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] classes.sort() class_to_idx = {classes[i]: i for i in range(len(classes))} return classes, class_to_idx def make_dataset(dir, class_to_idx, extensions): images = [] dir = os.path.expanduser(dir) for target in sorted(os.listdir(dir)): d = os.path.join(dir, target) if not os.path.isdir(d): continue for root, _, fnames in sorted(os.walk(d)): for fname in sorted(fnames): if has_file_allowed_extension(fname, extensions): path = os.path.join(root, fname) item = (path, class_to_idx[target]) images.append(item) return images def make_dataset_with_ann(ann_file, img_prefix, extensions): images = [] with open(ann_file, "r") as f: contents = f.readlines() for line_str in contents: path_contents = [c for c in line_str.split('\t')] im_file_name = path_contents[0] class_index = int(path_contents[1]) assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions item = (os.path.join(img_prefix, im_file_name), class_index) images.append(item) return images class DatasetFolder(data.Dataset): """A generic data loader where the samples are arranged in this way: :: root/class_x/xxx.ext root/class_x/xxy.ext root/class_x/xxz.ext root/class_y/123.ext root/class_y/nsdf3.ext root/class_y/asd932_.ext Args: root (string): Root directory path. loader (callable): A function to load a sample given its path. extensions (list[string]): A list of allowed extensions. transform (callable, optional): A function/transform that takes in a sample and returns a transformed version. E.g, ``transforms.RandomCrop`` for images. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
Attributes: samples (list): List of (sample path, class_index) tuples """ def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, cache_mode="no"): # image folder mode if ann_file == '': _, class_to_idx = find_classes(root) samples = make_dataset(root, class_to_idx, extensions) # zip mode else: samples = make_dataset_with_ann(os.path.join(root, ann_file), os.path.join(root, img_prefix), extensions) if len(samples) == 0: raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + "Supported extensions are: " + ",".join(extensions))) self.root = root self.loader = loader self.extensions = extensions self.samples = samples self.labels = [y_1k for _, y_1k in samples] self.classes = list(set(self.labels)) self.transform = transform self.target_transform = target_transform self.cache_mode = cache_mode if self.cache_mode != "no": self.init_cache() def init_cache(self): assert self.cache_mode in ["part", "full"] n_sample = len(self.samples) global_rank = dist.get_rank() world_size = dist.get_world_size() samples_bytes = [None for _ in range(n_sample)] start_time = time.time() for index in range(n_sample): if index % (n_sample // 10) == 0: t = time.time() - start_time print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') start_time = time.time() path, target = self.samples[index] if self.cache_mode == "full": samples_bytes[index] = (ZipReader.read(path), target) elif self.cache_mode == "part" and index % world_size == global_rank: samples_bytes[index] = (ZipReader.read(path), target) else: samples_bytes[index] = (path, target) self.samples = samples_bytes def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (sample, target) where target is class_index of the target class. 
""" path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: target = self.target_transform(target) return sample, target def __len__(self): return len(self.samples) def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) fmt_str += ' Root Location: {}\n'.format(self.root) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): ' fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) return fmt_str IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] def pil_loader(path): # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) if isinstance(path, bytes): img = Image.open(io.BytesIO(path)) elif is_zip_path(path): data = ZipReader.read(path) img = Image.open(io.BytesIO(data)) else: with open(path, 'rb') as f: img = Image.open(f) return img.convert('RGB') return img.convert('RGB') def accimage_loader(path): import accimage try: return accimage.Image(path) except IOError: # Potentially a decoding problem, fall back to PIL.Image return pil_loader(path) def default_img_loader(path): from torchvision import get_image_backend if get_image_backend() == 'accimage': return accimage_loader(path) else: return pil_loader(path) class CachedImageFolder(DatasetFolder): """A generic data loader where the images are arranged in this way: :: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png Args: root (string): Root directory path. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. 
E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. Attributes: imgs (list): List of (image path, class_index) tuples """ def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, loader=default_img_loader, cache_mode="no"): super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, ann_file=ann_file, img_prefix=img_prefix, transform=transform, target_transform=target_transform, cache_mode=cache_mode) self.imgs = self.samples def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is class_index of the target class. """ path, target = self.samples[index] image = self.loader(path) if self.transform is not None: img = self.transform(image) else: img = image if self.target_transform is not None: target = self.target_transform(target) return img, target
================================================ FILE: data/data_simmim_ft.py ================================================
# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Zhenda Xie
# --------------------------------------------------------

import os
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
from timm.data.transforms import _pil_interp


def build_loader_finetune(config):
    """Build the SimMIM fine-tuning datasets, distributed loaders and Mixup function.

    Args:
        config: config node (uses ``defrost()``/``freeze()``). DATA.*, AUG.* and
            MODEL.LABEL_SMOOTHING are read; MODEL.NUM_CLASSES is written from the
            value returned for the training dataset.

    Returns:
        tuple: ``(dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn)``;
        ``mixup_fn`` is None when Mixup/CutMix is disabled.
    """
    # NUM_CLASSES comes from the dataset, so the config is unlocked just for that write.
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    config.freeze()
    dataset_val, _ = build_dataset(is_train=False, config=config)

    # Requires torch.distributed to be initialised before this is called.
    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    sampler_train = DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    sampler_val = DistributedSampler(
        dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False
    )

    data_loader_train = DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
    )

    data_loader_val = DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
    )

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn


def build_dataset(is_train, config):
    """Return ``(dataset, nb_classes)`` for the train or val split.

    Only ``DATA.DATASET == 'imagenet'`` is handled: an ``ImageFolder`` rooted at
    ``DATA.DATA_PATH/train`` or ``DATA.DATA_PATH/val`` with 1000 classes.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    transform = build_transform(is_train, config)

    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        root = os.path.join(config.DATA.DATA_PATH, prefix)
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    else:
        raise NotImplementedError("We only support ImageNet Now.")

    return dataset, nb_classes


def build_transform(is_train, config):
    """Build the image transform for training (timm augmentation) or evaluation.

    Training uses ``timm.data.create_transform`` with the AUG.* settings. When
    DATA.IMG_SIZE is 32 or less, the first transform is swapped for a padded
    ``RandomCrop``. Evaluation resizes (optionally with a 256/224 resize plus a
    center crop), converts to a tensor, and normalizes with the ImageNet mean/std.
    """
    resize_im = config.DATA.IMG_SIZE > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.INTERPOLATION,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
        return transform

    t = []
    if resize_im:
        if config.TEST.CROP:
            size = int((256 / 224) * config.DATA.IMG_SIZE)
            t.append(
                transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
                # to maintain same ratio w.r.t. 224 images
            )
            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
        else:
            t.append(
                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                                  interpolation=_pil_interp(config.DATA.INTERPOLATION))
            )

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
================================================ FILE: data/data_simmim_pt.py ================================================
# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Zhenda Xie
# --------------------------------------------------------

import math
import random
import numpy as np

import torch
import torch.distributed as dist
import torchvision.transforms as T
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets import ImageFolder
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


class MaskGenerator:
    """Sample random SimMIM patch masks.

    The input is split into a ``rand_size x rand_size`` grid of mask patches.
    ``ceil(mask_ratio * rand_size**2)`` of those patches are set to 1 (masked).
    The grid is then repeated ``scale`` times along both axes, so each entry
    corresponds to one model patch.
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Both sizes must divide evenly so the upsampled mask tiles the model's patch grid.
        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0

        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Return an int ndarray of shape ``(rand_size*scale, rand_size*scale)`` (1 = masked).

        Draws from NumPy's global RNG.
        """
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return mask


class SimMIMTransform:
    """Pair an augmented image tensor with a freshly sampled SimMIM mask.

    Images are converted to RGB, random-resized-cropped to DATA.IMG_SIZE,
    randomly flipped horizontally, converted to a tensor, and normalized with
    the ImageNet mean/std.
    """

    def __init__(self, config):
        self.transform_img = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
        ])

        # Mask granularity follows the model patch size; only 'swin'/'swinv2' are handled here.
        if config.MODEL.TYPE in ['swin', 'swinv2']:
            model_patch_size=config.MODEL.SWIN.PATCH_SIZE
        else:
            raise NotImplementedError

        self.mask_generator = MaskGenerator(
            input_size=config.DATA.IMG_SIZE,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=model_patch_size,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        """Return ``(transformed_img, mask)`` for one image."""
        img = self.transform_img(img)
        mask = self.mask_generator()

        return img, mask


def collate_fn(batch):
    """Collate ``((img, mask), target)`` samples into ``[imgs, masks, targets]``.

    If the first element of a sample is not a tuple, the batch goes straight to
    ``default_collate``. Otherwise each tuple slot is collated separately; a slot
    that is None in the first sample stays None. The collated targets are
    appended last.
    """
    if not isinstance(batch[0][0], tuple):
        return default_collate(batch)
    else:
        batch_num = len(batch)
        ret = []
        for item_idx in range(len(batch[0][0])):
            if batch[0][0][item_idx] is None:
                ret.append(None)
            else:
                ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
        ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
        return ret


def build_loader_simmim(config):
    """Build the distributed, shuffled SimMIM pre-training loader over DATA.DATA_PATH.

    Incomplete final batches are dropped. Requires torch.distributed to be initialised.
    """
    transform = SimMIMTransform(config)
    dataset = ImageFolder(config.DATA.DATA_PATH, transform)

    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
    dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)

    return dataloader
================================================ FILE: data/imagenet22k_dataset.py ================================================ import os import json import torch.utils.data as data import numpy as np from PIL import Image import warnings warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) class IN22KDATASET(data.Dataset): def __init__(self, root, ann_file='', transform=None, target_transform=None): super(IN22KDATASET, self).__init__() self.data_path = root self.ann_path = os.path.join(self.data_path, ann_file) self.transform = transform self.target_transform = target_transform # id & label: https://github.com/google-research/big_transfer/issues/7 # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027 self.database = json.load(open(self.ann_path)) def _load_image(self, path): try: im = Image.open(path) except: print("ERROR IMG LOADED: ", path) random_img = np.random.rand(224, 224, 3) * 255 im = Image.fromarray(np.uint8(random_img)) return im def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is class_index of the target class. 
""" idb = self.database[index] # images images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB') if self.transform is not None: images = self.transform(images) # target target = int(idb[1]) if self.target_transform is not None: target = self.target_transform(target) return images, target def __len__(self): return len(self.database) ================================================ FILE: data/map22kto1k.txt ================================================ 359 368 460 475 486 492 496 514 516 525 547 548 556 563 575 641 648 723 733 765 801 826 852 858 878 896 900 905 908 910 935 946 947 994 999 1003 1005 1010 1027 1029 1048 1055 1064 1065 1069 1075 1079 1081 1085 1088 1093 1106 1143 1144 1145 1147 1168 1171 1178 1187 1190 1197 1205 1216 1223 1230 1236 1241 1245 1257 1259 1260 1267 1268 1269 1271 1272 1273 1277 1303 1344 1349 1355 1357 1384 1388 1391 1427 1429 1432 1437 1450 1461 1462 1474 1502 1503 1512 1552 1555 1577 1584 1587 1589 1599 1615 1616 1681 1692 1701 1716 1729 1757 1759 1764 1777 1786 1822 1841 1842 1848 1850 1856 1860 1861 1864 1876 1897 1898 1910 1913 1918 1922 1928 1932 1935 1947 1951 1953 1970 1977 1979 2001 2017 2067 2081 2087 2112 2128 2135 2147 2174 2175 2176 2177 2178 2181 2183 2184 2187 2189 2190 2191 2192 2193 2197 2202 2203 2206 2208 2209 2211 2212 2213 2214 2215 2216 2217 2219 2222 2223 2224 2225 2226 2227 2228 2229 2230 2236 2238 2240 2241 2242 2243 2244 2245 2247 2248 2249 2250 2251 2252 2255 2256 2257 2262 2263 2264 2265 2266 2268 2270 2271 2272 2273 2275 2276 2279 2280 2281 2282 2285 2289 2292 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2309 2310 2312 2313 2314 2315 2316 2318 2319 2321 2322 2326 2329 2330 2331 2332 2334 2335 2336 2337 2338 2339 2341 2342 2343 2344 2346 2348 2349 2351 2352 2353 2355 2357 2358 2359 2360 2364 2365 2368 2369 2377 2382 2383 2385 2397 2398 2400 2402 2405 2412 2421 2428 2431 2432 2433 2436 2441 2445 2450 2453 2454 2465 2469 2532 2533 2538 2544 2547 2557 2565 2578 2612 2658 2702 2722 
2731 2738 2741 2747 2810 2818 2833 2844 2845 2867 2874 2882 2884 2888 2889 3008 3012 3019 3029 3033 3042 3091 3106 3138 3159 3164 3169 3280 3296 3311 3318 3320 3324 3330 3366 3375 3381 3406 3419 3432 3434 3435 3493 3495 3503 3509 3511 3513 3517 3521 3526 3546 3554 3600 3601 3606 3612 3613 3616 3622 3623 3627 3632 3634 3636 3638 3644 3646 3649 3650 3651 3656 3663 3673 3674 3689 3690 3702 3733 3769 3971 3974 4065 4068 4073 4102 4136 4140 4151 4159 4165 4207 4219 4226 4249 4256 4263 4270 4313 4321 4378 4386 4478 4508 4512 4536 4542 4550 4560 4562 4570 4571 4572 4583 4588 4594 4604 4608 4623 4634 4636 4646 4651 4652 4686 4688 4691 4699 4724 4727 4737 4770 4774 4789 4802 4807 4819 4880 4886 4908 4927 4931 4936 4964 4976 4993 5028 5033 5043 5046 5096 5111 5114 5131 5132 5183 5199 5235 5275 5291 5293 5294 5343 5360 5362 5364 5390 5402 5418 5428 5430 5437 5443 5473 5484 5486 5505 5507 5508 5510 5567 5578 5580 5584 5606 5613 5629 5672 5676 5692 5701 5760 5769 5770 5779 5814 5850 5871 5893 5911 5949 5954 6005 6006 6012 6017 6023 6024 6040 6050 6054 6087 6105 6157 6235 6237 6256 6259 6286 6291 6306 6339 6341 6343 6379 6383 6393 6405 6479 6511 6517 6541 6561 6608 6611 6615 6678 6682 6707 6752 6798 6850 6880 6885 6890 6920 6981 7000 7009 7038 7049 7050 7052 7073 7078 7098 7111 7165 7198 7204 7280 7283 7286 7287 7293 7294 7305 7318 7341 7346 7354 7382 7427 7428 7435 7445 7450 7455 7467 7469 7497 7502 7506 7514 7523 7651 7661 7664 7672 7679 7685 7696 7730 7871 7873 7895 7914 7915 7920 7934 7935 7949 8009 8036 8051 8065 8074 8090 8112 8140 8164 8168 8178 8182 8198 8212 8216 8230 8242 8288 8289 8295 8318 8352 8368 8371 8375 8376 8401 8416 8419 8436 8460 8477 8478 8482 8498 8500 8539 8543 8552 8555 8580 8584 8586 8594 8598 8601 8606 8610 8611 8622 8627 8639 8649 8650 8653 8654 8667 8672 8673 8674 8676 8684 8720 8723 8750 8753 8801 8815 8831 8835 8842 8845 8858 8897 8916 8951 8954 8959 8970 8976 8981 8983 8989 8991 8993 9019 9039 9042 9043 9056 9057 9070 9087 9098 9106 9130 9131 9155 
9171 9183 9198 9199 9201 9204 9212 9221 9225 9229 9250 9260 9271 9279 9295 9300 9310 9322 9345 9352 9376 9377 9382 9392 9401 9405 9441 9449 9464 9475 9502 9505 9514 9515 9545 9567 9576 9608 9609 9624 9633 9639 9643 9656 9674 9740 9752 9760 9767 9778 9802 9820 9839 9879 9924 9956 9961 9963 9970 9997 10010 10031 10040 10052 10073 10075 10078 10094 10097 10109 10118 10121 10124 10158 10226 10276 10304 10307 10314 10315 10332 10337 10338 10413 10423 10451 10463 10465 10487 10519 10522 10523 10532 10534 10535 10551 10559 10574 10583 10586 10589 10612 10626 10635 10638 10677 10683 10726 10776 10782 10783 10807 10837 10840 10848 10859 10871 10881 10884 10908 10914 10921 10936 10947 10951 10952 10957 10999 11003 11018 11023 11025 11027 11045 11055 11095 11110 11137 5564 11168 11186 11221 11223 11242 11255 11259 11279 11306 11311 11331 11367 11377 11389 11392 11401 11407 11437 11449 11466 11469 11473 11478 11483 11484 11507 11536 11558 11566 11575 11584 11594 11611 11612 11619 11621 11640 11643 11664 11674 11689 11709 11710 11716 11721 11726 11729 11743 11760 11771 11837 11839 11856 11876 11878 11884 11889 11896 11917 11923 11930 11944 11952 11980 11984 12214 12229 12239 12241 12242 12247 12283 12349 12369 12373 12422 12560 12566 12575 12688 12755 12768 12778 12780 12812 12832 12835 12836 12843 12847 12849 12850 12856 12858 12873 12938 12971 13017 13038 13046 13059 13085 13086 13088 13094 13134 13182 13230 13406 13444 13614 13690 13698 13709 13749 13804 13982 14051 14059 14219 14246 14256 14264 14294 14324 14367 14389 14394 14438 14442 14965 15732 16744 18037 18205 18535 18792 19102 20019 20462 21026 21045 21163 21171 21181 21196 21200 21369 21817 ================================================ FILE: data/samplers.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # 
-------------------------------------------------------- import torch class SubsetRandomSampler(torch.utils.data.Sampler): r"""Samples elements randomly from a given list of indices, without replacement. Arguments: indices (sequence): a sequence of indices """ def __init__(self, indices): self.epoch = 0 self.indices = indices def __iter__(self): return (self.indices[i] for i in torch.randperm(len(self.indices))) def __len__(self): return len(self.indices) def set_epoch(self, epoch): self.epoch = epoch ================================================ FILE: data/zipreader.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import os import zipfile import io import numpy as np from PIL import Image from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True def is_zip_path(img_or_path): """judge if this is a zip path""" return '.zip@' in img_or_path class ZipReader(object): """A class to read zipped files""" zip_bank = dict() def __init__(self): super(ZipReader, self).__init__() @staticmethod def get_zipfile(path): zip_bank = ZipReader.zip_bank if path not in zip_bank: zfile = zipfile.ZipFile(path, 'r') zip_bank[path] = zfile return zip_bank[path] @staticmethod def split_zip_style_path(path): pos_at = path.index('@') assert pos_at != -1, "character '@' is not found from the given path '%s'" % path zip_path = path[0: pos_at] folder_path = path[pos_at + 1:] folder_path = str.strip(folder_path, '/') return zip_path, folder_path @staticmethod def list_folder(path): zip_path, folder_path = ZipReader.split_zip_style_path(path) zfile = ZipReader.get_zipfile(zip_path) folder_list = [] for file_foler_name in zfile.namelist(): file_foler_name = str.strip(file_foler_name, '/') if file_foler_name.startswith(folder_path) and \ 
len(os.path.splitext(file_foler_name)[-1]) == 0 and \ file_foler_name != folder_path: if len(folder_path) == 0: folder_list.append(file_foler_name) else: folder_list.append(file_foler_name[len(folder_path) + 1:]) return folder_list @staticmethod def list_files(path, extension=None): if extension is None: extension = ['.*'] zip_path, folder_path = ZipReader.split_zip_style_path(path) zfile = ZipReader.get_zipfile(zip_path) file_lists = [] for file_foler_name in zfile.namelist(): file_foler_name = str.strip(file_foler_name, '/') if file_foler_name.startswith(folder_path) and \ str.lower(os.path.splitext(file_foler_name)[-1]) in extension: if len(folder_path) == 0: file_lists.append(file_foler_name) else: file_lists.append(file_foler_name[len(folder_path) + 1:]) return file_lists @staticmethod def read(path): zip_path, path_img = ZipReader.split_zip_style_path(path) zfile = ZipReader.get_zipfile(zip_path) data = zfile.read(path_img) return data @staticmethod def imread(path): zip_path, path_img = ZipReader.split_zip_style_path(path) zfile = ZipReader.get_zipfile(zip_path) data = zfile.read(path_img) try: im = Image.open(io.BytesIO(data)) except: print("ERROR IMG LOADED: ", path_img) random_img = np.random.rand(224, 224, 3) * 255 im = Image.fromarray(np.uint8(random_img)) return im ================================================ FILE: get_started.md ================================================ # Swin Transformer for Image Classification This folder contains the implementation of the Swin Transformer for image classification. ## Model Zoo Please refer to [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) for more pre-trained models. ## Usage ### Install We recommend using the pytorch docker `nvcr>=21.05` by nvidia: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. 
- Clone this repo: ```bash git clone https://github.com/microsoft/Swin-Transformer.git cd Swin-Transformer ``` - Create a conda virtual environment and activate it: ```bash conda create -n swin python=3.7 -y conda activate swin ``` - Install `CUDA>=10.2` with `cudnn>=7` following the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) - Install `PyTorch>=1.8.0` and `torchvision>=0.9.0` with `CUDA>=10.2`: ```bash conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch ``` - Install `timm==0.4.12`: ```bash pip install timm==0.4.12 ``` - Install other requirements: ```bash pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy ``` - Install the fused window process kernel for acceleration; it is activated by passing `--fused_window_process` to the training or evaluation command: ```bash cd kernels/window_process python setup.py install #--user ``` ### Data preparation We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following two ways to load data: - For the standard folder dataset, move validation images to labeled sub-folders. The file structure should look like: ```bash $ tree data imagenet ├── train │ ├── class1 │ │ ├── img1.jpeg │ │ ├── img2.jpeg │ │ └── ... │ ├── class2 │ │ ├── img3.jpeg │ │ └── ... │ └── ... └── val ├── class1 │ ├── img4.jpeg │ ├── img5.jpeg │ └── ... ├── class2 │ ├── img6.jpeg │ └── ... └── ... ``` - To speed up reading a dataset made of a massive number of small files, we also support zipped ImageNet, which includes four files: - `train.zip`, `val.zip`: the zipped folders for the train and validation splits. - `train_map.txt`, `val_map.txt`: the relative path of each image in the corresponding zip file, followed by its ground-truth label.
Make sure the data folder looks like this: ```bash $ tree data data └── ImageNet-Zip ├── train_map.txt ├── train.zip ├── val_map.txt └── val.zip $ head -n 5 data/ImageNet-Zip/val_map.txt ILSVRC2012_val_00000001.JPEG 65 ILSVRC2012_val_00000002.JPEG 970 ILSVRC2012_val_00000003.JPEG 230 ILSVRC2012_val_00000004.JPEG 809 ILSVRC2012_val_00000005.JPEG 516 $ head -n 5 data/ImageNet-Zip/train_map.txt n01440764/n01440764_10026.JPEG 0 n01440764/n01440764_10027.JPEG 0 n01440764/n01440764_10029.JPEG 0 n01440764/n01440764_10040.JPEG 0 n01440764/n01440764_10042.JPEG 0 ``` - For the ImageNet-22K dataset, make a folder named `fall11_whole` and move all images to labeled sub-folders in this folder. Then download the train-val split files ([ILSVRC2011fall_whole_map_train.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_train.txt) & [ILSVRC2011fall_whole_map_val.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_val.txt)) and put them in the parent directory of `fall11_whole`. The file structure should look like: ```bash $ tree imagenet22k/ imagenet22k/ ├── ILSVRC2011fall_whole_map_train.txt ├── ILSVRC2011fall_whole_map_val.txt └── fall11_whole ├── n00004475 ├── n00005787 ├── n00006024 ├── n00006484 └── ...
``` ### Evaluation To evaluate a pre-trained `Swin Transformer` on ImageNet val, run: ```bash python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \ --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> ``` For example, to evaluate `Swin-B` on a single GPU: ```bash python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \ --cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path> ``` ### Training from scratch on ImageNet-1K To train a `Swin Transformer` on ImageNet from scratch, run: ```bash python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \ --cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] ``` **Notes**: - To use zipped ImageNet instead of the folder dataset, add `--zip` to the parameters. - To cache the dataset in memory instead of reading from files every time, add `--cache-mode part`, which will shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU. - When GPU memory is insufficient, you can try the following suggestions: - Use gradient accumulation by adding `--accumulation-steps <steps>`, and set an appropriate `<steps>` according to your needs. - Use gradient checkpointing by adding `--use-checkpoint`; for example, it saves about 60% memory when training `Swin-B`. Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details. - We recommend using multiple nodes with more GPUs for training very large models; a tutorial can be found on [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html). - To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g., `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5. - For additional options, see [config](config.py) and run `python main.py --help` to get a detailed message.
For example, to train `Swin Transformer` with 8 GPUs on a single node for 300 epochs, run: `Swin-T`: ```bash python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 ``` `Swin-S`: ```bash python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ --cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 ``` `Swin-B`: ```bash python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ --cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \ --accumulation-steps 2 [--use-checkpoint] ``` ### Pre-training on ImageNet-22K For example, to pre-train a `Swin-B` model on ImageNet-22K: ```bash python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ --cfg configs/swin/swin_base_patch4_window7_224_22k.yaml --data-path <imagenet22k-path> --batch-size 64 \ --accumulation-steps 8 [--use-checkpoint] ``` ### Fine-tuning on higher resolution For example, to fine-tune a `Swin-B` model pre-trained at 224x224 resolution to 384x384 resolution: ```bash python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ --cfg configs/swin/swin_base_patch4_window12_384_finetune.yaml --pretrained swin_base_patch4_window7_224.pth \ --data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint] ``` ### Fine-tuning from an ImageNet-22K(21K) pre-trained model For example, to fine-tune a `Swin-B` model pre-trained on ImageNet-22K(21K): ```bash python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ --cfg configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml --pretrained swin_base_patch4_window7_224_22k.pth \ --data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint] ``` ### Throughput To measure the throughput, run: ```bash python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --disable_amp ``` ## Mixture-of-Experts Support ### Install [Tutel](https://github.com/microsoft/tutel) ```bash python3 -m pip uninstall tutel -y python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@main ``` ### Training Swin-MoE For example, to train a `Swin-MoE-S` model with 32 experts on ImageNet-22K with 32 GPUs (4 nodes): ```bash python -m torch.distributed.launch --nproc_per_node 8 --nnodes=4 \ --node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345 main_moe.py \ --cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128 ``` ### Evaluating Swin-MoE To evaluate a `Swin-MoE-S` with 32 experts on ImageNet-22K with 32 GPUs (4 nodes): 1. Download the zip file [swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip), which contains the pre-trained models for each rank, and unzip it into the folder "swin_moe_small_patch4_window12_192_32expert_32gpu_22k". 2. Run the following evaluation command; note that the checkpoint path should not contain the `.rank<x>` suffix.
```bash python -m torch.distributed.launch --nproc_per_node 8 --nnodes=4 \ --node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345 main_moe.py \ --cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128 \ --resume swin_moe_small_patch4_window12_192_32expert_32gpu_22k/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.pth ``` More Swin-MoE models can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) ## SimMIM Support ### Evaluating provided models To evaluate a provided model on the ImageNet validation set, run: ```bash python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \ --eval --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> ``` For example, to evaluate the `Swin Base` model on a single GPU, run: ```bash python -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py \ --eval --cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --resume simmim_finetune__swin_base__img224_window7__800ep.pth --data-path <imagenet-path> ``` ### Pre-training with SimMIM To pre-train models with `SimMIM`, run: ```bash python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_pt.py \ --cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] ``` For example, to pre-train `Swin Base` for 800 epochs on one DGX-2 server, run: ```bash python -m torch.distributed.launch --nproc_per_node 16 main_simmim_pt.py \ --cfg configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>] ``` ### Fine-tuning pre-trained models To fine-tune models pre-trained by `SimMIM`, run: ```bash python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \ --cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] ``` For example, to fine-tune `Swin Base` pre-trained by `SimMIM` on one DGX-2 server, run: ```bash python -m torch.distributed.launch --nproc_per_node 16 main_simmim_ft.py \ --cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml
--batch-size 128 --data-path --pretrained [--output --tag ] ``` ================================================ FILE: kernels/window_process/setup.py ================================================ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension setup(name='swin_window_process', ext_modules=[ CUDAExtension('swin_window_process', [ 'swin_window_process.cpp', 'swin_window_process_kernel.cu', ]) ], cmdclass={'build_ext': BuildExtension}) ================================================ FILE: kernels/window_process/swin_window_process.cpp ================================================ /* * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include #include at::Tensor roll_and_window_partition_forward_cuda( at::Tensor & input, //at::Tensor & output, const int B, const int H, const int W, const int C, const int shift_size, const int window_size); at::Tensor roll_and_window_partition_backward_cuda( at::Tensor & grad_in, //at::Tensor & grad_out, const int B, const int H, const int W, const int C, const int shift_size, const int window_size); at::Tensor window_merge_and_roll_forward_cuda( at::Tensor & input, //at::Tensor & output, const int B, const int H, const int W, const int C, const int shift_size, const int window_size); at::Tensor window_merge_and_roll_backward_cuda( at::Tensor & grad_in, //at::Tensor & grad_out, const int B, const int H, const int W, const int C, const int shift_size, const int window_size); #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) at::Tensor roll_and_window_partition_forward( at::Tensor & input, //at::Tensor & output, const int B, const int H, const int W, const int C, const int shift_size, const int window_size){ CHECK_INPUT(input); return roll_and_window_partition_forward_cuda(input, B, H, W, C, shift_size, window_size); } at::Tensor roll_and_window_partition_backward( at::Tensor & grad_in, //at::Tensor & grad_out, const int B, const int H, const int W, const int C, const int shift_size, const int window_size){ CHECK_INPUT(grad_in); return roll_and_window_partition_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); } at::Tensor window_merge_and_roll_forward( at::Tensor & input, //at::Tensor & output, const int B, const int H, const int W, const int C, const int shift_size, const int window_size){ CHECK_INPUT(input); return window_merge_and_roll_forward_cuda(input, B, H, W, C, shift_size, window_size); } at::Tensor window_merge_and_roll_backward( at::Tensor & grad_in, //at::Tensor & 
grad_out, const int B, const int H, const int W, const int C, const int shift_size, const int window_size){ CHECK_INPUT(grad_in); return window_merge_and_roll_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("roll_and_window_partition_forward", &roll_and_window_partition_forward, "torch.roll and window_partition."); m.def("roll_and_window_partition_backward", &roll_and_window_partition_backward, "torch.roll and window_partition."); m.def("window_merge_and_roll_forward", &window_merge_and_roll_forward, "window merge and torch.roll."); m.def("window_merge_and_roll_backward", &window_merge_and_roll_backward, "window merge and torch.roll."); } ================================================ FILE: kernels/window_process/swin_window_process_kernel.cu ================================================ /* * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <vector>

// Choose the thread-block size from the channel count; inside the kernels each
// thread loops over channels with stride blockDim.x.
int best_block_dim(int feat_dim){
    int best_dim;
    if (feat_dim < 384){
        best_dim = 64;
    }
    else{
        if (feat_dim < 1024){
            best_dim = 128;
        }
        else{
            best_dim = 256;
        }
    }
    return best_dim;
}

// Fused roll + window partition (forward).
// Grid is (window_size, window_size, B * nH * nW): blockIdx.x / blockIdx.y are the column /
// row inside a window, and blockIdx.z is the flat window index (windows row-major, nW per row).
// Output pixel (y, x) of a window reads input row (win_row * window_size + y - shift_size) mod H
// (likewise for columns). For |shift_size| < H, W this matches
// torch.roll(input, (shift_size, shift_size), dims=(1, 2)) followed by window partitioning.
template <typename T>
__global__ void roll_and_window_partition_forward_cuda_kernel(
    T* input,
    T* output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // start
    //bool qual = threadIdx.x < C;
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
        int input_offset = blockIdx.z / (nH * nW) * H * W * C +
            (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C +
            (blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C +
            i;
        output[offset] = (T)(__ldg(input + input_offset));
    }
}

// Gradient of the roll + window partition.
// Grid is (W, H, B) over grad_out ([B, H, W, C]). Each element is gathered from the window
// position that the forward pass wrote it to: row (y + shift_size) mod H, same for columns.
template <typename T>
__global__ void roll_and_window_partition_backward_cuda_kernel(
    T* grad_in,
    T* grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // start
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
        int input_offset =
            (blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C +
            (blockIdx.y + shift_size + H ) % H % window_size * window_size * C +
            (blockIdx.x + shift_size + W ) % W % window_size * C +
            i;
        grad_out[offset] = (T)(__ldg(grad_in + input_offset));
    }
}

// Fused window merge + roll (forward).
// Grid is (W, H, B) over output ([B, H, W, C]). Output pixel (y, x) reads the merged pixel at
// ((y - shift_size) mod H, (x - shift_size) mod W) from the window that contains it.
template <typename T>
__global__ void window_merge_and_roll_forward_cuda_kernel(
    T* input,
    T* output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // start
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
        // Windows are stored row-major with nW windows per row (see the partition kernels),
        // so the window-row index must be scaled by nW. Scaling by nH is only correct when H == W.
        int input_offset =
            (blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nW + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C +
            (blockIdx.y - shift_size + H) % window_size * window_size * C +
            (blockIdx.x - shift_size + W) % window_size * C +
            i;
        output[offset] = (T)(__ldg(input + input_offset));
    }
}

// Gradient of the window merge + roll.
// Grid is (window_size, window_size, B * nH * nW) over grad_out (windowed layout). Each element
// reads grad_in ([B, H, W, C]) at (win_row * window_size + y + shift_size) mod H, same for columns.
template <typename T>
__global__ void window_merge_and_roll_backward_cuda_kernel(
    T* grad_in,
    T* grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size,
    const int nH,
    const int nW){
    // start
    int index = threadIdx.x;
    int offset;
    for (int i = index; i < C; i += blockDim.x) {
        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
        int input_offset =
            (blockIdx.z / (nH * nW)) * H * W * C +
            (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C +
            (blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C +
            i;
        grad_out[offset] = (T)(__ldg(grad_in + input_offset));
    }
}

// input: [B, H, W, C]
// output: [B*nH*nW, window_size, window_size, C]
at::Tensor roll_and_window_partition_forward_cuda(
    at::Tensor & input,
    //at::Tensor & output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    dim3 grid(window_size, window_size, B * nH * nW);
    //dim3 block((C + 31) / 32 * 32);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    // Allocate with the input's dtype and device, so every dispatched type
    // (half, float, double) gets a buffer of the matching element type.
    at::Tensor output = torch::empty({B*nH*nW, window_size, window_size, C},
                                     input.options().requires_grad(true));

    // Launch on PyTorch's current stream so the kernel is ordered with surrounding ops.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "roll_and_window_partition_forward_cuda_kernel", ([&] {
        roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return output;
}

// grad_in: [B*nH*nW, window_size, window_size, C]
// grad_out: [B, H, W, C]
at::Tensor roll_and_window_partition_backward_cuda(
    at::Tensor & grad_in,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    dim3 grid(W, H, B);
    //dim3 block((C + 31) / 32 * 32);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    at::Tensor grad_out = torch::empty({B, H, W, C}, grad_in.options().requires_grad(false));

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.scalar_type(), "roll_and_window_partition_backward_cuda_kernel", ([&] {
        roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            grad_in.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return grad_out;
}

// input: [B*nH*nW, window_size, window_size, C]
// output: [B, H, W, C]
at::Tensor window_merge_and_roll_forward_cuda(
    at::Tensor & input,
    //at::Tensor & output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    dim3 grid(W, H, B);
    //dim3 block((C + 31) / 32 * 32);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    //generate output tensor inside
    at::Tensor output = torch::empty({B, H, W, C}, input.options().requires_grad(true));

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "window_merge_and_roll_forward_cuda_kernel", ([&] {
        window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return output;
}

// grad_in: [B, H, W, C]
// grad_out: [B*nH*nW, window_size, window_size, C]
at::Tensor window_merge_and_roll_backward_cuda(
    at::Tensor & grad_in,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){

    int nH = H / window_size;
    int nW = W / window_size;

    dim3 grid(window_size, window_size, B * nH * nW);
    //dim3 block((C + 31) / 32 * 32);
    int blocknum = best_block_dim(C);
    dim3 block(blocknum);

    at::Tensor grad_out = torch::empty({B*nH*nW, window_size, window_size, C},
                                       grad_in.options().requires_grad(false));

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.scalar_type(), "window_merge_and_roll_backward_cuda_kernel", ([&] {
        window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
            grad_in.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),
            B,
            H,
            W,
            C,
            shift_size,
            window_size,
            nH,
            nW);
    }));
    return grad_out;
}
================================================ FILE: kernels/window_process/unit_test.py ================================================ # -------------------------------------------------------- # Fused kernel for window process for SwinTransformer # Copyright (c) 2022 Nvidia # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import torch import swin_window_process import random import time import unittest class WindowProcess(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) ctx.B = B ctx.H = H ctx.W = W ctx.C = C ctx.shift_size = shift_size
ctx.window_size = window_size return output @staticmethod def backward(ctx, grad_in): B = ctx.B H = ctx.H W = ctx.W C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) return grad_out, None, None, None, None, None, None, None class WindowProcessReverse(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) ctx.B = B ctx.H = H ctx.W = W ctx.C = C ctx.shift_size = shift_size ctx.window_size = window_size return output @staticmethod def backward(ctx, grad_in): B = ctx.B H = ctx.H W = ctx.W C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) return grad_out, None, None, None, None, None, None, None def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x def pyt_forward(x, shift_size, window_size): # x in shape(B, H, W, C) # cyclic shift if shift_size > 0: shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) else: 
shifted_x = x # partition windows x_windows = window_partition(shifted_x, window_size) return x_windows def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W): # x in shape(B*nH*nW, window_size, window_size, C) shifted_x = window_reverse(attn_windows, window_size, H, W) if shift_size > 0: x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) else: x = shifted_x return x def copy_one_tensor(input, requires_grad=True): input1 = input.clone().detach().requires_grad_(requires_grad).cuda() return input1 class Test_WindowProcess(unittest.TestCase): def setUp(self): self.B = 192 self.H = 56 self.W = 56 self.C = 96 self.shift_size = 2 self.window_size = 7 self.nH = self.H // self.window_size self.nW = self.W // self.window_size def test_roll_and_window_partition_forward(self, dtype=torch.float32): input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) with torch.no_grad(): # ori expected = pyt_forward(input1, self.shift_size, self.window_size) # fused kernel fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) self.assertTrue(torch.equal(expected, fused_output)) #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) def test_roll_and_window_partition_backward(self, dtype=torch.float32): input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda() input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) # ori expected = pyt_forward(input1, self.shift_size, self.window_size) expected.backward(d_loss_tensor) # fused kernel fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) fused_output.backward(d_loss_tensor) 
self.assertTrue(torch.equal(expected, fused_output)) #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) def test_window_merge_and_roll_forward(self, dtype=torch.float32): input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) with torch.no_grad(): # ori expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) # fused kernel fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) self.assertTrue(torch.equal(expected, fused_output)) #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) def test_window_merge_and_roll_backward(self, dtype=torch.float32): input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) # ori expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) expected.backward(d_loss_tensor) # fused kernel fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) fused_output.backward(d_loss_tensor) self.assertTrue(torch.equal(expected, fused_output)) #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) def test_forward_backward_speed(self, dtype=torch.float32, times=1000): input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() input1 = copy_one_tensor(input, True) input2 = copy_one_tensor(input, True) # SwinTransformer 
official def run_pyt(t=1000): for _ in range(t): expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) expected.backward(d_loss_tensor) # my op def run_fusedop(t=1000): for _ in range(t): fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) fused_output.backward(d_loss_tensor) torch.cuda.synchronize() t1 = time.time() run_pyt(t=times) torch.cuda.synchronize() t2 = time.time() run_fusedop(t=times) torch.cuda.synchronize() t3 = time.time() self.assertTrue((t3 - t2) < (t2 - t1)) print('Run {} times'.format(times)) print('Original time cost: {}'.format(t2 - t1)) print('Fused op time cost: {}'.format(t3 - t2)) def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16): self.test_roll_and_window_partition_forward(dtype=dtype) def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16): self.test_roll_and_window_partition_backward(dtype=dtype) def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16): self.test_window_merge_and_roll_forward(dtype=dtype) def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16): self.test_window_merge_and_roll_backward(dtype=dtype) def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000): self.test_forward_backward_speed(dtype=dtype, times=times) if __name__ == '__main__': print('Pass only two tensors are exactly the same (using torch.equal).\n') torch.manual_seed(0) unittest.main(verbosity=2) ================================================ FILE: kernels/window_process/window_process.py ================================================ # -------------------------------------------------------- # Fused kernel for window process for SwinTransformer # Copyright (c) 2022 Nvidia # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import torch import swin_window_process class 
WindowProcess(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) ctx.B = B ctx.H = H ctx.W = W ctx.C = C ctx.shift_size = shift_size ctx.window_size = window_size return output @staticmethod def backward(ctx, grad_in): B = ctx.B H = ctx.H W = ctx.W C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) return grad_out, None, None, None, None, None, None, None class WindowProcessReverse(torch.autograd.Function): @staticmethod def forward(ctx, input, B, H, W, C, shift_size, window_size): output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) ctx.B = B ctx.H = H ctx.W = W ctx.C = C ctx.shift_size = shift_size ctx.window_size = window_size return output @staticmethod def backward(ctx, grad_in): B = ctx.B H = ctx.H W = ctx.W C = ctx.C shift_size = ctx.shift_size window_size = ctx.window_size #grad_out = ctx.saved_tensors[0] #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda() grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) return grad_out, None, None, None, None, None, None, None ================================================ FILE: logger.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import os import sys import logging import functools from termcolor import colored @functools.lru_cache() def create_logger(output_dir, dist_rank=0, name=''): # create logger logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) 
logger.propagate = False # create formatter fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' # create console handlers for master process if dist_rank == 0: console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.DEBUG) console_handler.setFormatter( logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) logger.addHandler(console_handler) # create file handlers file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) logger.addHandler(file_handler) return logger ================================================ FILE: lr_scheduler.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import bisect import torch from timm.scheduler.cosine_lr import CosineLRScheduler from timm.scheduler.step_lr import StepLRScheduler from timm.scheduler.scheduler import Scheduler def build_scheduler(config, optimizer, n_iter_per_epoch): num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] lr_scheduler = None if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': lr_scheduler = CosineLRScheduler( optimizer, t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps, t_mul=1., lr_min=config.TRAIN.MIN_LR, 
warmup_lr_init=config.TRAIN.WARMUP_LR, warmup_t=warmup_steps, cycle_limit=1, t_in_epochs=False, warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX, ) elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': lr_scheduler = LinearLRScheduler( optimizer, t_initial=num_steps, lr_min_rate=0.01, warmup_lr_init=config.TRAIN.WARMUP_LR, warmup_t=warmup_steps, t_in_epochs=False, ) elif config.TRAIN.LR_SCHEDULER.NAME == 'step': lr_scheduler = StepLRScheduler( optimizer, decay_t=decay_steps, decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, warmup_lr_init=config.TRAIN.WARMUP_LR, warmup_t=warmup_steps, t_in_epochs=False, ) elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': lr_scheduler = MultiStepLRScheduler( optimizer, milestones=multi_steps, gamma=config.TRAIN.LR_SCHEDULER.GAMMA, warmup_lr_init=config.TRAIN.WARMUP_LR, warmup_t=warmup_steps, t_in_epochs=False, ) return lr_scheduler class LinearLRScheduler(Scheduler): def __init__(self, optimizer: torch.optim.Optimizer, t_initial: int, lr_min_rate: float, warmup_t=0, warmup_lr_init=0., t_in_epochs=True, noise_range_t=None, noise_pct=0.67, noise_std=1.0, noise_seed=42, initialize=True, ) -> None: super().__init__( optimizer, param_group_field="lr", noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, initialize=initialize) self.t_initial = t_initial self.lr_min_rate = lr_min_rate self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init self.t_in_epochs = t_in_epochs if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] super().update_groups(self.warmup_lr_init) else: self.warmup_steps = [1 for _ in self.base_values] def _get_lr(self, t): if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: t = t - self.warmup_t total_t = self.t_initial - self.warmup_t lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] return lrs def get_epoch_values(self, epoch: int): if self.t_in_epochs: 
return self._get_lr(epoch) else: return None def get_update_values(self, num_updates: int): if not self.t_in_epochs: return self._get_lr(num_updates) else: return None class MultiStepLRScheduler(Scheduler): def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: super().__init__(optimizer, param_group_field="lr") self.milestones = milestones self.gamma = gamma self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init self.t_in_epochs = t_in_epochs if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] super().update_groups(self.warmup_lr_init) else: self.warmup_steps = [1 for _ in self.base_values] assert self.warmup_t <= min(self.milestones) def _get_lr(self, t): if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values] return lrs def get_epoch_values(self, epoch: int): if self.t_in_epochs: return self._get_lr(epoch) else: return None def get_update_values(self, num_updates: int): if not self.t_in_epochs: return self._get_lr(num_updates) else: return None ================================================ FILE: main.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import os import time import json import random import argparse import datetime import numpy as np import torch import torch.backends.cudnn as cudnn import torch.distributed as dist from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import accuracy, AverageMeter from config import get_config from models import build_model from data import build_loader from lr_scheduler 
import build_scheduler from optimizer import build_optimizer from logger import create_logger from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \ reduce_tensor # pytorch major version (1.x or 2.x) PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0]) def parse_option(): parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+', ) # easy config modification parser.add_argument('--batch-size', type=int, help="batch size for single GPU") parser.add_argument('--data-path', type=str, help='path to dataset') parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], help='no: no cache, ' 'full: cache all data, ' 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') parser.add_argument('--pretrained', help='pretrained weight from checkpoint, could be imagenet22k pretrained weight') parser.add_argument('--resume', help='resume from checkpoint') parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory") parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp') parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'], help='mixed precision opt level, if O0, no amp is used (deprecated!)') parser.add_argument('--output', default='output', type=str, metavar='PATH', help='root of output folder, the full path is // (default: output)') parser.add_argument('--tag', help='tag of experiment') 
parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--throughput', action='store_true', help='Test throughput only') # distributed training # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead # (see https://pytorch.org/docs/stable/distributed.html#launch-utility) if PYTORCH_MAJOR_VERSION == 1: parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') # for acceleration parser.add_argument('--fused_window_process', action='store_true', help='Fused window shift & window partition, similar for reversed part.') parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.') ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb parser.add_argument('--optim', type=str, help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.') args, unparsed = parser.parse_known_args() config = get_config(args) return args, config def main(config): dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config) logger.info(str(model)) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params: {n_parameters}") if hasattr(model, 'flops'): flops = model.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") model.cuda() model_without_ddp = model optimizer = build_optimizer(config, model) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) loss_scaler = NativeScalerWithGradNormCount() if config.TRAIN.ACCUMULATION_STEPS > 1: lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) else: lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) if config.AUG.MIXUP > 0.: # smoothing is handled with 
mixup label transform criterion = SoftTargetCrossEntropy() elif config.MODEL.LABEL_SMOOTHING > 0.: criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) else: criterion = torch.nn.CrossEntropyLoss() max_accuracy = 0.0 if config.TRAIN.AUTO_RESUME: resume_file = auto_resume_helper(config.OUTPUT) if resume_file: if config.MODEL.RESUME: logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") config.defrost() config.MODEL.RESUME = resume_file config.freeze() logger.info(f'auto resuming from {resume_file}') else: logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') if config.MODEL.RESUME: max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") if config.EVAL_MODE: return if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): load_pretrained(config, model_without_ddp, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") if config.THROUGHPUT_MODE: throughput(data_loader_val, model, logger) return logger.info("Start training") start_time = time.time() for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler) if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") max_accuracy = max(max_accuracy, acc1) logger.info(f'Max 
accuracy: {max_accuracy:.2f}%') total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) logger.info('Training time {}'.format(total_time_str)) def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler): model.train() optimizer.zero_grad() num_steps = len(data_loader) batch_time = AverageMeter() loss_meter = AverageMeter() norm_meter = AverageMeter() scaler_meter = AverageMeter() start = time.time() end = time.time() for idx, (samples, targets) in enumerate(data_loader): samples = samples.cuda(non_blocking=True) targets = targets.cuda(non_blocking=True) if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): outputs = model(samples) loss = criterion(outputs, targets) loss = loss / config.TRAIN.ACCUMULATION_STEPS # this attribute is added by timm on one optimizer (adahessian) is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, parameters=model.parameters(), create_graph=is_second_order, update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: optimizer.zero_grad() lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS) loss_scale_value = loss_scaler.state_dict()["scale"] torch.cuda.synchronize() loss_meter.update(loss.item(), targets.size(0)) if grad_norm is not None: # loss_scaler return None if not update norm_meter.update(grad_norm) scaler_meter.update(loss_scale_value) batch_time.update(time.time() - end) end = time.time() if idx % config.PRINT_FREQ == 0: lr = optimizer.param_groups[0]['lr'] wd = optimizer.param_groups[0]['weight_decay'] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( f'Train: 
[{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t' f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t' f'mem {memory_used:.0f}MB') epoch_time = time.time() - start logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") @torch.no_grad() def validate(config, data_loader, model): criterion = torch.nn.CrossEntropyLoss() model.eval() batch_time = AverageMeter() loss_meter = AverageMeter() acc1_meter = AverageMeter() acc5_meter = AverageMeter() end = time.time() for idx, (images, target) in enumerate(data_loader): images = images.cuda(non_blocking=True) target = target.cuda(non_blocking=True) # compute output with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): output = model(images) # measure accuracy and record loss loss = criterion(output, target) acc1, acc5 = accuracy(output, target, topk=(1, 5)) acc1 = reduce_tensor(acc1) acc5 = reduce_tensor(acc5) loss = reduce_tensor(loss) loss_meter.update(loss.item(), target.size(0)) acc1_meter.update(acc1.item(), target.size(0)) acc5_meter.update(acc5.item(), target.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if idx % config.PRINT_FREQ == 0: memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) logger.info( f'Test: [{idx}/{len(data_loader)}]\t' f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' f'Mem {memory_used:.0f}MB') logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') return acc1_meter.avg, acc5_meter.avg, loss_meter.avg @torch.no_grad() def throughput(data_loader, model, logger): 
model.eval() for idx, (images, _) in enumerate(data_loader): images = images.cuda(non_blocking=True) batch_size = images.shape[0] for i in range(50): model(images) torch.cuda.synchronize() logger.info(f"throughput averaged with 30 times") tic1 = time.time() for i in range(30): model(images) torch.cuda.synchronize() tic2 = time.time() logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") return if __name__ == '__main__': args, config = parse_option() if config.AMP_OPT_LEVEL: print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ['WORLD_SIZE']) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 torch.cuda.set_device(config.LOCAL_RANK) torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) torch.distributed.barrier() seed = config.SEED + dist.get_rank() torch.manual_seed(seed) torch.cuda.manual_seed(seed) np.random.seed(seed) random.seed(seed) cudnn.benchmark = True # linear scale the learning rate according to total batch size, may not be optimal linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 # gradient accumulation also need to scale the learning rate if config.TRAIN.ACCUMULATION_STEPS > 1: linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS config.defrost() config.TRAIN.BASE_LR = linear_scaled_lr config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr config.TRAIN.MIN_LR = 
linear_scaled_min_lr config.freeze() os.makedirs(config.OUTPUT, exist_ok=True) logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") if dist.get_rank() == 0: path = os.path.join(config.OUTPUT, "config.json") with open(path, "w") as f: f.write(config.dump()) logger.info(f"Full config saved to {path}") # print config logger.info(config.dump()) logger.info(json.dumps(vars(args))) main(config) ================================================ FILE: main_moe.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- from tutel import system import os import time import json import random import argparse import datetime import numpy as np from functools import partial import torch import torch.backends.cudnn as cudnn import torch.distributed as dist from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import accuracy, AverageMeter from config import get_config from models import build_model from data import build_loader from lr_scheduler import build_scheduler from optimizer import build_optimizer from logger import create_logger from utils import NativeScalerWithGradNormCount, reduce_tensor from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_resume_helper, hook_scale_grad assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0" # pytorch major version (1.x or 2.x) PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0]) def parse_option(): parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) parser.add_argument( "--opts", help="Modify config options by 
adding 'KEY VALUE' pairs. ", default=None, nargs='+', ) # easy config modification parser.add_argument('--batch-size', type=int, help="batch size for single GPU") parser.add_argument('--data-path', type=str, help='path to dataset') parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], help='no: no cache, ' 'full: cache all data, ' 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') parser.add_argument('--pretrained', help='pretrained weight from checkpoint, could be imagenet22k pretrained weight') parser.add_argument('--resume', help='resume from checkpoint') parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory") parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp') parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'], help='mixed precision opt level, if O0, no amp is used (deprecated!)') parser.add_argument('--output', default='output', type=str, metavar='PATH', help='root of output folder, the full path is // (default: output)') parser.add_argument('--tag', help='tag of experiment') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--throughput', action='store_true', help='Test throughput only') # distributed training # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead # (see https://pytorch.org/docs/stable/distributed.html#launch-utility) if PYTORCH_MAJOR_VERSION == 1: parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') args, unparsed = parser.parse_known_args() config = get_config(args) return args, config def main(config): dataset_train, dataset_val, 
data_loader_train, data_loader_val, mixup_fn = build_loader(config) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config) logger.info(str(model)) # For Tutel MoE for name, param in model.named_parameters(): if param.requires_grad == True and hasattr(param, 'skip_allreduce') and param.skip_allreduce is True: model.add_param_to_skip_allreduce(name) param.register_hook(partial(hook_scale_grad, dist.get_world_size())) logger.info(f"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {dist.get_world_size()} for grad") n_parameters_single = sum(p.numel() * model.sharded_count if hasattr(p, 'skip_allreduce') else p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params single: {n_parameters_single}") n_parameters_whole = sum(p.numel() * model.sharded_count * model.global_experts if hasattr(p, 'skip_allreduce') else p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params whole: {n_parameters_whole}") if hasattr(model, 'flops'): flops = model.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") model.cuda(config.LOCAL_RANK) model_without_ddp = model optimizer = build_optimizer(config, model) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) loss_scaler = NativeScalerWithGradNormCount() if config.TRAIN.ACCUMULATION_STEPS > 1: lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) else: lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) if config.AUG.MIXUP > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif config.MODEL.LABEL_SMOOTHING > 0.: criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) else: criterion = torch.nn.CrossEntropyLoss() max_accuracy = 0.0 if config.TRAIN.AUTO_RESUME: resume_file = auto_resume_helper(config.OUTPUT, 
config.TRAIN.MOE.SAVE_MASTER) if resume_file: if config.MODEL.RESUME: logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") config.defrost() config.MODEL.RESUME = resume_file config.freeze() logger.info(f'auto resuming from {resume_file}') else: logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') if config.MODEL.RESUME: max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") if config.EVAL_MODE: return if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): load_pretrained(config, model_without_ddp, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") if config.EVAL_MODE: return if config.THROUGHPUT_MODE: throughput(data_loader_val, model, logger) return logger.info("Start training") start_time = time.time() for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler) if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") max_accuracy = max(max_accuracy, acc1) logger.info(f'Max accuracy: {max_accuracy:.2f}%') save_checkpoint(config, 'final', model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, zero_redundancy=True) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) 
logger.info('Training time {}'.format(total_time_str)) def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler): model.train() optimizer.zero_grad() num_steps = len(data_loader) batch_time = AverageMeter() loss_meter = AverageMeter() loss_aux_meter = AverageMeter() loss_cls_meter = AverageMeter() norm_meter = AverageMeter() scaler_meter = AverageMeter() start = time.time() end = time.time() for idx, (samples, targets) in enumerate(data_loader): samples = samples.cuda(non_blocking=True) targets = targets.cuda(non_blocking=True) if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): outputs, l_aux = model(samples) l_cls = criterion(outputs, targets) loss = l_cls + l_aux loss = loss / config.TRAIN.ACCUMULATION_STEPS # this attribute is added by timm on one optimizer (adahessian) is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, parameters=model.parameters(), create_graph=is_second_order, update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: optimizer.zero_grad() lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS) loss_scale_value = loss_scaler.state_dict()["scale"] torch.cuda.synchronize() loss_meter.update(loss.item(), targets.size(0)) loss_cls_meter.update(l_cls.item(), targets.size(0)) loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), targets.size(0)) if grad_norm is not None: # loss_scaler return None if not update norm_meter.update(grad_norm) scaler_meter.update(loss_scale_value) batch_time.update(time.time() - end) end = time.time() if idx % config.PRINT_FREQ == 0: lr = optimizer.param_groups[0]['lr'] wd = optimizer.param_groups[0]['weight_decay'] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 
1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t' f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' f'loss-cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t' f'loss-aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t' f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t' f'mem {memory_used:.0f}MB') epoch_time = time.time() - start logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") @torch.no_grad() def validate(config, data_loader, model): criterion = torch.nn.CrossEntropyLoss() model.eval() batch_time = AverageMeter() loss_cls_meter = AverageMeter() loss_aux_meter = AverageMeter() acc1_meter = AverageMeter() acc5_meter = AverageMeter() end = time.time() for idx, (images, target) in enumerate(data_loader): images = images.cuda(non_blocking=True) target = target.cuda(non_blocking=True) # compute output with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): output, l_aux = model(images) # measure accuracy and record loss l_cls = criterion(output, target) acc1, acc5 = accuracy(output, target, topk=(1, 5)) acc1 = reduce_tensor(acc1) acc5 = reduce_tensor(acc5) loss_cls_meter.update(l_cls.item(), target.size(0)) loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), target.size(0)) acc1_meter.update(acc1.item(), target.size(0)) acc5_meter.update(acc5.item(), target.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if idx % config.PRINT_FREQ == 0: memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) logger.info( f'Test: [{idx}/{len(data_loader)}]\t' f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' f'Loss-Cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t' 
f'Loss-Aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t' f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' f'Mem {memory_used:.0f}MB') logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') return acc1_meter.avg, acc5_meter.avg, loss_cls_meter.avg @torch.no_grad() def throughput(data_loader, model, logger): model.eval() for idx, (images, _) in enumerate(data_loader): images = images.cuda(non_blocking=True) batch_size = images.shape[0] for i in range(50): model(images) torch.cuda.synchronize() logger.info(f"throughput averaged with 30 times") tic1 = time.time() for i in range(30): model(images) torch.cuda.synchronize() tic2 = time.time() logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") return if __name__ == '__main__': args, config = parse_option() if config.AMP_OPT_LEVEL: print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ['WORLD_SIZE']) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 torch.cuda.set_device(config.LOCAL_RANK) torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) torch.distributed.barrier() seed = config.SEED + dist.get_rank() torch.manual_seed(seed) torch.cuda.manual_seed(seed) np.random.seed(seed) random.seed(seed) cudnn.benchmark = True # linear scale the learning rate according to total batch size, may not be optimal linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 # gradient accumulation also need to scale the learning rate if 
config.TRAIN.ACCUMULATION_STEPS > 1: linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS config.defrost() config.TRAIN.BASE_LR = linear_scaled_lr config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr config.TRAIN.MIN_LR = linear_scaled_min_lr config.freeze() os.makedirs(config.OUTPUT, exist_ok=True) logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") if dist.get_rank() == 0: path = os.path.join(config.OUTPUT, "config.json") with open(path, "w") as f: f.write(config.dump()) logger.info(f"Full config saved to {path}") # print config logger.info(config.dump()) logger.info(json.dumps(vars(args))) main(config) ================================================ FILE: main_simmim_ft.py ================================================ # -------------------------------------------------------- # SimMIM # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # Modified by Zhenda Xie # -------------------------------------------------------- import os import time import argparse import datetime import numpy as np import torch import torch.backends.cudnn as cudnn import torch.distributed as dist import torch.cuda.amp as amp from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.utils import accuracy, AverageMeter from config import get_config from models import build_model from data import build_loader from lr_scheduler import build_scheduler from optimizer import build_optimizer from logger import create_logger from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \ reduce_tensor # pytorch major version (1.x or 2.x) PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0]) def parse_option(): parser = 
argparse.ArgumentParser('SimMIM fine-tuning script', add_help=False) parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+', ) # easy config modification parser.add_argument('--batch-size', type=int, help="batch size for single GPU") parser.add_argument('--data-path', type=str, help='path to dataset') parser.add_argument('--pretrained', type=str, help='path to pre-trained model') parser.add_argument('--resume', help='resume from checkpoint') parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory") parser.add_argument('--enable-amp', action='store_true') parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') parser.set_defaults(enable_amp=True) parser.add_argument('--output', default='output', type=str, metavar='PATH', help='root of output folder, the full path is // (default: output)') parser.add_argument('--tag', help='tag of experiment') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--throughput', action='store_true', help='Test throughput only') # distributed training # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead # (see https://pytorch.org/docs/stable/distributed.html#launch-utility) if PYTORCH_MAJOR_VERSION == 1: parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') args = parser.parse_args() config = get_config(args) return args, config def main(config): dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True, is_pretrain=False) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config, is_pretrain=False) 
model.cuda() logger.info(str(model)) optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params: {n_parameters}") if hasattr(model_without_ddp, 'flops'): flops = model_without_ddp.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) scaler = amp.GradScaler() if config.AUG.MIXUP > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif config.MODEL.LABEL_SMOOTHING > 0.: criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) else: criterion = torch.nn.CrossEntropyLoss() max_accuracy = 0.0 if config.TRAIN.AUTO_RESUME: resume_file = auto_resume_helper(config.OUTPUT, logger) if resume_file: if config.MODEL.RESUME: logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") config.defrost() config.MODEL.RESUME = resume_file config.freeze() logger.info(f'auto resuming from {resume_file}') else: logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') if config.MODEL.RESUME: max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") if config.EVAL_MODE: return if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): load_pretrained(config, model_without_ddp, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") if config.THROUGHPUT_MODE: throughput(data_loader_val, model, logger) return logger.info("Start training") 
start_time = time.time() for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler) if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") max_accuracy = max(max_accuracy, acc1) logger.info(f'Max accuracy: {max_accuracy:.2f}%') total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) logger.info('Training time {}'.format(total_time_str)) def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler): model.train() optimizer.zero_grad() logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}') num_steps = len(data_loader) batch_time = AverageMeter() loss_meter = AverageMeter() norm_meter = AverageMeter() loss_scale_meter = AverageMeter() start = time.time() end = time.time() for idx, (samples, targets) in enumerate(data_loader): samples = samples.cuda(non_blocking=True) targets = targets.cuda(non_blocking=True) if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) outputs = model(samples) if config.TRAIN.ACCUMULATION_STEPS > 1: loss = criterion(outputs, targets) loss = loss / config.TRAIN.ACCUMULATION_STEPS scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) else: grad_norm = get_grad_norm(model.parameters()) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: scaler.step(optimizer) optimizer.zero_grad() scaler.update() 
lr_scheduler.step_update(epoch * num_steps + idx) else: loss = criterion(outputs, targets) optimizer.zero_grad() scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) else: grad_norm = get_grad_norm(model.parameters()) scaler.step(optimizer) scaler.update() lr_scheduler.step_update(epoch * num_steps + idx) torch.cuda.synchronize() loss_meter.update(loss.item(), targets.size(0)) norm_meter.update(grad_norm) loss_scale_meter.update(scaler.get_scale()) batch_time.update(time.time() - end) end = time.time() if idx % config.PRINT_FREQ == 0: lr = optimizer.param_groups[-1]['lr'] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t' f'mem {memory_used:.0f}MB') epoch_time = time.time() - start logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") @torch.no_grad() def validate(config, data_loader, model): criterion = torch.nn.CrossEntropyLoss() model.eval() batch_time = AverageMeter() loss_meter = AverageMeter() acc1_meter = AverageMeter() acc5_meter = AverageMeter() end = time.time() for idx, (images, target) in enumerate(data_loader): images = images.cuda(non_blocking=True) target = target.cuda(non_blocking=True) # compute output output = model(images) # measure accuracy and record loss loss = criterion(output, target) acc1, acc5 = accuracy(output, target, topk=(1, 5)) acc1 = reduce_tensor(acc1) acc5 = reduce_tensor(acc5) loss = reduce_tensor(loss) loss_meter.update(loss.item(), target.size(0)) 
acc1_meter.update(acc1.item(), target.size(0)) acc5_meter.update(acc5.item(), target.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if idx % config.PRINT_FREQ == 0: memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) logger.info( f'Test: [{idx}/{len(data_loader)}]\t' f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' f'Mem {memory_used:.0f}MB') logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') return acc1_meter.avg, acc5_meter.avg, loss_meter.avg @torch.no_grad() def throughput(data_loader, model, logger): model.eval() for idx, (images, _) in enumerate(data_loader): images = images.cuda(non_blocking=True) batch_size = images.shape[0] for i in range(50): model(images) torch.cuda.synchronize() logger.info(f"throughput averaged with 30 times") tic1 = time.time() for i in range(30): model(images) torch.cuda.synchronize() tic2 = time.time() logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") return if __name__ == '__main__': _, config = parse_option() if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ['WORLD_SIZE']) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 torch.cuda.set_device(config.LOCAL_RANK) torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) torch.distributed.barrier() seed = config.SEED + dist.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True # linear scale the learning rate according to total batch size, may not be optimal linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * 
config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 # gradient accumulation also need to scale the learning rate if config.TRAIN.ACCUMULATION_STEPS > 1: linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS config.defrost() config.TRAIN.BASE_LR = linear_scaled_lr config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr config.TRAIN.MIN_LR = linear_scaled_min_lr config.freeze() os.makedirs(config.OUTPUT, exist_ok=True) logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") if dist.get_rank() == 0: path = os.path.join(config.OUTPUT, "config.json") with open(path, "w") as f: f.write(config.dump()) logger.info(f"Full config saved to {path}") # print config logger.info(config.dump()) main(config) ================================================ FILE: main_simmim_pt.py ================================================ # -------------------------------------------------------- # SimMIM # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # Modified by Zhenda Xie # -------------------------------------------------------- import os import time import argparse import datetime import numpy as np import torch import torch.backends.cudnn as cudnn import torch.distributed as dist import torch.cuda.amp as amp from timm.utils import AverageMeter from config import get_config from models import build_model from data import build_loader from lr_scheduler import build_scheduler from optimizer import build_optimizer from logger import create_logger from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper # pytorch major version (1.x or 2.x) PYTORCH_MAJOR_VERSION = 
int(torch.__version__.split('.')[0]) def parse_option(): parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False) parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) parser.add_argument( "--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+', ) # easy config modification parser.add_argument('--batch-size', type=int, help="batch size for single GPU") parser.add_argument('--data-path', type=str, help='path to dataset') parser.add_argument('--resume', help='resume from checkpoint') parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") parser.add_argument('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory") parser.add_argument('--enable-amp', action='store_true') parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') parser.set_defaults(enable_amp=True) parser.add_argument('--output', default='output', type=str, metavar='PATH', help='root of output folder, the full path is // (default: output)') parser.add_argument('--tag', help='tag of experiment') # distributed training # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead # (see https://pytorch.org/docs/stable/distributed.html#launch-utility) if PYTORCH_MAJOR_VERSION == 1: parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') args = parser.parse_args() config = get_config(args) return args, config def main(config): data_loader_train = build_loader(config, simmim=True, is_pretrain=True) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config, is_pretrain=True) model.cuda() logger.info(str(model)) optimizer = build_optimizer(config, model, simmim=True, is_pretrain=True) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) model_without_ddp = 
model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params: {n_parameters}") if hasattr(model_without_ddp, 'flops'): flops = model_without_ddp.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) scaler = amp.GradScaler() if config.TRAIN.AUTO_RESUME: resume_file = auto_resume_helper(config.OUTPUT, logger) if resume_file: if config.MODEL.RESUME: logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") config.defrost() config.MODEL.RESUME = resume_file config.freeze() logger.info(f'auto resuming from {resume_file}') else: logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') if config.MODEL.RESUME: load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) logger.info("Start training") start_time = time.time() for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) train_one_epoch(config, model, data_loader_train, optimizer, epoch, lr_scheduler, scaler) if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, scaler, logger) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) logger.info('Training time {}'.format(total_time_str)) def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler): model.train() optimizer.zero_grad() num_steps = len(data_loader) batch_time = AverageMeter() loss_meter = AverageMeter() norm_meter = AverageMeter() loss_scale_meter = AverageMeter() start = time.time() end = time.time() for idx, (img, mask, _) in enumerate(data_loader): img = img.cuda(non_blocking=True) mask = mask.cuda(non_blocking=True) with amp.autocast(enabled=config.ENABLE_AMP): loss = model(img, mask) 
if config.TRAIN.ACCUMULATION_STEPS > 1: loss = loss / config.TRAIN.ACCUMULATION_STEPS scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) else: grad_norm = get_grad_norm(model.parameters()) if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: scaler.step(optimizer) optimizer.zero_grad() scaler.update() lr_scheduler.step_update(epoch * num_steps + idx) else: optimizer.zero_grad() scaler.scale(loss).backward() if config.TRAIN.CLIP_GRAD: scaler.unscale_(optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) else: grad_norm = get_grad_norm(model.parameters()) scaler.step(optimizer) scaler.update() lr_scheduler.step_update(epoch * num_steps + idx) torch.cuda.synchronize() loss_meter.update(loss.item(), img.size(0)) norm_meter.update(grad_norm) loss_scale_meter.update(scaler.get_scale()) batch_time.update(time.time() - end) end = time.time() if idx % config.PRINT_FREQ == 0: lr = optimizer.param_groups[0]['lr'] memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) etas = batch_time.avg * (num_steps - idx) logger.info( f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t' f'mem {memory_used:.0f}MB') epoch_time = time.time() - start logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") if __name__ == '__main__': _, config = parse_option() if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ['WORLD_SIZE']) print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") else: rank = -1 world_size = -1 
torch.cuda.set_device(config.LOCAL_RANK) torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) torch.distributed.barrier() seed = config.SEED + dist.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True # linear scale the learning rate according to total batch size, may not be optimal linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 # gradient accumulation also need to scale the learning rate if config.TRAIN.ACCUMULATION_STEPS > 1: linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS config.defrost() config.TRAIN.BASE_LR = linear_scaled_lr config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr config.TRAIN.MIN_LR = linear_scaled_min_lr config.freeze() os.makedirs(config.OUTPUT, exist_ok=True) logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") if dist.get_rank() == 0: path = os.path.join(config.OUTPUT, "config.json") with open(path, "w") as f: f.write(config.dump()) logger.info(f"Full config saved to {path}") # print config logger.info(config.dump()) main(config) ================================================ FILE: models/__init__.py ================================================ from .build import build_model ================================================ FILE: models/build.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu 
# -------------------------------------------------------- from .swin_transformer import SwinTransformer from .swin_transformer_v2 import SwinTransformerV2 from .swin_transformer_moe import SwinTransformerMoE from .swin_mlp import SwinMLP from .simmim import build_simmim def build_model(config, is_pretrain=False): model_type = config.MODEL.TYPE # accelerate layernorm if config.FUSED_LAYERNORM: try: import apex as amp layernorm = amp.normalization.FusedLayerNorm except: layernorm = None print("To use FusedLayerNorm, please install apex.") else: import torch.nn as nn layernorm = nn.LayerNorm if is_pretrain: model = build_simmim(config) return model if model_type == 'swin': model = SwinTransformer(img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWIN.PATCH_SIZE, in_chans=config.MODEL.SWIN.IN_CHANS, num_classes=config.MODEL.NUM_CLASSES, embed_dim=config.MODEL.SWIN.EMBED_DIM, depths=config.MODEL.SWIN.DEPTHS, num_heads=config.MODEL.SWIN.NUM_HEADS, window_size=config.MODEL.SWIN.WINDOW_SIZE, mlp_ratio=config.MODEL.SWIN.MLP_RATIO, qkv_bias=config.MODEL.SWIN.QKV_BIAS, qk_scale=config.MODEL.SWIN.QK_SCALE, drop_rate=config.MODEL.DROP_RATE, drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWIN.APE, norm_layer=layernorm, patch_norm=config.MODEL.SWIN.PATCH_NORM, use_checkpoint=config.TRAIN.USE_CHECKPOINT, fused_window_process=config.FUSED_WINDOW_PROCESS) elif model_type == 'swinv2': model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWINV2.PATCH_SIZE, in_chans=config.MODEL.SWINV2.IN_CHANS, num_classes=config.MODEL.NUM_CLASSES, embed_dim=config.MODEL.SWINV2.EMBED_DIM, depths=config.MODEL.SWINV2.DEPTHS, num_heads=config.MODEL.SWINV2.NUM_HEADS, window_size=config.MODEL.SWINV2.WINDOW_SIZE, mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, qkv_bias=config.MODEL.SWINV2.QKV_BIAS, drop_rate=config.MODEL.DROP_RATE, drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWINV2.APE, patch_norm=config.MODEL.SWINV2.PATCH_NORM, 
use_checkpoint=config.TRAIN.USE_CHECKPOINT, pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES) elif model_type == 'swin_moe': model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE, in_chans=config.MODEL.SWIN_MOE.IN_CHANS, num_classes=config.MODEL.NUM_CLASSES, embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM, depths=config.MODEL.SWIN_MOE.DEPTHS, num_heads=config.MODEL.SWIN_MOE.NUM_HEADS, window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE, mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO, qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS, qk_scale=config.MODEL.SWIN_MOE.QK_SCALE, drop_rate=config.MODEL.DROP_RATE, drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWIN_MOE.APE, patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM, mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS, init_std=config.MODEL.SWIN_MOE.INIT_STD, use_checkpoint=config.TRAIN.USE_CHECKPOINT, pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES, moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS, num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS, top_value=config.MODEL.SWIN_MOE.TOP_VALUE, capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR, cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER, normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE, use_bpr=config.MODEL.SWIN_MOE.USE_BPR, is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS, gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE, cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM, cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T, moe_drop=config.MODEL.SWIN_MOE.MOE_DROP, aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT) elif model_type == 'swin_mlp': model = SwinMLP(img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE, in_chans=config.MODEL.SWIN_MLP.IN_CHANS, num_classes=config.MODEL.NUM_CLASSES, embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM, depths=config.MODEL.SWIN_MLP.DEPTHS, num_heads=config.MODEL.SWIN_MLP.NUM_HEADS, 
window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE, mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO, drop_rate=config.MODEL.DROP_RATE, drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWIN_MLP.APE, patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM, use_checkpoint=config.TRAIN.USE_CHECKPOINT) else: raise NotImplementedError(f"Unkown model: {model_type}") return model ================================================ FILE: models/simmim.py ================================================ # -------------------------------------------------------- # SimMIM # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Zhenda Xie # -------------------------------------------------------- from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from timm.models.layers import trunc_normal_ from .swin_transformer import SwinTransformer from .swin_transformer_v2 import SwinTransformerV2 def norm_targets(targets, patch_size): assert patch_size % 2 == 1 targets_ = targets targets_count = torch.ones_like(targets) targets_square = targets ** 2. targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False) targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False) targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=True) * (patch_size ** 2) targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1)) targets_var = torch.clamp(targets_var, min=0.) 
targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5 return targets_ class SwinTransformerForSimMIM(SwinTransformer): def __init__(self, **kwargs): super().__init__(**kwargs) assert self.num_classes == 0 self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) trunc_normal_(self.mask_token, mean=0., std=.02) def forward(self, x, mask): x = self.patch_embed(x) assert mask is not None B, L, _ = x.shape mask_tokens = self.mask_token.expand(B, L, -1) w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) x = x * (1. - w) + mask_tokens * w if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) x = x.transpose(1, 2) B, C, L = x.shape H = W = int(L ** 0.5) x = x.reshape(B, C, H, W) return x @torch.jit.ignore def no_weight_decay(self): return super().no_weight_decay() | {'mask_token'} class SwinTransformerV2ForSimMIM(SwinTransformerV2): def __init__(self, **kwargs): super().__init__(**kwargs) assert self.num_classes == 0 self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) trunc_normal_(self.mask_token, mean=0., std=.02) def forward(self, x, mask): x = self.patch_embed(x) assert mask is not None B, L, _ = x.shape mask_tokens = self.mask_token.expand(B, L, -1) w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) x = x * (1. 
- w) + mask_tokens * w if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) x = x.transpose(1, 2) B, C, L = x.shape H = W = int(L ** 0.5) x = x.reshape(B, C, H, W) return x @torch.jit.ignore def no_weight_decay(self): return super().no_weight_decay() | {'mask_token'} class SimMIM(nn.Module): def __init__(self, config, encoder, encoder_stride, in_chans, patch_size): super().__init__() self.config = config self.encoder = encoder self.encoder_stride = encoder_stride self.decoder = nn.Sequential( nn.Conv2d( in_channels=self.encoder.num_features, out_channels=self.encoder_stride ** 2 * 3, kernel_size=1), nn.PixelShuffle(self.encoder_stride), ) self.in_chans = in_chans self.patch_size = patch_size def forward(self, x, mask): z = self.encoder(x, mask) x_rec = self.decoder(z) mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous() # norm target as prompted if self.config.NORM_TARGET.ENABLE: x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE) loss_recon = F.l1_loss(x, x_rec, reduction='none') loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans return loss @torch.jit.ignore def no_weight_decay(self): if hasattr(self.encoder, 'no_weight_decay'): return {'encoder.' + i for i in self.encoder.no_weight_decay()} return {} @torch.jit.ignore def no_weight_decay_keywords(self): if hasattr(self.encoder, 'no_weight_decay_keywords'): return {'encoder.' 
+ i for i in self.encoder.no_weight_decay_keywords()} return {} def build_simmim(config): model_type = config.MODEL.TYPE if model_type == 'swin': encoder = SwinTransformerForSimMIM( img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWIN.PATCH_SIZE, in_chans=config.MODEL.SWIN.IN_CHANS, num_classes=0, embed_dim=config.MODEL.SWIN.EMBED_DIM, depths=config.MODEL.SWIN.DEPTHS, num_heads=config.MODEL.SWIN.NUM_HEADS, window_size=config.MODEL.SWIN.WINDOW_SIZE, mlp_ratio=config.MODEL.SWIN.MLP_RATIO, qkv_bias=config.MODEL.SWIN.QKV_BIAS, qk_scale=config.MODEL.SWIN.QK_SCALE, drop_rate=config.MODEL.DROP_RATE, drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWIN.APE, patch_norm=config.MODEL.SWIN.PATCH_NORM, use_checkpoint=config.TRAIN.USE_CHECKPOINT) encoder_stride = 32 in_chans = config.MODEL.SWIN.IN_CHANS patch_size = config.MODEL.SWIN.PATCH_SIZE elif model_type == 'swinv2': encoder = SwinTransformerV2ForSimMIM( img_size=config.DATA.IMG_SIZE, patch_size=config.MODEL.SWINV2.PATCH_SIZE, in_chans=config.MODEL.SWINV2.IN_CHANS, num_classes=0, embed_dim=config.MODEL.SWINV2.EMBED_DIM, depths=config.MODEL.SWINV2.DEPTHS, num_heads=config.MODEL.SWINV2.NUM_HEADS, window_size=config.MODEL.SWINV2.WINDOW_SIZE, mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, qkv_bias=config.MODEL.SWINV2.QKV_BIAS, drop_rate=config.MODEL.DROP_RATE, drop_path_rate=config.MODEL.DROP_PATH_RATE, ape=config.MODEL.SWINV2.APE, patch_norm=config.MODEL.SWINV2.PATCH_NORM, use_checkpoint=config.TRAIN.USE_CHECKPOINT) encoder_stride = 32 in_chans = config.MODEL.SWINV2.IN_CHANS patch_size = config.MODEL.SWINV2.PATCH_SIZE else: raise NotImplementedError(f"Unknown pre-train model: {model_type}") model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans, patch_size=patch_size) return model ================================================ FILE: models/swin_mlp.py ================================================ # -------------------------------------------------------- 
# Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class SwinMLPBlock(nn.Module): r""" Swin MLP Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
drop (float, optional): Dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.padding = [self.window_size - self.shift_size, self.shift_size, self.window_size - self.shift_size, self.shift_size] # P_l,P_r,P_t,P_b self.norm1 = norm_layer(dim) # use group convolution to implement multi-head MLP self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2, self.num_heads * self.window_size ** 2, kernel_size=1, groups=self.num_heads) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # shift if self.shift_size > 0: P_l, P_r, P_t, P_b = self.padding shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0) else: shifted_x = x _, _H, _W, _ = shifted_x.shape # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # Window/Shifted-Window Spatial MLP x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads) x_windows_heads = x_windows_heads.transpose(1, 2) # nW*B, nH, window_size*window_size, C//nH x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size, C // self.num_heads) spatial_mlp_windows = self.spatial_mlp(x_windows_heads) # nW*B, nH*window_size*window_size, C//nH spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size, C // self.num_heads).transpose(1, 2) spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C) # merge windows spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W) # B H' W' C # reverse shift if self.shift_size > 0: P_l, P_r, P_t, P_b = self.padding x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous() else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, 
input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # Window/Shifted-Window Spatial MLP if self.shift_size > 0: nW = (H / self.window_size + 1) * (W / self.window_size + 1) else: nW = H * W / self.window_size / self.window_size flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin MLP layer for one stage. Args: dim (int): Number of input channels. 
input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. drop (float, optional): Dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinMLPBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. 
patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinMLP(nn.Module): r""" Swin MLP Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin MLP layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. 
Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 drop_rate (float): Dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, **kwargs): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], 
num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, drop=drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Linear, nn.Conv1d)): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops ================================================ FILE: models/swin_transformer.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import torch import torch.nn as nn import 
torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ try: import os, sys kernel_path = os.path.abspath(os.path.join('..')) sys.path.append(kernel_path) from kernels.window_process.window_process import WindowProcess, WindowProcessReverse except: WindowProcess = None WindowProcessReverse = None print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.") class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. 
window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: 
input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, fused_window_process=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) self.fused_window_process = fused_window_process def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: if not self.fused_window_process: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C else: x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows 
= attn_windows.view(-1, self.window_size, self.window_size, C) # reverse cyclic shift if self.shift_size > 0: if not self.fused_window_process: shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size) else: shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C x = shifted_x x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) # FFN x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. 
Default: False """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, fused_window_process=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, fused_window_process=fused_window_process) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. 
Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformer(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, fused_window_process=False, **kwargs): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = 
nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, fused_window_process=fused_window_process) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * 
self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops ================================================ FILE: models/swin_transformer_moe.py ================================================ # -------------------------------------------------------- # Swin Transformer MoE # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import numpy as np try: from tutel import moe as tutel_moe except: tutel_moe = None print("Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.") class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., mlp_fc2_bias=True): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class MoEMlp(nn.Module): def __init__(self, in_features, hidden_features, num_local_experts, top_value, capacity_factor=1.25, cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, init_std=0.02, mlp_fc2_bias=True): super().__init__() self.in_features = in_features self.hidden_features = hidden_features self.num_local_experts = num_local_experts self.top_value = top_value self.capacity_factor = capacity_factor self.cosine_router = cosine_router 
self.normalize_gate = normalize_gate self.use_bpr = use_bpr self.init_std = init_std self.mlp_fc2_bias = mlp_fc2_bias self.dist_rank = dist.get_rank() self._dropout = nn.Dropout(p=moe_drop) _gate_type = {'type': 'cosine_top' if cosine_router else 'top', 'k': top_value, 'capacity_factor': capacity_factor, 'gate_noise': gate_noise, 'fp32_gate': True} if cosine_router: _gate_type['proj_dim'] = cosine_router_dim _gate_type['init_t'] = cosine_router_init_t self._moe_layer = tutel_moe.moe_layer( gate_type=_gate_type, model_dim=in_features, experts={'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_features, 'activation_fn': lambda x: self._dropout(F.gelu(x))}, scan_expert_func=lambda name, param: setattr(param, 'skip_allreduce', True), seeds=(1, self.dist_rank + 1, self.dist_rank + 1), batch_prioritized_routing=use_bpr, normalize_gate=normalize_gate, is_gshard_loss=is_gshard_loss, ) if not self.mlp_fc2_bias: self._moe_layer.experts.batched_fc2_bias.requires_grad = False def forward(self, x): x = self._moe_layer(x) return x, x.l_aux def extra_repr(self) -> str: return f'[Statistics-{self.dist_rank}] param count for MoE, ' \ f'in_features = {self.in_features}, hidden_features = {self.hidden_features}, ' \ f'num_local_experts = {self.num_local_experts}, top_value = {self.top_value}, ' \ f'cosine_router={self.cosine_router} normalize_gate={self.normalize_gate}, use_bpr = {self.use_bpr}' def _init_weights(self): if hasattr(self._moe_layer, "experts"): trunc_normal_(self._moe_layer.experts.batched_fc1_w, std=self.init_std) trunc_normal_(self._moe_layer.experts.batched_fc2_w, std=self.init_std) nn.init.constant_(self._moe_layer.experts.batched_fc1_bias, 0) nn.init.constant_(self._moe_layer.experts.batched_fc2_bias, 0) def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W 
// window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
""" def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pretrained_window_size=[0, 0]): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.pretrained_window_size = pretrained_window_size self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)) # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) relative_coords_table = torch.stack( torch.meshgrid([relative_coords_h, relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) else: relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = torch.sign(relative_coords_table) * torch.log2( torch.abs(relative_coords_table) + 1.0) / np.log2(8) self.register_buffer("relative_coords_table", relative_coords_table) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift 
to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, ' \ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # 
attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True init_std: Initialization std. Default: 0.02 pretrained_window_size (int): Window size in pre-training. is_moe (bool): If True, this block is a MoE block. num_local_experts (int): number of local experts in each device (GPU). Default: 1 top_value (int): the value of k in top-k gating. Default: 1 capacity_factor (float): the capacity factor in MoE. Default: 1.25 cosine_router (bool): Whether to use cosine router. Default: False normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False use_bpr (bool): Whether to use batch-prioritized-routing. Default: True is_gshard_loss (bool): If True, use Gshard balance loss. If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False gate_noise (float): the noise ratio in top-k gating. 
Default: 1.0 cosine_router_dim (int): Projection dimension in cosine router. cosine_router_init_t (float): Initialization temperature in cosine router. moe_drop (float): Dropout rate in MoE. Default: 0.0 """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, mlp_fc2_bias=True, init_std=0.02, pretrained_window_size=0, is_moe=False, num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio self.is_moe = is_moe self.capacity_factor = capacity_factor self.top_value = top_value if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pretrained_window_size=to_2tuple(pretrained_window_size)) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) if self.is_moe: self.mlp = MoEMlp(in_features=dim, hidden_features=mlp_hidden_dim, num_local_experts=num_local_experts, top_value=top_value, capacity_factor=capacity_factor, cosine_router=cosine_router, normalize_gate=normalize_gate, use_bpr=use_bpr, is_gshard_loss=is_gshard_loss, gate_noise=gate_noise, cosine_router_dim=cosine_router_dim, cosine_router_init_t=cosine_router_init_t, moe_drop=moe_drop, mlp_fc2_bias=mlp_fc2_bias, init_std=init_std) else: self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, mlp_fc2_bias=mlp_fc2_bias) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, 
self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) # FFN shortcut = x x = self.norm2(x) if self.is_moe: x, l_aux = self.mlp(x) x = shortcut + self.drop_path(x) return x, l_aux else: x = shortcut + self.drop_path(self.mlp(x)) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp if self.is_moe: flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio * self.capacity_factor * self.top_value else: flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. 
Default: True init_std: Initialization std. Default: 0.02 use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. pretrained_window_size (int): Local window size in pre-training. moe_blocks (tuple(int)): The index of each MoE block. num_local_experts (int): number of local experts in each device (GPU). Default: 1 top_value (int): the value of k in top-k gating. Default: 1 capacity_factor (float): the capacity factor in MoE. Default: 1.25 cosine_router (bool): Whether to use cosine router Default: False normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False use_bpr (bool): Whether to use batch-prioritized-routing. Default: True is_gshard_loss (bool): If True, use Gshard balance loss. If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False gate_noise (float): the noise ratio in top-k gating. Default: 1.0 cosine_router_dim (int): Projection dimension in cosine router. cosine_router_init_t (float): Initialization temperature in cosine router. moe_drop (float): Dropout rate in MoE. 
Default: 0.0 """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_size=0, moe_block=[-1], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, cosine_router_dim=256, cosine_router_init_t=0.5, gate_noise=1.0, moe_drop=0.0): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, mlp_fc2_bias=mlp_fc2_bias, init_std=init_std, pretrained_window_size=pretrained_window_size, is_moe=True if i in moe_block else False, num_local_experts=num_local_experts, top_value=top_value, capacity_factor=capacity_factor, cosine_router=cosine_router, normalize_gate=normalize_gate, use_bpr=use_bpr, is_gshard_loss=is_gshard_loss, gate_noise=gate_noise, cosine_router_dim=cosine_router_dim, cosine_router_init_t=cosine_router_init_t, moe_drop=moe_drop) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): l_aux = 0.0 for blk in self.blocks: if self.use_checkpoint: out = checkpoint.checkpoint(blk, x) else: out = blk(x) if isinstance(out, tuple): x = out[0] cur_l_aux = out[1] l_aux = cur_l_aux + l_aux else: x = out if self.downsample is not None: x = self.downsample(x) return x, l_aux def 
extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformerMoE(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True mlp_fc2_bias (bool): Whether to add bias in fc2 of Mlp. Default: True init_std: Initialization std. Default: 0.02 use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. moe_blocks (tuple(tuple(int))): The index of each MoE block in each layer. 
num_local_experts (int): number of local experts in each device (GPU). Default: 1 top_value (int): the value of k in top-k gating. Default: 1 capacity_factor (float): the capacity factor in MoE. Default: 1.25 cosine_router (bool): Whether to use cosine router Default: False normalize_gate (bool): Whether to normalize the gating score in top-k gating. Default: False use_bpr (bool): Whether to use batch-prioritized-routing. Default: True is_gshard_loss (bool): If True, use Gshard balance loss. If False, use the load loss and importance loss in "arXiv:1701.06538". Default: False gate_noise (float): the noise ratio in top-k gating. Default: 1.0 cosine_router_dim (int): Projection dimension in cosine router. cosine_router_init_t (float): Initialization temperature in cosine router. moe_drop (float): Dropout rate in MoE. Default: 0.0 aux_loss_weight (float): auxiliary loss weight. Default: 0.1 """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, mlp_fc2_bias=True, init_std=0.02, use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], moe_blocks=[[-1], [-1], [-1], [-1]], num_local_experts=1, top_value=1, capacity_factor=1.25, cosine_router=False, normalize_gate=False, use_bpr=True, is_gshard_loss=True, gate_noise=1.0, cosine_router_dim=256, cosine_router_init_t=0.5, moe_drop=0.0, aux_loss_weight=0.01, **kwargs): super().__init__() self._ddp_params_and_buffers_to_ignore = list() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio self.init_std = init_std self.aux_loss_weight = aux_loss_weight self.num_local_experts = num_local_experts 
self.global_experts = num_local_experts * dist.get_world_size() if num_local_experts > 0 \ else dist.get_world_size() // (-num_local_experts) self.sharded_count = (1.0 / num_local_experts) if num_local_experts > 0 else (-num_local_experts) # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=self.init_std) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, mlp_fc2_bias=mlp_fc2_bias, init_std=init_std, use_checkpoint=use_checkpoint, pretrained_window_size=pretrained_window_sizes[i_layer], moe_block=moe_blocks[i_layer], num_local_experts=num_local_experts, top_value=top_value, capacity_factor=capacity_factor, cosine_router=cosine_router, normalize_gate=normalize_gate, use_bpr=use_bpr, is_gshard_loss=is_gshard_loss, gate_noise=gate_noise, cosine_router_dim=cosine_router_dim, cosine_router_init_t=cosine_router_init_t, 
moe_drop=moe_drop) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=self.init_std) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, MoEMlp): m._init_weights() @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {"cpb_mlp", 'relative_position_bias_table', 'fc1_bias', 'fc2_bias', 'temperature', 'cosine_projector', 'sim_matrix'} def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) l_aux = 0.0 for layer in self.layers: x, cur_l_aux = layer(x) l_aux = cur_l_aux + l_aux x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return x, l_aux def forward(self, x): x, l_aux = self.forward_features(x) x = self.head(x) return x, l_aux * self.aux_loss_weight def add_param_to_skip_allreduce(self, param_name): self._ddp_params_and_buffers_to_ignore.append(param_name) def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops ================================================ FILE: models/swin_transformer_v2.py ================================================ # -------------------------------------------------------- # Swin Transformer V2 # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # 
-------------------------------------------------------- import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import numpy as np class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 pretrained_window_size (tuple[int]): The height and width of the window in pre-training. """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., pretrained_window_size=[0, 0]): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.pretrained_window_size = pretrained_window_size self.num_heads = num_heads self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)) # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) relative_coords_table = torch.stack( torch.meshgrid([relative_coords_h, relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) else: relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = torch.sign(relative_coords_table) * torch.log2( torch.abs(relative_coords_table) + 1.0) / np.log2(8) self.register_buffer("relative_coords_table", relative_coords_table) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = 
torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(dim)) self.v_bias = nn.Parameter(torch.zeros(dim)) else: self.q_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # cosine attention attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01))).exp() attn = attn * logit_scale relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww relative_position_bias = 16 * torch.sigmoid(relative_position_bias) attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, ' \ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True drop (float, optional): Dropout rate. 
Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm pretrained_window_size (int): Window size in pre-training. """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, pretrained_window_size=to_2tuple(pretrained_window_size)) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H 
* W, C) x = shortcut + self.drop_path(self.norm1(x)) # FFN x = x + self.drop_path(self.norm2(self.mlp(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(2 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.reduction(x) x = self.norm(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim flops += H * W * self.dim // 2 return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. 
Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. pretrained_window_size (int): Local window size in pre-training. """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, pretrained_window_size=0): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, pretrained_window_size=pretrained_window_size) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x 
= self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops def _init_respostnorm(self): for blk in self.blocks: nn.init.constant_(blk.norm1.bias, 0) nn.init.constant_(blk.norm1.weight, 0) nn.init.constant_(blk.norm2.bias, 0) nn.init.constant_(blk.norm2.weight, 0) class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformerV2(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. 
""" def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0], **kwargs): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, pretrained_window_size=pretrained_window_sizes[i_layer]) self.layers.append(layer) self.norm = 
norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) for bly in self.layers: bly._init_respostnorm() def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'} def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) # B L C x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops ================================================ FILE: optimizer.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- from functools import partial from torch import optim as optim try: from apex.optimizers import FusedAdam, FusedLAMB except: FusedAdam = None FusedLAMB = None print("To use FusedLAMB or FusedAdam, please install apex.") def build_optimizer(config, model, simmim=False, is_pretrain=False): """ Build 
optimizer, set weight decay of normalization to 0 by default. """ skip = {} skip_keywords = {} if hasattr(model, 'no_weight_decay'): skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay_keywords'): skip_keywords = model.no_weight_decay_keywords() if simmim: if is_pretrain: parameters = get_pretrain_param_groups(model, skip, skip_keywords) else: depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS num_layers = sum(depths) get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) parameters = get_finetune_param_groups(model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords) else: parameters = set_weight_decay(model, skip, skip_keywords) opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() optimizer = None if opt_lower == 'sgd': optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) elif opt_lower == 'adamw': optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) elif opt_lower == 'fused_adam': optimizer = FusedAdam(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) elif opt_lower == 'fused_lamb': optimizer = FusedLAMB(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) return optimizer def set_weight_decay(model, skip_list=(), skip_keywords=()): has_decay = [] no_decay = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 
check_keywords_in_name(name, skip_keywords): no_decay.append(param) # print(f"{name} has no weight decay") else: has_decay.append(param) return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}] def check_keywords_in_name(name, keywords=()): isin = False for keyword in keywords: if keyword in name: isin = True return isin def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): has_decay = [] no_decay = [] has_decay_name = [] no_decay_name = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ check_keywords_in_name(name, skip_keywords): no_decay.append(param) no_decay_name.append(name) else: has_decay.append(param) has_decay_name.append(name) return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}] def get_swin_layer(name, num_layers, depths): if name in ("mask_token"): return 0 elif name.startswith("patch_embed"): return 0 elif name.startswith("layers"): layer_id = int(name.split('.')[1]) block_id = name.split('.')[3] if block_id == 'reduction' or block_id == 'norm': return sum(depths[:layer_id + 1]) layer_id = sum(depths[:layer_id]) + int(block_id) return layer_id + 1 else: return num_layers - 1 def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): parameter_group_names = {} parameter_group_vars = {} for name, param in model.named_parameters(): if not param.requires_grad: continue if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ check_keywords_in_name(name, skip_keywords): group_name = "no_decay" this_weight_decay = 0. else: group_name = "decay" this_weight_decay = weight_decay if get_layer_func is not None: layer_id = get_layer_func(name) group_name = "layer_%d_%s" % (layer_id, group_name) else: layer_id = None if group_name not in parameter_group_names: if scales is not None: scale = scales[layer_id] else: scale = 1. 
parameter_group_names[group_name] = { "group_name": group_name, "weight_decay": this_weight_decay, "params": [], "lr": lr * scale, "lr_scale": scale, } parameter_group_vars[group_name] = { "group_name": group_name, "weight_decay": this_weight_decay, "params": [], "lr": lr * scale, "lr_scale": scale } parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) return list(parameter_group_vars.values()) ================================================ FILE: utils.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import os import torch import torch.distributed as dist try: from torch._six import inf except: from torch import inf def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger): logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") if config.MODEL.RESUME.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( config.MODEL.RESUME, map_location='cpu', check_hash=True) else: checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') msg = model.load_state_dict(checkpoint['model'], strict=False) logger.info(msg) max_accuracy = 0.0 if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) config.defrost() config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 config.freeze() if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") if 'max_accuracy' in checkpoint: max_accuracy = checkpoint['max_accuracy'] del checkpoint 
torch.cuda.empty_cache() return max_accuracy def load_pretrained(config, model, logger): logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') state_dict = checkpoint['model'] # delete relative_position_index since we always re-init it relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] for k in relative_position_index_keys: del state_dict[k] # delete relative_coords_table since we always re-init it relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] for k in relative_position_index_keys: del state_dict[k] # delete attn_mask since we always re-init it attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] for k in attn_mask_keys: del state_dict[k] # bicubic interpolate relative_position_bias_table if not match relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] for k in relative_position_bias_table_keys: relative_position_bias_table_pretrained = state_dict[k] relative_position_bias_table_current = model.state_dict()[k] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logger.warning(f"Error in loading {k}, passing......") else: if L1 != L2: # bicubic interpolate relative_position_bias_table if not match S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) # bicubic interpolate absolute_pos_embed if not match absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] for k in absolute_pos_embed_keys: # dpe absolute_pos_embed_pretrained = 
state_dict[k] absolute_pos_embed_current = model.state_dict()[k] _, L1, C1 = absolute_pos_embed_pretrained.size() _, L2, C2 = absolute_pos_embed_current.size() if C1 != C1: logger.warning(f"Error in loading {k}, passing......") else: if L1 != L2: S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) state_dict[k] = absolute_pos_embed_pretrained_resized # check classifier, if not match, then re-init classifier to zero head_bias_pretrained = state_dict['head.bias'] Nc1 = head_bias_pretrained.shape[0] Nc2 = model.head.bias.shape[0] if (Nc1 != Nc2): if Nc1 == 21841 and Nc2 == 1000: logger.info("loading ImageNet-22K weight to ImageNet-1K ......") map22kto1k_path = f'data/map22kto1k.txt' with open(map22kto1k_path) as f: map22kto1k = f.readlines() map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] else: torch.nn.init.constant_(model.head.bias, 0.) torch.nn.init.constant_(model.head.weight, 0.) 
del state_dict['head.weight'] del state_dict['head.bias'] logger.warning(f"Error in loading classifier head, re-init classifier head to 0") msg = model.load_state_dict(state_dict, strict=False) logger.warning(msg) logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'") del checkpoint torch.cuda.empty_cache() def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger): save_state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'max_accuracy': max_accuracy, 'scaler': loss_scaler.state_dict(), 'epoch': epoch, 'config': config} save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') logger.info(f"{save_path} saving......") torch.save(save_state, save_path) logger.info(f"{save_path} saved !!!") def get_grad_norm(parameters, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) norm_type = float(norm_type) total_norm = 0 for p in parameters: param_norm = p.grad.data.norm(norm_type) total_norm += param_norm.item() ** norm_type total_norm = total_norm ** (1. 
/ norm_type) return total_norm def auto_resume_helper(output_dir): checkpoints = os.listdir(output_dir) checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] print(f"All checkpoints founded in {output_dir}: {checkpoints}") if len(checkpoints) > 0: latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) print(f"The latest checkpoint founded: {latest_checkpoint}") resume_file = latest_checkpoint else: resume_file = None return resume_file def reduce_tensor(tensor): rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) rt /= dist.get_world_size() return rt def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = [p for p in parameters if p.grad is not None] norm_type = float(norm_type) if len(parameters) == 0: return torch.tensor(0.) device = parameters[0].grad.device if norm_type == inf: total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) else: total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) return total_norm class NativeScalerWithGradNormCount: state_dict_key = "amp_scaler" def __init__(self): self._scaler = torch.cuda.amp.GradScaler() def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): self._scaler.scale(loss).backward(create_graph=create_graph) if update_grad: if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) else: self._scaler.unscale_(optimizer) norm = ampscaler_get_grad_norm(parameters) self._scaler.step(optimizer) self._scaler.update() else: norm = None return norm def state_dict(self): return self._scaler.state_dict() def load_state_dict(self, state_dict): 
self._scaler.load_state_dict(state_dict) ================================================ FILE: utils_moe.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import os import torch import torch.distributed as dist def split_moe_model_state_dict(moe_keys, model_state_dict): moe_model_state_dict = {} non_moe_model_state_dict = {} for (k, v) in model_state_dict.items(): if k in moe_keys: moe_model_state_dict[k] = v else: non_moe_model_state_dict[k] = v return moe_model_state_dict, non_moe_model_state_dict def merge_moe_model_state_dict(moe_model_state_dict, non_moe_model_state_dict): model_state_dict = {} model_state_dict.update(moe_model_state_dict) model_state_dict.update(non_moe_model_state_dict) return model_state_dict def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger): global_rank = dist.get_rank() logger.info(f"==============> Rank[{global_rank}] Resuming form {config.MODEL.RESUME}....................") if config.MODEL.RESUME.endswith(f'.pth'): if config.TRAIN.MOE.SAVE_MASTER: resume_path = config.MODEL.RESUME + f'.global' else: resume_path = config.MODEL.RESUME + f'.rank{global_rank}' logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {resume_path}......") else: resume_path = config.MODEL.RESUME checkpoint = torch.load(resume_path, map_location='cpu') msg = model.load_state_dict(checkpoint['model'], strict=False) logger.info(msg) max_accuracy = 0.0 if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) config.defrost() config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 config.freeze() if 'scaler' in 
checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) logger.info(f"=>Rank[{global_rank}] loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") if 'max_accuracy' in checkpoint: max_accuracy = checkpoint['max_accuracy'] del checkpoint torch.cuda.empty_cache() return max_accuracy def load_pretrained(config, model, logger): global_rank = dist.get_rank() logger.info(f"==============> Rank[{global_rank}] Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") if config.MODEL.PRETRAINED.endswith(f'.pth'): if config.TRAIN.MOE.SAVE_MASTER: pretrained_path = config.MODEL.PRETRAINED + f'.global' else: pretrained_path = config.MODEL.PRETRAINED + f'.rank{global_rank}' logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {pretrained_path}......") else: pretrained_path = config.MODEL.PRETRAINED if pretrained_path.endswith(f'.rank{global_rank}'): checkpoint = torch.load(pretrained_path, map_location='cpu') if os.path.exists(pretrained_path.replace(f'.rank{global_rank}', f'.master')): checkpoint_master = torch.load(pretrained_path.replace(f'.rank{global_rank}', f'.master'), map_location='cpu') state_dict = merge_moe_model_state_dict(checkpoint['model'], checkpoint_master['model']) else: state_dict = checkpoint['model'] elif pretrained_path.endswith(f'.pth.global'): checkpoint = torch.load(pretrained_path, map_location='cpu') state_dict = checkpoint['model'] else: raise NotImplementedError(f"{config.MODEL.PRETRAINED} file error...") # delete relative_position_index since we always re-init it relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] for k in relative_position_index_keys: del state_dict[k] # delete relative_coords_table since we always re-init it relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] for k in relative_position_index_keys: del state_dict[k] # delete attn_mask since we always re-init it attn_mask_keys = [k for k in 
state_dict.keys() if "attn_mask" in k] for k in attn_mask_keys: del state_dict[k] # bicubic interpolate relative_position_bias_table if not match relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] for k in relative_position_bias_table_keys: relative_position_bias_table_pretrained = state_dict[k] relative_position_bias_table_current = model.state_dict()[k] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logger.warning(f"Error in loading {k}, passing......") else: if L1 != L2: # bicubic interpolate relative_position_bias_table if not match S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode='bicubic') state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) # bicubic interpolate absolute_pos_embed if not match absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] for k in absolute_pos_embed_keys: # dpe absolute_pos_embed_pretrained = state_dict[k] absolute_pos_embed_current = model.state_dict()[k] _, L1, C1 = absolute_pos_embed_pretrained.size() _, L2, C2 = absolute_pos_embed_current.size() if C1 != C1: logger.warning(f"Error in loading {k}, passing......") else: if L1 != L2: S1 = int(L1 ** 0.5) S2 = int(L2 ** 0.5) absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 
state_dict[k] = absolute_pos_embed_pretrained_resized # check classifier, if not match, then re-init classifier to zero head_bias_pretrained = state_dict['head.bias'] Nc1 = head_bias_pretrained.shape[0] Nc2 = model.head.bias.shape[0] if (Nc1 != Nc2): if Nc1 == 21841 and Nc2 == 1000: logger.info("loading ImageNet-22K weight to ImageNet-1K ......") map22kto1k_path = f'data/map22kto1k.txt' with open(map22kto1k_path) as f: map22kto1k = f.readlines() map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] else: torch.nn.init.constant_(model.head.bias, 0.) torch.nn.init.constant_(model.head.weight, 0.) del state_dict['head.weight'] del state_dict['head.bias'] logger.warning(f"Error in loading classifier head, re-init classifier head to 0") msg = model.load_state_dict(state_dict, strict=False) logger.warning(msg) logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'") del checkpoint torch.cuda.empty_cache() def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, zero_redundancy=False): global_rank = dist.get_rank() if zero_redundancy: if config.TRAIN.MOE.SAVE_MASTER: save_state = {'model': model.state_dict()} if global_rank == 0: save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global') logger.info(f"{save_path} saving......") torch.save(save_state, save_path) logger.info(f"{save_path} saved !!!") else: moe_model_state_dict, non_moe_model_state_dict = \ split_moe_model_state_dict(model._ddp_params_and_buffers_to_ignore, model.state_dict()) save_state = {'model': moe_model_state_dict} save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}') logger.info(f"{save_path} saving......") torch.save(save_state, save_path) logger.info(f"{save_path} saved !!!") if global_rank == 0: save_state_master = {'model': non_moe_model_state_dict} save_path = 
os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.master') logger.info(f"{save_path} saving......") torch.save(save_state_master, save_path) logger.info(f"{save_path} saved !!!") else: save_state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'max_accuracy': max_accuracy, 'scaler': loss_scaler.state_dict(), 'epoch': epoch, 'config': config} if config.TRAIN.MOE.SAVE_MASTER: if global_rank == 0: save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global') logger.info(f"{save_path} saving......") torch.save(save_state, save_path) logger.info(f"{save_path} saved !!!") else: save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}') logger.info(f"{save_path} saving......") torch.save(save_state, save_path) logger.info(f"{save_path} saved !!!") def auto_resume_helper(output_dir, save_master=False): global_rank = dist.get_rank() checkpoints = os.listdir(output_dir) if not save_master: master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith(f'pth.rank0')] else: master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith(f'pth.global')] print(f"All master checkpoints founded in {output_dir}: {master_checkpoints}") if len(master_checkpoints) > 0: latest_master_checkpoint = max([os.path.join(output_dir, d) for d in master_checkpoints], key=os.path.getmtime) latest_checkpoint = latest_master_checkpoint.replace('pth.rank0', f'pth.rank{global_rank}') print(f"The latest checkpoint founded: {latest_checkpoint}") resume_file = latest_checkpoint else: resume_file = None return resume_file def hook_scale_grad(scale, tensor): return tensor / scale ================================================ FILE: utils_simmim.py ================================================ # -------------------------------------------------------- # SimMIM # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # Modified by Zhenda 
Xie # -------------------------------------------------------- import os import torch import torch.distributed as dist import numpy as np from scipy import interpolate def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") if config.MODEL.RESUME.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( config.MODEL.RESUME, map_location='cpu', check_hash=True) else: checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') # re-map keys due to name change (only for loading provided models) rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k] for k in rpe_mlp_keys: checkpoint['model'][k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k) msg = model.load_state_dict(checkpoint['model'], strict=False) logger.info(msg) max_accuracy = 0.0 if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'scaler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) scaler.load_state_dict(checkpoint['scaler']) config.defrost() config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 config.freeze() logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") if 'max_accuracy' in checkpoint: max_accuracy = checkpoint['max_accuracy'] else: max_accuracy = 0.0 del checkpoint torch.cuda.empty_cache() return max_accuracy def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger): save_state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'scaler': scaler.state_dict(), 'max_accuracy': max_accuracy, 'epoch': epoch, 'config': config} save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') logger.info(f"{save_path} saving......") torch.save(save_state, save_path) 
logger.info(f"{save_path} saved !!!") def get_grad_norm(parameters, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) norm_type = float(norm_type) total_norm = 0 for p in parameters: param_norm = p.grad.data.norm(norm_type) total_norm += param_norm.item() ** norm_type total_norm = total_norm ** (1. / norm_type) return total_norm def auto_resume_helper(output_dir, logger): checkpoints = os.listdir(output_dir) checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}") if len(checkpoints) > 0: latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) logger.info(f"The latest checkpoint founded: {latest_checkpoint}") resume_file = latest_checkpoint else: resume_file = None return resume_file def reduce_tensor(tensor): rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) rt /= dist.get_world_size() return rt def load_pretrained(config, model, logger): logger.info(f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........") checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') checkpoint_model = checkpoint['model'] if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]): checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')} logger.info('Detect pre-trained model, remove [encoder.] 
prefix.') else: logger.info('Detect non-pre-trained model, pass without doing anything.') if config.MODEL.TYPE in ['swin', 'swinv2']: logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") checkpoint = remap_pretrained_keys_swin(model, checkpoint_model, logger) else: raise NotImplementedError msg = model.load_state_dict(checkpoint_model, strict=False) logger.info(msg) del checkpoint torch.cuda.empty_cache() logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'") def remap_pretrained_keys_swin(model, checkpoint_model, logger): state_dict = model.state_dict() # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size all_keys = list(checkpoint_model.keys()) for key in all_keys: if "relative_position_bias_table" in key: relative_position_bias_table_pretrained = checkpoint_model[key] relative_position_bias_table_current = state_dict[key] L1, nH1 = relative_position_bias_table_pretrained.size() L2, nH2 = relative_position_bias_table_current.size() if nH1 != nH2: logger.info(f"Error in loading {key}, passing......") else: if L1 != L2: logger.info(f"{key}: Interpolate relative_position_bias_table using geo.") src_size = int(L1 ** 0.5) dst_size = int(L2 ** 0.5) def geometric_progression(a, r, n): return a * (1.0 - r ** n) / (1.0 - r) left, right = 1.01, 1.5 while right - left > 1e-6: q = (left + right) / 2.0 gp = geometric_progression(1, q, src_size // 2) if gp > dst_size // 2: right = q else: left = q # if q > 1.090307: # q = 1.090307 dis = [] cur = 1 for i in range(src_size // 2): dis.append(cur) cur += q ** (i + 1) r_ids = [-_ for _ in reversed(dis)] x = r_ids + [0] + dis y = r_ids + [0] + dis t = dst_size // 2.0 dx = np.arange(-t, t + 0.1, 1.0) dy = np.arange(-t, t + 0.1, 1.0) logger.info("Original positions = %s" % str(x)) logger.info("Target positions = %s" % str(dx)) all_rel_pos_bias = [] for i in range(nH1): z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy() 
f_cubic = interpolate.interp2d(x, y, z, kind='cubic') all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to( relative_position_bias_table_pretrained.device)) new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) checkpoint_model[key] = new_rel_pos_bias # delete relative_position_index since we always re-init it relative_position_index_keys = [k for k in checkpoint_model.keys() if "relative_position_index" in k] for k in relative_position_index_keys: del checkpoint_model[k] # delete relative_coords_table since we always re-init it relative_coords_table_keys = [k for k in checkpoint_model.keys() if "relative_coords_table" in k] for k in relative_coords_table_keys: del checkpoint_model[k] # re-map keys due to name change rpe_mlp_keys = [k for k in checkpoint_model.keys() if "rpe_mlp" in k] for k in rpe_mlp_keys: checkpoint_model[k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint_model.pop(k) # delete attn_mask since we always re-init it attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] for k in attn_mask_keys: del checkpoint_model[k] return checkpoint_model