Repository: jamycheung/DELIVER Branch: main Commit: dd6a5aacff3f Files: 58 Total size: 262.6 KB Directory structure: gitextract_aaso4r_f/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── deliver_rgbdel.yaml │ ├── kitti360_rgbdel.yaml │ ├── mcubes_rgbadn.yaml │ ├── mfnet_rgbt.yaml │ ├── nyu_rgbd.yaml │ └── urbanlf.yaml ├── environment.yaml ├── requirements.txt ├── semseg/ │ ├── __init__.py │ ├── augmentations.py │ ├── augmentations_mm.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── deliver.py │ │ ├── kitti360.py │ │ ├── mcubes.py │ │ ├── mfnet.py │ │ ├── nyu.py │ │ ├── unzip.py │ │ └── urbanlf.py │ ├── losses.py │ ├── metrics.py │ ├── models/ │ │ ├── __init__.py │ │ ├── backbones/ │ │ │ ├── __init__.py │ │ │ ├── cmnext.py │ │ │ └── cmx.py │ │ ├── base.py │ │ ├── cmnext.py │ │ ├── cmx.py │ │ ├── heads/ │ │ │ ├── __init__.py │ │ │ ├── condnet.py │ │ │ ├── fapn.py │ │ │ ├── fcn.py │ │ │ ├── fpn.py │ │ │ ├── hem.py │ │ │ ├── lawin.py │ │ │ ├── segformer.py │ │ │ ├── sfnet.py │ │ │ └── upernet.py │ │ ├── layers/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ └── initialize.py │ │ └── modules/ │ │ ├── __init__.py │ │ ├── crossatt.py │ │ ├── ffm.py │ │ ├── mspa.py │ │ ├── ppm.py │ │ └── psa.py │ ├── optimizers.py │ ├── schedulers.py │ └── utils/ │ ├── __init__.py │ ├── utils.py │ └── visualize.py └── tools/ ├── infer_mm.py ├── train_mm.py └── val_mm.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- *.jpg *.jpeg *.png *.bmp *.tif *.tiff *.heic *.JPG *.JPEG *.PNG *.BMP *.TIF *.TIFF *.HEIC *.mp4 *.mov *.MOV *.avi *.data *.json *.pth *.cfg !cfg/yolov3*.cfg storage.googleapis.com runs/* data/* !data/images/zidane.jpg !data/images/bus.jpg !data/coco.names 
!data/coco_paper.names !data/coco.data !data/coco_*.data !data/coco_*.txt !data/trainvalno5k.shapes !data/*.sh test.py test_imgs/ pycocotools/* results*.txt gcp_test*.sh checkpoints/ # output/ # output*/ *events* assests/*/ # Datasets ------------------------------------------------------------------------------------------------------------- coco/ coco128/ VOC/ # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- *.m~ *.mat !targets*.mat # Neural Network weights ----------------------------------------------------------------------------------------------- *.weights *.pt *.onnx *.mlmodel *.torchscript darknet53.conv.74 yolov3-tiny.conv.15 # GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ wandb/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ # Translations *.mo *.pot # Django stuff: # *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # dotenv .env # virtualenv .venv* venv*/ ENV*/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- # General .DS_Store .AppleDouble .LSOverride # Icon must end with two \r Icon Icon? # Thumbnails ._* # Files that might appear in the root of a volume .DocumentRevisions-V100 .fseventsd .Spotlight-V100 .TemporaryItems .Trashes .VolumeIcon.icns .com.apple.timemachine.donotpresent # Directories potentially created on remote AFP share .AppleDB .AppleDesktop Network Trash Folder Temporary Items .apdisk # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 # User-specific stuff: .idea/* .idea/**/workspace.xml .idea/**/tasks.xml .idea/dictionaries .html # Bokeh Plots .pg # TensorFlow Frozen Graphs .avi # videos # Sensitive or high-churn files: .idea/**/dataSources/ .idea/**/dataSources.ids .idea/**/dataSources.local.xml .idea/**/sqlDataSources.xml .idea/**/dynamic.xml .idea/**/uiDesigner.xml # Gradle: .idea/**/gradle.xml .idea/**/libraries # CMake cmake-build-debug/ cmake-build-release/ # Mongo Explorer plugin: .idea/**/mongoSettings.xml 
## File-based project format: *.iws ## Plugin-specific files: # IntelliJ out/ # mpeltonen/sbt-idea plugin .idea_modules/ # JIRA plugin atlassian-ide-plugin.xml # Cursive Clojure plugin .idea/replstate.xml # Crashlytics plugin (for Android Studio and IntelliJ) com_crashlytics_export_strings.xml crashlytics.properties crashlytics-build.properties fabric.properties output/ data/ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [2023] [Jiaming Zhang] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
Visualization results on the DELIVER dataset. From left to right: the *cloudy*, *foggy*, *night* and *rainy* scenes, respectively.
## Acknowledgements
Thanks for the public repositories:
- [RGBX-semantic-segmentation](https://github.com/huaaaliu/RGBX_Semantic_Segmentation)
- [Semantic-segmentation](https://github.com/sithu31296/semantic-segmentation)
## License
This repository is under the Apache-2.0 license. For commercial use, please contact the authors.
## Citations
If you use the DeLiVER dataset or the CMNeXt model, please cite the following works:
- **DeLiVER & CMNeXt** [[**PDF**](https://arxiv.org/pdf/2303.01480.pdf)]
```
@inproceedings{zhang2023delivering,
title={Delivering Arbitrary-Modal Semantic Segmentation},
author={Zhang, Jiaming and Liu, Ruiping and Shi, Hao and Yang, Kailun and Rei{\ss}, Simon and Peng, Kunyu and Fu, Haodong and Wang, Kaiwei and Stiefelhagen, Rainer},
booktitle={CVPR},
year={2023}
}
```
- **CMX** [[**PDF**](https://arxiv.org/pdf/2203.04838.pdf)]
```
@article{zhang2023cmx,
title={CMX: Cross-modal fusion for RGB-X semantic segmentation with transformers},
author={Zhang, Jiaming and Liu, Huayao and Yang, Kailun and Hu, Xinxin and Liu, Ruiping and Stiefelhagen, Rainer},
journal={IEEE Transactions on Intelligent Transportation Systems},
year={2023}
}
```
================================================
FILE: configs/deliver_rgbdel.yaml
================================================
DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
MODEL:
NAME : CMNeXt # name of the model you are using
BACKBONE : CMNeXt-B2 # model variant
PRETRAINED : 'checkpoints/pretrained/segformer/mit_b2.pth' # backbone model's weight
RESUME : '' # checkpoint file
DATASET:
NAME : DELIVER # dataset name to be trained with (e.g. DELIVER, KITTI360, MCubeS, MFNet, NYU, UrbanLF)
ROOT : 'data/DELIVER' # dataset root path
IGNORE_LABEL : 255
# MODALS : ['img']
# MODALS : ['img', 'depth']
# MODALS : ['img', 'event']
# MODALS : ['img', 'lidar']
# MODALS : ['img', 'depth', 'event']
# MODALS : ['img', 'depth', 'lidar']
MODALS : ['img', 'depth', 'event', 'lidar']
TRAIN:
IMAGE_SIZE : [1024, 1024] # training image size in (h, w)
BATCH_SIZE : 2 # batch size used to train
EPOCHS : 200 # number of epochs to train
EVAL_START : 100 # evaluation interval start
EVAL_INTERVAL : 1 # evaluation interval during training
AMP : false # use AMP in training
DDP : true # use DDP training
LOSS:
NAME : OhemCrossEntropy # loss function name
CLS_WEIGHTS : false # use class weights in loss calculation
OPTIMIZER:
NAME : adamw # optimizer name
LR : 0.00006 # initial learning rate used in optimizer
WEIGHT_DECAY : 0.01 # decay rate used in optimizer
SCHEDULER:
NAME : warmuppolylr # scheduler name
POWER : 0.9 # scheduler power
WARMUP : 10 # warmup epochs used in scheduler
WARMUP_RATIO : 0.1 # warmup ratio
EVAL:
# MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgb.pth'
# MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgbd.pth'
# MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgbe.pth'
# MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgbl.pth'
# MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgbde.pth'
# MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgbdl.pth'
MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgbdel.pth'
IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w)
BATCH_SIZE : 4 # batch size used for evaluation
MSF:
ENABLE : false # multi-scale and flip evaluation
FLIP : true # use flip in evaluation
SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
TEST:
MODEL_PATH : 'output/DELIVER/cmnext_b2_deliver_rgbdel.pth' # trained model file path
FILE : 'data/DELIVER' # filename or foldername
IMAGE_SIZE : [1024, 1024] # inference image size in (h, w)
OVERLAY : false # save the overlay result (image_alpha+label_alpha)
================================================
FILE: configs/kitti360_rgbdel.yaml
================================================
DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
MODEL:
NAME : CMNeXt # name of the model you are using
BACKBONE : CMNeXt-B2 # model variant
PRETRAINED : 'checkpoints/pretrained/segformer/mit_b2.pth' # backbone model's weight
RESUME : '' # checkpoint file
DATASET:
NAME : KITTI360 # dataset name to be trained with (e.g. DELIVER, KITTI360, MCubeS, MFNet, NYU, UrbanLF)
ROOT : 'data/KITTI360' # dataset root path
IGNORE_LABEL : 255
# MODALS : ['img']
# MODALS : ['img', 'depth']
# MODALS : ['img', 'event']
# MODALS : ['img', 'lidar']
# MODALS : ['img', 'depth', 'event']
# MODALS : ['img', 'depth', 'lidar']
MODALS : ['img', 'depth', 'event', 'lidar']
TRAIN:
IMAGE_SIZE : [376, 1408] # training image size in (h, w)
BATCH_SIZE : 4 # batch size used to train --- KD
EPOCHS : 40 # number of epochs to train
EVAL_START : 10 # epoch at which evaluation starts during training
EVAL_INTERVAL : 1 # evaluation interval during training
AMP : false # use AMP in training
DDP : true # use DDP training
LOSS:
NAME : OhemCrossEntropy # loss function name
CLS_WEIGHTS : false # use class weights in loss calculation
OPTIMIZER:
NAME : adamw # optimizer name
LR : 0.00006 # initial learning rate used in optimizer
WEIGHT_DECAY : 0.01 # decay rate used in optimizer
SCHEDULER:
NAME : warmuppolylr # scheduler name
POWER : 0.9 # scheduler power
WARMUP : 10 # warmup epochs used in scheduler
WARMUP_RATIO : 0.1 # warmup ratio
EVAL:
# MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgb.pth'
# MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgbd.pth'
# MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgbe.pth'
# MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgbl.pth'
# MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgbde.pth'
# MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgbdl.pth'
MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgbdel.pth'
IMAGE_SIZE : [376, 1408] # evaluation image size in (h, w)
BATCH_SIZE : 4 # batch size used for evaluation
MSF:
ENABLE : false # multi-scale and flip evaluation
FLIP : true # use flip in evaluation
SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
TEST:
MODEL_PATH : 'output/KITTI360/cmnext_b2_kitti360_rgbdel.pth' # trained model file path
FILE : 'data/KITTI360' # filename or foldername
IMAGE_SIZE : [376, 1408] # inference image size in (h, w)
OVERLAY : false # save the overlay result (image_alpha+label_alpha)
================================================
FILE: configs/mcubes_rgbadn.yaml
================================================
DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
MODEL:
NAME : CMNeXt # name of the model you are using
BACKBONE : CMNeXt-B2 # model variant
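PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight; NOTE(review): BACKBONE above is CMNeXt-B2 but this points to mit_b4.pth - confirm which is intended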
RESUME : '' # checkpoint file
DATASET:
NAME : MCubeS # dataset name to be trained with (e.g. DELIVER, KITTI360, MCubeS, MFNet, NYU, UrbanLF)
ROOT : 'data/MCubeS/multimodal_dataset' # dataset root path
IGNORE_LABEL : 255
# MODALS : ['image'] #
# MODALS : ['image', 'aolp']
# MODALS : ['image', 'aolp', 'dolp']
MODALS : ['image', 'aolp', 'dolp', 'nir']
TRAIN:
IMAGE_SIZE : [512, 512] # training image size in (h, w) === Fixed in dataloader, following MCubeSNet
BATCH_SIZE : 4 # batch size used to train
EPOCHS : 500 # number of epochs to train
EVAL_START : 400 # epoch at which evaluation starts during training
EVAL_INTERVAL : 1 # evaluation interval during training
AMP : false # use AMP in training
DDP : true # use DDP training
LOSS:
NAME : OhemCrossEntropy # loss function name
CLS_WEIGHTS : false # use class weights in loss calculation
OPTIMIZER:
NAME : adamw # optimizer name
LR : 0.00006 # initial learning rate used in optimizer
WEIGHT_DECAY : 0.01 # decay rate used in optimizer
SCHEDULER:
NAME : warmuppolylr # scheduler name
POWER : 0.9 # scheduler power
WARMUP : 10 # warmup epochs used in scheduler
WARMUP_RATIO : 0.1 # warmup ratio
EVAL:
# MODEL_PATH : 'output/MCubeS/cmnext_b2_mcubes_rgb.pth'
# MODEL_PATH : 'output/MCubeS/cmnext_b2_mcubes_rgba.pth'
# MODEL_PATH : 'output/MCubeS/cmnext_b2_mcubes_rgbad.pth'
MODEL_PATH : 'output/MCubeS/cmnext_b2_mcubes_rgbadn.pth'
IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w)
BATCH_SIZE : 2 # batch size used for evaluation
MSF:
ENABLE : false # multi-scale and flip evaluation
FLIP : true # use flip in evaluation
SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
================================================
FILE: configs/mfnet_rgbt.yaml
================================================
DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
MODEL:
NAME : CMNeXt # name of the model you are using
BACKBONE : CMNeXt-B4 # model variant
PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight
RESUME : '' # checkpoint file
DATASET:
NAME : MFNet # dataset name to be trained with (e.g. DELIVER, KITTI360, MCubeS, MFNet, NYU, UrbanLF)
ROOT : 'data/MFNet' # dataset root path
IGNORE_LABEL : 255
# MODALS : ['img']
MODALS : ['img', 'thermal']
TRAIN:
IMAGE_SIZE : [480, 640] # training image size in (h, w)
BATCH_SIZE : 4 # batch size used to train
EPOCHS : 500 # number of epochs to train
EVAL_START : 300 # epoch at which evaluation starts during training
EVAL_INTERVAL : 1 # evaluation interval during training
AMP : false # use AMP in training
DDP : true # use DDP training
LOSS:
NAME : CrossEntropy # loss function name (ohemce, ce, dice)
CLS_WEIGHTS : false # use class weights in loss calculation
OPTIMIZER:
NAME : adamw # optimizer name
LR : 0.00006 # initial learning rate used in optimizer
WEIGHT_DECAY : 0.01 # decay rate used in optimizer
SCHEDULER:
NAME : warmuppolylr # scheduler name
POWER : 0.9 # scheduler power
WARMUP : 10 # warmup epochs used in scheduler
WARMUP_RATIO : 0.1 # warmup ratio
EVAL:
MODEL_PATH : 'output/MFNet/cmnext_b4_mfnet_rgbt.pth'
IMAGE_SIZE : [480, 640] # evaluation image size in (h, w)
BATCH_SIZE : 2 # batch size used for evaluation
MSF:
ENABLE : false # multi-scale and flip evaluation
FLIP : true # use flip in evaluation
SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
================================================
FILE: configs/nyu_rgbd.yaml
================================================
DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
MODEL:
NAME : CMNeXt # name of the model you are using
BACKBONE : CMNeXt-B4 # model variant
PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight
RESUME : '' # checkpoint file
DATASET:
NAME : NYU # dataset name to be trained with (e.g. DELIVER, KITTI360, MCubeS, MFNet, NYU, UrbanLF)
ROOT : 'data/NYUDepthv2' # dataset root path
IGNORE_LABEL : 255
# MODALS : ['img']
MODALS : ['img', 'depth']
TRAIN:
IMAGE_SIZE : [480, 640] # training image size in (h, w)
BATCH_SIZE : 4 # batch size used to train
EPOCHS : 500 # number of epochs to train
EVAL_START : 300 # epoch at which evaluation starts during training
EVAL_INTERVAL : 1 # evaluation interval during training
AMP : false # use AMP in training
DDP : true # use DDP training
LOSS:
NAME : CrossEntropy # loss function name
CLS_WEIGHTS : false # use class weights in loss calculation
OPTIMIZER:
NAME : adamw # optimizer name
LR : 0.00006 # initial learning rate used in optimizer
WEIGHT_DECAY : 0.01 # decay rate used in optimizer
SCHEDULER:
NAME : warmuppolylr # scheduler name
POWER : 0.9 # scheduler power
WARMUP : 10 # warmup epochs used in scheduler
WARMUP_RATIO : 0.1 # warmup ratio
EVAL:
MODEL_PATH : 'output/NYU_Depth_V2/cmnext_b4_nyu_rgbd.pth'
IMAGE_SIZE : [480, 640] # evaluation image size in (h, w)
BATCH_SIZE : 2 # batch size used for evaluation
MSF:
ENABLE : true # multi-scale and flip evaluation
FLIP : true # use flip in evaluation
SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
================================================
FILE: configs/urbanlf.yaml
================================================
DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
SAVE_DIR : 'output' # output folder name used for saving the model, logs and inference results
MODEL:
NAME : CMNeXt # name of the model you are using
BACKBONE : CMNeXt-B4 # model variant
PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight
RESUME : '' # checkpoint file
DATASET:
NAME : UrbanLF # dataset name to be trained with (e.g. DELIVER, KITTI360, MCubeS, MFNet, NYU, UrbanLF)
# ROOT : 'data/UrBanLF/real' # dataset root path, for real dataset
ROOT : 'data/UrBanLF/Syn' # dataset root path, for synthetic dataset
IGNORE_LABEL : 255
# MODALS : ['img']
# MODALS : ['img', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9']
# MODALS : ['img', '1_1', '1_5', '1_9', '2_2', '2_5', '2_8', '3_3', '3_5', '3_7', '4_4', '4_5', '4_6', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9', '6_4', '6_5', '6_6', '7_3', '7_5', '7_7', '8_2', '8_5', '8_8', '9_1', '9_5', '9_9']
MODALS : ['img', '1_1', '1_2', '1_3', '1_4', '1_5', '1_6', '1_7', '1_8', '1_9', '2_1', '2_2', '2_3', '2_4', '2_5', '2_6', '2_7', '2_8', '2_9', '3_1', '3_2', '3_3', '3_4', '3_5', '3_6', '3_7', '3_8', '3_9', '4_1', '4_2', '4_3', '4_4', '4_5', '4_6', '4_7', '4_8', '4_9', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9', '6_1', '6_2', '6_3', '6_4', '6_5', '6_6', '6_7', '6_8', '6_9', '7_1', '7_2', '7_3', '7_4', '7_5', '7_6', '7_7', '7_8', '7_9', '8_1', '8_2', '8_3', '8_4', '8_5', '8_6', '8_7', '8_8', '8_9', '9_1', '9_2', '9_3', '9_4', '9_5', '9_6', '9_7', '9_8', '9_9']
TRAIN:
IMAGE_SIZE : [480, 640] # training image size in (h, w)
BATCH_SIZE : 2 # batch size used to train
EPOCHS : 500 # number of epochs to train
EVAL_START : 300 # evaluation interval start
EVAL_INTERVAL : 1 # evaluation interval during training
AMP : false # use AMP in training
DDP : true # use DDP training
LOSS:
NAME : OhemCrossEntropy # loss function name
CLS_WEIGHTS : false # use class weights in loss calculation
OPTIMIZER:
NAME : adamw # optimizer name
LR : 0.00006 # initial learning rate used in optimizer
WEIGHT_DECAY : 0.01 # decay rate used in optimizer
SCHEDULER:
NAME : warmuppolylr # scheduler name
POWER : 0.9 # scheduler power
WARMUP : 10 # warmup epochs used in scheduler
WARMUP_RATIO : 0.1 # warmup ratio
EVAL:
# MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf1.pth'
# MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf8.pth'
# MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf33.pth'
# MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf80.pth'
# MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf1.pth'
# MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf8.pth'
# MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf33.pth'
MODEL_PATH : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf80.pth'
IMAGE_SIZE : [480, 640] # eval image size in (h, w)
BATCH_SIZE : 2 # batch size used for evaluation
MSF:
ENABLE : false # multi-scale and flip evaluation
FLIP : true # use flip in evaluation
SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
================================================
FILE: environment.yaml
================================================
name: cmnext
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- _pytorch_select=0.1=cpu_0
- blas=1.0=mkl
- bzip2=1.0.8=h7f98852_4
- ca-certificates=2021.10.26=h06a4308_2
- certifi=2021.10.8=py38h06a4308_2
- cffi=1.14.6=py38ha65f79e_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- cudnn=8.2.1.32=h86fa8c9_0
- ffmpeg=4.3=hf484d3e_0
- freetype=2.10.4=h5ab3b9f_0
- future=0.18.2=py38h578d9bd_4
- gmp=6.2.1=h58526e2_0
- gnutls=3.6.13=h85f3911_1
- intel-openmp=2021.3.0=h06a4308_3350
- jpeg=9d=h7f8727e_0
- lame=3.100=h7f98852_1001
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libblas=3.9.0=11_linux64_mkl
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.16=h516909a_0
- liblapack=3.9.0=11_linux64_mkl
- libpng=1.6.37=hbc83047_0
- libprotobuf=3.16.0=h780b84a_0
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtiff=4.2.0=h85742a9_0
- libuv=1.40.0=h7b6447c_0
- libwebp-base=1.2.0=h27cfd23_0
- lz4-c=1.9.3=h295c915_1
- magma=2.5.4=h6103c52_2
- mkl=2021.3.0=h06a4308_520
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.0=py38h42c9631_2
- mkl_random=1.2.2=py38h51133e4_0
- nccl=2.11.4.1=hdc17891_0
- ncurses=6.2=he6710b0_1
- nettle=3.6=he412f7d_0
- ninja=1.10.2=hff7bd54_1
- numpy=1.21.2=py38h20f2e39_0
- numpy-base=1.21.2=py38h79a1101_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.1=h780b84a_0
- openjpeg=2.4.0=h3ad879b_0
- openssl=1.1.1m=h7f8727e_0
- pillow=8.3.1=py38h2c7a002_0
- pycparser=2.21=pyhd8ed1ab_0
- python=3.8.12=h12debd9_0
- python_abi=3.8=2_cp38
- pytorch=1.9.0=cuda112py38h3d13190_1
- pytorch-gpu=1.9.0=cuda112py38h0bbbad9_1
- pytorch-mutex=1.0=cuda
- readline=8.1=h27cfd23_0
- six=1.16.0=pyhd3eb1b0_0
- sleef=3.5.1=h7f98852_1
- sqlite=3.36.0=hc218d9a_0
- tk=8.6.11=h1ccaba5_0
- torchaudio=0.9.0=py38
- torchvision=0.10.0=py38cuda112h04b465a_0_cuda
- typing_extensions=3.10.0.2=pyh06a4308_0
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- absl-py==1.2.0
- addict==2.4.0
- argon2-cffi==21.3.0
- argon2-cffi-bindings==21.2.0
- asttokens==2.0.5
- attrs==21.4.0
- backcall==0.2.0
- bleach==4.1.0
- cachetools==5.0.0
- charset-normalizer==2.1.1
- cycler==0.11.0
- dataclasses==0.6
- debugpy==1.5.1
- decorator==5.1.1
- defusedxml==0.7.1
- descartes==1.1.0
- easydict==1.9
- einops==0.4.1
- entrypoints==0.4
- executing==0.8.3
- fire==0.4.0
- fvcore==0.1.5.post20220512
- google-auth==2.11.0
- google-auth-oauthlib==0.4.6
- grpcio==1.48.1
- idna==3.3
- importlib-metadata==4.12.0
- importlib-resources==5.4.0
- iopath==0.1.10
- ipykernel==6.9.1
- ipython==8.1.0
- ipython-genutils==0.2.0
- ipywidgets==7.6.5
- jedi==0.18.1
- jinja2==3.0.3
- joblib==1.1.0
- jsonschema==4.4.0
- jupyter==1.0.0
- jupyter-client==7.1.2
- jupyter-console==6.4.0
- jupyter-core==4.9.2
- jupyterlab-pygments==0.1.2
- jupyterlab-widgets==1.0.2
- kiwisolver==1.3.2
- markdown==3.4.1
- markupsafe==2.1.1
- matplotlib==3.4.3
- matplotlib-inline==0.1.3
- mistune==0.8.4
- mmcv-full==1.6.1
- nbclient==0.5.11
- nbconvert==6.4.2
- nbformat==5.1.3
- nest-asyncio==1.5.4
- notebook==6.4.8
- nuscenes-devkit==1.1.9
- oauthlib==3.2.1
- opencv-python==4.5.3.56
- packaging==21.3
- pandocfilters==1.5.0
- parso==0.8.3
- pexpect==4.8.0
- pickleshare==0.7.5
- pip==22.0.3
- plyfile==0.7.4
- portalocker==2.5.1
- prometheus-client==0.13.1
- prompt-toolkit==3.0.28
- protobuf==3.18.1
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycocotools==2.0.4
- pygments==2.11.2
- pyparsing==3.0.6
- pyquaternion==0.9.9
- pyrsistent==0.18.1
- python-dateutil==2.8.2
- pyyaml==6.0
- pyzmq==22.3.0
- qtconsole==5.2.2
- qtpy==2.0.1
- requests==2.28.1
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-learn==1.0.2
- scipy==1.7.1
- send2trash==1.8.0
- setuptools==59.5.0
- shapely==1.8.1.post1
- stack-data==0.2.0
- tabulate==0.8.10
- tensorboard==2.10.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- tensorboardx==2.4
- termcolor==1.1.0
- terminado==0.13.1
- testpath==0.6.0
- threadpoolctl==3.1.0
- timm==0.4.12
- tornado==6.1
- tqdm==4.62.3
- traitlets==5.1.1
- urllib3==1.26.12
- wcwidth==0.2.5
- webencodings==0.5.1
- werkzeug==2.2.2
- wheel==0.37.1
- widgetsnbextension==3.5.2
- yacs==0.1.8
- yapf==0.32.0
- zipp==3.7.0
================================================
FILE: requirements.txt
================================================
absl-py==1.2.0
addict==2.4.0
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.0.5
attrs==21.4.0
backcall==0.2.0
bleach==4.1.0
cachetools==5.0.0
charset-normalizer==2.1.1
cycler==0.11.0
dataclasses==0.6
debugpy==1.5.1
decorator==5.1.1
defusedxml==0.7.1
descartes==1.1.0
easydict==1.9
einops==0.4.1
entrypoints==0.4
executing==0.8.3
fire==0.4.0
fvcore==0.1.5.post20220512
google-auth==2.11.0
google-auth-oauthlib==0.4.6
grpcio==1.48.1
idna==3.3
importlib-metadata==4.12.0
importlib-resources==5.4.0
iopath==0.1.10
ipykernel==6.9.1
ipython==8.1.0
ipython-genutils==0.2.0
ipywidgets==7.6.5
jedi==0.18.1
jinja2==3.0.3
joblib==1.1.0
jsonschema==4.4.0
jupyter==1.0.0
jupyter-client==7.1.2
jupyter-console==6.4.0
jupyter-core==4.9.2
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.2
kiwisolver==1.3.2
markdown==3.4.1
markupsafe==2.1.1
matplotlib==3.4.3
matplotlib-inline==0.1.3
mistune==0.8.4
mmcv-full==1.6.1
nbclient==0.5.11
nbconvert==6.4.2
nbformat==5.1.3
nest-asyncio==1.5.4
notebook==6.4.8
nuscenes-devkit==1.1.9
oauthlib==3.2.1
opencv-python==4.5.3.56
packaging==21.3
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
pip==22.0.3
plyfile==0.7.4
portalocker==2.5.1
prometheus-client==0.13.1
prompt-toolkit==3.0.28
protobuf==3.18.1
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycocotools==2.0.4
pygments==2.11.2
pyparsing==3.0.6
pyquaternion==0.9.9
pyrsistent==0.18.1
python-dateutil==2.8.2
pyyaml==6.0
pyzmq==22.3.0
qtconsole==5.2.2
qtpy==2.0.1
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scikit-learn==1.0.2
scipy==1.7.1
send2trash==1.8.0
setuptools==59.5.0
shapely==1.8.1.post1
stack-data==0.2.0
tabulate==0.8.10
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardx==2.4
termcolor==1.1.0
terminado==0.13.1
testpath==0.6.0
threadpoolctl==3.1.0
timm==0.4.12
tornado==6.1
tqdm==4.62.3
traitlets==5.1.1
urllib3==1.26.12
wcwidth==0.2.5
webencodings==0.5.1
werkzeug==2.2.2
wheel==0.37.1
widgetsnbextension==3.5.2
yacs==0.1.8
yapf==0.32.0
zipp==3.7.0
================================================
FILE: semseg/__init__.py
================================================
from tabulate import tabulate
from semseg import models
from semseg import datasets
from semseg.models import backbones, heads
def show_models():
    """Print a numbered table of the names listed in ``models.__all__``."""
    names = models.__all__
    table = {
        'No.': [position for position, _ in enumerate(names, start=1)],
        'Model Names': names,
    }
    print(tabulate(table, headers='keys'))
def show_backbones():
    """Print a table of backbone names together with their available variants.

    For every name in ``backbones.__all__`` the module attribute
    ``<lowercased name>_settings`` is looked up; its keys are shown as the
    variants. Backbones without such an attribute (or whose attribute has no
    ``keys()``) are shown with ``'-'``.
    """
    backbone_names = backbones.__all__
    variants = []
    for name in backbone_names:
        # Plain attribute lookup replaces eval() on a constructed expression.
        settings = getattr(backbones, f"{name.lower()}_settings", None)
        try:
            variants.append(list(settings.keys()))
        except AttributeError:
            # Missing settings table, or an object that is not a mapping.
            variants.append('-')
    print(tabulate({'Backbone Names': backbone_names, 'Variants': variants}, headers='keys'))
def show_heads():
    """Print a numbered table of the names listed in ``heads.__all__``."""
    names = heads.__all__
    table = {
        'No.': [position for position, _ in enumerate(names, start=1)],
        'Heads': names,
    }
    print(tabulate(table, headers='keys'))
def show_datasets():
    """Print a numbered table of the names listed in ``datasets.__all__``."""
    names = datasets.__all__
    table = {
        'No.': [position for position, _ in enumerate(names, start=1)],
        'Datasets': names,
    }
    print(tabulate(table, headers='keys'))
================================================
FILE: semseg/augmentations.py
================================================
import torchvision.transforms.functional as TF
import random
import math
import torch
from torch import Tensor
from typing import Tuple, List, Union, Tuple, Optional
class Compose:
    """Apply a sequence of paired ``(img, mask)`` transforms in order."""

    def __init__(self, transforms: list) -> None:
        self.transforms = transforms

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Check that image and mask agree spatially, then run every transform.

        A 2-D mask is compared as a whole against ``img.shape[1:]``; for any
        other rank the mask's leading dimension is skipped.
        """
        expected = mask.shape if mask.ndim == 2 else mask.shape[1:]
        assert img.shape[1:] == expected
        for step in self.transforms:
            img, mask = step(img, mask)
        return img, mask
class Normalize:
    """Convert the image to float, divide by 255 and standardise with mean/std."""

    def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the normalised image; the mask is passed through untouched."""
        scaled = img.float()
        scaled /= 255  # in-place on the float tensor
        return TF.normalize(scaled, self.mean, self.std), mask
class ColorJitter:
    """Apply fixed brightness/contrast/saturation/hue adjustments to the image.

    Each adjustment is only applied when its factor is greater than zero, and
    the stored factor is passed to torchvision unchanged (no random sampling).
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0) -> None:
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the adjusted image and the unchanged mask."""
        adjustments = (
            (self.brightness, 'adjust_brightness'),
            (self.contrast, 'adjust_contrast'),
            (self.saturation, 'adjust_saturation'),
            (self.hue, 'adjust_hue'),
        )
        for amount, op_name in adjustments:
            if amount > 0:
                img = getattr(TF, op_name)(img, amount)
        return img, mask
class AdjustGamma:
    """Gamma-correct the image with a fixed gamma and gain."""

    def __init__(self, gamma: float, gain: float = 1) -> None:
        """
        Args:
            gamma: Non-negative real number. Values above 1 darken shadows,
                values below 1 lighten dark regions.
            gain: constant multiplier
        """
        self.gamma = gamma
        self.gain = gain

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the gamma-adjusted image and the unchanged mask."""
        corrected = TF.adjust_gamma(img, self.gamma, self.gain)
        return corrected, mask
class RandomAdjustSharpness:
    """Adjust image sharpness by a fixed factor with probability ``p``."""

    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
        self.sharpness = sharpness_factor
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly sharpened) image and the unchanged mask."""
        if random.random() >= self.p:
            return img, mask
        return TF.adjust_sharpness(img, self.sharpness), mask
class RandomAutoContrast:
    """Apply torchvision's autocontrast to the image with probability ``p``."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly auto-contrasted) image and the unchanged mask."""
        if random.random() >= self.p:
            return img, mask
        return TF.autocontrast(img), mask
class RandomGaussianBlur:
    """Blur the image with a Gaussian kernel with probability ``p``."""

    def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None:
        self.kernel_size = kernel_size
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly blurred) image and the unchanged mask."""
        if random.random() >= self.p:
            return img, mask
        return TF.gaussian_blur(img, self.kernel_size), mask
class RandomHorizontalFlip:
    """Flip image and mask horizontally together with probability ``p``."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return either the flipped pair or the inputs unchanged."""
        do_flip = random.random() < self.p
        if not do_flip:
            return img, mask
        flipped_img = TF.hflip(img)
        flipped_mask = TF.hflip(mask)
        return flipped_img, flipped_mask
class RandomVerticalFlip:
    """Flip image and mask vertically together with probability ``p``."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return either the flipped pair or the inputs unchanged."""
        do_flip = random.random() < self.p
        if not do_flip:
            return img, mask
        flipped_img = TF.vflip(img)
        flipped_mask = TF.vflip(mask)
        return flipped_img, flipped_mask
class RandomGrayscale:
    """Convert the image to 3-channel grayscale with probability ``p``."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly grayscale) image and the unchanged mask."""
        if random.random() >= self.p:
            return img, mask
        return TF.rgb_to_grayscale(img, 3), mask
class Equalize:
    """Apply torchvision's ``equalize`` to the image; the label is untouched."""

    def __call__(self, image, label):
        equalized = TF.equalize(image)
        return equalized, label
class Posterize:
    """Posterize the image to a fixed number of bits; the label is untouched."""

    def __init__(self, bits=2):
        # number of bits passed to TF.posterize (0-8)
        self.bits = bits

    def __call__(self, image, label):
        posterized = TF.posterize(image, self.bits)
        return posterized, label
class Affine:
    """Apply one fixed affine transform to both the image and its label.

    Args:
        angle: rotation angle passed to ``TF.affine``.
        translate: two-element translation; defaults to ``[0, 0]``.
        scale: scale factor.
        shear: two-element shear; defaults to ``[0, 0]``.
        seg_fill: fill value used for the label outside the transformed area.
    """

    def __init__(self, angle=0, translate=None, scale=1.0, shear=None, seg_fill=0):
        self.angle = angle
        # A fresh list per instance: a list literal in the signature would be
        # shared (and mutable) across every Affine built with the defaults.
        self.translate = [0, 0] if translate is None else translate
        self.scale = scale
        self.shear = [0, 0] if shear is None else shear
        self.seg_fill = seg_fill

    def __call__(self, img, label):
        """Return the transformed ``(img, label)`` pair.

        The image is interpolated bilinearly with fill 0; the label uses
        nearest-neighbour interpolation with ``seg_fill``.
        """
        warped_img = TF.affine(img, self.angle, self.translate, self.scale, self.shear,
                               TF.InterpolationMode.BILINEAR, 0)
        warped_label = TF.affine(label, self.angle, self.translate, self.scale, self.shear,
                                 TF.InterpolationMode.NEAREST, self.seg_fill)
        return warped_img, warped_label
class RandomRotation:
    """Rotate image and mask by a shared random angle with probability ``p``."""

    def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None:
        """
        Args:
            degrees: bound of the rotation; the angle is drawn from [-degrees, degrees).
            p: probability of applying the rotation.
            seg_fill: fill value for mask pixels exposed by the rotation.
            expand: forwarded to ``TF.rotate`` as its expand flag.
        """
        self.p = p
        self.angle = degrees
        self.expand = expand
        self.seg_fill = seg_fill

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly rotated) pair."""
        # The angle is sampled first, even when the rotation is then skipped.
        theta = random.random() * 2 * self.angle - self.angle
        if random.random() >= self.p:
            return img, mask
        rotated_img = TF.rotate(img, theta, TF.InterpolationMode.BILINEAR, self.expand, fill=0)
        rotated_mask = TF.rotate(mask, theta, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)
        return rotated_img, rotated_mask
class CenterCrop:
    """Crop image and mask around their centre."""

    def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None:
        """
        Args:
            size: height and width of the crop box. An int is used for both
                directions.
        """
        if isinstance(size, int):
            size = (size, size)
        self.size = size

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the centre-cropped ``(img, mask)`` pair."""
        cropped_img = TF.center_crop(img, self.size)
        cropped_mask = TF.center_crop(mask, self.size)
        return cropped_img, cropped_mask
class RandomCrop:
    """Crop a random window of a fixed size from image and mask with probability ``p``."""

    def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None:
        """
        Args:
            size: height and width of the crop box. An int is used for both
                directions.
            p: probability of cropping; otherwise the inputs are returned as is.
        """
        self.size = (size, size) if isinstance(size, int) else size
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly cropped) pair.

        The mask is sliced on its last two dimensions, so both ``(H, W)`` and
        ``(C, H, W)`` masks are supported. If the image is smaller than the
        crop size along a dimension, that dimension is returned in full.
        """
        height, width = img.shape[1:]
        crop_h, crop_w = self.size

        if random.random() < self.p:
            # random.randint includes both endpoints, so the largest valid
            # offset is exactly the margin; margin + 1 would push the window
            # past the border and yield an undersized crop.
            top = random.randint(0, max(height - crop_h, 0))
            left = random.randint(0, max(width - crop_w, 0))
            bottom, right = top + crop_h, left + crop_w
            img = img[:, top:bottom, left:right]
            mask = mask[..., top:bottom, left:right]
        return img, mask
class Pad:
    """Pad image and mask up to a target ``(h, w)`` size."""

    def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None:
        """
        Args:
            size: expected output image size (h, w)
            seg_fill: fill value used when padding the mask.
        """
        self.size = size
        self.seg_fill = seg_fill

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the padded pair.

        The padding tuple is zero in its first two entries and carries the
        width and height deficits in its last two.
        """
        target_h, target_w = self.size[0], self.size[1]
        padding = (0, 0, target_w - img.shape[2], target_h - img.shape[1])
        padded_img = TF.pad(img, padding)
        padded_mask = TF.pad(mask, padding, self.seg_fill)
        return padded_img, padded_mask
class ResizePad:
    """Resize image and mask with a single scale factor, then pad to ``size``."""

    def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None:
        """
        Args:
            size: target ``(h, w)``; it is unpacked into two values when called.
            seg_fill: fill value used when padding the mask.
        """
        self.size = size
        self.seg_fill = seg_fill

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the resized and padded pair."""
        src_h, src_w = img.shape[1:]
        dst_h, dst_w = self.size

        # Wider-than-tall inputs take the smaller of the two ratios,
        # everything else takes the larger one.
        ratio_h, ratio_w = dst_h / src_h, dst_w / src_w
        if src_w > src_h:
            factor = min(ratio_h, ratio_w)
        else:
            factor = max(ratio_h, ratio_w)
        new_h = round(src_h * factor)
        new_w = round(src_w * factor)

        img = TF.resize(img, (new_h, new_w), TF.InterpolationMode.BILINEAR)
        mask = TF.resize(mask, (new_h, new_w), TF.InterpolationMode.NEAREST)

        # Remaining deficit goes into the last two padding entries.
        padding = [0, 0, dst_w - new_w, dst_h - new_h]
        img = TF.pad(img, padding, fill=0)
        mask = TF.pad(mask, padding, fill=self.seg_fill)
        return img, mask
class Resize:
    """Scale the shorter edge to ``size[0]`` and snap both sides to multiples of 32."""

    def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None:
        """
        Args:
            size: sequence whose first element is the target length of the
                shorter image edge.
        """
        self.size = size

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the resized pair (bilinear for the image, nearest for the mask)."""
        src_h, src_w = img.shape[1:]

        # First pass: aspect-preserving resize driven by the shorter edge.
        factor = self.size[0] / min(src_h, src_w)
        new_h = round(src_h * factor)
        new_w = round(src_w * factor)
        img = TF.resize(img, (new_h, new_w), TF.InterpolationMode.BILINEAR)
        mask = TF.resize(mask, (new_h, new_w), TF.InterpolationMode.NEAREST)

        # Second pass: round each side up to the next multiple of 32.
        align_h = int(math.ceil(new_h / 32)) * 32
        align_w = int(math.ceil(new_w / 32)) * 32
        img = TF.resize(img, (align_h, align_w), TF.InterpolationMode.BILINEAR)
        mask = TF.resize(mask, (align_h, align_w), TF.InterpolationMode.NEAREST)
        return img, mask
class RandomResizedCrop:
    """Randomly rescale image and mask, crop a ``size`` window, and pad if needed."""

    def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None:
        """
        Args:
            size: output ``(h, w)``; it is unpacked into two values when called.
            scale: ``(low, high)`` range from which the resize ratio is drawn.
            seg_fill: fill value used when padding the mask.
        """
        self.size = size
        self.scale = scale
        self.seg_fill = seg_fill

    @staticmethod
    def _crop_origin(height: int, width: int, crop_h: int, crop_w: int) -> Tuple[int, int]:
        """Pick a random top-left corner that keeps the crop window inside the image."""
        # random.randint includes both endpoints, so the upper bound is the
        # margin itself; margin + 1 would clip one row/column off the crop.
        top = random.randint(0, max(height - crop_h, 0))
        left = random.randint(0, max(width - crop_w, 0))
        return top, left

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return an ``(img, mask)`` pair whose spatial size equals ``size``."""
        H, W = img.shape[1:]
        tH, tW = self.size

        # draw the resize ratio uniformly from the configured range
        ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
        # the width bound uses 4x the target width, as in the original code
        scale = int(tH * ratio), int(tW * 4 * ratio)

        # aspect-preserving resize that fits inside the scaled bounds
        scale_factor = min(max(scale) / max(H, W), min(scale) / min(H, W))
        nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
        img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
        mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)

        # random crop
        y1, x1 = self._crop_origin(img.shape[1], img.shape[2], tH, tW)
        y2, x2 = y1 + tH, x1 + tW
        img = img[:, y1:y2, x1:x2]
        mask = mask[:, y1:y2, x1:x2]

        # pad only when the crop came out smaller than the target; comparing
        # the deficits works whether ``size`` is a tuple or a list
        pad_w = tW - img.shape[2]
        pad_h = tH - img.shape[1]
        if pad_w > 0 or pad_h > 0:
            padding = [0, 0, pad_w, pad_h]
            img = TF.pad(img, padding, fill=0)
            mask = TF.pad(mask, padding, fill=self.seg_fill)
        return img, mask
def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0):
    """Build the training pipeline: horizontal flip, random resized crop, normalise.

    The other transforms defined in this module (colour jitter, sharpness,
    auto-contrast, vertical flip, blur, grayscale, rotation) are not enabled.
    """
    pipeline = [
        RandomHorizontalFlip(p=0.5),
        RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return Compose(pipeline)
def get_val_augmentation(size: Union[int, Tuple[int], List[int]]):
    """Build the validation pipeline: resize, then normalise."""
    pipeline = [
        Resize(size),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return Compose(pipeline)
if __name__ == '__main__':
    # Smoke test: push one random image/mask pair through RandomResizedCrop
    # and print the resulting shapes.
    height, width = 230, 420
    img = torch.randn(3, height, width)
    mask = torch.randn(1, height, width)
    aug = Compose([RandomResizedCrop((512, 512))])
    img, mask = aug(img, mask)
    print(img.shape, mask.shape)
================================================
FILE: semseg/augmentations_mm.py
================================================
import torchvision.transforms.functional as TF
import random
import math
import torch
from torch import Tensor
from typing import Tuple, List, Union, Tuple, Optional
class Compose:
    """Apply a sequence of transforms to a multi-modal sample in order.

    The sample is indexed by string keys (``'img'``, ``'mask'`` and any
    further modalities) and every transform receives and returns the whole
    sample.
    """

    def __init__(self, transforms: list) -> None:
        self.transforms = transforms

    def __call__(self, sample: list) -> list:
        """Check that ``'img'`` and ``'mask'`` agree spatially, then run every transform."""
        image, target = sample['img'], sample['mask']
        # a 2-D mask is compared whole; otherwise its leading dim is skipped
        expected = target.shape if target.ndim == 2 else target.shape[1:]
        assert image.shape[1:] == expected
        for step in self.transforms:
            sample = step(sample)
        return sample
class Normalize:
    """Normalise every entry of a multi-modal sample except the mask.

    All non-mask entries are converted to float and divided by 255; only the
    ``'img'`` entry is additionally standardised with ``mean``/``std``.
    """

    def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, sample: list) -> list:
        """Update the sample in place and return it."""
        for key in sample:
            if key == 'mask':
                continue
            scaled = sample[key].float()
            scaled /= 255  # in-place on the float tensor
            if key == 'img':
                scaled = TF.normalize(scaled, self.mean, self.std)
            sample[key] = scaled
        return sample
class RandomColorJitter:
    """With probability ``p``, jitter brightness, contrast and saturation of ``sample['img']``.

    Each factor is drawn uniformly from [0.5, 1.5] and the last drawn values
    are kept on the instance as ``brightness``, ``contrast`` and ``saturation``.
    """

    def __init__(self, p=0.5) -> None:
        self.p = p

    def __call__(self, sample: list) -> list:
        """Return the sample, with its ``'img'`` entry possibly jittered."""
        if random.random() >= self.p:
            return sample
        steps = (
            ('brightness', 'adjust_brightness'),
            ('contrast', 'adjust_contrast'),
            ('saturation', 'adjust_saturation'),
        )
        for attr_name, op_name in steps:
            factor = random.uniform(0.5, 1.5)
            setattr(self, attr_name, factor)
            sample['img'] = getattr(TF, op_name)(sample['img'], factor)
        return sample
class AdjustGamma:
    """Gamma-correct the image with a fixed gamma and gain.

    Note: this transform takes ``(img, mask)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.
    """

    def __init__(self, gamma: float, gain: float = 1) -> None:
        """
        Args:
            gamma: Non-negative real number. Values above 1 darken shadows,
                values below 1 lighten dark regions.
            gain: constant multiplier
        """
        self.gamma = gamma
        self.gain = gain

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the gamma-adjusted image and the unchanged mask."""
        corrected = TF.adjust_gamma(img, self.gamma, self.gain)
        return corrected, mask
class RandomAdjustSharpness:
    """Adjust the sharpness of ``sample['img']`` with probability ``p``."""

    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
        self.sharpness = sharpness_factor
        self.p = p

    def __call__(self, sample: list) -> list:
        """Return the sample; only the ``'img'`` entry may change."""
        if random.random() >= self.p:
            return sample
        sample['img'] = TF.adjust_sharpness(sample['img'], self.sharpness)
        return sample
class RandomAutoContrast:
    """Apply autocontrast to ``sample['img']`` with probability ``p``."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, sample: list) -> list:
        """Return the sample; only the ``'img'`` entry may change."""
        if random.random() >= self.p:
            return sample
        sample['img'] = TF.autocontrast(sample['img'])
        return sample
class RandomGaussianBlur:
    """Gaussian-blur ``sample['img']`` with probability ``p``."""

    def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None:
        self.kernel_size = kernel_size
        self.p = p

    def __call__(self, sample: list) -> list:
        """Return the sample; only the ``'img'`` entry may change."""
        if random.random() >= self.p:
            return sample
        sample['img'] = TF.gaussian_blur(sample['img'], self.kernel_size)
        return sample
class RandomHorizontalFlip:
    """Flip every entry of the sample horizontally with probability ``p``."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, sample: list) -> list:
        """Return the sample, with all entries flipped together or none at all."""
        if random.random() >= self.p:
            return sample
        for key in list(sample.keys()):
            sample[key] = TF.hflip(sample[key])
        return sample
class RandomVerticalFlip:
    """Flip image and mask vertically together with probability ``p``.

    Note: this transform takes ``(img, mask)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.
    """

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return either the flipped pair or the inputs unchanged."""
        do_flip = random.random() < self.p
        if not do_flip:
            return img, mask
        flipped_img = TF.vflip(img)
        flipped_mask = TF.vflip(mask)
        return flipped_img, flipped_mask
class RandomGrayscale:
    """Convert the image to 3-channel grayscale with probability ``p``.

    Note: this transform takes ``(img, mask)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.
    """

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly grayscale) image and the unchanged mask."""
        if random.random() >= self.p:
            return img, mask
        return TF.rgb_to_grayscale(img, 3), mask
class Equalize:
    """Apply torchvision's ``equalize`` to the image; the label is untouched.

    Note: this transform takes ``(image, label)`` arguments, whereas
    ``Compose`` in this module calls its transforms with a single sample.
    """

    def __call__(self, image, label):
        equalized = TF.equalize(image)
        return equalized, label
class Posterize:
    """Posterize the image to a fixed number of bits; the label is untouched.

    Note: this transform takes ``(image, label)`` arguments, whereas
    ``Compose`` in this module calls its transforms with a single sample.
    """

    def __init__(self, bits=2):
        # number of bits passed to TF.posterize (0-8)
        self.bits = bits

    def __call__(self, image, label):
        posterized = TF.posterize(image, self.bits)
        return posterized, label
class Affine:
    """Apply one fixed affine transform to both the image and its label.

    Note: this transform takes ``(img, label)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.

    Args:
        angle: rotation angle passed to ``TF.affine``.
        translate: two-element translation; defaults to ``[0, 0]``.
        scale: scale factor.
        shear: two-element shear; defaults to ``[0, 0]``.
        seg_fill: fill value used for the label outside the transformed area.
    """

    def __init__(self, angle=0, translate=None, scale=1.0, shear=None, seg_fill=0):
        self.angle = angle
        # A fresh list per instance: a list literal in the signature would be
        # shared (and mutable) across every Affine built with the defaults.
        self.translate = [0, 0] if translate is None else translate
        self.scale = scale
        self.shear = [0, 0] if shear is None else shear
        self.seg_fill = seg_fill

    def __call__(self, img, label):
        """Return the transformed ``(img, label)`` pair.

        The image is interpolated bilinearly with fill 0; the label uses
        nearest-neighbour interpolation with ``seg_fill``.
        """
        warped_img = TF.affine(img, self.angle, self.translate, self.scale, self.shear,
                               TF.InterpolationMode.BILINEAR, 0)
        warped_label = TF.affine(label, self.angle, self.translate, self.scale, self.shear,
                                 TF.InterpolationMode.NEAREST, self.seg_fill)
        return warped_img, warped_label
class RandomRotation:
    """Rotate every entry of the sample by a shared random angle with probability ``p``."""

    def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None:
        """
        Args:
            degrees: bound of the rotation; the angle is drawn from [-degrees, degrees).
            p: probability of applying the rotation.
            seg_fill: fill value for mask pixels exposed by the rotation.
            expand: forwarded to ``TF.rotate`` as its expand flag.
        """
        self.p = p
        self.angle = degrees
        self.expand = expand
        self.seg_fill = seg_fill

    def __call__(self, sample: list) -> list:
        """Return the sample, rotated in place when the draw succeeds.

        The ``'mask'`` entry uses nearest-neighbour interpolation and
        ``seg_fill``; every other entry uses bilinear interpolation and fill 0.
        """
        # The angle is sampled first, even when the rotation is then skipped.
        theta = random.random() * 2 * self.angle - self.angle
        if random.random() >= self.p:
            return sample
        for key in list(sample.keys()):
            if key == 'mask':
                mode, fill = TF.InterpolationMode.NEAREST, self.seg_fill
            else:
                mode, fill = TF.InterpolationMode.BILINEAR, 0
            sample[key] = TF.rotate(sample[key], theta, mode, self.expand, fill=fill)
        return sample
class CenterCrop:
    """Crop image and mask around their centre.

    Note: this transform takes ``(img, mask)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.
    """

    def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None:
        """
        Args:
            size: height and width of the crop box. An int is used for both
                directions.
        """
        if isinstance(size, int):
            size = (size, size)
        self.size = size

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the centre-cropped ``(img, mask)`` pair."""
        cropped_img = TF.center_crop(img, self.size)
        cropped_mask = TF.center_crop(mask, self.size)
        return cropped_img, cropped_mask
class RandomCrop:
    """Crop a random window of a fixed size from image and mask with probability ``p``.

    Note: this transform takes ``(img, mask)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.
    """

    def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None:
        """
        Args:
            size: height and width of the crop box. An int is used for both
                directions.
            p: probability of cropping; otherwise the inputs are returned as is.
        """
        self.size = (size, size) if isinstance(size, int) else size
        self.p = p

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly cropped) pair.

        The mask is sliced on its last two dimensions, so both ``(H, W)`` and
        ``(C, H, W)`` masks are supported. If the image is smaller than the
        crop size along a dimension, that dimension is returned in full.
        """
        height, width = img.shape[1:]
        crop_h, crop_w = self.size

        if random.random() < self.p:
            # random.randint includes both endpoints, so the largest valid
            # offset is exactly the margin; margin + 1 would push the window
            # past the border and yield an undersized crop.
            top = random.randint(0, max(height - crop_h, 0))
            left = random.randint(0, max(width - crop_w, 0))
            bottom, right = top + crop_h, left + crop_w
            img = img[:, top:bottom, left:right]
            mask = mask[..., top:bottom, left:right]
        return img, mask
class Pad:
    """Pad image and mask up to a target ``(h, w)`` size.

    Note: this transform takes ``(img, mask)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.
    """

    def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None:
        """
        Args:
            size: expected output image size (h, w)
            seg_fill: fill value used when padding the mask.
        """
        self.size = size
        self.seg_fill = seg_fill

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the padded pair.

        The padding tuple is zero in its first two entries and carries the
        width and height deficits in its last two.
        """
        target_h, target_w = self.size[0], self.size[1]
        padding = (0, 0, target_w - img.shape[2], target_h - img.shape[1])
        padded_img = TF.pad(img, padding)
        padded_mask = TF.pad(mask, padding, self.seg_fill)
        return padded_img, padded_mask
class ResizePad:
    """Resize image and mask with a single scale factor, then pad to ``size``.

    Note: this transform takes ``(img, mask)`` arguments, whereas ``Compose``
    in this module calls its transforms with a single sample.
    """

    def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None:
        """
        Args:
            size: target ``(h, w)``; it is unpacked into two values when called.
            seg_fill: fill value used when padding the mask.
        """
        self.size = size
        self.seg_fill = seg_fill

    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the resized and padded pair."""
        src_h, src_w = img.shape[1:]
        dst_h, dst_w = self.size

        # Wider-than-tall inputs take the smaller of the two ratios,
        # everything else takes the larger one.
        ratio_h, ratio_w = dst_h / src_h, dst_w / src_w
        if src_w > src_h:
            factor = min(ratio_h, ratio_w)
        else:
            factor = max(ratio_h, ratio_w)
        new_h = round(src_h * factor)
        new_w = round(src_w * factor)

        img = TF.resize(img, (new_h, new_w), TF.InterpolationMode.BILINEAR)
        mask = TF.resize(mask, (new_h, new_w), TF.InterpolationMode.NEAREST)

        # Remaining deficit goes into the last two padding entries.
        padding = [0, 0, dst_w - new_w, dst_h - new_h]
        img = TF.pad(img, padding, fill=0)
        mask = TF.pad(mask, padding, fill=self.seg_fill)
        return img, mask
class Resize:
    """Scale the shorter edge to ``size[0]`` and snap both sides to multiples of 32.

    Every entry of the sample is resized; ``'mask'`` uses nearest-neighbour
    interpolation and all other entries use bilinear interpolation.
    """

    def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None:
        """
        Args:
            size: sequence whose first element is the target length of the
                shorter image edge.
        """
        self.size = size

    def __call__(self, sample:list) -> list:
        """Resize all entries in place (two passes) and return the sample."""
        src_h, src_w = sample['img'].shape[1:]

        # First pass target: aspect-preserving, driven by the shorter edge.
        factor = self.size[0] / min(src_h, src_w)
        new_h = round(src_h * factor)
        new_w = round(src_w * factor)

        # Second pass target: each side rounded up to a multiple of 32.
        align_h = int(math.ceil(new_h / 32)) * 32
        align_w = int(math.ceil(new_w / 32)) * 32

        for target_hw in ((new_h, new_w), (align_h, align_w)):
            for key in list(sample.keys()):
                if key == 'mask':
                    mode = TF.InterpolationMode.NEAREST
                else:
                    mode = TF.InterpolationMode.BILINEAR
                sample[key] = TF.resize(sample[key], target_hw, mode)
        return sample
class RandomResizedCrop:
    """Randomly rescale every entry of the sample, crop a ``size`` window, and pad if needed.

    ``'mask'`` is resized with nearest-neighbour interpolation and padded with
    ``seg_fill``; all other entries are resized bilinearly and padded with 0.
    """

    def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None:
        """
        Args:
            size: output ``(h, w)``; it is unpacked into two values when called.
            scale: ``(low, high)`` range from which the resize ratio is drawn.
            seg_fill: fill value used when padding the mask.
        """
        self.size = size
        self.scale = scale
        self.seg_fill = seg_fill

    @staticmethod
    def _crop_origin(height: int, width: int, crop_h: int, crop_w: int) -> Tuple[int, int]:
        """Pick a random top-left corner that keeps the crop window inside the image."""
        # random.randint includes both endpoints, so the upper bound is the
        # margin itself; margin + 1 would clip one row/column off the crop.
        top = random.randint(0, max(height - crop_h, 0))
        left = random.randint(0, max(width - crop_w, 0))
        return top, left

    def __call__(self, sample: list) -> list:
        """Update all entries in place and return the sample."""
        H, W = sample['img'].shape[1:]
        tH, tW = self.size

        # draw the resize ratio uniformly from the configured range
        ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
        # the width bound uses 4x the target width, as in the original code
        scale = int(tH * ratio), int(tW * 4 * ratio)

        # aspect-preserving resize that fits inside the scaled bounds
        scale_factor = min(max(scale) / max(H, W), min(scale) / min(H, W))
        nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
        for k, v in sample.items():
            if k == 'mask':
                sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.NEAREST)
            else:
                sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR)

        # random crop, shared by every entry
        y1, x1 = self._crop_origin(sample['img'].shape[1], sample['img'].shape[2], tH, tW)
        y2, x2 = y1 + tH, x1 + tW
        for k, v in sample.items():
            sample[k] = v[:, y1:y2, x1:x2]

        # pad only when the crop came out smaller than the target; comparing
        # the deficits works whether ``size`` is a tuple or a list
        pad_w = tW - sample['img'].shape[2]
        pad_h = tH - sample['img'].shape[1]
        if pad_w > 0 or pad_h > 0:
            padding = [0, 0, pad_w, pad_h]
            for k, v in sample.items():
                if k == 'mask':
                    sample[k] = TF.pad(v, padding, fill=self.seg_fill)
                else:
                    sample[k] = TF.pad(v, padding, fill=0)
        return sample
def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0):
    """Build the multi-modal training pipeline.

    Steps, in order: colour jitter (p=0.2), horizontal flip (p=0.5),
    Gaussian blur with a 3x3 kernel (p=0.2), random resized crop to ``size``,
    and normalisation.
    """
    pipeline = [
        RandomColorJitter(p=0.2),
        RandomHorizontalFlip(p=0.5),
        RandomGaussianBlur((3, 3), p=0.2),
        RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return Compose(pipeline)
def get_val_augmentation(size: Union[int, Tuple[int], List[int]]):
    """Build the multi-modal validation pipeline: resize, then normalise."""
    pipeline = [
        Resize(size),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return Compose(pipeline)
if __name__ == '__main__':
    # Smoke test: run a random multi-modal sample through a small pipeline
    # and print the shape of every entry.
    height, width = 230, 420
    sample = {
        'img': torch.randn(3, height, width),
        'depth': torch.randn(3, height, width),
        'lidar': torch.randn(3, height, width),
        'event': torch.randn(3, height, width),
        'mask': torch.randn(1, height, width),
    }
    aug = Compose([
        RandomHorizontalFlip(p=0.5),
        RandomResizedCrop((512, 512)),
        Resize((224, 224)),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    sample = aug(sample)
    for name, tensor in sample.items():
        print(name, tensor.shape)
================================================
FILE: semseg/datasets/__init__.py
================================================
from .deliver import DELIVER
from .kitti360 import KITTI360
from .nyu import NYU
from .mfnet import MFNet
from .urbanlf import UrbanLF
from .mcubes import MCubeS
# Public dataset classes of this package; `semseg.show_datasets()` lists
# exactly these names.
__all__ = [
    'DELIVER',
    'KITTI360',
    'NYU',
    'MFNet',
    'UrbanLF',
    'MCubeS'
]
================================================
FILE: semseg/datasets/deliver.py
================================================
import os
import torch
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from torchvision import io
from pathlib import Path
from typing import Tuple
import glob
import einops
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler, RandomSampler
from semseg.augmentations_mm import get_train_augmentation
class DELIVER(Dataset):
    """DELIVER multi-modal semantic segmentation dataset.

    num_classes: 25

    RGB frames are discovered with the glob ``<root>/img/*/<split>/*/*.png``.
    The paths of the other modalities and of the label are derived from the
    RGB path by string substitution (``/img`` -> ``/hha``, ``/lidar``,
    ``/event``, ``/semantic`` and ``_rgb`` -> ``_depth``, ``_lidar``,
    ``_event``, ``_semantic``).
    """
    CLASSES = ["Building", "Fence", "Other", "Pedestrian", "Pole", "RoadLine", "Road", "SideWalk", "Vegetation",
               "Cars", "Wall", "TrafficSign", "Sky", "Ground", "Bridge", "RailTrack", "GroundRail",
               "TrafficLight", "Static", "Dynamic", "Water", "Terrain", "TwoWheeler", "Bus", "Truck"]

    # One RGB color per entry of CLASSES (same order).
    PALETTE = torch.tensor([[70, 70, 70],
                            [100, 40, 40],
                            [55, 90, 80],
                            [220, 20, 60],
                            [153, 153, 153],
                            [157, 234, 50],
                            [128, 64, 128],
                            [244, 35, 232],
                            [107, 142, 35],
                            [0, 0, 142],
                            [102, 102, 156],
                            [220, 220, 0],
                            [70, 130, 180],
                            [81, 0, 81],
                            [150, 100, 100],
                            [230, 150, 140],
                            [180, 165, 180],
                            [250, 170, 30],
                            [110, 190, 160],
                            [170, 120, 50],
                            [45, 60, 150],
                            [145, 170, 100],
                            [0, 0, 230],
                            [0, 60, 100],
                            [0, 0, 70],
                            ])

    def __init__(self, root: str = 'data/DELIVER', split: str = 'train', transform = None, modals = ['img'], case = None) -> None:
        """Collect the RGB file list for ``split``.

        Args:
            root: dataset root directory.
            split: one of 'train', 'val', 'test'.
            transform: optional callable applied to the sample dict.
            modals: keys of the modalities returned by ``__getitem__``, in order.
            case: optional condition name; when given, only files whose path
                contains this string are kept.

        Raises:
            Exception: if no image matches the split (and case).
        """
        super().__init__()
        assert split in ['train', 'val', 'test']
        self.transform = transform
        self.n_classes = len(self.CLASSES)
        self.ignore_label = 255
        self.modals = modals
        self.files = sorted(glob.glob(os.path.join(*[root, 'img', '*', split, '*', '*.png'])))
        # --- split as case
        if case is not None:
            assert case in ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres'], "Case name not available."
            self.files = [f for f in self.files if case in f]

        if not self.files:
            # The original message referenced an undefined name and raised NameError.
            raise Exception(f"No images found in {root}")
        print(f"Found {len(self.files)} {split} {case} images.")

    def __len__(self) -> int:
        """Return the number of RGB frames found."""
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load one sample.

        Returns:
            A pair ``(sample, label)`` where ``sample`` is a list holding the
            tensors of ``self.modals`` (in that order) and ``label`` is a long
            tensor.
        """
        rgb = str(self.files[index])
        x1 = rgb.replace('/img', '/hha').replace('_rgb', '_depth')
        x2 = rgb.replace('/img', '/lidar').replace('_rgb', '_lidar')
        x3 = rgb.replace('/img', '/event').replace('_rgb', '_event')
        lbl_path = rgb.replace('/img', '/semantic').replace('_rgb', '_semantic')

        sample = {}
        sample['img'] = io.read_image(rgb)[:3, ...]
        H, W = sample['img'].shape[1:]
        if 'depth' in self.modals:
            sample['depth'] = self._open_img(x1)
        if 'lidar' in self.modals:
            sample['lidar'] = self._open_img(x2)
        if 'event' in self.modals:
            # Event frames are resized to the RGB resolution.
            eimg = self._open_img(x3)
            sample['event'] = TF.resize(eimg, (H, W), TF.InterpolationMode.NEAREST)
        label = io.read_image(lbl_path)[0, ...].unsqueeze(0)
        # Shift labels down by one. Assuming the label is read as uint8, id 0
        # (and 255, mapped to 0 first) wraps around to 255, i.e. ignore_label.
        label[label == 255] = 0
        label -= 1
        sample['mask'] = label

        if self.transform:
            sample = self.transform(sample)
        label = sample['mask']
        del sample['mask']
        label = self.encode(label.squeeze().numpy()).long()
        sample = [sample[k] for k in self.modals]
        return sample, label

    def _open_img(self, file):
        """Read an image and return it with three channels.

        A 4-channel image loses its last channel; a 1-channel image is
        repeated three times.
        """
        img = io.read_image(file)
        channels = img.shape[0]
        if channels == 4:
            img = img[:3, ...]
        if channels == 1:
            img = img.repeat(3, 1, 1)
        return img

    def encode(self, label: Tensor) -> Tensor:
        """Convert a numpy label array to a tensor (no id remapping)."""
        return torch.from_numpy(label)
if __name__ == '__main__':
    # Smoke test: iterate the validation split of every adverse-condition
    # case and print the label ids present in each batch.
    train_aug = get_train_augmentation((1024, 1024), seg_fill=255)
    for case_name in ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']:
        dataset = DELIVER(transform=train_aug, split='val', case=case_name)
        loader = DataLoader(dataset, batch_size=2, num_workers=2, drop_last=False, pin_memory=False)
        for _, lbl in loader:
            print(torch.unique(lbl))
================================================
FILE: semseg/datasets/kitti360.py
================================================
import os
import torch
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io
from pathlib import Path
from typing import Tuple
import glob
import einops
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler, RandomSampler
from semseg.augmentations_mm import get_train_augmentation
class KITTI360(Dataset):
    """KITTI-360 multi-modal semantic segmentation dataset.

    num_classes: 19

    File names are read from ``<root>/<split>.txt``; raw label ids are
    remapped to train ids through ``ID2TRAINID``.
    """
    CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation',
               'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']

    PALETTE = torch.tensor([[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35],
                            [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])

    # Raw label id -> train id; 255 is the ignore label.
    ID2TRAINID = {0:255, 1:255, 2:255, 3:255, 4:255, 5:255, 6:255, 7:0, 8:1, 9:255, 10:255, 11:2, 12:3, 13:4, 14:255, 15:255, 16:255, 17:5, 18:255, 19:6,
                  20:7, 21:8, 22:9, 23:10, 24:11, 25:12, 26:13, 27:14, 28:15, 29:255, 30:255, 31:16, 32:17, 33:18, 34:2, 35:4, 36:255, 37:5, 38:255, 39:255, 40:255, 41:255, 42:255, 43:255, 44:255, -1:255}

    def __init__(self, root: str = 'data/KITTI360', split: str = 'train', transform = None, modals = ['img', 'depth', 'event', 'lidar'], case = None) -> None:
        """Read the split file and build the label lookup table.

        Args:
            root: dataset root directory (must contain ``<split>.txt``).
            split: 'train' or 'val'.
            transform: optional callable applied to the sample dict.
            modals: keys of the modalities returned by ``__getitem__``, in order.
            case: accepted for signature parity with the other datasets; unused.

        Raises:
            Exception: if the split file lists no images.
        """
        super().__init__()
        assert split in ['train', 'val']
        self.root = root
        self.transform = transform
        self.n_classes = len(self.CLASSES)
        self.ignore_label = 255
        self.modals = modals

        # 256-entry lookup table; ids without an ID2TRAINID entry map to
        # themselves. The key -1 indexes the last slot (255).
        self.label_map = np.arange(256)
        for raw_id, train_id in self.ID2TRAINID.items():
            self.label_map[raw_id] = train_id

        self.files = self._get_file_names(split)

        if not self.files:
            # The original message referenced an undefined name and raised NameError.
            raise Exception(f"No images found in {root}")
        print(f"Found {len(self.files)} {split} images.")

    def __len__(self) -> int:
        """Return the number of listed frames."""
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load one sample.

        Returns:
            A pair ``(sample, label)`` where ``sample`` is a list holding the
            tensors of ``self.modals`` (in that order) and ``label`` is a long
            tensor of train ids.
        """
        item_name = str(self.files[index])
        rgb = os.path.join(self.root, item_name)
        x1 = os.path.join(self.root, item_name.replace('data_2d_raw', 'data_2d_hha'))
        x2 = os.path.join(self.root, item_name.replace('data_2d_raw', 'data_2d_lidar'))
        x2 = x2.replace('.png', '_color.png')
        x3 = os.path.join(self.root, item_name.replace('data_2d_raw', 'data_2d_event'))
        x3 = x3.replace('/image_00/data_rect/', '/').replace('.png', '_event_image.png')
        lbl_path = os.path.join(*[self.root, item_name.replace('data_2d_raw', 'data_2d_semantics/train').replace('data_rect', 'semantic')])

        sample = {}
        sample['img'] = io.read_image(rgb)[:3, ...]
        if 'depth' in self.modals:
            sample['depth'] = self._open_img(x1)
        if 'lidar' in self.modals:
            sample['lidar'] = self._open_img(x2)
        if 'event' in self.modals:
            sample['event'] = self._open_img(x3)
        label = io.read_image(lbl_path)[0, ...].unsqueeze(0)
        sample['mask'] = label

        if self.transform:
            sample = self.transform(sample)
        label = sample['mask']
        del sample['mask']
        label = self.encode(label.squeeze().numpy()).long()
        sample = [sample[k] for k in self.modals]
        return sample, label

    def _open_img(self, file):
        """Read an image and return it with three channels.

        A 4-channel image loses its last channel; a 1-channel image is
        repeated three times.
        """
        img = io.read_image(file)
        channels = img.shape[0]
        if channels == 4:
            img = img[:3, ...]
        if channels == 1:
            img = img.repeat(3, 1, 1)
        return img

    def encode(self, label: Tensor) -> Tensor:
        """Map a numpy array of raw ids to train ids and return it as a tensor."""
        label = self.label_map[label]
        return torch.from_numpy(label)

    def _get_file_names(self, split_name):
        """Return the first whitespace-separated token of every line of ``<root>/<split_name>.txt``."""
        assert split_name in ['train', 'val']
        source = os.path.join(self.root, '{}.txt'.format(split_name))
        file_names = []
        with open(source) as f:
            for item in f:
                file_name = item.strip()
                if ' ' in file_name:
                    # Lines may carry several space-separated paths; keep the first.
                    file_name = file_name.split(' ')[0]
                file_names.append(file_name)
        return file_names
if __name__ == '__main__':
    # Smoke test: iterate the training split and print the label ids of each batch.
    train_aug = get_train_augmentation((376, 1408), seg_fill=255)
    dataset = KITTI360(transform=train_aug)
    loader = DataLoader(dataset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)
    for _, lbl in loader:
        print(torch.unique(lbl))
================================================
FILE: semseg/datasets/mcubes.py
================================================
import os
import torch
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io
from torchvision import transforms
from pathlib import Path
from typing import Tuple
import glob
import einops
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler, RandomSampler
from semseg.augmentations_mm import get_train_augmentation
import cv2
import random
from PIL import Image, ImageOps, ImageFilter
class MCubeS(Dataset):
    """MCubeS multi-modal material segmentation dataset.

    num_classes: 20

    File names are read from ``<root>/list_folder/train.txt`` (train) or
    ``<root>/list_folder/test.txt`` (val). Samples are augmented by the
    dataset's own ``transform_tr`` / ``transform_val`` pipelines; the
    ``transform`` argument is stored but not applied by ``__getitem__``.
    """
    CLASSES = ['asphalt','concrete','metal','road_marking','fabric','glass','plaster','plastic','rubber','sand',
               'gravel','ceramic','cobblestone','brick','grass','wood','leaf','water','human','sky',]

    PALETTE = torch.tensor([[44, 160, 44],
                            [31, 119, 180],
                            [255, 127, 14],
                            [214, 39, 40],
                            [140, 86, 75],
                            [127, 127, 127],
                            [188, 189, 34],
                            [255, 152, 150],
                            [23, 190, 207],
                            [174, 199, 232],
                            [196, 156, 148],
                            [197, 176, 213],
                            [247, 182, 210],
                            [199, 199, 199],
                            [219, 219, 141],
                            [158, 218, 229],
                            [57, 59, 121],
                            [107, 110, 207],
                            [156, 158, 222],
                            [99, 121, 57]])

    def __init__(self, root: str = 'data/MCubeS/multimodal_dataset', split: str = 'train', transform = None, modals = ['image', 'aolp', 'dolp', 'nir'], case = None) -> None:
        """Read the split list and precompute the pixel-coordinate maps.

        Args:
            root: dataset root directory.
            split: 'train' or 'val'.
            transform: stored on the instance; not used by ``__getitem__``.
            modals: keys of the modalities returned by ``__getitem__``, in order.
            case: accepted for signature parity with the other datasets; unused.

        Raises:
            Exception: if the split list is empty.
        """
        super().__init__()
        assert split in ['train', 'val']
        self.split = split
        self.root = root
        self.transform = transform
        self.n_classes = len(self.CLASSES)
        self.ignore_label = 255
        self.modals = modals
        # Number of leftmost columns dropped from every loaded image.
        self._left_offset = 192

        self.img_h = 1024
        self.img_w = 1224
        max_dim = max(self.img_h, self.img_w)
        u_vec = (np.arange(self.img_w)-self.img_w/2)/max_dim*2
        v_vec = (np.arange(self.img_h)-self.img_h/2)/max_dim*2
        self.u_map, self.v_map = np.meshgrid(u_vec, v_vec)
        # NOTE(review): this keeps the FIRST `_left_offset` columns of u_map,
        # whereas the images keep the columns AFTER the offset, and v_map is
        # not sliced at all — confirm this is intended.
        self.u_map = self.u_map[:,:self._left_offset]

        self.base_size = 512
        self.crop_size = 512
        self.files = self._get_file_names(split)

        if not self.files:
            # The original message referenced an undefined name and raised NameError.
            raise Exception(f"No images found in {root}")
        print(f"Found {len(self.files)} {split} images.")

    def __len__(self) -> int:
        """Return the number of listed samples."""
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load one sample.

        Returns:
            A pair ``(sample, label)`` where ``sample`` is a list holding the
            tensors of ``self.modals`` (in that order) and ``label`` is a long
            tensor.
        """
        item_name = str(self.files[index])
        rgb = os.path.join(*[self.root, 'polL_color', item_name+'.png'])
        x1 = os.path.join(*[self.root, 'polL_aolp_sin', item_name+'.npy'])
        x1_1 = os.path.join(*[self.root, 'polL_aolp_cos', item_name+'.npy'])
        x2 = os.path.join(*[self.root, 'polL_dolp', item_name+'.npy'])
        x3 = os.path.join(*[self.root, 'NIR_warped', item_name+'.png'])
        lbl_path = os.path.join(*[self.root, 'GT', item_name+'.png'])
        nir_mask_path = os.path.join(*[self.root, 'NIR_warped_mask', item_name+'.png'])
        ss_path = os.path.join(*[self.root, 'SS', item_name+'.png'])

        # cv2 loads BGR; reverse the channel axis, then scale by the bit depth.
        _img = cv2.imread(rgb,-1)[:,:,::-1]
        _img = _img.astype(np.float32)/65535 if _img.dtype==np.uint16 else _img.astype(np.float32)/255
        _target = cv2.imread(lbl_path,-1)
        _mask = cv2.imread(ss_path,-1)
        _aolp_sin = np.load(x1)
        _aolp_cos = np.load(x1_1)
        _aolp = np.stack([_aolp_sin, _aolp_cos, _aolp_sin], axis=2) # H x W x 3
        dolp = np.load(x2)
        _dolp = np.stack([dolp, dolp, dolp], axis=2) # H x W x 3
        nir = cv2.imread(x3,-1)
        nir = nir.astype(np.float32)/65535 if nir.dtype==np.uint16 else nir.astype(np.float32)/255
        _nir = np.stack([nir, nir, nir], axis=2) # H x W x 3
        _nir_mask = cv2.imread(nir_mask_path,0)

        # Drop the leftmost `_left_offset` columns of every loaded array.
        off = self._left_offset
        _img, _target, _aolp, _dolp, _nir, _nir_mask, _mask = \
            _img[:,off:], _target[:,off:], _aolp[:,off:], _dolp[:,off:], \
            _nir[:,off:], _nir_mask[:,off:], _mask[:,off:]
        sample = {'image': _img, 'label': _target, 'aolp': _aolp, 'dolp': _dolp, 'nir': _nir, 'nir_mask': _nir_mask, 'u_map': self.u_map, 'v_map': self.v_map, 'mask':_mask}

        if self.split == "train":
            sample = self.transform_tr(sample)
        elif self.split in ('val', 'test'):
            sample = self.transform_val(sample)
        else:
            raise NotImplementedError()
        label = sample['label'].long()
        sample = [sample[k] for k in self.modals]
        return sample, label

    def transform_tr(self, sample):
        """Apply the training pipeline: flip, scale-crop, blur, normalize, to-tensor."""
        composed_transforms = transforms.Compose([
            RandomHorizontalFlip(),
            RandomScaleCrop(base_size=self.base_size, crop_size=self.crop_size, fill=255),
            RandomGaussianBlur(),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensor()])
        return composed_transforms(sample)

    def transform_val(self, sample):
        """Apply the evaluation pipeline: fixed scale-crop, normalize, to-tensor."""
        composed_transforms = transforms.Compose([
            FixScaleCrop(crop_size=1024),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensor()])
        return composed_transforms(sample)

    def _get_file_names(self, split_name):
        """Return the first whitespace-separated token of every line of the split list."""
        assert split_name in ['train', 'val']
        source = os.path.join(self.root, 'list_folder/test.txt') if split_name == 'val' else os.path.join(self.root, 'list_folder/train.txt')
        file_names = []
        with open(source) as f:
            for item in f:
                file_name = item.strip()
                if ' ' in file_name:
                    # Keep only the first token of a multi-token line.
                    file_name = file_name.split(' ')[0]
                file_names.append(file_name)
        return file_names
class Normalize(object):
    """Standardize the 'image' entry of a sample with a per-channel mean and std.

    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.

    'image', 'label' and 'nir' are returned as float32 copies; only 'image'
    is shifted and scaled. Every other entry is passed through unchanged.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        image = np.array(sample['image']).astype(np.float32)
        label = np.array(sample['label']).astype(np.float32)
        # In-place on the float32 copy, so the caller's array is untouched.
        image -= self.mean
        image /= self.std
        nir = np.array(sample['nir']).astype(np.float32)

        result = {'image': image, 'label': label}
        result['aolp'] = sample['aolp']
        result['dolp'] = sample['dolp']
        result['nir'] = nir
        for key in ('nir_mask', 'u_map', 'v_map', 'mask'):
            result[key] = sample[key]
        return result
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        def chw(arr):
            # numpy image: H x W x C  ->  torch image: C x H x W
            return torch.from_numpy(np.array(arr).astype(np.float32).transpose((2, 0, 1))).float()

        def plain(arr):
            # Single-plane array: convert to float32 without reordering axes.
            return torch.from_numpy(np.array(arr).astype(np.float32)).float()

        def coord(arr):
            # Coordinate maps are expected to already be ndarrays.
            return torch.from_numpy(arr.astype(np.float32)).float()

        image = chw(sample['image'])
        label = plain(sample['label'])
        aolp = chw(sample['aolp'])
        dolp = chw(sample['dolp'])
        ss = plain(sample['mask'])
        nir = chw(sample['nir'])
        nir_mask = plain(sample['nir_mask'])
        u_map = coord(sample['u_map'])
        v_map = coord(sample['v_map'])

        return {'image': image,
                'label': label,
                'aolp': aolp,
                'dolp': dolp,
                'nir': nir,
                'nir_mask': nir_mask,
                'u_map': u_map,
                'v_map': v_map,
                'mask': ss}
class RandomHorizontalFlip(object):
    """Mirror a sample left-right when a uniform draw falls below 0.5.

    Every entry except 'v_map' is reversed along axis 1 when the flip is
    applied; 'v_map' is always passed through unchanged.
    """

    _KEYS = ('image', 'label', 'aolp', 'dolp', 'nir', 'nir_mask', 'u_map', 'v_map', 'mask')
    _FLIPPED = ('image', 'label', 'nir', 'nir_mask', 'aolp', 'dolp', 'mask', 'u_map')

    def __call__(self, sample):
        result = {key: sample[key] for key in self._KEYS}
        if random.random() < 0.5:
            for key in self._FLIPPED:
                result[key] = result[key][:, ::-1]
        return result
class RandomGaussianBlur(object):
    """Blur 'image' and 'nir' with a shared random sigma when a uniform draw falls below 0.5.

    All other entries are passed through unchanged.
    """

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        nir = sample['nir']
        apply_blur = random.random() < 0.5
        if apply_blur:
            # One sigma drawn from [0, 1) is used for both arrays; a (0, 0)
            # kernel size leaves the kernel extent to OpenCV.
            sigma = random.random()
            image = cv2.GaussianBlur(image, (0, 0), sigma)
            nir = cv2.GaussianBlur(nir, (0, 0), sigma)

        result = {'image': image, 'label': label}
        result['aolp'] = sample['aolp']
        result['dolp'] = sample['dolp']
        result['nir'] = nir
        for key in ('nir_mask', 'u_map', 'v_map', 'mask'):
            result[key] = sample[key]
        return result
class RandomScaleCrop(object):
    """Randomly rescale a sample, zero-pad it if needed, and take a random square crop.

    The short edge is resized to a random integer in
    ``[0.5 * base_size, 2.0 * base_size]``. When that size is below
    ``crop_size`` the arrays are padded on the bottom/right up to
    ``crop_size`` (the label with ``fill``, everything else with zeros).
    """

    def __init__(self, base_size, crop_size, fill=255):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, sample):
        crop = self.crop_size

        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        h, w = sample['image'].shape[:2]
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)

        # pad crop
        needs_pad = short_size < crop
        if needs_pad:
            padh = crop - oh if oh < crop else 0
            padw = crop - ow if ow < crop else 0

        # random crop origin inside the resized (unpadded) extent
        x1 = random.randint(0, max(0, ow - crop))
        y1 = random.randint(0, max(0, oh - crop))

        def fit(arr, interpolation=None, channels=None, is_label=False):
            # Resize to (ow, oh), pad on the bottom/right if required, then crop.
            if interpolation is None:
                arr = cv2.resize(arr, (ow, oh))
            else:
                arr = cv2.resize(arr, (ow, oh), interpolation=interpolation)
            if needs_pad:
                if channels is None:
                    shape = (oh + padh, ow + padw)
                else:
                    shape = (oh + padh, ow + padw, channels)
                canvas = np.full(shape, self.fill) if is_label else np.zeros(shape)
                canvas[:oh, :ow] = arr
                arr = canvas
            return arr[y1:y1 + crop, x1:x1 + crop]

        u_map = fit(sample['u_map'])
        v_map = fit(sample['v_map'])
        aolp = fit(sample['aolp'], channels=3)
        dolp = fit(sample['dolp'], channels=3)
        ss = fit(sample['mask'])
        image = fit(sample['image'], interpolation=cv2.INTER_LINEAR, channels=3)
        label = fit(sample['label'], interpolation=cv2.INTER_NEAREST, is_label=True)
        nir = fit(sample['nir'], interpolation=cv2.INTER_LINEAR, channels=3)
        nir_mask = fit(sample['nir_mask'], interpolation=cv2.INTER_NEAREST)

        return {'image': image,
                'label': label,
                'aolp': aolp,
                'dolp': dolp,
                'nir': nir,
                'nir_mask': nir_mask,
                'u_map': u_map,
                'v_map': v_map,
                'mask': ss}
class FixScaleCrop(object):
    """Resize a sample so its short edge equals ``crop_size`` and take the center square crop."""

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        crop = self.crop_size

        h, w = sample['image'].shape[:2]
        if w > h:
            oh = crop
            ow = int(1.0 * w * oh / h)
        else:
            ow = crop
            oh = int(1.0 * h * ow / w)

        # center crop origin
        x1 = int(round((ow - crop) / 2.))
        y1 = int(round((oh - crop) / 2.))

        def fit(arr, interpolation=None):
            # Resize to (ow, oh), then cut out the centered crop window.
            if interpolation is None:
                arr = cv2.resize(arr, (ow, oh))
            else:
                arr = cv2.resize(arr, (ow, oh), interpolation=interpolation)
            return arr[y1:y1 + crop, x1:x1 + crop]

        u_map = fit(sample['u_map'])
        v_map = fit(sample['v_map'])
        aolp = fit(sample['aolp'])
        dolp = fit(sample['dolp'])
        ss = fit(sample['mask'])
        image = fit(sample['image'], cv2.INTER_LINEAR)
        label = fit(sample['label'], cv2.INTER_NEAREST)
        nir = fit(sample['nir'], cv2.INTER_LINEAR)
        nir_mask = fit(sample['nir_mask'], cv2.INTER_NEAREST)

        return {'image': image,
                'label': label,
                'aolp': aolp,
                'dolp': dolp,
                'nir': nir,
                'nir_mask': nir_mask,
                'u_map': u_map,
                'v_map': v_map,
                'mask': ss}
if __name__ == '__main__':
    # Smoke test: iterate the validation split and print the label ids of each batch.
    train_aug = get_train_augmentation((1024, 1224), seg_fill=255)
    dataset = MCubeS(transform=train_aug, split='val')
    loader = DataLoader(dataset, batch_size=1, num_workers=0, drop_last=False, pin_memory=False)
    for _, lbl in loader:
        print(torch.unique(lbl))
================================================
FILE: semseg/datasets/mfnet.py
================================================
import os
import torch
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io
from pathlib import Path
from typing import Tuple
import glob
import einops
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler, RandomSampler
from semseg.augmentations_mm import get_train_augmentation
class MFNet(Dataset):
    """MFNet RGB-thermal semantic segmentation dataset.

    num_classes: 9

    File names are read from ``<root>/train.txt`` (train) or
    ``<root>/test.txt`` (val).
    """
    CLASSES = ['unlabeled', 'car', 'person', 'bike', 'curve', 'car_stop', 'guardrail', 'color_cone', 'bump']
    # NOTE(review): PALETTE has 8 colors for 9 CLASSES — confirm whether the
    # color for 'unlabeled' is intentionally omitted.
    PALETTE = torch.tensor([[64,0,128],[64,64,0],[0,128,192],[0,0,192],[128,128,0],[64,64,128],[192,128,128],[192,64,0]])

    def __init__(self, root: str = 'data/MFNet', split: str = 'train', transform = None, modals = ['img', 'thermal'], case = None) -> None:
        """Read the split list.

        Args:
            root: dataset root directory.
            split: 'train' or 'val'.
            transform: optional callable applied to the sample dict.
            modals: keys of the modalities returned by ``__getitem__``, in order.
            case: accepted for signature parity with the other datasets; unused.

        Raises:
            Exception: if the split list is empty.
        """
        super().__init__()
        assert split in ['train', 'val']
        self.root = root
        self.transform = transform
        self.n_classes = len(self.CLASSES)
        self.ignore_label = 255
        self.modals = modals
        self.files = self._get_file_names(split)

        if not self.files:
            # The original message referenced an undefined name and raised NameError.
            raise Exception(f"No images found in {root}")
        print(f"Found {len(self.files)} {split} images.")

    def __len__(self) -> int:
        """Return the number of listed samples."""
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load one sample.

        Returns:
            A pair ``(sample, label)`` where ``sample`` is a list holding the
            tensors of ``self.modals`` (in that order) and ``label`` is a long
            tensor.
        """
        item_name = str(self.files[index])
        rgb = os.path.join(*[self.root, 'rgb', item_name+'.jpg'])
        x1 = os.path.join(*[self.root, 'ther', item_name+'.jpg'])
        lbl_path = os.path.join(*[self.root, 'labels', item_name+'.png'])

        sample = {}
        sample['img'] = io.read_image(rgb)[:3, ...]
        if 'thermal' in self.modals:
            sample['thermal'] = self._open_img(x1)
        label = io.read_image(lbl_path)[0, ...].unsqueeze(0)
        sample['mask'] = label

        if self.transform:
            sample = self.transform(sample)
        label = sample['mask']
        del sample['mask']
        label = self.encode(label.squeeze().numpy()).long()
        sample = [sample[k] for k in self.modals]
        return sample, label

    def _open_img(self, file):
        """Read an image and return it with three channels.

        A 4-channel image loses its last channel; a 1-channel image is
        repeated three times.
        """
        img = io.read_image(file)
        channels = img.shape[0]
        if channels == 4:
            img = img[:3, ...]
        if channels == 1:
            img = img.repeat(3, 1, 1)
        return img

    def encode(self, label: Tensor) -> Tensor:
        """Convert a numpy label array to a tensor (no id remapping)."""
        return torch.from_numpy(label)

    def _get_file_names(self, split_name):
        """Return the first whitespace-separated token of every line of the split list."""
        assert split_name in ['train', 'val']
        source = os.path.join(self.root, 'test.txt') if split_name == 'val' else os.path.join(self.root, 'train.txt')
        file_names = []
        with open(source) as f:
            for item in f:
                file_name = item.strip()
                if ' ' in file_name:
                    file_name = file_name.split(' ')[0]
                file_names.append(file_name)
        return file_names
if __name__ == '__main__':
    # Smoke test: iterate the training split and print the label ids of each batch.
    train_aug = get_train_augmentation((480, 640), seg_fill=255)
    dataset = MFNet(transform=train_aug)
    loader = DataLoader(dataset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)
    for _, lbl in loader:
        print(torch.unique(lbl))
================================================
FILE: semseg/datasets/nyu.py
================================================
import os
import torch
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from torchvision import io
from pathlib import Path
from typing import Tuple
import glob
import einops
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler, RandomSampler
from semseg.augmentations_mm import get_train_augmentation
class NYU(Dataset):
    """NYU Depth v2 RGB-D semantic segmentation dataset.

    num_classes: 40

    File names are read from ``<root>/train.txt`` (train) or
    ``<root>/test.txt`` (val). Only the 'img' and 'depth' modalities are
    implemented.
    """
    CLASSES = ['wall','floor','cabinet','bed','chair','sofa','table','door','window','bookshelf','picture','counter','blinds',
               'desk','shelves','curtain','dresser','pillow','mirror','floor mat','clothes','ceiling','books','refridgerator',
               'television','paper','towel','shower curtain','box','whiteboard','person','night stand','toilet',
               'sink','lamp','bathtub','bag','otherstructure','otherfurniture','otherprop']

    PALETTE = None

    def __init__(self, root: str = 'data/NYUDepthv2', split: str = 'train', transform = None, modals = ['img', 'depth'], case = None) -> None:
        """Read the split list.

        Args:
            root: dataset root directory.
            split: 'train' or 'val'.
            transform: optional callable applied to the sample dict.
            modals: keys of the modalities returned by ``__getitem__``, in order.
            case: accepted for signature parity with the other datasets; unused.

        Raises:
            Exception: if the split list is empty.
        """
        super().__init__()
        assert split in ['train', 'val']
        self.root = root
        self.transform = transform
        self.n_classes = len(self.CLASSES)
        self.ignore_label = 255
        self.modals = modals
        self.files = self._get_file_names(split)

        if not self.files:
            # The original message referenced an undefined name and raised NameError.
            raise Exception(f"No images found in {root}")
        print(f"Found {len(self.files)} {split} images.")

    def __len__(self) -> int:
        """Return the number of listed samples."""
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load one sample.

        Returns:
            A pair ``(sample, label)`` where ``sample`` is a list holding the
            tensors of ``self.modals`` (in that order) and ``label`` is a long
            tensor.

        Raises:
            NotImplementedError: if 'lidar' or 'event' is requested.
        """
        item_name = str(self.files[index])
        rgb = os.path.join(*[self.root, 'RGB', item_name+'.jpg'])
        x1 = os.path.join(*[self.root, 'HHA', item_name+'.jpg'])
        lbl_path = os.path.join(*[self.root, 'Label', item_name+'.png'])

        sample = {}
        sample['img'] = io.read_image(rgb)[:3, ...]
        if 'depth' in self.modals:
            sample['depth'] = self._open_img(x1)
        if 'lidar' in self.modals:
            raise NotImplementedError()
        if 'event' in self.modals:
            raise NotImplementedError()
        label = io.read_image(lbl_path)[0, ...].unsqueeze(0)
        # Shift labels down by one. Assuming the label is read as uint8, id 0
        # (and 255, mapped to 0 first) wraps around to 255, i.e. ignore_label.
        label[label == 255] = 0
        label -= 1
        sample['mask'] = label

        if self.transform:
            sample = self.transform(sample)
        label = sample['mask']
        del sample['mask']
        label = self.encode(label.squeeze().numpy()).long()
        sample = [sample[k] for k in self.modals]
        return sample, label

    def _open_img(self, file):
        """Read an image and return it with three channels.

        A 4-channel image loses its last channel; a 1-channel image is
        repeated three times.
        """
        img = io.read_image(file)
        channels = img.shape[0]
        if channels == 4:
            img = img[:3, ...]
        if channels == 1:
            img = img.repeat(3, 1, 1)
        return img

    def encode(self, label: Tensor) -> Tensor:
        """Convert a numpy label array to a tensor (no id remapping)."""
        return torch.from_numpy(label)

    def _get_file_names(self, split_name):
        """Return the first whitespace-separated token of every line of the split list."""
        assert split_name in ['train', 'val']
        source = os.path.join(self.root, 'test.txt') if split_name == 'val' else os.path.join(self.root, 'train.txt')
        file_names = []
        with open(source) as f:
            for item in f:
                file_name = item.strip()
                if ' ' in file_name:
                    file_name = file_name.split(' ')[0]
                file_names.append(file_name)
        return file_names
if __name__ == '__main__':
    # Smoke test: iterate the validation split and print the label ids of each batch.
    train_aug = get_train_augmentation((480, 640), seg_fill=255)
    dataset = NYU(transform=train_aug, split='val')
    loader = DataLoader(dataset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)
    for _, lbl in loader:
        print(torch.unique(lbl))
================================================
FILE: semseg/datasets/unzip.py
================================================
import zipfile
# Extract the MCubeS archive member by member so that one corrupt entry is
# reported and skipped instead of aborting the whole extraction.
with zipfile.ZipFile("data/MCubeS/multimodal_dataset.zip", "r") as archive:
    for member in archive.namelist():
        try:
            archive.extract(member, "multimodal_dataset_extracted/")
        except zipfile.BadZipFile as err:
            print(err)
================================================
FILE: semseg/datasets/urbanlf.py
================================================
import os
import torch
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from torchvision import io
from pathlib import Path
from typing import Tuple
import glob
import einops
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler, RandomSampler
from semseg.augmentations_mm import get_train_augmentation
class UrbanLF(Dataset):
    """UrbanLF light-field semantic segmentation dataset.

    num_classes: 14

    Each scene directory under ``<root>/<split>/`` holds sub-aperture views
    named ``<i>_<j>.png``; ``5_5.png`` is used as the RGB image and the
    other views listed in ``modals`` are loaded as extra modalities.
    """
    CLASSES = ['bike','building','fence','others','person','pole','road','sidewalk','traffic sign','vegetation','vehicle','bridge','rider','sky']

    PALETTE = [[168,198,168],[198,0,0],[202,154,198],[0,0,0],[100,198,198],[198,100,0],[52,42,198],[154,52,192],[198,0,168],[0,198,0],[198,186,90],[108,107,161],[156,200,26],[158,179,202]]

    def __init__(self, root: str = 'data/UrBanLF/Syn', split: str = 'train', transform = None, modals = ['img', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9'], case = None) -> None:
        """Collect the ``5_5.png`` file of every scene of ``split``.

        Args:
            root: dataset root; must contain 'real' or 'Syn' so that
                ``__getitem__`` can pick the label format.
            split: 'train' or 'val'.
            transform: optional callable applied to the sample dict.
            modals: 'img' plus the names of the extra views to load.
            case: accepted for signature parity with the other datasets; unused.

        Raises:
            Exception: if no scene is found.
        """
        super().__init__()
        assert split in ['train', 'val']
        self.root = root
        self.transform = transform
        self.n_classes = len(self.CLASSES)
        self.ignore_label = 255
        self.modals = modals
        self.files = sorted(glob.glob(os.path.join(*[root, split, '*', '5_5.png'])))

        if not self.files:
            # The original message referenced an undefined name and raised NameError.
            raise Exception(f"No images found in {root}")
        print(f"Found {len(self.files)} {split} images.")

    def __len__(self) -> int:
        """Return the number of scenes found."""
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load one sample.

        Returns:
            A pair ``(sample_list, label)`` where ``sample_list`` holds the
            RGB image followed by the requested views (ordered by view name,
            row-major over the 9x9 grid) and ``label`` is a long tensor.

        Raises:
            NotImplementedError: if ``root`` contains neither 'real' nor 'Syn'.
        """
        item_name = str(self.files[index])
        rgb = item_name
        rgb_dir_name = os.path.dirname(rgb)

        # Requested views in grid order; the central view is the RGB image.
        lf_names = []
        lf_paths = []
        for i in range(1, 10):
            for j in range(1, 10):
                lf_name = '{}_{}'.format(i, j)
                if lf_name != '5_5' and lf_name in self.modals:
                    lf_names.append(lf_name)
                    lf_paths.append(os.path.join(rgb_dir_name, lf_name+'.png'))

        if 'real' in self.root:
            lbl_path = item_name.replace('5_5', 'label')
        elif 'Syn' in self.root:
            lbl_path = item_name.replace('5_5.png', '5_5_label.npy')
        else:
            # `raise NotImplemented` would itself fail with a TypeError.
            raise NotImplementedError("root must contain 'real' or 'Syn'")

        sample = {}
        sample['img'] = io.read_image(rgb)[:3, ...]
        if len(self.modals) > 1:
            for lf_name, lf_path in zip(lf_names, lf_paths):
                assert lf_name in lf_path, "Not matched."
                sample[lf_name] = self._open_img(lf_path)

        if 'real' in self.root:
            # Color-coded label image -> class indices.
            label = io.read_image(lbl_path)
            label = self.encode(label.numpy())
        elif 'Syn' in self.root:
            label = np.load(lbl_path)
            # Shift labels down by one; presumably the array is uint8 so that
            # 0 wraps to 255 (ignore_label) — TODO confirm the stored dtype.
            label[label == 255] = 0
            label -= 1
            label = torch.tensor(label[None, ...])
        else:
            raise NotImplementedError("root must contain 'real' or 'Syn'")
        sample['mask'] = label

        if self.transform:
            sample = self.transform(sample)
        label = sample['mask']
        del sample['mask']
        label = label.long().squeeze(0)
        sample_list = [sample['img']]
        sample_list += [sample[k] for k in lf_names]
        return sample_list, label

    def _open_img(self, file):
        """Read an image and return it with three channels.

        A 4-channel image loses its last channel; a 1-channel image is
        repeated three times.
        """
        img = io.read_image(file)
        channels = img.shape[0]
        if channels == 4:
            img = img[:3, ...]
        if channels == 1:
            img = img.repeat(3, 1, 1)
        return img

    def encode(self, label: Tensor) -> Tensor:
        """Convert a color-coded label (numpy, C x H x W) to class indices.

        Returns a ``1 x H x W`` integer tensor. Pixels whose color is not in
        ``PALETTE`` keep the initial value 0.
        """
        label = label.transpose(1, 2, 0)  # C, H, W -> H, W, C
        label_mask = np.zeros((label.shape[0], label.shape[1]), dtype=np.int16)
        for class_idx, color in enumerate(self.PALETTE):
            label_mask[np.all(label == color, axis=-1)] = class_idx
        label_mask = label_mask[None, ...].astype(int)
        return torch.from_numpy(label_mask)
if __name__ == '__main__':
    # Smoke test: iterate the training split with one extra view and print
    # the label ids of each batch.
    train_aug = get_train_augmentation((432, 623), seg_fill=255)
    dataset = UrbanLF(transform=train_aug, modals=['img', '1_2'])
    loader = DataLoader(dataset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)
    for _, lbl in loader:
        print(torch.unique(lbl))
================================================
FILE: semseg/losses.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
class CrossEntropy(nn.Module):
    """Cross-entropy loss with optional weighting of auxiliary outputs.

    When ``preds`` is a tuple, each prediction is paired with the entry of
    ``aux_weights`` at the same position and the weighted losses are summed.
    """

    def __init__(self, ignore_label: int = 255, weight: Tensor = None, aux_weights: list = [1, 0.4, 0.4]) -> None:
        super().__init__()
        self.aux_weights = aux_weights
        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)

    def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
        # preds in shape [B, C, H, W] and labels in shape [B, H, W]
        loss = self.criterion(preds, labels)
        return loss

    def forward(self, preds, labels: Tensor) -> Tensor:
        if not isinstance(preds, tuple):
            return self._forward(preds, labels)
        weighted = (
            aux_w * self._forward(single_pred, labels)
            for single_pred, aux_w in zip(preds, self.aux_weights)
        )
        return sum(weighted)
class OhemCrossEntropy(nn.Module):
    """Cross-entropy with online hard example mining (OHEM).

    Only pixels whose loss exceeds ``-log(thresh)`` contribute; if fewer than
    ``n_min`` (1/16 of the non-ignored pixels) qualify, the ``n_min`` largest
    losses are used instead. When ``preds`` is a tuple, each prediction is
    paired with the entry of ``aux_weights`` at the same position.
    """

    def __init__(self, ignore_label: int = 255, weight: Tensor = None, thresh: float = 0.7, aux_weights: list = [1, 1]) -> None:
        super().__init__()
        self.ignore_label = ignore_label
        self.aux_weights = aux_weights
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label, reduction='none')

    def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
        # preds in shape [B, C, H, W] and labels in shape [B, H, W]
        n_min = labels[labels != self.ignore_label].numel() // 16
        loss = self.criterion(preds, labels).view(-1)
        loss_hard = loss[loss > self.thresh]

        if loss_hard.numel() < n_min:
            loss_hard, _ = loss.topk(n_min)

        if loss_hard.numel() == 0:
            # Nothing was selected (no pixel above the threshold and
            # n_min == 0): the mean of an empty tensor would be NaN. Return a
            # zero that is still attached to the graph so backward() works.
            return loss.sum() * 0.0

        return torch.mean(loss_hard)

    def forward(self, preds, labels: Tensor) -> Tensor:
        if isinstance(preds, tuple):
            return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)])
        return self._forward(preds, labels)
class Dice(nn.Module):
    """Dice-style loss computed per class from one-hot labels and raw predictions."""

    def __init__(self, delta: float = 0.5, aux_weights: list = [1, 0.4, 0.4]):
        """
        delta: balances false negatives against false positives; 0.5 gives the plain dice score.
        aux_weights: coefficients applied to each element when predictions come as a tuple.
        """
        super().__init__()
        self.delta = delta
        self.aux_weights = aux_weights

    def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
        """Score one prediction (preds [B, C, H, W], labels [B, H, W])."""
        num_classes = preds.shape[1]
        # one_hot appends the class axis last; move it next to the batch axis.
        onehot = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)
        spatial = (2, 3)
        tp = (onehot * preds).sum(dim=spatial)
        fn = (onehot * (1 - preds)).sum(dim=spatial)
        fp = ((1 - onehot) * preds).sum(dim=spatial)

        score = (tp + 1e-6) / (tp + self.delta * fn + (1 - self.delta) * fp + 1e-6)
        per_sample = (1 - score).sum(dim=-1) / num_classes
        return per_sample.mean()

    def forward(self, preds, targets: Tensor) -> Tensor:
        """Return the loss for ``preds``, weighting auxiliary outputs if a tuple is given."""
        if not isinstance(preds, tuple):
            return self._forward(preds, targets)
        total = 0
        for head_out, coeff in zip(preds, self.aux_weights):
            total = total + coeff * self._forward(head_out, targets)
        return total
__all__ = ['CrossEntropy', 'OhemCrossEntropy', 'Dice']


def get_loss(loss_fn_name: str = 'CrossEntropy', ignore_label: int = 255, cls_weights: Tensor = None):
    """Build a loss module from its class name.

    ``loss_fn_name`` must be one of the names in ``__all__``. ``Dice`` is
    created with its defaults; the cross-entropy variants receive
    ``ignore_label`` and ``cls_weights`` positionally.
    """
    assert loss_fn_name in __all__, f"Unavailable loss function name >> {loss_fn_name}.\nAvailable loss functions: {__all__}"
    if loss_fn_name == 'Dice':
        return Dice()
    # Explicit lookup rather than eval(): the name usually comes from a config
    # file, and with assertions disabled (-O) eval would execute it unchecked.
    loss_classes = {'CrossEntropy': CrossEntropy, 'OhemCrossEntropy': OhemCrossEntropy}
    return loss_classes[loss_fn_name](ignore_label, cls_weights)
if __name__ == '__main__':
    # Manual check: evaluate the Dice loss on random predictions and labels.
    num_cls = 19
    pred = torch.randint(0, num_cls, (2, num_cls, 480, 640), dtype=torch.float)
    label = torch.randint(0, num_cls, (2, 480, 640), dtype=torch.long)
    loss_fn = Dice()
    print(loss_fn(pred, label))
================================================
FILE: semseg/metrics.py
================================================
import torch
from torch import Tensor
from typing import Tuple
class Metrics:
    """Accumulates a confusion matrix and derives IoU, F1 and per-class accuracy.

    ``hist[t, p]`` counts pixels with target class ``t`` predicted as ``p``.
    Counts are kept as int64: a float32 accumulator stops representing every
    integer above 2**24, so long evaluations would silently drop pixels.
    """

    def __init__(self, num_classes: int, ignore_label: int, device) -> None:
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.hist = torch.zeros(num_classes, num_classes, dtype=torch.int64, device=device)

    def update(self, pred: Tensor, target: Tensor) -> None:
        """Add one batch; ``pred`` holds per-class scores along dim 1."""
        pred = pred.argmax(dim=1)
        keep = target != self.ignore_label
        # Encode each (target, pred) pair as a single index into the flattened matrix.
        flat = target[keep] * self.num_classes + pred[keep]
        self.hist += torch.bincount(flat, minlength=self.num_classes**2).view(self.num_classes, self.num_classes)

    @staticmethod
    def _finalize(values: Tensor) -> Tuple[list, float]:
        """Zero out NaNs (classes with an empty denominator) and convert to percentages."""
        values[values.isnan()] = 0.
        mean = values.mean().item()
        values *= 100
        mean *= 100
        return values.cpu().numpy().round(2).tolist(), round(mean, 2)

    def compute_iou(self) -> Tuple[list, float]:
        """Return (per-class IoU list, mean IoU), both in percent."""
        tp = self.hist.diag()
        union = self.hist.sum(0) + self.hist.sum(1) - tp
        return self._finalize(tp.float() / union.float())

    def compute_f1(self) -> Tuple[list, float]:
        """Return (per-class F1 list, mean F1), both in percent."""
        tp = self.hist.diag()
        denom = self.hist.sum(0) + self.hist.sum(1)
        return self._finalize(2 * tp.float() / denom.float())

    def compute_pixel_acc(self) -> Tuple[list, float]:
        """Return (per-class accuracy list, mean accuracy), both in percent."""
        tp = self.hist.diag()
        return self._finalize(tp.float() / self.hist.sum(1).float())
================================================
FILE: semseg/models/__init__.py
================================================
from .cmx import CMX
from .cmnext import CMNeXt
# Names re-exported by this package (imported from .cmx and .cmnext above).
__all__ = [
'CMX',
'CMNeXt',
]
================================================
FILE: semseg/models/backbones/__init__.py
================================================
from .cmx import CMX
from .cmnext import CMNeXt
# Backbone classes re-exported by this package (imported from .cmx and .cmnext above).
__all__ = [
'CMX',
'CMNeXt',
]
================================================
FILE: semseg/models/backbones/cmnext.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from semseg.models.layers import DropPath
import functools
from functools import partial
from fvcore.nn import flop_count_table, FlopCountAnalysis
from semseg.models.modules.ffm import FeatureFusionModule as FFM
from semseg.models.modules.ffm import FeatureRectifyModule as FRM
from semseg.models.modules.ffm import ChannelEmbed
from semseg.models.modules.mspa import MSPABlock
from semseg.utils.utils import nchw_to_nlc, nlc_to_nchw
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence with optional key/value reduction.

    When ``sr_ratio > 1`` the tokens are folded back into an ``H x W`` map,
    shrunk by a strided convolution and layer-normalised before keys and
    values are computed; queries always use the full sequence.
    """

    def __init__(self, dim, head, sr_ratio):
        super().__init__()
        self.head = head
        self.sr_ratio = sr_ratio
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim*2)
        self.proj = nn.Linear(dim, dim)

        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Attend over ``x`` of shape (B, N, C); ``H`` and ``W`` are only used when sr_ratio > 1."""
        B, N, C = x.shape
        head_dim = C // self.head
        queries = self.q(x).reshape(B, N, self.head, head_dim).permute(0, 2, 1, 3)

        kv_source = x
        if self.sr_ratio > 1:
            fmap = x.permute(0, 2, 1).reshape(B, C, H, W)
            reduced = self.sr(fmap).reshape(B, C, -1).permute(0, 2, 1)
            kv_source = self.norm(reduced)

        # Leading axis of size two separates keys from values.
        kv = self.kv(kv_source).reshape(B, -1, 2, self.head, head_dim).permute(2, 0, 3, 1, 4)
        keys, values = kv[0], kv[1]

        weights = ((queries @ keys.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        merged = (weights @ values).transpose(1, 2).reshape(B, N, C)
        return self.proj(merged)
class DWConv(nn.Module):
    """3x3 depthwise convolution applied to tokens that are viewed as an ``H x W`` map."""

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Take tokens (B, N, C) with N == H * W and return tokens of the same shape."""
        batch, _, channels = x.shape
        fmap = x.transpose(1, 2).view(batch, channels, H, W)
        fmap = self.dwconv(fmap)
        # Back from a feature map to a token sequence.
        return fmap.flatten(2).transpose(1, 2)
class MLP(nn.Module):
    """Token feed-forward: linear expand, depthwise conv, GELU, linear project back."""

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Apply the feed-forward to tokens; ``H`` and ``W`` are forwarded to the depthwise conv."""
        hidden = self.fc1(x)
        hidden = self.dwconv(hidden, H, W)
        hidden = F.gelu(hidden)
        return self.fc2(hidden)
class PatchEmbed(nn.Module):
    """Strided convolution that turns an image or feature map into normalised tokens."""

    def __init__(self, c1=3, c2=32, patch_size=7, stride=4, padding=0):
        super().__init__()
        self.proj = nn.Conv2d(c1, c2, patch_size, stride, padding)
        self.norm = nn.LayerNorm(c2)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``(tokens, H, W)`` where tokens is (B, H*W, c2) and H, W is the conv output size."""
        fmap = self.proj(x)
        H, W = fmap.shape[2], fmap.shape[3]
        tokens = fmap.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W
class PatchEmbedParallel(nn.Module):
    """Patch embedding for a list of modalities.

    One convolution is shared by all list entries (via ``ModuleParallel``),
    while each entry gets its own channel-first layer norm.
    """

    def __init__(self, c1=3, c2=32, patch_size=7, stride=4, padding=0, num_modals=4):
        super().__init__()
        shared_conv = nn.Conv2d(c1, c2, patch_size, stride, padding)
        self.proj = ModuleParallel(shared_conv)
        self.norm = LayerNormParallel(c2, num_modals)

    def forward(self, x: list) -> list:
        """Return ``(maps, H, W)``; H and W are read from the first projected map."""
        projected = self.proj(x)
        H, W = projected[0].shape[2], projected[0].shape[3]
        return self.norm(projected), H, W
class Block(nn.Module):
    """Pre-norm transformer block: attention and a feed-forward, each on a residual path.

    With ``is_fan`` set, the feed-forward is a ``ChannelProcessing`` module
    instead of the plain ``MLP``.
    """

    def __init__(self, dim, head, sr_ratio=1, dpr=0., is_fan=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, head, sr_ratio)
        self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim*4)
        if is_fan:
            self.mlp = ChannelProcessing(dim, mlp_hidden_dim=hidden_dim)
        else:
            self.mlp = MLP(dim, hidden_dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Run both residual sub-layers on tokens ``x``."""
        attn_out = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_out)
        ffn_out = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(ffn_out)
class ChannelProcessing(nn.Module):
    """Channel re-weighting module used as an alternative feed-forward in ``Block``.

    A sigmoid gate per head and channel is derived from softmax-normalised
    queries and channel-averaged keys, then multiplied onto values produced
    by an ``MLP`` followed by a norm layer.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., drop_path=0., mlp_hidden_dim=None, norm_layer=nn.LayerNorm):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.mlp_v = MLP(dim, mlp_hidden_dim)
        self.norm_v = norm_layer(dim)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.pool = nn.AdaptiveAvgPool2d((None, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, H, W, atten=None):
        """Gate tokens ``x`` of shape (B, N, C); ``atten`` is accepted but not used."""
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        def split_heads(t):
            # (B, N, C) -> (B, heads, N, head_dim)
            return t.reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)

        # Keys and values both start from the unprojected input.
        per_head = split_heads(x)

        q = split_heads(self.q(x)).softmax(-2).transpose(-1, -2)
        k = per_head.softmax(-2)
        k = torch.nn.functional.avg_pool2d(k, (1, head_dim))
        attn = self.sigmoid(q @ k)

        merged = per_head.transpose(1, 2).reshape(B, N, heads * head_dim)
        v = self.norm_v(self.mlp_v(merged, H, W))
        v = v.reshape(B, N, heads, head_dim).transpose(1, 2)

        gated = attn * v.transpose(-1, -2)
        return gated.permute(0, 3, 1, 2).reshape(B, N, C)
class PredictorConv(nn.Module):
    """One score head per modality: depthwise 3x3 conv, 1x1 conv to one channel, sigmoid."""

    def __init__(self, embed_dim=384, num_modals=4):
        super().__init__()
        self.num_modals = num_modals
        self.score_nets = nn.ModuleList([nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, 1, 1, groups=(embed_dim)),
            nn.Conv2d(embed_dim, 1, 1),
            nn.Sigmoid()
        ) for _ in range(num_modals)])

    def forward(self, x):
        """Return a list with one single-channel score map per modality.

        ``x`` is a list of feature maps; entry ``i`` is scored by head ``i``.
        """
        # Build the result directly: the previous version first allocated a
        # CPU zero tensor per modality on every call, only to overwrite it.
        return [self.score_nets[i](x[i]) for i in range(self.num_modals)]
class ModuleParallel(nn.Module):
    """Applies one shared module to every element of a list of inputs."""

    def __init__(self, module):
        super(ModuleParallel, self).__init__()
        self.module = module

    def forward(self, x_parallel):
        """Return a list holding ``self.module(x)`` for each ``x`` in ``x_parallel``."""
        outputs = []
        for branch_input in x_parallel:
            outputs.append(self.module(branch_input))
        return outputs
class ConvLayerNorm(nn.Module):
    """Layer norm for channel-first maps: normalises over dim 1 with a per-channel affine."""

    def __init__(self, normalized_shape, eps=1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        """Normalise ``x`` across its channel axis and apply weight and bias."""
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        # Parameters are 1-D; the added axes broadcast them over the two trailing dims.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normed + shift
class LayerNormParallel(nn.Module):
    """Holds one ``ConvLayerNorm`` per modality, stored as attributes ``ln_0``, ``ln_1``, ..."""

    def __init__(self, num_features, num_modals=4):
        super(LayerNormParallel, self).__init__()
        for idx in range(num_modals):
            setattr(self, 'ln_' + str(idx), ConvLayerNorm(num_features, eps=1e-6))

    def forward(self, x_parallel):
        """Normalise the i-th list entry with the i-th norm layer and return the results as a list."""
        outputs = []
        for idx, feat in enumerate(x_parallel):
            norm = getattr(self, 'ln_' + str(idx))
            outputs.append(norm(feat))
        return outputs
# Variant name -> [embed_dims, depths], one entry per stage in each list.
# CMNeXt asserts that the requested model name is a key of this dict, so the
# commented-out variants below cannot be selected.
cmnext_settings = {
# 'B0': [[32, 64, 160, 256], [2, 2, 2, 2]],
# 'B1': [[64, 128, 320, 512], [2, 2, 2, 2]],
'B2': [[64, 128, 320, 512], [3, 4, 6, 3]],
# 'B3': [[64, 128, 320, 512], [3, 4, 18, 3]],
'B4': [[64, 128, 320, 512], [3, 8, 27, 3]],
'B5': [[64, 128, 320, 512], [3, 6, 40, 3]]
}
class CMNeXt(nn.Module):
    """Four-stage backbone with a main branch and an optional branch for extra modalities.

    The first entry of ``modals`` feeds the main transformer branch. Any
    remaining modalities are embedded in parallel, merged into one feature
    map (``tokenselect``) when there is more than one, passed through
    ``MSPABlock`` stacks, and fused with the main branch by an FRM/FFM pair
    at every stage. ``forward`` returns one feature map per stage.
    """

    def __init__(self, model_name: str = 'B0', modals: list = ['rgb', 'depth', 'event', 'lidar']):
        super().__init__()
        assert model_name in cmnext_settings.keys(), f"Model name should be in {list(cmnext_settings.keys())}"
        embed_dims, depths = cmnext_settings[model_name]
        extra_depths = depths
        # Everything after the first modality belongs to the extra branch.
        self.modals = modals[1:] if len(modals)>1 else []
        self.num_modals = len(self.modals)
        drop_path_rate = 0.1
        self.channels = embed_dims
        norm_cfg = dict(type='BN', requires_grad=True)

        # Per-stage configuration. Modules are registered in a fixed order so that
        # parameter names and creation order stay stable.
        in_dims = [3, *embed_dims[:3]]
        kernels = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]
        num_heads = [1, 2, 5, 8]
        sr_ratios = [8, 4, 2, 1]
        extra_mlp_ratios = [8, 8, 4, 4]

        # patch_embed
        for idx in range(4):
            embed = PatchEmbed(in_dims[idx], embed_dims[idx], kernels[idx], strides[idx], kernels[idx] // 2)
            setattr(self, f'patch_embed{idx + 1}', embed)

        if self.num_modals > 0:
            self.extra_downsample_layers = nn.ModuleList([
                PatchEmbedParallel(in_dims[idx], embed_dims[idx], kernels[idx], strides[idx], kernels[idx] // 2, self.num_modals)
                for idx in range(4)
            ])
        if self.num_modals > 1:
            self.extra_score_predictor = nn.ModuleList([
                PredictorConv(embed_dims[i], self.num_modals) for i in range(len(depths))
            ])

        # Stochastic-depth rates rise linearly over all blocks of the main branch.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        cur = 0
        for idx in range(4):
            stage = idx + 1
            dim = embed_dims[idx]
            main_blocks = nn.ModuleList([
                Block(dim, num_heads[idx], sr_ratios[idx], dpr[cur+i]) for i in range(depths[idx])
            ])
            setattr(self, f'block{stage}', main_blocks)
            setattr(self, f'norm{stage}', nn.LayerNorm(dim))
            if self.num_modals > 0:
                extra_blocks = nn.ModuleList([
                    MSPABlock(dim, mlp_ratio=extra_mlp_ratios[idx], drop_path=dpr[cur+i], norm_cfg=norm_cfg)
                    for i in range(extra_depths[idx])
                ])
                setattr(self, f'extra_block{stage}', extra_blocks)
                setattr(self, f'extra_norm{stage}', ConvLayerNorm(dim))
            cur += depths[idx]

        if self.num_modals > 0:
            self.FRMs = nn.ModuleList([FRM(dim=dim, reduction=1) for dim in embed_dims])
            self.FFMs = nn.ModuleList([
                FFM(dim=dim, reduction=1, num_heads=heads, norm_layer=nn.BatchNorm2d)
                for dim, heads in zip(embed_dims, num_heads)
            ])

    def tokenselect(self, x_ext, module):
        """Merge the extra-modality maps into a single map.

        Each entry of ``x_ext`` is replaced in place by ``score * x + x`` using
        the scores returned by ``module``; the result is the element-wise
        maximum over all entries.
        """
        x_scores = module(x_ext)
        for idx, feat in enumerate(x_ext):
            x_ext[idx] = x_scores[idx] * feat + feat
        return functools.reduce(torch.max, x_ext)

    def forward(self, x: list) -> list:
        """Run all four stages on a list of modality tensors (main modality first)."""
        cam = x[0]
        if self.num_modals > 0:
            x_ext = x[1:]
        B = cam.shape[0]
        outs = []
        last_stage = 4

        for stage in range(1, last_stage + 1):
            idx = stage - 1

            # Main branch: embed, run the transformer blocks, back to a channel-first map.
            tokens, H, W = getattr(self, f'patch_embed{stage}')(cam)
            for blk in getattr(self, f'block{stage}'):
                tokens = blk(tokens, H, W)
            cam = getattr(self, f'norm{stage}')(tokens).reshape(B, H, W, -1).permute(0, 3, 1, 2)

            if self.num_modals == 0:
                outs.append(cam)
                continue

            # Extra branch: embed every modality, collapse them to one map, refine it.
            x_ext, _, _ = self.extra_downsample_layers[idx](x_ext)
            if self.num_modals > 1:
                x_f = self.tokenselect(x_ext, self.extra_score_predictor[idx])
            else:
                x_f = x_ext[0]
            for blk in getattr(self, f'extra_block{stage}'):
                x_f = blk(x_f)
            x_f = getattr(self, f'extra_norm{stage}')(x_f)

            # Rectify both branches against each other, then fuse them.
            cam, x_f = self.FRMs[idx](cam, x_f)
            outs.append(self.FFMs[idx](cam, x_f))

            if stage < last_stage:
                if self.num_modals > 1:
                    # NOTE(review): the entries of x_ext come from PatchEmbedParallel and
                    # look channel-first already, so this reshape/permute would reinterpret
                    # the layout rather than transpose it. Left unchanged; confirm intent.
                    x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x_f for x_ in x_ext]
                else:
                    x_ext = [x_f]

        return outs
if __name__ == '__main__':
    # Manual check: feed constant inputs (filled with 0, 1, 2, 3) and print each stage's output shape.
    modals = ['img', 'depth', 'event', 'lidar']
    x = [torch.ones(1, 3, 1024, 1024) * fill for fill in range(len(modals))]
    model = CMNeXt('B2', modals)
    outs = model(x)
    for y in outs:
        print(y.shape)
================================================
FILE: semseg/models/backbones/cmx.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
import einops
from semseg.models.layers import DropPath
from semseg.models.modules.ffm import FeatureFusionModule as FFM
from semseg.models.modules.ffm import FeatureRectifyModule as FRM
class Attention(nn.Module):
    """Multi-head self-attention whose keys and values may come from a reduced sequence.

    For ``sr_ratio > 1`` the input tokens are viewed as an ``H x W`` map,
    downsampled with a strided convolution and layer-normalised before the
    key/value projection.
    """

    def __init__(self, dim, head, sr_ratio):
        super().__init__()
        self.head = head
        self.sr_ratio = sr_ratio
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim*2)
        self.proj = nn.Linear(dim, dim)

        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def _kv_input(self, x, H, W):
        """Return the tokens that keys and values are computed from."""
        if self.sr_ratio <= 1:
            return x
        B, _, C = x.shape
        fmap = x.permute(0, 2, 1).reshape(B, C, H, W)
        return self.norm(self.sr(fmap).reshape(B, C, -1).permute(0, 2, 1))

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Attend over tokens ``x`` of shape (B, N, C) and return the same shape."""
        B, N, C = x.shape
        per_head = C // self.head
        q = self.q(x).reshape(B, N, self.head, per_head).permute(0, 2, 1, 3)

        packed = self.kv(self._kv_input(x, H, W))
        packed = packed.reshape(B, -1, 2, self.head, per_head).permute(2, 0, 3, 1, 4)
        k, v = packed[0], packed[1]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        probs = scores.softmax(dim=-1)
        context = (probs @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(context)
class DWConv(nn.Module):
    """Depthwise 3x3 convolution over a token sequence laid out as an ``H x W`` grid."""

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Take tokens (B, N, C) with N == H * W and return tokens of the same shape."""
        n_batch, _, n_channels = x.shape
        grid = x.transpose(1, 2).view(n_batch, n_channels, H, W)
        convolved = self.dwconv(grid)
        return convolved.flatten(2).transpose(1, 2)
class MLP(nn.Module):
    """Feed-forward on tokens: ``fc1`` -> depthwise conv -> GELU -> ``fc2``."""

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Apply the feed-forward; ``H`` and ``W`` are handed to the depthwise conv."""
        expanded = self.fc1(x)
        mixed = self.dwconv(expanded, H, W)
        activated = F.gelu(mixed)
        return self.fc2(activated)
class PatchEmbed(nn.Module):
    """Strided convolution (padding ``patch_size // 2``) followed by layer norm over tokens."""

    def __init__(self, c1=3, c2=32, patch_size=7, stride=4):
        super().__init__()
        half_kernel = patch_size//2
        self.proj = nn.Conv2d(c1, c2, patch_size, stride, half_kernel)
        self.norm = nn.LayerNorm(c2)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``(tokens, H, W)`` where tokens is (B, H*W, c2) and H, W is the conv output size."""
        fmap = self.proj(x)
        H, W = fmap.shape[2], fmap.shape[3]
        tokens = self.norm(fmap.flatten(2).transpose(1, 2))
        return tokens, H, W
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each added back through drop-path."""

    def __init__(self, dim, head, sr_ratio=1, dpr=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, head, sr_ratio)
        # Drop-path is only instantiated for a positive rate.
        self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim*4))

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Run both residual sub-layers on tokens ``x``."""
        attended = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(transformed)
# class PredictorConv(nn.Module):
# """ Image to modality score, in spatial selection
# b, h, w, c -> b, h, w, 1
# """
# def __init__(self, embed_dim=384, num_modals=4):
# super().__init__()
# self.num_modals = num_modals
# self.dwconv = ModuleParallel(nn.Conv2d(embed_dim, embed_dim, 7, 1, 3, groups=embed_dim))
# self.score_nets = nn.ModuleList([nn.Sequential(
# nn.LayerNorm(embed_dim, eps=1e-6),
# # nn.Linear(embed_dim, embed_dim),
# # nn.GELU(),
# # nn.Linear(embed_dim, embed_dim // 2),
# # nn.GELU(),
# # nn.Linear(embed_dim // 2, embed_dim // 4),
# # nn.GELU(),
# # nn.Linear(embed_dim // 4, 1),
# # nn.Linear(embed_dim, embed_dim // 4),
# # nn.GELU(),
# nn.Linear(embed_dim, 1),
# # nn.Sigmoid()
# # nn.LogSoftmax(dim=-1)
# nn.Softmax(dim=-1)
# ) for _ in range(num_modals)])
# def forward(self, x):
# x = self.dwconv(x)
# x = [x_.permute(0, 2, 3, 1) for x_ in x] # NCHW to NHWC
# x = [self.score_nets[i](x[i]) for i in range(self.num_modals)]
# x = [x_.permute(0, 3, 1, 2) for x_ in x] # NHWC to NCHW
# # for i, (xi, rat) in enumerate(zip(x, [1, 1, 0.5, 0.2])):
# # x[i] = xi * rat
# return x
class PredictorLG(nn.Module):
    """Per-modality token score heads (the original docstring credits DynamicViT).

    Each head is LayerNorm -> Linear -> GELU -> Linear(to 1) -> Softmax(dim=-1).

    NOTE(review): the final Linear emits a single value per token, so the
    softmax over that last axis always evaluates to 1.0 and every modality
    gets the same score. Kept as is; confirm whether this is intended.
    """

    def __init__(self, embed_dim=384, num_modals=4):
        super().__init__()
        self.num_modals = num_modals
        heads = []
        for _ in range(num_modals):
            heads.append(nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, embed_dim // 4),
                nn.GELU(),
                nn.Linear(embed_dim // 4, 1),
                nn.Softmax(dim=-1)
            ))
        self.score_nets = nn.ModuleList(heads)

    def forward(self, x):
        """Score entry ``i`` of the list ``x`` with head ``i`` and return the list of scores."""
        scores = []
        for idx in range(self.num_modals):
            scores.append(self.score_nets[idx](x[idx]))
        return scores
# Variant name -> [embed_dims, depths], one entry per stage in each list.
# CMX asserts that the requested model name is a key of this dict.
mit_settings = {
'B0': [[32, 64, 160, 256], [2, 2, 2, 2]], # [embed_dims, depths]
'B1': [[64, 128, 320, 512], [2, 2, 2, 2]],
'B2': [[64, 128, 320, 512], [3, 4, 6, 3]],
'B3': [[64, 128, 320, 512], [3, 4, 18, 3]],
'B4': [[64, 128, 320, 512], [3, 8, 27, 3]],
'B5': [[64, 128, 320, 512], [3, 6, 40, 3]]
}
class CMX(nn.Module):
    """Dual-branch four-stage transformer backbone.

    The first entry of ``modals`` feeds the main branch. The remaining
    modalities share one extra branch built from the same block types: they
    are embedded with a shared patch embedding, reduced to one token sequence
    by ``tokenselect`` when there is more than one, and fused with the main
    branch by an FRM/FFM pair at every stage. ``forward`` returns one feature
    map per stage.
    """

    def __init__(self, model_name: str = 'B0', modals: list = ['rgb', 'depth', 'event', 'lidar']):
        super().__init__()
        assert model_name in mit_settings.keys(), f"Model name should be in {list(mit_settings.keys())}"
        embed_dims, depths = mit_settings[model_name]
        # Everything after the first modality belongs to the extra branch.
        self.modals = modals[1:] if len(modals)>1 else []
        self.num_modals = len(self.modals)
        drop_path_rate = 0.1
        self.channels = embed_dims

        # (in_channels, out_channels, patch_size, stride) of each stage's embedding.
        embed_cfgs = [(3, embed_dims[0], 7, 4)]
        embed_cfgs += [(embed_dims[i], embed_dims[i + 1], 3, 2) for i in range(3)]
        num_heads = [1, 2, 5, 8]
        sr_ratios = [8, 4, 2, 1]

        # patch_embed
        for stage, cfg in enumerate(embed_cfgs, start=1):
            setattr(self, f'patch_embed{stage}', PatchEmbed(*cfg))

        if self.num_modals > 0:
            for stage, cfg in enumerate(embed_cfgs, start=1):
                setattr(self, f'extra_patch_embed{stage}', PatchEmbed(*cfg))
        if self.num_modals > 1:
            self.extra_score_predictor = nn.ModuleList([
                PredictorLG(embed_dims[i], self.num_modals) for i in range(len(depths))
            ])

        # Stochastic-depth rates rise linearly over all blocks of one branch.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        cur = 0
        for idx in range(4):
            stage = idx + 1
            dim, depth = embed_dims[idx], depths[idx]
            heads, sr = num_heads[idx], sr_ratios[idx]
            setattr(self, f'block{stage}', nn.ModuleList([Block(dim, heads, sr, dpr[cur+i]) for i in range(depth)]))
            setattr(self, f'norm{stage}', nn.LayerNorm(dim))
            if self.num_modals > 0:
                setattr(self, f'extra_block{stage}', nn.ModuleList([Block(dim, heads, sr, dpr[cur+i]) for i in range(depth)]))
                setattr(self, f'extra_norm{stage}', nn.LayerNorm(dim))
            cur += depth

        if self.num_modals > 0:
            self.FRMs = nn.ModuleList([FRM(dim=dim, reduction=1) for dim in embed_dims])
            self.FFMs = nn.ModuleList([
                FFM(dim=dim, reduction=1, num_heads=heads, norm_layer=nn.BatchNorm2d)
                for dim, heads in zip(embed_dims, num_heads)
            ])

    # ---Hard selection
    def tokenselect(self, x_ext, module):
        """For every token, keep the features of the modality with the highest score.

        ``x_ext`` is a list of token tensors (one per modality) and ``module``
        returns one score tensor per modality.
        """
        scores = module(x_ext)
        stacked = torch.stack(x_ext, dim=-1)                    # modalities on a new last axis
        _, _, channels, _ = stacked.shape
        scores = torch.stack(scores, dim=-1)
        best = torch.argmax(scores, dim=-1, keepdim=True)       # index of the winning modality
        # Repeat the winning index along the channel axis so gather picks whole feature vectors.
        best = einops.repeat(best, 'b n 1 m -> b n c m', c=channels)
        chosen = stacked.gather(-1, best)
        return chosen.squeeze(-1)

    def forward(self, x: list) -> list:
        """Run all four stages on a list of modality tensors (main modality first)."""
        x_cam = x[0]
        if self.num_modals > 0:
            x_ext = x[1:]
        B = x_cam.shape[0]
        outs = []
        last_stage = 4

        for stage in range(1, last_stage + 1):
            idx = stage - 1

            # Main branch: embed, run the blocks, back to a channel-first map.
            tokens, H, W = getattr(self, f'patch_embed{stage}')(x_cam)
            for blk in getattr(self, f'block{stage}'):
                tokens = blk(tokens, H, W)
            x_cam = getattr(self, f'norm{stage}')(tokens).reshape(B, H, W, -1).permute(0, 3, 1, 2)

            if self.num_modals == 0:
                outs.append(x_cam)
                continue

            # Extra branch: one shared embedding per stage, applied to each modality.
            extra_embed = getattr(self, f'extra_patch_embed{stage}')
            x_ext = [extra_embed(x_ext_)[0] for x_ext_ in x_ext]
            if self.num_modals > 1:
                x_f = self.tokenselect(x_ext, self.extra_score_predictor[idx])
            else:
                x_f = x_ext[0]
            for blk in getattr(self, f'extra_block{stage}'):
                x_f = blk(x_f, H, W)
            x_f = getattr(self, f'extra_norm{stage}')(x_f).reshape(B, H, W, -1).permute(0, 3, 1, 2)

            # --- FFM
            x_cam, x_f = self.FRMs[idx](x_cam, x_f)
            outs.append(self.FFMs[idx](x_cam, x_f))

            # The last stage has no successor, so the per-modality inputs are not rebuilt.
            if stage < last_stage:
                if self.num_modals > 1:
                    x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x_f for x_ in x_ext]
                else:
                    x_ext = [x_f]

        return outs
if __name__ == '__main__':
    # Manual check with a single modality; add names to `modals` and tensors to `x`
    # (one per modality) to exercise the fusion branch.
    modals = ['img']
    x = [torch.zeros(1, 3, 1024, 1024)]
    model = CMX('B2', modals)
    outs = model(x)
    print(model)
    for y in outs:
        print(y.shape)
================================================
FILE: semseg/models/base.py
================================================
import torch
import math
from torch import nn
from semseg.models.backbones import *
from semseg.models.layers import trunc_normal_
from collections import OrderedDict
def load_dualpath_model(model, model_file):
    """Load the patch-embed, block and norm weights of a checkpoint into ``model``.

    ``model_file`` is either a path (loaded on CPU, unwrapping a top-level
    ``'model'`` entry when present) or an already loaded state dict. Keys
    containing none of ``patch_embed``, ``block`` or ``norm`` are dropped. The
    load is non-strict and the result message is printed.
    """
    if isinstance(model_file, str):
        raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))
        if 'model' in raw_state_dict.keys():
            raw_state_dict = raw_state_dict['model']
    else:
        raw_state_dict = model_file

    wanted_tags = ('patch_embed', 'block', 'norm')
    state_dict = {
        key: value
        for key, value in raw_state_dict.items()
        if any(tag in key for tag in wanted_tags)
    }

    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)
    del state_dict
class BaseModel(nn.Module):
    """Common base for segmentation models: builds the backbone and provides weight helpers.

    ``backbone`` has the form ``'<ClassName>-<variant>'``; the class name is
    resolved among the names imported into this module.
    """

    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19, modals: list = ['rgb', 'depth', 'event', 'lidar']) -> None:
        super().__init__()
        backbone, variant = backbone.split('-')
        # NOTE(review): eval() on a configuration string executes arbitrary code if the
        # config is untrusted; consider an explicit name -> class registry.
        self.backbone = eval(backbone)(variant, modals)
        self.modals = modals

    def _init_weights(self, m: nn.Module) -> None:
        """Initialise one sub-module; intended for use with ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            # Each output channel of a grouped conv only sees its own group, so the
            # fan-out shrinks by the group count. This was previously written as the
            # bare expression `fan_out // m.groups`, which discarded the result.
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def init_pretrained(self, pretrained: str = None) -> None:
        """Load backbone weights from the checkpoint path ``pretrained`` (no-op when falsy).

        With more than one modality the checkpoint goes through
        ``load_dualpath_model``. Otherwise it is loaded directly, unwrapping
        top-level ``'state_dict'`` and ``'model'`` entries when present.
        """
        if not pretrained:
            return
        if len(self.modals)>1:
            load_dualpath_model(self.backbone, pretrained)
            return
        checkpoint = torch.load(pretrained, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        if 'model' in checkpoint.keys(): # --- for HorNet
            checkpoint = checkpoint['model']
        msg = self.backbone.load_state_dict(checkpoint, strict=False)
        print(msg)
================================================
FILE: semseg/models/cmnext.py
================================================
import torch
from torch import Tensor
from torch.nn import functional as F
from semseg.models.base import BaseModel
from semseg.models.heads import SegFormerHead
from semseg.models.heads import LightHamHead
from semseg.models.heads import UPerHead
from fvcore.nn import flop_count_table, FlopCountAnalysis
class CMNeXt(BaseModel):
    """Segmentation model: CMNeXt backbone plus a SegFormer decode head.

    The head width is 256 when the backbone string contains ``B0`` or ``B1``
    and 512 otherwise.
    """

    def __init__(self, backbone: str = 'CMNeXt-B0', num_classes: int = 25, modals: list = ['img', 'depth', 'event', 'lidar']) -> None:
        super().__init__(backbone, num_classes, modals)
        is_small_variant = 'B0' in backbone or 'B1' in backbone
        head_channels = 256 if is_small_variant else 512
        self.decode_head = SegFormerHead(self.backbone.channels, head_channels, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: list) -> list:
        """Predict logits and upsample them bilinearly to the spatial size of ``x[0]``."""
        features = self.backbone(x)
        logits = self.decode_head(features)
        target_size = x[0].shape[2:]
        return F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)

    def init_pretrained(self, pretrained: str = None) -> None:
        """Load backbone weights from the checkpoint path ``pretrained`` (no-op when falsy).

        With extra modalities the checkpoint goes through
        ``load_dualpath_model``. Otherwise it is loaded directly, unwrapping
        top-level ``'state_dict'`` and ``'model'`` entries when present.
        """
        if not pretrained:
            return
        if self.backbone.num_modals > 0:
            load_dualpath_model(self.backbone, pretrained)
            return
        checkpoint = torch.load(pretrained, map_location='cpu')
        for wrapper_key in ('state_dict', 'model'):
            if wrapper_key in checkpoint.keys():
                checkpoint = checkpoint[wrapper_key]
        msg = self.backbone.load_state_dict(checkpoint, strict=False)
        print(msg)
def load_dualpath_model(model, model_file, extra_pretrained=None):
    """Load pretrained weights into ``model`` non-strictly.

    Args:
        model: object exposing ``load_state_dict`` (and ``num_modals`` when
            ``extra_pretrained`` is given).
        model_file: checkpoint path (read on CPU) or an already loaded
            state dict.
        extra_pretrained: optional second checkpoint (path or state dict)
            whose entries are renamed for the ``extra_*`` parameters.
            ``None`` (default) skips this step, which is what the function
            always did while the value was hard-coded.
            NOTE(review): the rename targets
            (``extra_downsample_layers.N.proj.module``, ``...norm.ln_i``,
            ``extra_block*``, ``extra_norm*``) come from code that was
            previously unreachable; confirm they match the backbone before
            relying on this path.

    Returns:
        Whatever ``model.load_state_dict(state_dict, strict=False)`` returns.
    """
    def _read(source):
        # Paths are loaded on CPU; anything else is taken to be a mapping already.
        ckpt = torch.load(source, map_location=torch.device('cpu')) if isinstance(source, str) else source
        # Unwrap container keys. Without this a checkpoint saved as
        # {'state_dict': {...}} matched none of the filters below and the
        # call silently loaded nothing.
        for wrapper in ('state_dict', 'model'):
            if wrapper in ckpt and isinstance(ckpt[wrapper], dict):
                ckpt = ckpt[wrapper]
        return ckpt

    # Keep only the entries whose key mentions one of these name fragments.
    wanted = ('patch_embed', 'block', 'norm')
    state_dict = {k: v for k, v in _read(model_file).items() if any(tag in k for tag in wanted)}

    if extra_pretrained is not None:
        for k, v in _read(extra_pretrained).items():
            if 'patch_embed' in k:
                for stage in range(1, 5):
                    src = f'patch_embed{stage}'
                    dst = f'extra_downsample_layers.{stage - 1}'
                    if f'{src}.proj' in k:
                        state_dict[k.replace(f'{src}.proj', f'{dst}.proj.module')] = v
                    if f'{src}.norm' in k:
                        # The same norm weights are copied once per extra modality.
                        for i in range(model.num_modals):
                            state_dict[k.replace(f'{src}.norm', f'{dst}.norm.ln_{i}')] = v
            elif 'block' in k:
                state_dict[k.replace('block', 'extra_block')] = v
            elif 'norm' in k:
                state_dict[k.replace('norm', 'extra_norm')] = v

    # Return the report instead of discarding it so callers can inspect it.
    return model.load_state_dict(state_dict, strict=False)
if __name__ == '__main__':
    # Smoke test: build the four-modality B2 model, load the MiT-B2 weights
    # and print the output shape for constant-valued inputs.
    modalities = ['img', 'depth', 'event', 'lidar']
    net = CMNeXt('CMNeXt-B2', 25, modalities)
    net.init_pretrained('checkpoints/pretrained/segformer/mit_b2.pth')
    # One 1x3x1024x1024 tensor per modality, filled with 0, 1, 2 and 3.
    inputs = [torch.ones(1, 3, 1024, 1024) * fill for fill in range(4)]
    prediction = net(inputs)
    print(prediction.shape)
================================================
FILE: semseg/models/cmx.py
================================================
import torch
from torch import Tensor
from torch.nn import functional as F
from semseg.models.base import BaseModel
from semseg.models.heads import SegFormerHead
from fvcore.nn import flop_count_table, FlopCountAnalysis
class CMX(BaseModel):
    """Segmentation model: a named backbone followed by a ``SegFormerHead``.

    The head's embedding width is 256 when the backbone name contains
    ``'B0'`` or ``'B1'`` and 512 otherwise.
    """

    def __init__(self, backbone: str = 'CMX-B0', num_classes: int = 25, modals: list = None) -> None:
        """
        Args:
            backbone: backbone identifier forwarded to ``BaseModel``.
            num_classes: number of output channels of the decode head.
            modals: modality names forwarded to ``BaseModel``; ``None``
                selects ``['img', 'depth', 'event', 'lidar']``.
        """
        # Build the default per instance: a list literal in the signature is a
        # single object shared by every default-constructed model.
        if modals is None:
            modals = ['img', 'depth', 'event', 'lidar']
        super().__init__(backbone, num_classes, modals)
        self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 512, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: list) -> Tensor:
        """Run backbone and head on the list of inputs ``x``.

        The logits are bilinearly resized to the spatial size of ``x[0]``.
        """
        y = self.backbone(x)
        y = self.decode_head(y)
        y = F.interpolate(y, size=x[0].shape[2:], mode='bilinear', align_corners=False)  # to original image shape
        return y

    def init_pretrained(self, pretrained: str = None) -> None:
        """Load pretrained backbone weights from ``pretrained``; no-op when falsy.

        When ``self.backbone.num_modals > 0`` loading is delegated to
        ``load_dualpath_model``; otherwise the checkpoint is read on CPU,
        unwrapped ('state_dict', then 'model'), loaded non-strictly and the
        load report is printed.
        """
        if not pretrained:
            return
        if self.backbone.num_modals > 0:
            load_dualpath_model(self.backbone, pretrained)
            return
        checkpoint = torch.load(pretrained, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        if 'model' in checkpoint.keys():
            checkpoint = checkpoint['model']
        msg = self.backbone.load_state_dict(checkpoint, strict=False)
        print(msg)
def load_dualpath_model(model, model_file):
    """Load the patch-embed / block / norm entries of a checkpoint into ``model``.

    Args:
        model: object exposing ``load_state_dict``.
        model_file: checkpoint path (read on CPU) or an already loaded
            state dict.

    Returns:
        Whatever ``model.load_state_dict(state_dict, strict=False)`` returns.
    """
    if isinstance(model_file, str):
        raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))
    else:
        raw_state_dict = model_file
    # Unwrap container keys. Only 'model' used to be handled, so a checkpoint
    # saved as {'state_dict': {...}} matched none of the filters below and the
    # call silently loaded nothing.
    for wrapper in ('state_dict', 'model'):
        if wrapper in raw_state_dict and isinstance(raw_state_dict[wrapper], dict):
            raw_state_dict = raw_state_dict[wrapper]

    # Keep only the entries whose key mentions one of these name fragments.
    wanted = ('patch_embed', 'block', 'norm')
    state_dict = {k: v for k, v in raw_state_dict.items() if any(tag in k for tag in wanted)}

    # Return the report instead of discarding it so callers can inspect it.
    return model.load_state_dict(state_dict, strict=False)
if __name__ == '__main__':
    # Smoke test: single-modality B2 model on one all-zero 512x512 image.
    modalities = ['img']
    # modalities = ['img', 'depth', 'event', 'lidar']
    net = CMX('CMX-B2', 25, modalities)
    net.init_pretrained('checkpoints/pretrained/segformer/mit_b2.pth')
    inputs = [torch.zeros(1, 3, 512, 512)]
    prediction = net(inputs)
    print(prediction.shape)
================================================
FILE: semseg/models/heads/__init__.py
================================================
from .upernet import UPerHead
from .segformer import SegFormerHead
from .sfnet import SFHead
from .fpn import FPNHead
from .fapn import FaPNHead
from .fcn import FCNHead
from .condnet import CondHead
from .lawin import LawinHead
from .hem import LightHamHead
# Export every head imported by this package so that a star-import exposes the
# same names as the explicit imports (previously only two of the nine were listed).
__all__ = [
    'UPerHead', 'SegFormerHead', 'SFHead', 'FPNHead', 'FaPNHead',
    'FCNHead', 'CondHead', 'LawinHead', 'LightHamHead',
]
================================================
FILE: semseg/models/heads/condnet.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from semseg.models.layers import ConvModule
class CondHead(nn.Module):
    """Head that predicts per-image, per-class 1x1 filters from the last feature map.

    In training mode ``forward`` returns ``(cond_logit, seg_logit)``;
    otherwise only ``seg_logit``.
    """

    def __init__(self, in_channel: int = 2048, channel: int = 512, num_classes: int = 19):
        super().__init__()
        self.num_classes = num_classes
        # Each class owns `channel` filter weights plus one bias.
        self.weight_num = channel * num_classes
        self.bias_num = num_classes
        self.conv = ConvModule(in_channel, channel, 1)
        self.dropout = nn.Dropout2d(0.1)
        self.guidance_project = nn.Conv2d(channel, num_classes, 1)
        self.filter_project = nn.Conv2d(
            channel * num_classes,
            self.weight_num + self.bias_num,
            1,
            groups=num_classes,
        )

    def forward(self, features) -> Tensor:
        feat = self.dropout(self.conv(features[-1]))
        B, C, H, W = feat.shape

        # Coarse per-class logits; also returned as the auxiliary output in training.
        cond_logit = self.guidance_project(feat)

        # Class-wise soft masks over the flattened spatial positions.
        class_mask = cond_logit.softmax(dim=1).view(*cond_logit.shape[:2], -1)
        pixels = feat.view(B, C, -1).permute(0, 2, 1)
        filters = torch.matmul(class_mask, pixels)
        filters /= H * W

        filters = self.filter_project(filters.view(B, -1, 1, 1)).view(B, -1)
        weight, bias = torch.split(filters, [self.weight_num, self.bias_num], dim=1)

        # Fold the batch into the channel axis so one grouped conv applies
        # each image's own filters.
        grouped_input = feat.view(-1, H, W).unsqueeze(0)
        seg_logit = F.conv2d(
            grouped_input,
            weight.reshape(B * self.num_classes, -1, 1, 1),
            bias.reshape(B * self.num_classes),
            1,
            0,
            groups=B,
        ).view(B, self.num_classes, H, W)

        return (cond_logit, seg_logit) if self.training else seg_logit
if __name__ == '__main__':
    # Smoke test: print the up-sampled shape of every output of the head.
    from semseg.models.backbones import ResNetD
    net = ResNetD('50')
    cond_head = CondHead()
    image = torch.randn(2, 3, 224, 224)
    predictions = cond_head(net(image))
    for pred in predictions:
        resized = F.interpolate(pred, size=image.shape[-2:], mode='bilinear', align_corners=False)
        print(resized.shape)
================================================
FILE: semseg/models/heads/fapn.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import DeformConv2d
from semseg.models.layers import ConvModule
class DCNv2(nn.Module):
    """Deformable convolution whose offsets and mask come from a second input.

    ``offset_mask`` starts with all-zero weights and bias.
    """

    def __init__(self, c1, c2, k, s, p, g=1):
        super().__init__()
        self.dcn = DeformConv2d(c1, c2, k, s, p, groups=g)
        # Three equal channel chunks: two are concatenated into the offsets,
        # the third becomes the mask (see forward).
        self.offset_mask = nn.Conv2d(c2, g * 3 * k * k, k, s, p)
        self._init_offset()

    def _init_offset(self):
        """Zero the offset/mask predictor's parameters."""
        for param in (self.offset_mask.weight, self.offset_mask.bias):
            param.data.zero_()

    def forward(self, x, offset):
        first, second, gate = torch.chunk(self.offset_mask(offset), 3, dim=1)
        offsets = torch.cat([first, second], dim=1)
        return self.dcn(x, offsets, gate.sigmoid())
class FSM(nn.Module):
    """Re-weights channels with a globally pooled sigmoid gate, then projects c1 -> c2."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv_atten = nn.Conv2d(c1, c1, 1, bias=False)
        self.conv = nn.Conv2d(c1, c2, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # Average over the full spatial extent -> one value per channel.
        pooled = F.avg_pool2d(x, x.shape[2:])
        gate = self.conv_atten(pooled).sigmoid()
        # Residual form: input plus its gated copy.
        return self.conv(x + x * gate)
class FAM(nn.Module):
    """Aligns an up-sampled feature to a lateral feature with a deformable conv."""

    def __init__(self, c1, c2):
        super().__init__()
        self.lateral_conv = FSM(c1, c2)
        self.offset = nn.Conv2d(c2 * 2, c2, 1, bias=False)
        self.dcpack_l2 = DCNv2(c2, c2, 3, 1, 1, 8)

    def forward(self, feat_l, feat_s):
        # Bring feat_s to feat_l's spatial size only when they differ.
        if feat_l.shape[2:] == feat_s.shape[2:]:
            upsampled = feat_s
        else:
            upsampled = F.interpolate(feat_s, size=feat_l.shape[2:], mode='bilinear', align_corners=False)
        lateral = self.lateral_conv(feat_l)
        # The up-sampled branch is doubled before the offset features are predicted.
        offset_feat = self.offset(torch.cat([lateral, upsampled * 2], dim=1))
        aligned = F.relu(self.dcpack_l2(upsampled, offset_feat))
        return aligned + lateral
class FaPNHead(nn.Module):
    """Top-down head: the last feature is projected, then fused with each
    earlier feature through a ``FAM`` followed by a 3x3 ``ConvModule``."""

    def __init__(self, in_channels, channel=128, num_classes=19):
        super().__init__()
        top_ch, *lower_chs = in_channels[::-1]
        self.align_modules = nn.ModuleList([ConvModule(top_ch, channel, 1)])
        self.output_convs = nn.ModuleList([])
        # Modules are created pairwise per level, in this order.
        for level_ch in lower_chs:
            self.align_modules.append(FAM(level_ch, channel))
            self.output_convs.append(ConvModule(channel, channel, 3, 1, 1))
        self.conv_seg = nn.Conv2d(channel, num_classes, 1)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, features) -> Tensor:
        ordered = features[::-1]
        fused = self.align_modules[0](ordered[0])
        stages = zip(ordered[1:], self.align_modules[1:], self.output_convs)
        for level_feat, align, smooth in stages:
            fused = smooth(align(level_feat, fused))
        return self.conv_seg(self.dropout(fused))
if __name__ == '__main__':
    # Smoke test: run the head on backbone features and print the up-sampled shape.
    from semseg.models.backbones import ResNet
    net = ResNet('50')
    fapn_head = FaPNHead([256, 512, 1024, 2048], 128, 19)
    image = torch.randn(2, 3, 224, 224)
    logits = fapn_head(net(image))
    logits = F.interpolate(logits, size=image.shape[-2:], mode='bilinear', align_corners=False)
    print(logits.shape)
================================================
FILE: semseg/models/heads/fcn.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from semseg.models.layers import ConvModule
class FCNHead(nn.Module):
    """Classifies the last feature map: 1x1 ``ConvModule`` then a 1x1 conv to ``num_classes``."""

    def __init__(self, c1, c2, num_classes: int = 19):
        super().__init__()
        self.conv = ConvModule(c1, c2, 1)
        self.cls = nn.Conv2d(c2, num_classes, 1)

    def forward(self, features) -> Tensor:
        # Only the final entry of `features` is used.
        return self.cls(self.conv(features[-1]))
if __name__ == '__main__':
    # Smoke test: run the head on backbone features and print the up-sampled shape.
    from semseg.models.backbones import ResNet
    net = ResNet('50')
    fcn_head = FCNHead(2048, 256, 19)
    image = torch.randn(2, 3, 224, 224)
    logits = fcn_head(net(image))
    logits = F.interpolate(logits, size=image.shape[-2:], mode='bilinear', align_corners=False)
    print(logits.shape)
================================================
FILE: semseg/models/heads/fpn.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from semseg.models.layers import ConvModule
class FPNHead(nn.Module):
    """Panoptic Feature Pyramid Networks
    https://arxiv.org/abs/1901.02446

    Top-down pathway: starting from the last feature, the running map is
    up-sampled by 2 (nearest), added to the next lateral projection and
    smoothed by a 3x3 conv.
    """

    def __init__(self, in_channels, channel=128, num_classes=19):
        super().__init__()
        self.lateral_convs = nn.ModuleList([])
        self.output_convs = nn.ModuleList([])
        # One lateral/output pair per level, stored from last level to first.
        for level_ch in reversed(in_channels):
            self.lateral_convs.append(ConvModule(level_ch, channel, 1))
            self.output_convs.append(ConvModule(channel, channel, 3, 1, 1))
        self.conv_seg = nn.Conv2d(channel, num_classes, 1)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, features) -> Tensor:
        ordered = features[::-1]
        fused = self.lateral_convs[0](ordered[0])
        for idx, level_feat in enumerate(ordered[1:], start=1):
            fused = F.interpolate(fused, scale_factor=2.0, mode='nearest')
            fused = fused + self.lateral_convs[idx](level_feat)
            fused = self.output_convs[idx](fused)
        return self.conv_seg(self.dropout(fused))
if __name__ == '__main__':
    # Smoke test: run the head on backbone features and print the up-sampled shape.
    from semseg.models.backbones import ResNet
    net = ResNet('50')
    fpn_head = FPNHead([256, 512, 1024, 2048], 128, 19)
    image = torch.randn(2, 3, 224, 224)
    logits = fpn_head(net(image))
    logits = F.interpolate(logits, size=image.shape[-2:], mode='bilinear', align_corners=False)
    print(logits.shape)
================================================
FILE: semseg/models/heads/hem.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from fvcore.nn import flop_count_table, FlopCountAnalysis
class _MatrixDecomposition2DBase(nn.Module):
def __init__(self, args=dict()):
super().__init__()
self.spatial = args.setdefault('SPATIAL', True)
self.S = args.setdefault('MD_S', 1)
self.D = args.setdefault('MD_D', 512)
self.R = args.setdefault('MD_R', 64)
self.train_steps = args.setdefault('TRAIN_STEPS', 6)
self.eval_steps = args.setdefault('EVAL_STEPS', 7)
self.inv_t = args.setdefault('INV_T', 100)
self.eta = args.setdefault('ETA', 0.9)
self.rand_init = args.setdefault('RAND_INIT', True)
print('spatial', self.spatial)
print('S', self.S)
print('D', self.D)
print('R', self.R)
print('train_steps', self.train_steps)
print('eval_steps', self.eval_steps)
print('inv_t', self.inv_t)
print('eta', self.eta)
print('rand_init', self.rand_init)
def _build_bases(self, B, S, D, R, cuda=False):
raise NotImplementedError
def local_step(self, x, bases, coef):
raise NotImplementedError
# @torch.no_grad()
def local_inference(self, x, bases):
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
coef = torch.bmm(x.transpose(1, 2), bases)
coef = F.softmax(self.inv_t * coef, dim=-1)
steps = self.train_steps if self.training else self.eval_steps
for _ in range(steps):
bases, coef = self.local_step(x, bases, coef)
return bases, coef
def compute_coef(self, x, bases, coef):
raise NotImplementedError
def forward(self, x, return_bases=False):
B, C, H, W = x.shape
# (B, C, H, W) -> (B * S, D, N)
if self.spatial:
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
else:
D = H * W
N = C // self.S
x = x.view(B * self.S, N, D).transpose(1, 2)
if not self.rand_init and not hasattr(self, 'bases'):
bases = self._build_bases(1, self.S, D, self.R, cuda=True)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
bases = self._build_bases(B, self.S, D, self.R, cuda=True)
else:
bases = self.bases.repeat(B, 1, 1)
bases, coef = self.local_inference(x, bases)
# (B * S, N, R)
coef = self.compute_coef(x, bases, coef)
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
x = torch.bmm(bases, coef.transpose(1, 2))
# (B * S, D, N) -> (B, C, H, W)
if self.spatial:
x = x.view(B, C, H, W)
else:
x = x.transpose(1, 2).view(B, C, H, W)
# (B * H, D, R) -> (B, H, N, D)
bases = bases.view(B, self.S, D, self.R)
return x
class NMF2D(_MatrixDecomposition2DBase):
    """Matrix decomposition with multiplicative updates on ``bases`` and ``coef``.

    ``inv_t`` is forced to 1 regardless of the ``INV_T`` entry in ``args``.
    """

    def __init__(self, args=None):
        # Hand the base class a fresh dict when none is given: it fills in
        # defaults with `setdefault`, which would mutate a shared default.
        super().__init__({} if args is None else args)
        self.inv_t = 1

    def _build_bases(self, B, S, D, R, cuda=False):
        """Return uniform-random bases of shape (B * S, D, R), normalised along dim 1."""
        bases = torch.rand((B * S, D, R))
        if cuda:
            bases = bases.cuda()
        return F.normalize(bases, dim=1)

    # @torch.no_grad()
    def local_step(self, x, bases, coef):
        """One multiplicative update of ``coef`` followed by one of ``bases``."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative update; 1e-6 keeps the division finite.
        coef = coef * numerator / (denominator + 1e-6)
        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative update
        bases = bases * numerator / (denominator + 1e-6)
        return bases, coef

    def compute_coef(self, x, bases, coef):
        """Apply one more multiplicative update to ``coef`` with ``bases`` fixed."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative update
        coef = coef * numerator / (denominator + 1e-6)
        return coef
class Hamburger(nn.Module):
    """1x1 conv -> ReLU -> ``NMF2D`` -> 1x1 conv, added back to the input and passed through ReLU."""

    def __init__(self, ham_channels=512, ham_kwargs=None, norm_cfg=None):
        super().__init__()
        self.ham_in = ConvModule(ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
        # Use a fresh dict when none is given: the decomposition layer writes
        # its defaults into the dict it receives, so a `dict()` default in
        # the signature would be shared and mutated across instances.
        self.ham = NMF2D({} if ham_kwargs is None else ham_kwargs)
        self.ham_out = ConvModule(ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        enjoy = self.ham_in(x)
        enjoy = F.relu(enjoy, inplace=True)
        enjoy = self.ham(enjoy)
        enjoy = self.ham_out(enjoy)
        # Residual connection around the decomposition branch.
        ham = F.relu(x + enjoy, inplace=True)
        return ham
class LightHamHead(nn.Module):
    """Decode head that fuses feature levels 1-3 and refines them with a ``Hamburger``.

    Args:
        in_channels: channel counts of the backbone levels; the first entry
            is dropped. ``None`` selects ``[64, 128, 320, 512]``.
        ham_channels: width of the fused representation.
        ham_kwargs: options forwarded to ``Hamburger``; ``None`` means no overrides.
        num_classes: number of output channels.
    """

    def __init__(self, in_channels=None, ham_channels=512, ham_kwargs=None, num_classes=25):
        super().__init__()
        # Defaults are created per call rather than in the signature, where a
        # list/dict literal would be shared by every instance.
        if in_channels is None:
            in_channels = [64, 128, 320, 512]
        if ham_kwargs is None:
            ham_kwargs = {}
        self.in_channels = in_channels[1:]
        self.in_index = [1, 2, 3]
        self.ham_channels = self.channels = ham_channels
        self.conv_cfg = None
        self.norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.act_cfg = dict(type='ReLU')
        self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1, conv_cfg=self.conv_cfg,
                                  norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
        self.hamburger = Hamburger(ham_channels, ham_kwargs, self.norm_cfg)
        self.align = ConvModule(self.ham_channels, self.channels, 1, conv_cfg=self.conv_cfg,
                                norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
        self.conv_seg = nn.Conv2d(self.channels, num_classes, kernel_size=1)

    def forward(self, inputs):
        """Forward function."""
        inputs = [inputs[i] for i in self.in_index]
        # Resize every selected level to the first selected level's size before concatenating.
        inputs = [F.interpolate(level, size=inputs[0].shape[2:], mode='bilinear', align_corners=False) for level in inputs]
        inputs = torch.cat(inputs, dim=1)
        x = self.squeeze(inputs)
        x = self.hamburger(x)
        output = self.align(x)
        output = self.conv_seg(output)
        return output
if __name__ == '__main__':
    # Smoke test on GPU: constant-valued pyramid in, shapes and FLOP table out.
    head = LightHamHead(num_classes=25).cuda()
    # (channels, spatial size, fill value) for each pyramid level.
    level_specs = [(64, 256, 0), (128, 128, 1), (320, 64, 2), (512, 32, 3)]
    feats = [(torch.ones(1, ch, side, side) * fill).cuda() for ch, side, fill in level_specs]
    outs = head(feats)
    print(head)
    for item in outs:
        print(item.shape)
    print(flop_count_table(FlopCountAnalysis(head, feats)))
================================================
FILE: semseg/models/heads/lawin.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from einops import rearrange
class MLP(nn.Module):
    """Flattens the spatial axes of a (B, C, H, W) map and linearly projects the channel axis."""

    def __init__(self, dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        # (B, C, H, W) -> (B, H*W, C) so the linear layer acts on channels.
        tokens = torch.flatten(x, 2).transpose(1, 2)
        return self.proj(tokens)
class PatchEmbed(nn.Module):
    """Down-samples by ``patch_size`` and applies LayerNorm over channels.

    ``type='conv'`` uses a strided grouped convolution; any other value
    averages a max-pool and an avg-pool of the input.
    """

    def __init__(self, patch_size=4, in_ch=3, dim=96, type='pool') -> None:
        super().__init__()
        self.patch_size = patch_size
        self.type = type
        self.dim = dim
        if type == 'conv':
            self.proj = nn.Conv2d(in_ch, dim, patch_size, patch_size, groups=patch_size*patch_size)
        else:
            pools = [
                nn.MaxPool2d(patch_size, patch_size),
                nn.AvgPool2d(patch_size, patch_size),
            ]
            self.proj = nn.ModuleList(pools)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        step = self.patch_size
        _, _, H, W = x.shape
        # Pad right, then bottom, up to the next multiple of the patch size.
        pad_w = (-W) % step
        pad_h = (-H) % step
        if pad_w:
            x = F.pad(x, (0, pad_w))
        if pad_h:
            x = F.pad(x, (0, 0, 0, pad_h))

        if self.type == 'conv':
            x = self.proj(x)
        else:
            max_pool, avg_pool = self.proj[0], self.proj[1]
            x = 0.5 * (max_pool(x) + avg_pool(x))

        out_h, out_w = x.size(2), x.size(3)
        # LayerNorm expects channels last: normalise tokens, then restore the map layout.
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.dim, out_h, out_w)
class LawinAttn(nn.Module):
    """Multi-head attention from a query patch to a context patch.

    The context first receives a per-head linear mixing over its flattened
    positions; the attention output is projected back to ``in_ch`` and added
    to the query.
    """

    def __init__(self, in_ch=512, head=4, patch_size=8, reduction=2) -> None:
        super().__init__()
        self.head = head
        n_positions = patch_size * patch_size
        self.position_mixing = nn.ModuleList(
            [nn.Linear(n_positions, n_positions) for _ in range(self.head)]
        )
        self.inter_channels = max(in_ch // reduction, 1)
        self.g = nn.Conv2d(in_ch, self.inter_channels, 1)
        self.theta = nn.Conv2d(in_ch, self.inter_channels, 1)
        self.phi = nn.Conv2d(in_ch, self.inter_channels, 1)
        self.conv_out = nn.Sequential(
            nn.Conv2d(self.inter_channels, in_ch, 1, bias=False),
            nn.BatchNorm2d(in_ch),
        )

    def forward(self, query: Tensor, context: Tensor) -> Tensor:
        B, C, H, W = context.shape
        flat = context.reshape(B, C, -1)

        # Each head mixes positions within its own slice of the channels.
        group = C // self.head
        mixed = torch.cat(
            [mixer(flat[:, group * i:group * (i + 1), :]) for i, mixer in enumerate(self.position_mixing)],
            dim=1,
        )
        context = (flat + mixed).reshape(B, C, H, W)

        def split_heads(t: Tensor) -> Tensor:
            # (B, inter, h, w) -> (B * head, inter / head, h * w)
            return rearrange(t.view(B, self.inter_channels, -1), "b (h dim) n -> (b h) dim n", h=self.head)

        g_x = split_heads(self.g(context)).permute(0, 2, 1)
        theta_x = split_heads(self.theta(query)).permute(0, 2, 1)
        phi_x = split_heads(self.phi(context))

        # Scaled dot-product attention over context positions.
        attn = torch.matmul(theta_x, phi_x)
        attn /= theta_x.shape[-1] ** 0.5
        attn = attn.softmax(dim=-1)

        y = rearrange(torch.matmul(attn, g_x), "(b h) n dim -> b n (h dim)", h=self.head)
        y = y.permute(0, 2, 1).contiguous().reshape(B, self.inter_channels, *query.shape[-2:])
        return query + self.conv_out(y)
class ConvModule(nn.Module):
    """Bias-free 1x1 convolution followed by BatchNorm and in-place ReLU."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, 1, bias=False)
        self.bn = nn.BatchNorm2d(c2)        # use SyncBN in original
        self.activate = nn.ReLU(True)

    def forward(self, x: Tensor) -> Tensor:
        normed = self.bn(self.conv(x))
        return self.activate(normed)
class LawinHead(nn.Module):
    """Decode head built from per-level projections, a pyramid of
    ``LawinAttn`` branches (context ratios 8, 4, 2) and a low-level fusion.

    Args:
        in_channels: channel count of each input feature level.
        embed_dim: working width; the first level is projected to 48 channels.
        num_classes: number of output channels.
    """

    def __init__(self, in_channels: list, embed_dim=512, num_classes=19) -> None:
        super().__init__()
        for i, dim in enumerate(in_channels):
            self.add_module(f"linear_c{i+1}", MLP(dim, 48 if i == 0 else embed_dim))
        self.lawin_8 = LawinAttn(embed_dim, 64)
        self.lawin_4 = LawinAttn(embed_dim, 16)
        self.lawin_2 = LawinAttn(embed_dim, 4)
        self.ds_8 = PatchEmbed(8, embed_dim, embed_dim)
        self.ds_4 = PatchEmbed(4, embed_dim, embed_dim)
        self.ds_2 = PatchEmbed(2, embed_dim, embed_dim)
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(embed_dim, embed_dim)
        )
        self.linear_fuse = ConvModule(embed_dim*3, embed_dim)
        self.short_path = ConvModule(embed_dim, embed_dim)
        self.cat = ConvModule(embed_dim*5, embed_dim)
        self.low_level_fuse = ConvModule(embed_dim+48, embed_dim)
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
        self.dropout = nn.Dropout2d(0.1)

    def get_lawin_att_feats(self, x: Tensor, patch_size: int):
        """Return one attention-refined map per context ratio (8, 4, 2)."""
        _, _, H, W = x.shape
        query = F.unfold(x, patch_size, stride=patch_size)
        query = rearrange(query, 'b (c ph pw) (nh nw) -> (b nh nw) c ph pw', ph=patch_size, pw=patch_size, nh=H//patch_size, nw=W//patch_size)
        outs = []
        for r in [8, 4, 2]:
            # Context windows are r times larger than the query patch and share its stride.
            context = F.unfold(x, patch_size*r, stride=patch_size, padding=int((r-1)/2*patch_size))
            context = rearrange(context, "b (c ph pw) (nh nw) -> (b nh nw) c ph pw", ph=patch_size*r, pw=patch_size*r, nh=H//patch_size, nw=W//patch_size)
            context = getattr(self, f"ds_{r}")(context)
            output = getattr(self, f"lawin_{r}")(query, context)
            output = rearrange(output, "(b nh nw) c ph pw -> b c (nh ph) (nw pw)", ph=patch_size, pw=patch_size, nh=H//patch_size, nw=W//patch_size)
            outs.append(output)
        return outs

    def forward(self, features):
        B, _, H, W = features[1].shape
        outs = [self.linear_c2(features[1]).permute(0, 2, 1).reshape(B, -1, *features[1].shape[-2:])]
        for i, feature in enumerate(features[2:]):
            # Plain attribute lookup; this used to go through eval() on a
            # formatted string for every level on every forward pass.
            proj = getattr(self, f"linear_c{i+3}")
            cf = proj(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
            outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False))
        feat = self.linear_fuse(torch.cat(outs[::-1], dim=1))
        B, _, H, W = feat.shape
        ## Lawin attention spatial pyramid pooling
        feat_short = self.short_path(feat)
        feat_pool = F.interpolate(self.image_pool(feat), size=(H, W), mode='bilinear', align_corners=False)
        feat_lawin = self.get_lawin_att_feats(feat, 8)
        output = self.cat(torch.cat([feat_short, feat_pool, *feat_lawin], dim=1))
        ## Low-level feature enhancement
        c1 = self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])
        output = F.interpolate(output, size=features[0].shape[-2:], mode='bilinear', align_corners=False)
        fused = self.low_level_fuse(torch.cat([output, c1], dim=1))
        seg = self.linear_pred(self.dropout(fused))
        return seg
================================================
FILE: semseg/models/heads/segformer.py
================================================
import torch
from torch import nn, Tensor
from typing import Tuple
from torch.nn import functional as F
class MLP(nn.Module):
    """Flattens the spatial axes of a (B, C, H, W) map and linearly projects the channel axis."""

    def __init__(self, dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        # (B, C, H, W) -> (B, H*W, C) so the linear layer acts on channels.
        tokens = torch.flatten(x, 2).transpose(1, 2)
        return self.proj(tokens)
class ConvModule(nn.Module):
    """Bias-free 1x1 convolution followed by BatchNorm and in-place ReLU."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, 1, bias=False)
        self.bn = nn.BatchNorm2d(c2)        # use SyncBN in original
        self.activate = nn.ReLU(True)

    def forward(self, x: Tensor) -> Tensor:
        normed = self.bn(self.conv(x))
        return self.activate(normed)
class SegFormerHead(nn.Module):
    """All-MLP decode head.

    Every input level is projected to ``embed_dim`` channels, resized to the
    first level's spatial size, concatenated (last level first), fused by a
    1x1 ``ConvModule`` and classified.

    Args:
        dims: channel count of each input feature level.
        embed_dim: common projection width.
        num_classes: number of output channels.
    """

    def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):
        super().__init__()
        for i, dim in enumerate(dims):
            self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))
        # forward() concatenates one projected map per level, so the fuse
        # width follows len(dims) (identical to the former embed_dim*4 for
        # the usual four levels).
        self.linear_fuse = ConvModule(embed_dim * len(dims), embed_dim)
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
        B, _, H, W = features[0].shape
        outs = []
        for i, feature in enumerate(features):
            # Plain attribute lookup; this used to go through eval() on a
            # formatted string for every level on every forward pass.
            proj = getattr(self, f"linear_c{i+1}")
            cf = proj(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
            if i > 0:
                # The first level defines the target size and is used as-is.
                cf = F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(cf)
        seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))
        seg = self.linear_pred(self.dropout(seg))
        return seg
================================================
FILE: semseg/models/heads/sfnet.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from semseg.models.layers import ConvModule
from semseg.models.modules import PPM
class AlignedModule(nn.Module):
    """Predicts a 2-channel flow field from a low/high feature pair and warps
    the high-level feature onto the low-level feature's grid."""

    def __init__(self, c1, c2, k=3):
        super().__init__()
        self.down_h = nn.Conv2d(c1, c2, 1, bias=False)
        self.down_l = nn.Conv2d(c1, c2, 1, bias=False)
        self.flow_make = nn.Conv2d(c2 * 2, 2, k, 1, 1, bias=False)

    def forward(self, low_feature: Tensor, high_feature: Tensor) -> Tensor:
        target_size = tuple(low_feature.shape[-2:])
        reduced_low = self.down_l(low_feature)
        reduced_high = F.interpolate(
            self.down_h(high_feature), size=target_size, mode='bilinear', align_corners=True
        )
        flow = self.flow_make(torch.cat([reduced_high, reduced_low], dim=1))
        # The un-reduced high-level feature is the one that gets warped.
        return self.flow_warp(high_feature, flow, target_size)

    def flow_warp(self, x: Tensor, flow: Tensor, size: tuple) -> Tensor:
        """Sample ``x`` on a regular [-1, 1] grid of ``size`` displaced by ``flow``."""
        n_rows, n_cols = size[0], size[1]
        # NOTE(review): the divisor is ordered (size[0], size[1]) while the
        # grid's last axis is (column coordinate, row coordinate); the two
        # agree only for square maps. Confirm whether this is intended.
        scale = torch.tensor([[[[*size]]]]).type_as(x).to(x.device)
        row_coords = torch.linspace(-1.0, 1.0, n_rows).view(-1, 1).repeat(1, n_cols)
        col_coords = torch.linspace(-1.0, 1.0, n_cols).repeat(n_rows, 1)
        base_grid = torch.stack((col_coords, row_coords), dim=2)
        base_grid = base_grid.repeat(x.shape[0], 1, 1, 1).type_as(x).to(x.device)
        sample_grid = base_grid + flow.permute(0, 2, 3, 1) / scale
        return F.grid_sample(x, sample_grid, align_corners=False)
class SFHead(nn.Module):
    """PPM on the last feature plus a top-down pathway whose up-sampling is
    done by ``AlignedModule`` flow warping."""

    def __init__(self, in_channels, channel=256, num_classes=19, scales=(1, 2, 3, 6)):
        super().__init__()
        self.ppm = PPM(in_channels[-1], channel, scales)
        self.fpn_in = nn.ModuleList([])
        self.fpn_out = nn.ModuleList([])
        self.fpn_out_align = nn.ModuleList([])
        # One lateral / output / alignment triple for every level except the last.
        for level_ch in in_channels[:-1]:
            self.fpn_in.append(ConvModule(level_ch, channel, 1))
            self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))
            self.fpn_out_align.append(AlignedModule(channel, channel//2))
        self.bottleneck = ConvModule(len(in_channels) * channel, channel, 3, 1, 1)
        self.dropout = nn.Dropout2d(0.1)
        self.conv_seg = nn.Conv2d(channel, num_classes, 1)

    def forward(self, features: list) -> Tensor:
        running = self.ppm(features[-1])
        pyramid = [running]
        # Walk from the second-to-last level down to the first.
        for level in range(len(features) - 2, -1, -1):
            lateral = self.fpn_in[level](features[level])
            running = lateral + self.fpn_out_align[level](lateral, running)
            pyramid.append(self.fpn_out[level](running))
        pyramid = pyramid[::-1]

        # Bring every coarser level to the first level's size before fusing.
        target = pyramid[0].shape[-2:]
        resized = [pyramid[0]] + [
            F.interpolate(level_map, size=target, mode='bilinear', align_corners=True)
            for level_map in pyramid[1:]
        ]
        fused = self.bottleneck(torch.cat(resized, dim=1))
        return self.conv_seg(self.dropout(fused))
================================================
FILE: semseg/models/heads/upernet.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Tuple
from semseg.models.layers import ConvModule
from semseg.models.modules import PPM
class UPerHead(nn.Module):
    """Unified Perceptual Parsing for Scene Understanding
    https://arxiv.org/abs/1807.10221
    scales: Pooling scales used in PPM module applied on the last feature
    """

    def __init__(self, in_channels, channel=128, num_classes: int = 19, scales=(1, 2, 3, 6)):
        super().__init__()
        # Pyramid pooling on the last level
        self.ppm = PPM(in_channels[-1], channel, scales)
        # Lateral / output convs for every level except the last
        self.fpn_in = nn.ModuleList()
        self.fpn_out = nn.ModuleList()
        for level_ch in in_channels[:-1]:
            self.fpn_in.append(ConvModule(level_ch, channel, 1))
            self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))
        self.bottleneck = ConvModule(len(in_channels)*channel, channel, 3, 1, 1)
        self.dropout = nn.Dropout2d(0.1)
        self.conv_seg = nn.Conv2d(channel, num_classes, 1)

    def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:
        running = self.ppm(features[-1])
        pyramid = [running]
        # Walk from the second-to-last level down to the first.
        for level in range(len(features) - 2, -1, -1):
            lateral = self.fpn_in[level](features[level])
            upsampled = F.interpolate(running, size=lateral.shape[-2:], mode='bilinear', align_corners=False)
            running = lateral + upsampled
            pyramid.append(self.fpn_out[level](running))
        pyramid = pyramid[::-1]

        # Bring every coarser level to the first level's size before fusing.
        target = pyramid[0].shape[-2:]
        resized = [pyramid[0]] + [
            F.interpolate(level_map, size=target, mode='bilinear', align_corners=False)
            for level_map in pyramid[1:]
        ]
        fused = self.bottleneck(torch.cat(resized, dim=1))
        return self.conv_seg(self.dropout(fused))
if __name__ == '__main__':
    # Smoke test: four random pyramid levels in, output shape printed.
    head = UPerHead([64, 128, 256, 512], 128)
    level_shapes = [(64, 56), (128, 28), (256, 14), (512, 7)]
    pyramid = [torch.randn(2, ch, side, side) for ch, side in level_shapes]
    logits = head(pyramid)
    print(logits.shape)
================================================
FILE: semseg/models/layers/__init__.py
================================================
from .common import *
from .initialize import *
================================================
FILE: semseg/models/layers/common.py
================================================
import torch
from torch import nn, Tensor
class ConvModule(nn.Sequential):
    """Bias-free ``Conv2d`` -> ``BatchNorm2d`` -> in-place ``ReLU``.

    Args:
        c1, c2: input / output channels.
        k, s, p, d, g: kernel size, stride, padding, dilation and groups of the convolution.
    """

    def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
        conv = nn.Conv2d(c1, c2, kernel_size=k, stride=s, padding=p, dilation=d, groups=g, bias=False)
        norm = nn.BatchNorm2d(c2)
        act = nn.ReLU(True)
        super().__init__(conv, norm, act)
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Copied from timm
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        p: probability of dropping a sample's path. ``None`` or ``0`` disables dropping.
    """

    def __init__(self, p: float = None):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        # `not self.p` covers both 0 and the default None; the latter used to
        # reach `1 - None` in training mode and raise TypeError.
        if not self.p or not self.training:
            return x
        kp = 1 - self.p
        if kp <= 0:
            # Every path is dropped; dividing by kp == 0 would yield NaN.
            return torch.zeros_like(x)
        # One Bernoulli(kp) draw per sample, broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        # Surviving samples are scaled by 1 / kp.
        return x.div(kp) * random_tensor
================================================
FILE: semseg/models/layers/initialize.py
================================================
import torch
import math
import warnings
from torch import nn, Tensor
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples, without tracking gradients.

    Cut & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """
    sqrt_two = math.sqrt(2.)

    def standard_normal_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / sqrt_two)) / 2.

    mean_too_low = mean < a - 2 * std
    mean_too_high = mean > b + 2 * std
    if mean_too_low or mean_too_high:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the two cut-offs.
        lower_cdf = standard_normal_cdf((a - mean) / std)
        upper_cdf = standard_normal_cdf((b - mean) / std)

        # Sample uniformly in [2l-1, 2u-1], push through the inverse error
        # function to get a truncated standard normal, then scale and shift
        # to the requested mean/std and clamp into [a, b].
        (tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
               .erfinv_()
               .mul_(std * sqrt_two)
               .add_(mean)
               .clamp_(min=a, max=b))
        return tensor


def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The sampling method works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
================================================
FILE: semseg/models/modules/__init__.py
================================================
from .ppm import PPM
from .psa import PSAP, PSAS
__all__ = ['PPM', 'PSAP', 'PSAS']
================================================
FILE: semseg/models/modules/crossatt.py
================================================
import torch
from torch import nn
from einops import rearrange
from torch import einsum
def exists(val):
    """Return ``True`` unless ``val`` is ``None`` (falsy values such as 0 still count)."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
def stable_softmax(t, dim = -1):
    """Softmax along ``dim`` after subtracting the per-slice maximum.

    The maximum is removed first so the exponentials inside softmax see
    non-positive inputs only.
    """
    peak = t.amax(dim = dim, keepdim = True)
    shifted = t - peak
    return shifted.softmax(dim = dim)
# bidirectional cross attention - have two sequences attend to each other with 1 attention step
class BidirectionalCrossAttention(nn.Module):
    """Two sequences attend to each other through one shared similarity matrix.

    A single query/key projection per sequence produces the similarity
    ``sim``. A softmax over the context axis lets ``x`` read from ``context``;
    a softmax over the sequence axis lets ``context`` read from ``x``.

    Args:
        dim: Feature size of ``x``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        context_dim: Feature size of ``context``; falls back to ``dim`` when ``None``.
        dropout: Dropout probability applied to both attention maps.
        talking_heads: If true, mix attention maps across heads with a 1x1 conv.
        prenorm: If true, apply LayerNorm to both inputs first.
    """

    def __init__(self, dim, heads=8, dim_head=64, context_dim=None, dropout=0., talking_heads=False, prenorm=True,):
        super().__init__()
        if context_dim is None:
            context_dim = dim
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim) if prenorm else nn.Identity()
        self.context_norm = nn.LayerNorm(context_dim) if prenorm else nn.Identity()

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.dropout = nn.Dropout(dropout)
        self.context_dropout = nn.Dropout(dropout)

        # Shared query/key projection plus a value projection, per sequence.
        self.to_qk = nn.Linear(dim, inner_dim, bias = False)
        self.context_to_qk = nn.Linear(context_dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.context_to_v = nn.Linear(context_dim, inner_dim, bias = False)

        self.to_out = nn.Linear(inner_dim, dim)
        self.context_to_out = nn.Linear(inner_dim, context_dim)

        if talking_heads:
            self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
            self.context_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
        else:
            self.talking_heads = nn.Identity()
            self.context_talking_heads = nn.Identity()

    def forward(self, x, context, return_attn=False, rel_pos_bias=None):
        """Return ``(out, context_out)``, plus both attention maps when ``return_attn`` is set.

        ``rel_pos_bias``, if given, is added to the similarity matrix before
        either softmax.
        """
        num_heads = self.heads

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = num_heads)

        def merge_heads(t):
            return rearrange(t, 'b h n d -> b n (h d)')

        x = self.norm(x)
        context = self.context_norm(context)

        # Shared query/keys and values for sequence and context, split per head.
        qk = split_heads(self.to_qk(x))
        v = split_heads(self.to_v(x))
        context_qk = split_heads(self.context_to_qk(context))
        context_v = split_heads(self.context_to_v(context))

        sim = einsum('b h i d, b h j d -> b h i j', qk, context_qk) * self.scale
        if rel_pos_bias is not None:
            sim = sim + rel_pos_bias

        # One similarity matrix, normalised along each of its two axes.
        attn = self.dropout(stable_softmax(sim, dim = -1))
        context_attn = self.context_dropout(stable_softmax(sim, dim = -2))

        attn = self.talking_heads(attn)
        context_attn = self.context_talking_heads(context_attn)

        # x aggregates context values; context aggregates x values.
        out = einsum('b h i j, b h j d -> b h i d', attn, context_v)
        context_out = einsum('b h j i, b h j d -> b h i d', context_attn, v)

        out = self.to_out(merge_heads(out))
        context_out = self.context_to_out(merge_heads(context_out))

        if not return_attn:
            return out, context_out
        return out, context_out, attn, context_attn
if __name__ == '__main__':
    # Shape check with two sequences of different length and width.
    video = torch.randn(1, 4096, 512)
    audio = torch.randn(1, 8192, 386)

    joint_cross_attn = BidirectionalCrossAttention(dim = 512, heads = 8, dim_head = 64, context_dim = 386)
    video_out, audio_out = joint_cross_attn(video, audio)

    # Each attended output must keep the shape of its own input.
    for produced, source in ((video_out, video), (audio_out, audio)):
        assert produced.shape == source.shape
================================================
FILE: semseg/models/modules/ffm.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
import math
# Feature Rectify Module
class ChannelWeights(nn.Module):
    """Per-channel gates for two feature maps, computed from their pooled statistics.

    The two inputs are concatenated on the channel axis, reduced by global
    average and max pooling, and passed through an MLP ending in a sigmoid.

    Args:
        dim: Channel count of each input.
        reduction: Divisor applied to the hidden width of the MLP.
    """

    def __init__(self, dim, reduction=1):
        super().__init__()
        self.dim = dim
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        pooled_dim = self.dim * 4           # avg + max descriptors of 2*dim channels
        hidden_dim = pooled_dim // reduction
        self.mlp = nn.Sequential(
            nn.Linear(pooled_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, self.dim * 2),
            nn.Sigmoid(),
        )

    def forward(self, x1, x2):
        """Return gates of shape ``[2, B, dim, 1, 1]``; index 0 and 1 are the two halves of the MLP output."""
        batch, _, _, _ = x1.shape
        stacked = torch.cat((x1, x2), dim=1)
        avg_desc = self.avg_pool(stacked).view(batch, self.dim * 2)
        max_desc = self.max_pool(stacked).view(batch, self.dim * 2)
        descriptor = torch.cat((avg_desc, max_desc), dim=1)  # B 4C
        gates = self.mlp(descriptor).view(batch, self.dim * 2, 1)
        return gates.reshape(batch, 2, self.dim, 1, 1).permute(1, 0, 2, 3, 4)  # 2 B C 1 1
class SpatialWeights(nn.Module):
    """Per-pixel gates for two feature maps, produced by 1x1 convolutions.

    Args:
        dim: Channel count of each input.
        reduction: Divisor applied to the hidden channel width.
    """

    def __init__(self, dim, reduction=1):
        super().__init__()
        self.dim = dim
        hidden = self.dim // reduction
        self.mlp = nn.Sequential(
            nn.Conv2d(self.dim * 2, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 2, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x1, x2):
        """Return gates of shape ``[2, B, 1, H, W]``, one map per output channel of the MLP."""
        batch, _, height, width = x1.shape
        stacked = torch.cat((x1, x2), dim=1)  # B 2C H W
        gates = self.mlp(stacked).reshape(batch, 2, 1, height, width)
        return gates.permute(1, 0, 2, 3, 4)  # 2 B 1 H W
class FeatureRectifyModule(nn.Module):
    """Rectify two feature maps with channel and spatial gates computed from both.

    Each output is its own input plus the *other* input scaled by a channel
    gate (weighted by ``lambda_c``) and a spatial gate (weighted by ``lambda_s``).

    Args:
        dim: Channel count of each input.
        reduction: Passed to ``ChannelWeights`` and ``SpatialWeights``.
        lambda_c: Weight of the channel-gated term.
        lambda_s: Weight of the spatial-gated term.
    """

    def __init__(self, dim, reduction=1, lambda_c=.5, lambda_s=.5):
        super().__init__()
        self.lambda_c = lambda_c
        self.lambda_s = lambda_s
        self.channel_weights = ChannelWeights(dim=dim, reduction=reduction)
        self.spatial_weights = SpatialWeights(dim=dim, reduction=reduction)
        # NOTE(review): _init_weights is defined below but never applied here
        # (no self.apply(self._init_weights)) -- confirm this is intentional.

    def _init_weights(self, m):
        """Initialise one submodule: truncated normal for Linear, constants for LayerNorm, fan-out normal for Conv2d."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            kernel_h, kernel_w = m.kernel_size[0], m.kernel_size[1]
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x1, x2):
        """Return the rectified pair ``(out_x1, out_x2)``."""
        channel_gates = self.channel_weights(x1, x2)
        spatial_gates = self.spatial_weights(x1, x2)

        # Gate index 1 modulates x2 on its way into x1; index 0 the reverse.
        x2_by_channel = self.lambda_c * channel_gates[1] * x2
        x2_by_space = self.lambda_s * spatial_gates[1] * x2
        x1_by_channel = self.lambda_c * channel_gates[0] * x1
        x1_by_space = self.lambda_s * spatial_gates[0] * x1

        out_x1 = x1 + x2_by_channel + x2_by_space
        out_x2 = x2 + x1_by_channel + x1_by_space
        return out_x1, out_x2
# Stage 1
class CrossAttention(nn.Module):
    """Exchange global context between two token sequences.

    For each sequence a ``[head_dim, head_dim]`` context matrix is built from
    its own keys and values; the *other* sequence's un-projected tokens are
    then used as queries against it.

    Args:
        dim: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of heads.
        qkv_bias: Whether the key/value projections carry a bias.
        qk_scale: Optional override for the default ``head_dim ** -0.5`` scale.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.kv1 = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.kv2 = nn.Linear(dim, dim * 2, bias=qkv_bias)

    def forward(self, x1, x2):
        """Return ``(x1_out, x2_out)``, each with the shape of ``x1``."""
        B, N, C = x1.shape
        heads = self.num_heads
        head_dim = C // heads

        def as_query(tokens):
            # [B, N, C] -> [B, heads, N, head_dim]
            return tokens.reshape(B, -1, heads, head_dim).permute(0, 2, 1, 3).contiguous()

        def context_of(proj, tokens):
            # Project to keys and values, then build the softmax-normalised
            # [B, heads, head_dim, head_dim] context matrix.
            kv = proj(tokens).reshape(B, -1, 2, heads, head_dim).permute(2, 0, 3, 1, 4).contiguous()
            keys, values = kv[0], kv[1]
            ctx = (keys.transpose(-2, -1) @ values) * self.scale
            return ctx.softmax(dim=-2)

        q1, q2 = as_query(x1), as_query(x2)
        ctx1 = context_of(self.kv1, x1)
        ctx2 = context_of(self.kv2, x2)

        # Each sequence queries the context built from the other one.
        out1 = (q1 @ ctx2).permute(0, 2, 1, 3).reshape(B, N, C).contiguous()
        out2 = (q2 @ ctx1).permute(0, 2, 1, 3).reshape(B, N, C).contiguous()
        return out1, out2
class CrossPath(nn.Module):
    """Cross-attend two token sequences and merge the result back residually.

    Each input is projected to twice the reduced width and split in two
    halves: one half is kept as is, the other goes through ``CrossAttention``
    with the other sequence. Both halves are concatenated, projected back to
    ``dim``, added to the input and normalised.

    Args:
        dim: Token feature size.
        reduction: Divisor for the inner width (``dim // reduction``).
        num_heads: Heads for ``CrossAttention``. ``None`` (the default) defers
            to ``CrossAttention``'s own default instead of being forwarded;
            forwarding ``None`` used to fail on ``dim % num_heads``.
        norm_layer: Constructor of the output normalisation, called with ``dim``.
    """

    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.LayerNorm):
        super().__init__()
        inner_dim = dim // reduction
        self.channel_proj1 = nn.Linear(dim, inner_dim * 2)
        self.channel_proj2 = nn.Linear(dim, inner_dim * 2)
        self.act1 = nn.ReLU(inplace=True)
        self.act2 = nn.ReLU(inplace=True)
        if num_heads is None:
            self.cross_attn = CrossAttention(inner_dim)
        else:
            self.cross_attn = CrossAttention(inner_dim, num_heads=num_heads)
        self.end_proj1 = nn.Linear(inner_dim * 2, dim)
        self.end_proj2 = nn.Linear(inner_dim * 2, dim)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

    def forward(self, x1, x2):
        """Return the updated pair ``(out_x1, out_x2)``, same shapes as the inputs."""
        # y*: half kept untouched, u*: half sent through cross attention.
        y1, u1 = self.act1(self.channel_proj1(x1)).chunk(2, dim=-1)
        y2, u2 = self.act2(self.channel_proj2(x2)).chunk(2, dim=-1)
        v1, v2 = self.cross_attn(u1, u2)
        y1 = torch.cat((y1, v1), dim=-1)
        y2 = torch.cat((y2, v2), dim=-1)
        out_x1 = self.norm1(x1 + self.end_proj1(y1))
        out_x2 = self.norm2(x2 + self.end_proj2(y2))
        return out_x1, out_x2
# Stage 2
class ChannelEmbed(nn.Module):
    """Turn a token sequence back into a feature map and embed its channels.

    A 1x1 residual projection is added to a 1x1 -> depthwise 3x3 -> ReLU ->
    1x1 -> norm branch, and the sum is normalised again.

    Args:
        in_channels: Token feature size.
        out_channels: Channels of the produced feature map.
        reduction: Divisor for the hidden width of the embedding branch.
        norm_layer: Constructor of the normalisation layers.
    """

    def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.out_channels = out_channels
        hidden = out_channels // reduction
        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.channel_embed = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1, bias=True),
            nn.Conv2d(hidden, hidden, kernel_size=3, stride=1, padding=1, bias=True, groups=hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=1, bias=True),
            norm_layer(out_channels),
        )
        self.norm = norm_layer(out_channels)

    def forward(self, x, H, W):
        """Map ``x`` of shape ``[B, N, C]`` (with ``N == H * W``) to ``[B, out_channels, H, W]``."""
        batch, _, channels = x.shape
        feature_map = x.permute(0, 2, 1).reshape(batch, channels, H, W).contiguous()
        shortcut = self.residual(feature_map)
        embedded = self.channel_embed(feature_map)
        return self.norm(shortcut + embedded)
class FeatureFusionModule(nn.Module):
    """Fuse two same-shaped feature maps into one.

    The maps are flattened to token sequences, exchanged through
    ``CrossPath``, concatenated on the feature axis and folded back into a
    feature map by ``ChannelEmbed``.

    Args:
        dim: Channel count of each input (and of the output).
        reduction: Passed to ``CrossPath`` and ``ChannelEmbed``.
        num_heads: Passed to ``CrossPath``.
        norm_layer: Normalisation constructor for ``ChannelEmbed``.
    """

    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.cross = CrossPath(dim=dim, reduction=reduction, num_heads=num_heads)
        self.channel_emb = ChannelEmbed(in_channels=dim*2, out_channels=dim, reduction=reduction, norm_layer=norm_layer)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule: truncated normal for Linear, constants for LayerNorm, fan-out normal for Conv2d."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            kernel_h, kernel_w = m.kernel_size[0], m.kernel_size[1]
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x1, x2):
        """Return the fused feature map with the spatial size of ``x1``."""
        _, _, height, width = x1.shape
        # [B, C, H, W] -> [B, H*W, C]
        tokens1 = x1.flatten(2).transpose(1, 2)
        tokens2 = x2.flatten(2).transpose(1, 2)
        tokens1, tokens2 = self.cross(tokens1, tokens2)
        merged = torch.cat((tokens1, tokens2), dim=-1)
        return self.channel_emb(merged, height, width)
================================================
FILE: semseg/models/modules/mspa.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd
from timm.models.layers import DropPath
from fvcore.nn import flop_count_table, FlopCountAnalysis
from mmcv.cnn import build_norm_layer
class DWConv(nn.Module):
    """Depthwise 3x3 convolution (one filter per channel) that keeps the spatial size.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x):
        return self.dwconv(x)
class Mlp(nn.Module):
    """Convolutional MLP: 1x1 conv, depthwise 3x3 conv, activation, 1x1 conv.

    Dropout is applied after the activation and again after the last conv.

    Args:
        in_features: Input channels.
        hidden_features: Hidden channels; falls back to ``in_features`` when falsy.
        out_features: Output channels; falls back to ``in_features`` when falsy.
        act_layer: Activation constructor.
        drop: Dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.dwconv(self.fc1(x)))
        hidden = self.drop(hidden)
        return self.drop(self.fc2(hidden))
class MSPoolAttention(nn.Module):
    """Gate the input with a map built from average pools of three window sizes.

    A depthwise 7x7 conv output is summed with its 3x3, 7x7 and 11x11 average
    pools, mixed by a 1x1 conv and squashed by a sigmoid; the gate multiplies
    the input and the input is added back.

    Args:
        dim: Number of channels.
    """

    def __init__(self, dim):
        super().__init__()
        pools = [3,7,11]
        self.conv0 = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        # stride 1 with padding k//2 keeps the spatial size for every window.
        self.pool1, self.pool2, self.pool3 = (
            nn.AvgPool2d(size, stride=1, padding=size // 2, count_include_pad=False)
            for size in pools
        )
        self.conv4 = nn.Conv2d(dim, dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x.clone()
        base = self.conv0(x)
        pooled_small = self.pool1(base)
        pooled_mid = self.pool2(base)
        pooled_large = self.pool3(base)
        gate = self.sigmoid(self.conv4(base + pooled_small + pooled_mid + pooled_large))
        return gate * identity + identity
class MSPABlock(nn.Module):
    """Residual block: multi-scale pooling attention followed by a conv MLP.

    Both branches are scaled by learnable per-channel factors (initialised to
    1e-2) and go through ``drop_path``. When ``is_channel_mix`` is set (always,
    as written), the skip connection of the MLP branch is additionally gated
    by a per-channel weight from a 1-D conv over the pooled channels.

    Args:
        dim: Number of channels.
        mlp_ratio: Hidden width of the MLP relative to ``dim``.
        drop: Dropout probability inside the MLP.
        drop_path: Stochastic-depth probability; identity when not > 0.
        act_layer: Activation constructor for the MLP.
        norm_cfg: Config handed to ``build_norm_layer`` for both norms.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.attn = MSPoolAttention(dim)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)

        self.is_channel_mix = True
        if self.is_channel_mix:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.c_nets = nn.Sequential(
                nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False),
                nn.Sigmoid(),
            )

    def forward(self, x):
        # Per-channel scales broadcast over the two spatial dims.
        attn_scale = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        mlp_scale = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)

        x = x + self.drop_path(attn_scale * self.attn(self.norm1(x)))

        if not self.is_channel_mix:
            return x + self.drop_path(mlp_scale * self.mlp(self.norm2(x)))

        # Channel gate: pool to [B, C, 1, 1], run the 1-D conv along the
        # channel axis, and broadcast the result back over the feature map.
        pooled = self.avg_pool(x)
        gate = self.c_nets(pooled.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        gated_skip = gate.expand_as(x) * x
        mlp_out = self.drop_path(mlp_scale * self.mlp(self.norm2(x)))
        return gated_skip + mlp_out
if __name__ == '__main__':
    # Smoke test. The original referenced an undefined ``MSDyBlock`` and fed a
    # [B, N, C] token tensor; this module only defines ``MSPABlock``, whose
    # Conv2d layers need a [B, C, H, W] input.
    x = torch.zeros(2, 64, 224, 224)
    c1 = MSPABlock(64)
    outs = c1(x)
    print(outs.shape)
    print(c1)
    print(flop_count_table(FlopCountAnalysis(c1, x)))
================================================
FILE: semseg/models/modules/ppm.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from semseg.models.layers import ConvModule
class PPM(nn.Module):
    """Pyramid Pooling Module in PSPNet.

    The input is average-pooled to each grid size in ``scales``, projected by
    a 1x1 ``ConvModule``, resized back to the input size and concatenated
    with the input before a 3x3 bottleneck.

    Args:
        c1: Input channels.
        c2: Channels of every pooled branch and of the output.
        scales: Output grid sizes of the adaptive average pools.
    """

    def __init__(self, c1, c2=128, scales=(1, 2, 3, 6)):
        super().__init__()
        branches = []
        for scale in scales:
            branches.append(nn.Sequential(nn.AdaptiveAvgPool2d(scale), ConvModule(c1, c2, 1)))
        self.stages = nn.ModuleList(branches)
        self.bottleneck = ConvModule(c1 + c2 * len(scales), c2, 3, 1, 1)

    def forward(self, x: Tensor) -> Tensor:
        target_size = x.shape[-2:]
        pooled = [
            F.interpolate(stage(x), size=target_size, mode='bilinear', align_corners=True)
            for stage in self.stages
        ]
        # Input first, then the branches in reverse order of ``scales``.
        pooled.reverse()
        return self.bottleneck(torch.cat([x, *pooled], dim=1))
if __name__ == '__main__':
    # Smoke test: the output keeps the spatial size and has c2 channels.
    in_channels, out_channels = 512, 128
    model = PPM(in_channels, out_channels)
    x = torch.randn(2, in_channels, 7, 7)
    y = model(x)
    print(y.shape)   # [2, 128, 7, 7]
================================================
FILE: semseg/models/modules/psa.py
================================================
import torch
from torch import nn, Tensor
from torch.nn import functional as F
class PSAP(nn.Module):
    """Attention block that sums a channel-gated and a spatially-gated copy of the input.

    (Presumably the parallel variant of Polarized Self-Attention -- inferred
    from the class and file names; confirm against the paper.)

    Both branches read the *same* input and return new tensors. The previous
    implementation gated ``x`` in place, which modified the caller's tensor,
    made the second branch see the already-gated output of the first (so the
    sum was twice one tensor), and broke autograd for inputs that require grad.

    Args:
        c1: Input channels.
        c2: Channels of the channel gate; must equal ``c1`` for the gate to
            multiply with the input.
    """

    def __init__(self, c1, c2):
        super().__init__()
        ch = c2 // 2
        self.conv_q_right = nn.Conv2d(c1, 1, 1, bias=False)
        self.conv_v_right = nn.Conv2d(c1, ch, 1, bias=False)
        self.conv_up = nn.Conv2d(ch, c2, 1, bias=False)

        self.conv_q_left = nn.Conv2d(c1, ch, 1, bias=False)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_v_left = nn.Conv2d(c1, ch, 1, bias=False)

    def spatial_pool(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by a per-channel gate of shape ``[B, c2, 1, 1]``."""
        input_x = self.conv_v_right(x)              # [B, C, H, W]
        context_mask = self.conv_q_right(x)         # [B, 1, H, W]
        B, C, _, _ = input_x.shape

        input_x = input_x.view(B, C, -1)
        context_mask = context_mask.view(B, 1, -1).softmax(dim=2)

        # Softmax-weighted sum over all positions -> [B, C, 1].
        context = input_x @ context_mask.transpose(1, 2)
        context = self.conv_up(context.unsqueeze(-1)).sigmoid()
        # Out-of-place so the caller's tensor and the autograd graph stay intact.
        return x * context

    def channel_pool(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by a per-pixel gate of shape ``[B, 1, H, W]``."""
        g_x = self.conv_q_left(x)
        B, C, H, W = g_x.shape
        avg_x = self.avg_pool(g_x).view(B, C, -1).permute(0, 2, 1)   # [B, 1, C]
        theta_x = self.conv_v_left(x).view(B, C, -1)                 # [B, C, H*W]

        context = avg_x @ theta_x                                    # [B, 1, H*W]
        context = context.softmax(dim=2).view(B, 1, H, W).sigmoid()
        return x * context

    def forward(self, x: Tensor) -> Tensor:
        return self.spatial_pool(x) + self.channel_pool(x)
class PSAS(nn.Module):
    """Attention block that applies a channel gate and then a spatial gate in sequence.

    (Presumably the sequential variant of Polarized Self-Attention -- inferred
    from the class and file names; confirm against the paper.)

    Both gating steps return new tensors. The previous implementation
    multiplied ``x`` in place, which overwrote the caller's tensor and broke
    autograd for inputs that require grad.

    Args:
        c1: Input channels.
        c2: Channels of the channel gate; must equal ``c1`` for the gate to
            multiply with the input.
    """

    def __init__(self, c1, c2):
        super().__init__()
        ch = c2 // 2
        self.conv_q_right = nn.Conv2d(c1, 1, 1, bias=False)
        self.conv_v_right = nn.Conv2d(c1, ch, 1, bias=False)
        self.conv_up = nn.Sequential(
            nn.Conv2d(ch, ch // 4, 1),
            nn.LayerNorm([ch // 4, 1, 1]),
            nn.ReLU(),
            nn.Conv2d(ch // 4, c2, 1)
        )

        self.conv_q_left = nn.Conv2d(c1, ch, 1, bias=False)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_v_left = nn.Conv2d(c1, ch, 1, bias=False)

    def spatial_pool(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by a per-channel gate of shape ``[B, c2, 1, 1]``."""
        input_x = self.conv_v_right(x)              # [B, C, H, W]
        context_mask = self.conv_q_right(x)         # [B, 1, H, W]
        B, C, _, _ = input_x.shape

        input_x = input_x.view(B, C, -1)
        context_mask = context_mask.view(B, 1, -1).softmax(dim=2)

        # Softmax-weighted sum over all positions -> [B, C, 1].
        context = input_x @ context_mask.transpose(1, 2)
        context = self.conv_up(context.unsqueeze(-1)).sigmoid()
        # Out-of-place so the caller's tensor and the autograd graph stay intact.
        return x * context

    def channel_pool(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by a per-pixel gate of shape ``[B, 1, H, W]``."""
        g_x = self.conv_q_left(x)
        B, C, H, W = g_x.shape
        avg_x = self.avg_pool(g_x).view(B, C, -1).permute(0, 2, 1)       # [B, 1, C]
        theta_x = self.conv_v_left(x).view(B, C, -1).softmax(dim=2)      # [B, C, H*W]

        context = avg_x @ theta_x                                        # [B, 1, H*W]
        context = context.view(B, 1, H, W).sigmoid()
        return x * context

    def forward(self, x: Tensor) -> Tensor:
        return self.channel_pool(self.spatial_pool(x))
"""
PSA Module Usage
class BasicBlock(nn.Module):
# 2 Layer No Expansion Block
expansion: int = 1
def __init__(self, c1, c2, s=1, downsample= None) -> None:
super().__init__()
self.conv1 = nn.Conv2d(c1, c2, 3, s, 1, bias=False)
self.bn1 = nn.BatchNorm2d(c2)
self.deattn = PSAS(c2, c2)
self.conv2 = nn.Conv2d(c2, c2, 3, 1, 1, bias=False)
self.bn2 = nn.BatchNorm2d(c2)
self.downsample = downsample
def forward(self, x: Tensor) -> Tensor:
identity = x
out = F.relu(self.bn1(self.conv1(x)))
out = self.deattn(out)
out = self.bn2(self.conv2(out))
if self.downsample is not None: identity = self.downsample(x)
out += identity
return F.relu(out)
class Bottleneck(nn.Module):
# 3 Layer 4x Expansion Block
expansion: int = 4
def __init__(self, c1, c2, s=1, downsample=None) -> None:
super().__init__()
self.conv1 = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
self.bn1 = nn.BatchNorm2d(c2)
self.conv2 = nn.Conv2d(c2, c2, 3, s, 1, bias=False)
self.bn2 = nn.BatchNorm2d(c2)
self.deattn = PSAP(c2, c2)
self.conv3 = nn.Conv2d(c2, c2 * self.expansion, 1, 1, 0, bias=False)
self.bn3 = nn.BatchNorm2d(c2 * self.expansion)
self.downsample = downsample
def forward(self, x: Tensor) -> Tensor:
identity = x
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.deattn(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None: identity = self.downsample(x)
out += identity
return F.relu(out)
resnet_settings = {
'18': [BasicBlock, [2, 2, 2, 2]],
'34': [BasicBlock, [3, 4, 6, 3]],
'50': [Bottleneck, [3, 4, 6, 3]],
'101': [Bottleneck, [3, 4, 23, 3]],
'152': [Bottleneck, [3, 8, 36, 3]]
}
class ResNet(nn.Module):
def __init__(self, model_name: str = '50') -> None:
super().__init__()
assert model_name in resnet_settings.keys(), f"ResNet model name should be in {list(resnet_settings.keys())}"
block, depths = resnet_settings[model_name]
self.inplanes = 64
self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.maxpool = nn.MaxPool2d(3, 2, 1)
self.layer1 = self._make_layer(block, 64, depths[0], s=1)
self.layer2 = self._make_layer(block, 128, depths[1], s=2)
self.layer3 = self._make_layer(block, 256, depths[2], s=2)
self.layer4 = self._make_layer(block, 512, depths[3], s=2)
def _make_layer(self, block, planes, depth, s=1) -> nn.Sequential:
downsample = None
if s != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, 1, s, bias=False),
nn.BatchNorm2d(planes * block.expansion)
)
layers = nn.Sequential(
block(self.inplanes, planes, s, downsample),
*[block(planes * block.expansion, planes) for _ in range(1, depth)]
)
self.inplanes = planes * block.expansion
return layers
def forward(self, x: Tensor) -> Tensor:
x = self.maxpool(F.relu(self.bn1(self.conv1(x)))) # [1, 64, H/4, W/4]
x1 = self.layer1(x) # [1, 64/256, H/4, W/4]
x2 = self.layer2(x1) # [1, 128/512, H/8, W/8]
x3 = self.layer3(x2) # [1, 256/1024, H/16, W/16]
x4 = self.layer4(x3) # [1, 512/2048, H/32, W/32]
return x1, x2, x3, x4
if __name__ == '__main__':
model = ResNet('18')
x = torch.zeros(2, 3, 224, 224)
outs = model(x)
for y in outs:
print(y.shape)
"""
================================================
FILE: semseg/optimizers.py
================================================
from torch import nn
from torch.optim import AdamW, SGD
def get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01):
    """Build an optimizer that exempts 1-D parameters from weight decay.

    Args:
        model: Module whose trainable parameters are optimised.
        optimizer: ``'adamw'`` selects AdamW; any other value selects SGD with momentum 0.9.
        lr: Learning rate.
        weight_decay: Decay for parameters with more than one dimension.

    Returns:
        The optimizer, with the decayed parameters in group 0 and the 1-D
        parameters (decay 0) in group 1.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    decayed = [p for p in trainable if p.dim() != 1]
    undecayed = [p for p in trainable if p.dim() == 1]

    param_groups = [
        {"params": decayed},
        {"params": undecayed, "weight_decay": 0},
    ]

    if optimizer != 'adamw':
        return SGD(param_groups, lr, momentum=0.9, weight_decay=weight_decay)
    return AdamW(param_groups, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay)
================================================
FILE: semseg/schedulers.py
================================================
import torch
import math
from torch.optim.lr_scheduler import _LRScheduler
class PolyLR(_LRScheduler):
    """Polynomial learning-rate decay: ``base_lr * (1 - iter / max_iter) ** power``.

    The rate is recomputed every ``decay_iter`` iterations and held in between.

    The previous condition also tested ``last_epoch % max_iter``, which is
    non-zero on every iteration except multiples of ``max_iter``; the
    scheduler therefore returned the base rate throughout and never decayed.

    Args:
        optimizer: Wrapped optimizer.
        max_iter: Iteration at which the rate reaches zero.
        decay_iter: Recompute the rate only every this many iterations.
        power: Exponent of the polynomial.
        last_epoch: Index of the last iteration (``-1`` to start fresh).
    """

    def __init__(self, optimizer, max_iter, decay_iter=1, power=0.9, last_epoch=-1) -> None:
        self.decay_iter = decay_iter
        self.max_iter = max_iter
        self.power = power
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        if self.last_epoch % self.decay_iter:
            # Between decay steps: keep whatever rate is currently set.
            return [group['lr'] for group in self.optimizer.param_groups]
        # Clamp so stepping past max_iter holds the rate at zero instead of
        # raising a negative base to a fractional power (a complex number).
        progress = min(self.last_epoch, self.max_iter) / float(self.max_iter)
        factor = (1 - progress) ** self.power
        return [factor * lr for lr in self.base_lrs]
class WarmupLR(_LRScheduler):
    """Base scheduler: a warmup ratio for the first ``warmup_iter`` steps, then a subclass-defined ratio.

    Subclasses implement ``get_main_ratio``; every learning rate is the base
    rate multiplied by the current ratio.

    Args:
        optimizer: Wrapped optimizer.
        warmup_iter: Number of warmup iterations.
        warmup_ratio: Ratio at iteration 0.
        warmup: ``'linear'`` or ``'exp'`` interpolation towards ratio 1.
        last_epoch: Index of the last iteration (``-1`` to start fresh).
    """

    def __init__(self, optimizer, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
        self.warmup_iter = warmup_iter
        self.warmup_ratio = warmup_ratio
        self.warmup = warmup
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        scale = self.get_lr_ratio()
        return [scale * base_lr for base_lr in self.base_lrs]

    def get_lr_ratio(self):
        if self.last_epoch < self.warmup_iter:
            return self.get_warmup_ratio()
        return self.get_main_ratio()

    def get_main_ratio(self):
        """Ratio used once warmup is over; must be provided by subclasses."""
        raise NotImplementedError

    def get_warmup_ratio(self):
        assert self.warmup in ['linear', 'exp']
        alpha = self.last_epoch / self.warmup_iter
        if self.warmup == 'linear':
            # Straight line from warmup_ratio (alpha=0) to 1 (alpha=1).
            return self.warmup_ratio + (1. - self.warmup_ratio) * alpha
        # Geometric interpolation from warmup_ratio to 1.
        return self.warmup_ratio ** (1. - alpha)
class WarmupPolyLR(WarmupLR):
    """Warmup followed by polynomial decay ``(1 - progress) ** power``.

    ``progress`` is the fraction of post-warmup iterations already done.

    Args:
        optimizer: Wrapped optimizer.
        power: Exponent of the polynomial.
        max_iter: Total number of iterations, warmup included.
        warmup_iter, warmup_ratio, warmup, last_epoch: See ``WarmupLR``.
    """

    def __init__(self, optimizer, power, max_iter, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
        self.power = power
        self.max_iter = max_iter
        super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)

    def get_main_ratio(self):
        iters_after_warmup = self.last_epoch - self.warmup_iter
        total_after_warmup = self.max_iter - self.warmup_iter
        progress = iters_after_warmup / total_after_warmup
        return (1 - progress) ** self.power
class WarmupExpLR(WarmupLR):
    """Warmup followed by stepwise exponential decay.

    After warmup the ratio is ``gamma`` raised to the number of completed
    ``interval``-sized periods.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Multiplicative decay per interval.
        interval: Number of iterations per decay step.
        warmup_iter, warmup_ratio, warmup, last_epoch: See ``WarmupLR``.
    """

    def __init__(self, optimizer, gamma, interval=1, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
        self.gamma = gamma
        self.interval = interval
        super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)

    def get_main_ratio(self):
        iters_after_warmup = self.last_epoch - self.warmup_iter
        completed_intervals = iters_after_warmup // self.interval
        return self.gamma ** completed_intervals
class WarmupCosineLR(WarmupLR):
    """Warmup followed by cosine annealing from ratio 1 down to ``eta_ratio``.

    The cosine phase is driven by the number of iterations *since warmup
    ended*. The previous version computed that count but then used the raw
    ``last_epoch``, so the rate dropped abruptly right after warmup and the
    cosine argument ran past pi before ``max_iter``, making the rate rise again.

    Args:
        optimizer: Wrapped optimizer.
        max_iter: Total number of iterations, warmup included.
        eta_ratio: Final ratio reached at ``max_iter``.
        warmup_iter, warmup_ratio, warmup, last_epoch: See ``WarmupLR``.
    """

    def __init__(self, optimizer, max_iter, eta_ratio=0, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
        self.eta_ratio = eta_ratio
        self.max_iter = max_iter
        super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)

    def get_main_ratio(self):
        real_iter = self.last_epoch - self.warmup_iter
        real_max_iter = self.max_iter - self.warmup_iter
        # 0 right after warmup (ratio 1), pi at max_iter (ratio eta_ratio).
        phase = math.pi * real_iter / real_max_iter
        return self.eta_ratio + (1 - self.eta_ratio) * (1 + math.cos(phase)) / 2
# Scheduler names accepted by get_scheduler (not names exported by this module).
__all__ = ['polylr', 'warmuppolylr', 'warmupcosinelr', 'warmupsteplr']


def get_scheduler(scheduler_name: str, optimizer, max_iter: int, power: int, warmup_iter: int, warmup_ratio: float):
    """Build the learning-rate scheduler selected by ``scheduler_name``.

    Args:
        scheduler_name: One of the names listed in ``__all__``.
        optimizer: Optimizer to wrap.
        max_iter: Total number of training iterations.
        power: Polynomial exponent (used by ``'warmuppolylr'`` only).
        warmup_iter: Number of warmup iterations (warmup schedulers only).
        warmup_ratio: Starting ratio for warmup (warmup schedulers only).

    Returns:
        ``WarmupPolyLR`` with linear warmup, ``WarmupCosineLR``, or ``PolyLR``.
    """
    assert scheduler_name in __all__, f"Unavailable scheduler name >> {scheduler_name}.\nAvailable schedulers: {__all__}"

    if scheduler_name == 'warmupcosinelr':
        return WarmupCosineLR(optimizer, max_iter, warmup_iter=warmup_iter, warmup_ratio=warmup_ratio)
    if scheduler_name == 'warmuppolylr':
        return WarmupPolyLR(optimizer, power, max_iter, warmup_iter, warmup_ratio, warmup='linear')
    # Both 'polylr' and 'warmupsteplr' end up here.
    return PolyLR(optimizer, max_iter)
if __name__ == '__main__':
    # Plot the learning-rate curve of WarmupPolyLR over a full run.
    max_iter = 20000
    model = torch.nn.Conv2d(3, 16, 3, 1, 1)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
    sched = WarmupPolyLR(optim, power=0.9, max_iter=max_iter, warmup_iter=200, warmup_ratio=0.1, warmup='exp', last_epoch=-1)

    lrs = []
    for _ in range(max_iter):
        lrs.append(sched.get_lr()[0])
        optim.step()
        sched.step()

    import matplotlib.pyplot as plt
    import numpy as np

    iterations = np.arange(len(lrs))
    plt.plot(iterations, np.array(lrs))
    plt.grid()
    plt.show()
================================================
FILE: semseg/utils/__init__.py
================================================
================================================
FILE: semseg/utils/utils.py
================================================
import torch
import numpy as np
import random
import time
import os
import sys
import functools
from pathlib import Path
from torch.backends import cudnn
from torch import nn, Tensor
from torch.autograd import profiler
from typing import Union
from torch import distributed as dist
from tabulate import tabulate
from semseg import models
import logging
from fvcore.nn import flop_count_table, FlopCountAnalysis
import datetime
def fix_seeds(seed: int = 3407) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and current CUDA device) with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
def setup_cudnn() -> None:
    """Configure cuDNN for speed: autotuning on, deterministic mode off.

    Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    """
    cudnn.deterministic = False
    cudnn.benchmark = True
def time_sync() -> float:
    """Return ``time.time()`` after waiting for pending CUDA work, when CUDA is available."""
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Make sure queued GPU kernels have finished before reading the clock.
        torch.cuda.synchronize()
    now = time.time()
    return now
def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]):
    """Return the serialized size of ``model`` in megabytes (bytes / 1e6).

    A ``ScriptModule`` is saved with ``torch.jit.save``; any other module has
    its ``state_dict`` saved with ``torch.save``.

    The file is written inside a private temporary directory that is removed
    even when saving fails. The previous version used a fixed ``temp.p`` in
    the working directory, which could collide between concurrent processes
    and was left behind on errors.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_model_path = Path(tmp_dir) / 'temp.p'
        if isinstance(model, torch.jit.ScriptModule):
            torch.jit.save(model, str(tmp_model_path))
        else:
            torch.save(model.state_dict(), tmp_model_path)
        size = tmp_model_path.stat().st_size
    return size / 1e6   # in MB
@torch.no_grad()
def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float:
    """Profile one forward pass of ``model`` on ``inputs`` and return the profiler's total self CPU time.

    Args:
        model: Module to run.
        inputs: Passed to ``model`` as its only argument.
        use_cuda: Forwarded to the autograd profiler.
    """
    with profiler.profile(use_cuda=use_cuda) as prof:
        model(inputs)
    total_self_cpu = prof.self_cpu_time_total
    return total_self_cpu / 1000   # ms
def count_parameters(model: nn.Module) -> float:
    """Return the number of trainable parameters of ``model``, in millions."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable / 1e6      # in M
def setup_ddp():
    """Initialise distributed training from environment variables and return the local GPU index.

    Three cases, decided by the environment:
      * ``SLURM_PROCID`` set and ``RANK`` unset: the rank comes from SLURM and
        the GPU index is the rank modulo ``SLURM_GPUS_ON_NODE``.
      * ``RANK`` and ``WORLD_SIZE`` set: the GPU index is ``LOCAL_RANK``; a
        barrier follows the process-group initialisation.
      * otherwise: no process group is created and GPU 0 is returned.
    """
    env = os.environ
    timeout = datetime.timedelta(seconds=7200)

    if 'SLURM_PROCID' in env and 'RANK' not in env:
        # --- multi nodes
        world_size = int(env['WORLD_SIZE'])
        rank = int(env["SLURM_PROCID"])
        gpus_per_node = int(env["SLURM_GPUS_ON_NODE"])
        gpu = rank % gpus_per_node
        torch.cuda.set_device(gpu)
        dist.init_process_group(backend="nccl", world_size=world_size, rank=rank, timeout=timeout)
        return gpu

    if 'RANK' in env and 'WORLD_SIZE' in env:
        rank = int(env['RANK'])
        world_size = int(env['WORLD_SIZE'])
        gpu = int(env['LOCAL_RANK'])
        torch.cuda.set_device(gpu)
        dist.init_process_group('nccl', init_method="env://",world_size=world_size, rank=rank, timeout=timeout)
        dist.barrier()
        return gpu

    return 0
def cleanup_ddp():
    """Destroy the default process group if one has been initialised; otherwise do nothing."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def reduce_tensor(tensor: Tensor) -> Tensor:
    """Return the mean of ``tensor`` across all processes, leaving the input untouched.

    Requires an initialised process group.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    averaged /= world_size
    return averaged
@torch.no_grad()
def throughput(dataloader, model: nn.Module, times: int = 30):
    """Print the images-per-second rate of ``model`` on the first batch of ``dataloader``.

    The model is put in eval mode and the same batch is forwarded ``times``
    times on the GPU; timing brackets the whole loop.
    """
    model.eval()
    first_batch = next(iter(dataloader))
    images, _ = first_batch
    images = images.cuda(non_blocking=True)
    B = images.shape[0]
    print(f"Throughput averaged with {times} times")

    start = time_sync()
    for _ in range(times):
        model(images)
    end = time_sync()

    print(f"Batch Size {B} throughput {times * B / (end - start)} images/s")
def show_models():
    """Print a table of every model name in ``models.__all__`` with its variants.

    The variants are the keys of the ``<name>_settings`` attribute of the
    ``models`` package, looked up with ``getattr`` rather than ``eval`` so no
    constructed string is ever executed as code.
    """
    model_names = models.__all__
    model_variants = [
        list(getattr(models, f'{name.lower()}_settings').keys())
        for name in model_names
    ]
    print(tabulate({'Model Names': model_names, 'Model Variants': model_variants}, headers='keys'))
def timer(func):
    """Decorator that prints how long each call to ``func`` took, in milliseconds."""

    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        started = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - started
        print(f"Elapsed time: {elapsed_time * 1000:.2f}ms")
        return result

    return wrapper_timer
# _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
# _default_level = logging.getLevelName(_default_level_name.upper())
def get_logger(log_file=None):
    """Reconfigure and return the root logger at INFO level.

    Every handler currently attached to the root logger is dropped. If
    ``log_file`` is given, a file handler that truncates the file (mode
    ``'w'``) is attached first; a stream handler is always attached.
    All handlers share the same message format.

    Args:
        log_file: Optional path of the log file to (over)write.

    Returns:
        The root ``logging.Logger``.
    """
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: - %(message)s',datefmt='%Y%m%d %H:%M:%S')

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    def _attach(handler):
        # Same level and format for every sink.
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    if log_file:
        _attach(logging.FileHandler(log_file, mode='w'))
    _attach(logging.StreamHandler())
    return logger
def cal_flops(model, modals, logger):
    """Log a FLOP-count table for ``model`` on dummy 512x512 inputs.

    One zero tensor of shape ``(1, 3, 512, 512)`` is created per entry in
    ``modals`` and the list is passed to the model as its single input.

    Args:
        model: Network to analyse. It is moved to CUDA when CUDA is available.
        modals: Sequence of modality names; only its length is used.
        logger: Logger whose ``info`` method receives the rendered table.
    """
    x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))]
    # x = [torch.zeros(2, 3, 512, 512) for _ in range(len(modals))] #--- PGSNet
    # The former per-architecture ("HR") branches all rebuilt this exact same
    # input list, so they were dropped; they also touched ``model.module``,
    # which only exists on wrapped models.

    # ``is_available`` must be *called*: the bare function object is always
    # truthy and used to force the CUDA path even on CPU-only machines.
    if torch.cuda.is_available():
        x = [xi.cuda() for xi in x]
        model = model.cuda()
    logger.info(flop_count_table(FlopCountAnalysis(model, x)))
def print_iou(epoch, iou, miou, acc, macc, class_names):
    """Format per-class IoU / accuracy values as a tab-separated report.

    Args:
        epoch: Epoch number shown in the summary line (formatted with ``%d``).
        iou: Per-class IoU values.
        miou: Mean IoU shown in the summary line.
        acc: Per-class accuracy values; must have the same length as ``iou``.
        macc: Mean accuracy shown in the summary line.
        class_names: Class names matching ``iou`` one-to-one, or ``None`` to
            label the rows generically as ``Class <n>:``.

    Returns:
        A multi-line string starting with a newline: a header row, one row per
        class (1-based index) and a final ``==`` summary row.

    Raises:
        AssertionError: If the sequence lengths disagree.
    """
    assert len(acc) == len(iou)
    # Only compare against class_names when they were supplied; the previous
    # unconditional len(class_names) made the None fallback unreachable.
    if class_names is not None:
        assert len(iou) == len(class_names)

    lines = ['\n%-8s\t%-8s\t%-8s' % ('Class', 'IoU', 'Acc')]
    for i, (class_iou, class_acc) in enumerate(zip(iou, acc)):
        if class_names is None:
            cls = 'Class %d:' % (i+1)
        else:
            cls = '%d %s' % (i+1, class_names[i])
        lines.append('%-8s\t%.2f\t%.2f' % (cls, class_iou, class_acc))
    lines.append('== %-8s\t%d\t%-8s\t%.2f\t%-8s\t%.2f' % ('Epoch:', epoch, 'mean_IoU', miou, 'mean_Acc',macc))
    return "\n".join(lines)
def nchw_to_nlc(x):
    """Turn a [N, C, H, W] feature map into a [N, L, C] token sequence.

    The two spatial axes are merged into one of length ``L = H * W`` and the
    channel axis is moved last.

    Args:
        x (Tensor): 4-D input of shape [N, C, H, W].

    Returns:
        Tensor: Contiguous tensor of shape [N, L, C].
    """
    assert x.dim() == 4
    tokens = x.flatten(start_dim=2)
    return tokens.permute(0, 2, 1).contiguous()
def nlc_to_nchw(x, hw_shape):
    """Turn a [N, L, C] token sequence back into a [N, C, H, W] feature map.

    Args:
        x (Tensor): 3-D input of shape [N, L, C].
        hw_shape (Sequence[int]): Target ``(H, W)``; ``H * W`` must equal L.

    Returns:
        Tensor: Contiguous tensor of shape [N, C, H, W].
    """
    H, W = hw_shape
    assert x.dim() == 3
    batch, seq_len, channels = x.shape
    assert seq_len == H * W, 'The seq_len does not match H, W'
    channels_first = x.permute(0, 2, 1)
    return channels_first.reshape(batch, channels, H, W).contiguous()
def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
"""Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the
reshaped tensor as the input of `module`, and convert the output of
`module`, whose shape is.
[N, C, H, W], to [N, L, C].
Args:
module (Callable): A callable object the takes a tensor
with shape [N, C, H, W] as input.
x (Tensor): The input tensor of shape [N, L, C].
hw_shape: (Sequence[int]): The height and width of the
feature map with shape [N, C, H, W].
contiguous (Bool): Whether to make the tensor contiguous
after each shape transform.
Returns:
Tensor: The output tensor of shape [N, L, C].
Example:
>>> import torch
>>> import torch.nn as nn
>>> conv = nn.Conv2d(16, 16, 3, 1, 1)
>>> feature_map = torch.rand(4, 25, 16)
>>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
"""
H, W = hw_shape
assert len(x.shape) == 3
B, L, C = x.shape
assert L == H * W, 'The seq_len doesn\'t match H, W'
if not contiguous:
x = x.transpose(1, 2).reshape(B, C, H, W)
x = module(x, **kwargs)
x = x.flatten(2).transpose(1, 2)
else:
x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
x = module(x, **kwargs)
x = x.flatten(2).transpose(1, 2).contiguous()
return x
================================================
FILE: semseg/utils/visualize.py
================================================
import torch
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.utils import make_grid
from semseg.augmentations import Compose, Normalize, RandomResizedCrop
from PIL import Image, ImageDraw, ImageFont
def visualize_dataset_sample(dataset, root, split='val', batch_size=4):
    """Plot one random batch of images next to their colour-coded labels.

    Args:
        dataset: Dataset class (not an instance); it is called as
            ``dataset(root, split=..., transform=...)`` and must expose a
            ``PALETTE`` that can be indexed with an integer label map.
        root: Dataset root passed to the dataset constructor.
        split: Dataset split to load.
        batch_size: Number of samples drawn for the figure.
    """
    pipeline = Compose([
        RandomResizedCrop((512, 512), scale=(1.0, 1.0)),
        Normalize()
    ])

    dataset = dataset(root, split=split, transform=pipeline)
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
    image, label = next(iter(loader))

    print(f"Image Shape\t: {image.shape}")
    print(f"Label Shape\t: {label.shape}")
    print(f"Classes\t\t: {label.unique().tolist()}")

    # Fold the -1 and 255 label values into class 0 before the palette lookup.
    for special_value in (-1, 255):
        label[label == special_value] = 0
    labels = torch.stack([dataset.PALETTE[lbl.to(int)].permute(2, 0, 1) for lbl in label])

    # Undo the mean/std normalisation and rescale to the 0-255 range.
    inv_normalize = T.Normalize(
        mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225),
        std=(1/0.229, 1/0.224, 1/0.225)
    )
    image = inv_normalize(image)
    image *= 255

    images = torch.vstack([image, labels])
    grid = make_grid(images, nrow=4).to(torch.uint8).numpy().transpose((1, 2, 0))
    plt.imshow(grid)
    plt.show()
# Pool of 150 three-component colour triples (values 0-255, presumably RGB)
# from which generate_palette() draws its class colours.
colors = [
    [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
    [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
    [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
    [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
    [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
    [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
    [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
    [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
    [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
    [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
    [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
    [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
    [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
    [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
    [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
    [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
    [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
    [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
    [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]
]
def generate_palette(num_classes, background: bool = False):
    """Build a random colour palette with one row per class.

    Colours are drawn from the module-level ``colors`` table in random order.
    The table itself is no longer shuffled in place, so repeated calls (and
    any other user of ``colors``) always see it in its original order.

    Args:
        num_classes: Number of palette rows requested. If it exceeds the
            number of available colours, fewer rows are returned.
        background: When True, the first row is black ``[0, 0, 0]`` and only
            ``num_classes - 1`` random colours follow.

    Returns:
        ``numpy.ndarray`` holding the selected colour triples.
    """
    shuffled = list(colors)
    random.shuffle(shuffled)
    if background:
        # max() guards num_classes == 0, where a -1 slice bound would
        # otherwise return almost the whole table.
        palette = [[0, 0, 0]] + shuffled[:max(num_classes - 1, 0)]
    else:
        palette = shuffled[:num_classes]
    return np.array(palette)
def draw_text(image: torch.Tensor, seg_map: torch.Tensor, labels: list, fontsize: int = 15):
    """Write each present class name onto the image at its region's centre.

    For every distinct value in ``seg_map`` the median coordinate of the
    matching pixels is used as the text anchor; the name is drawn in black on
    a white box that extends 3 px beyond the text bounds.

    Args:
        image: Image tensor; converted to ``uint8`` and handed to PIL.
        seg_map: Segmentation map whose unique values index into ``labels``.
        labels: Class names, indexable by the values found in ``seg_map``.
        fontsize: Size used when loading ``Helvetica.ttf``.

    Returns:
        The annotated ``PIL.Image``.
    """
    image = image.to(torch.uint8)
    font = ImageFont.truetype("Helvetica.ttf", fontsize)
    pil_image = Image.fromarray(image.numpy())
    draw = ImageDraw.Draw(pil_image)

    present = seg_map.unique().tolist()
    names = [labels[class_id] for class_id in present]

    for class_id, name in zip(present, names):
        region = (seg_map == class_id).squeeze().numpy()
        # nonzero() yields (rows, cols); reverse to the (x, y) order PIL expects.
        center = np.median((region == 1).nonzero(), axis=1)[::-1]
        left, top, right, bottom = draw.textbbox(center, name, font=font)
        box = (left-3, top-3, right+3, bottom+3)
        draw.rectangle(box, fill=(255, 255, 255), width=1)
        draw.text(center, name, fill=(0, 0, 0), font=font)
    return pil_image
================================================
FILE: tools/infer_mm.py
================================================
import torch
import argparse
import yaml
import math
from torch import Tensor
from torch.nn import functional as F
from pathlib import Path
from torchvision import io
from torchvision import transforms as T
import torchvision.transforms.functional as TF
from semseg.models import *
from semseg.datasets import *
from semseg.utils.utils import timer
from semseg.utils.visualize import draw_text
import glob
import os
from PIL import Image, ImageDraw, ImageFont
class SemSeg:
    """Single-image, multi-modal segmentation inference helper.

    The instance keeps the dataset name and modality list from the config, so
    ``predict`` no longer depends on module-level ``cfg`` / ``modals``
    variables that only exist when this file is run as a script.
    """

    def __init__(self, cfg) -> None:
        """Build the model described by ``cfg`` and load its weights.

        Args:
            cfg: Parsed configuration mapping. The keys read here are
                ``DEVICE``, ``DATASET.NAME``, ``DATASET.MODALS``,
                ``MODEL.NAME``, ``MODEL.BACKBONE``, ``EVAL.MODEL_PATH`` and
                ``TEST.IMAGE_SIZE``.
        """
        # inference device cuda or cpu
        self.device = torch.device(cfg['DEVICE'])

        # values predict() needs later
        self.dataset_name = cfg['DATASET']['NAME']
        self.modals = cfg['DATASET']['MODALS']

        # get dataset classes' colors and labels
        # NOTE(review): eval() resolves the class from a config string; only
        # run with trusted configuration files.
        dataset_cls = eval(self.dataset_name)
        self.palette = dataset_cls.PALETTE
        self.labels = dataset_cls.CLASSES

        # initialize the model and load weights and send to device
        self.model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(self.palette), self.modals)
        msg = self.model.load_state_dict(torch.load(cfg['EVAL']['MODEL_PATH'], map_location='cpu'))
        print(msg)
        self.model = self.model.to(self.device)
        self.model.eval()

        # preprocess parameters and transformation pipeline
        self.size = cfg['TEST']['IMAGE_SIZE']
        self.tf_pipeline_img = T.Compose([
            T.Resize(self.size),
            T.Lambda(lambda x: x / 255),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            T.Lambda(lambda x: x.unsqueeze(0))
        ])
        # Extra modalities are resized and scaled but not mean/std normalised.
        self.tf_pipeline_modal = T.Compose([
            T.Resize(self.size),
            T.Lambda(lambda x: x / 255),
            T.Lambda(lambda x: x.unsqueeze(0))
        ])

    def postprocess(self, orig_img: Tensor, seg_map: Tensor, overlay: bool) -> "Image.Image":
        """Convert model logits into a colour-coded PIL image.

        Args:
            orig_img: Original image tensor, used only when ``overlay`` is set.
            seg_map: Model output; the class is taken with argmax over dim 1.
            overlay: Blend 40% of the original image with 60% of the colours.
                NOTE(review): the blend assumes ``orig_img`` and the
                prediction share the same spatial size - confirm for resized
                inputs.

        Returns:
            The rendered ``PIL.Image``.
        """
        seg_map = seg_map.softmax(dim=1).argmax(dim=1).cpu().to(int)
        seg_image = self.palette[seg_map].squeeze()
        if overlay:
            seg_image = (orig_img.permute(1, 2, 0) * 0.4) + (seg_image * 0.6)
        image = seg_image.to(torch.uint8)
        pil_image = Image.fromarray(image.numpy())
        return pil_image

    @torch.inference_mode()
    @timer
    def model_forward(self, img: Tensor) -> Tensor:
        """Run the model on ``img`` in inference mode and print the elapsed time."""
        return self.model(img)

    def _open_img(self, file):
        """Read an image file and return it with exactly three channels.

        A fourth channel is dropped; a single channel is repeated three times.
        """
        img = io.read_image(file)
        C, H, W = img.shape
        if C == 4:
            img = img[:3, ...]
        if C == 1:
            img = img.repeat(3, 1, 1)
        return img

    def _modal_paths(self, img_fname: str) -> list:
        """Derive the extra-modality file paths that belong to ``img_fname``.

        Returns:
            Three paths in the order the script has always used them
            (depth/HHA, lidar, event).

        Raises:
            NotImplementedError: If the configured dataset is not supported.
        """
        if self.dataset_name == 'DELIVER':
            return [
                img_fname.replace('/img', '/hha').replace('_rgb', '_depth'),
                img_fname.replace('/img', '/lidar').replace('_rgb', '_lidar'),
                img_fname.replace('/img', '/event').replace('_rgb', '_event'),
            ]
        if self.dataset_name == 'KITTI360':
            hha = img_fname.replace('data_2d_raw', 'data_2d_hha')
            lidar = img_fname.replace('data_2d_raw', 'data_2d_lidar').replace('.png', '_color.png')
            event = img_fname.replace('data_2d_raw', 'data_2d_event')
            event = event.replace('/image_00/data_rect/', '/').replace('.png', '_event_image.png')
            return [hha, lidar, event]
        # Previously this fell through to an UnboundLocalError further down.
        raise NotImplementedError(f"Inference is not implemented for dataset {self.dataset_name!r}")

    def predict(self, img_fname: str, overlay: bool) -> "Image.Image":
        """Segment one RGB image together with its companion modalities.

        Only as many companion files are read as the configured modality list
        requires. The ground-truth label image is no longer read: it was
        loaded but never used, which made inference fail when no label file
        was present.

        Args:
            img_fname: Path of the RGB image; companion paths are derived
                from it.
            overlay: Forwarded to ``postprocess``.

        Returns:
            The rendered prediction as a ``PIL.Image``.
        """
        modal_paths = self._modal_paths(img_fname)

        image = io.read_image(img_fname)[:3, ...]
        img = self.tf_pipeline_img(image).to(self.device)

        # --- modals
        n_extra = max(len(self.modals) - 1, 0)
        extras = [self.tf_pipeline_modal(self._open_img(path)).to(self.device) for path in modal_paths[:n_extra]]

        sample = ([img] + extras)[:len(self.modals)]
        seg_map = self.model_forward(sample)
        seg_map = self.postprocess(image, seg_map, overlay)
        return seg_map
if __name__ == '__main__':
    # Script entry point: run inference on a single file or on every file of
    # the configured dataset's validation listing and save the predictions.
    parser = argparse.ArgumentParser()
    # Default now points at a config that exists in the repo (same default as train_mm.py).
    parser.add_argument('--cfg', type=str, default='configs/deliver_rgbdel.yaml')
    args = parser.parse_args()
    with open(args.cfg) as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)

    modals = cfg['DATASET']['MODALS']

    test_file = Path(cfg['TEST']['FILE'])
    if not test_file.exists():
        raise FileNotFoundError(test_file)

    # print(f"Model {cfg['MODEL']['NAME']} {cfg['MODEL']['BACKBONE']}")
    # print(f"Model {cfg['DATASET']['NAME']}")
    modals_name = ''.join([m[0] for m in cfg['DATASET']['MODALS']])
    save_dir = Path(cfg['SAVE_DIR']) / 'test_results' / (cfg['DATASET']['NAME']+'_'+cfg['MODEL']['BACKBONE']+'_'+modals_name)

    semseg = SemSeg(cfg)

    if test_file.is_file():
        segmap = semseg.predict(str(test_file), cfg['TEST']['OVERLAY'])
        # The output directory was never created on this path, so save() failed.
        save_dir.mkdir(parents=True, exist_ok=True)
        segmap.save(save_dir / f"{str(test_file.stem)}.png")
    else:
        if cfg['DATASET']['NAME'] == 'DELIVER':
            files = sorted(glob.glob(os.path.join(*[str(test_file), 'img', '*', 'val', '*', '*.png']))) # --- Deliver
        elif cfg['DATASET']['NAME'] == 'KITTI360':
            source = os.path.join(test_file, 'val.txt')
            files = []
            with open(source) as f:
                files_ = f.readlines()
            for item in files_:
                file_name = item.strip()
                if ' ' in file_name:
                    # --- KITTI-360
                    file_name = os.path.join(*[str(test_file), file_name.split(' ')[0]])
                files.append(file_name)
        else:
            raise NotImplementedError()

        for file in files:
            print(file)
            # NOTE(review): hard-coded sequence filter; every file whose path
            # lacks this drive name is skipped. Confirm it is still wanted.
            if not '2013_05_28_drive_0000_sync' in file:
                continue
            segmap = semseg.predict(file, cfg['TEST']['OVERLAY'])
            # os.path.join() discards save_dir when `file` is absolute, which
            # made the prediction overwrite the input image. Mirror absolute
            # inputs relative to the dataset root instead.
            rel_path = os.path.relpath(file, str(test_file)) if os.path.isabs(file) else file
            save_path = os.path.join(str(save_dir), rel_path)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            segmap.save(save_path)
================================================
FILE: tools/train_mm.py
================================================
import os
import torch
import argparse
import yaml
import time
import multiprocessing as mp
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from semseg.models import *
from semseg.datasets import *
from semseg.augmentations_mm import get_train_augmentation, get_val_augmentation
from semseg.losses import get_loss
from semseg.schedulers import get_scheduler
from semseg.optimizers import get_optimizer
from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou
from val_mm import evaluate
def main(cfg, gpu, save_dir):
    """Train a multi-modal segmentation model and keep the best checkpoint.

    Relies on a module-level ``logger`` that the script entry point creates
    before calling this function.

    Args:
        cfg: Parsed configuration mapping (``DEVICE``, ``TRAIN``, ``EVAL``,
            ``DATASET``, ``MODEL``, ``LOSS``, ``OPTIMIZER``, ``SCHEDULER``).
        gpu: Local device index returned by ``setup_ddp``; used for DDP.
        save_dir: ``pathlib.Path`` directory for TensorBoard logs and
            checkpoints.
    """
    start = time.time()
    best_mIoU = 0.0
    best_epoch = 0
    num_workers = 8
    device = torch.device(cfg['DEVICE'])
    train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
    dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
    loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
    epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']
    resume_path = cfg['MODEL']['RESUME']
    # WORLD_SIZE is only exported by distributed launchers; a plain
    # single-process run used to die here with a KeyError.
    gpus = int(os.environ.get('WORLD_SIZE', 1))

    traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL'])
    valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])

    # NOTE(review): eval() resolves classes from config strings; trusted configs only.
    trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform, dataset_cfg['MODALS'])
    valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform, dataset_cfg['MODALS'])
    class_names = trainset.CLASSES

    model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes, dataset_cfg['MODALS'])
    resume_checkpoint = None
    if os.path.isfile(resume_path):
        resume_checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
        msg = model.load_state_dict(resume_checkpoint['model_state_dict'])
        # print(msg)
        logger.info(msg)
    else:
        model.init_pretrained(model_cfg['PRETRAINED'])
    model = model.to(device)

    iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE'] // gpus
    loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None)
    start_epoch = 0
    optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY'])
    scheduler = get_scheduler(sched_cfg['NAME'], optimizer, int((epochs+1)*iters_per_epoch), sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO'])

    if train_cfg['DDP']:
        sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)
        sampler_val = None
        # Outputs stay on this rank's own device (was hard-coded to device 0).
        model = DDP(model, device_ids=[gpu], output_device=gpu, find_unused_parameters=True)
    else:
        sampler = RandomSampler(trainset)
        sampler_val = None

    if resume_checkpoint:
        start_epoch = resume_checkpoint['epoch'] - 1
        optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(resume_checkpoint['scheduler_state_dict'])
        best_mIoU = resume_checkpoint['best_miou']
        # The checkpoint stores the best epoch under 'epoch'; restoring it lets
        # the previous best files be found and removed when a better one appears.
        best_epoch = resume_checkpoint['epoch']

    trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=False, sampler=sampler)
    valloader = DataLoader(valset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=num_workers, pin_memory=False, sampler=sampler_val)

    scaler = GradScaler(enabled=train_cfg['AMP'])
    # Logging, evaluation and checkpointing happen on one process only.
    is_main_process = (not train_cfg['DDP']) or torch.distributed.get_rank() == 0
    if is_main_process:
        writer = SummaryWriter(str(save_dir))
        logger.info('================== model complexity =====================')
        cal_flops(model, dataset_cfg['MODALS'], logger)
        logger.info('================== model structure =====================')
        logger.info(model)
        logger.info('================== training config =====================')
        logger.info(cfg)

    pbar = None
    for epoch in range(start_epoch, epochs):
        model.train()
        if train_cfg['DDP']:
            sampler.set_epoch(epoch)

        train_loss = 0.0
        num_steps = 0
        lr = scheduler.get_lr()
        lr = sum(lr) / len(lr)
        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")

        for step, (sample, lbl) in pbar:
            optimizer.zero_grad(set_to_none=True)
            sample = [x.to(device) for x in sample]
            lbl = lbl.to(device)

            with autocast(enabled=train_cfg['AMP']):
                logits = model(sample)
                loss = loss_fn(logits, lbl)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            torch.cuda.synchronize()

            lr = scheduler.get_lr()
            lr = sum(lr) / len(lr)
            if lr <= 1e-8:
                lr = 1e-8 # minimum of lr
            train_loss += loss.item()
            num_steps = step + 1

            pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{step+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (step+1):.8f}")

        # max() keeps an empty loader from raising on the average.
        train_loss /= max(num_steps, 1)
        if is_main_process:
            writer.add_scalar('train/loss', train_loss, epoch)
        torch.cuda.empty_cache()

        if ((epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 and (epoch+1)>train_cfg['EVAL_START']) or (epoch+1) == epochs:
            if is_main_process:
                acc, macc, _, _, ious, miou = evaluate(model, valloader, device)
                writer.add_scalar('val/mIoU', miou, epoch)

                if miou > best_mIoU:
                    # Remove the files of the previous best before writing the new ones.
                    prev_best_ckp = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth"
                    prev_best = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth"
                    if os.path.isfile(prev_best):
                        os.remove(prev_best)
                    if os.path.isfile(prev_best_ckp):
                        os.remove(prev_best_ckp)
                    best_mIoU = miou
                    best_epoch = epoch+1
                    cur_best_ckp = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth"
                    cur_best = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth"
                    # Weights-only file plus a full checkpoint that can be resumed from.
                    torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), cur_best)
                    # ---
                    torch.save({'epoch': best_epoch,
                                'model_state_dict': model.module.state_dict() if train_cfg['DDP'] else model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'loss': train_loss,
                                'scheduler_state_dict': scheduler.state_dict(),
                                'best_miou': best_mIoU,
                                }, cur_best_ckp)
                    logger.info(print_iou(epoch, ious, miou, acc, macc, class_names))
                logger.info(f"Current epoch:{epoch} mIoU: {miou} Best mIoU: {best_mIoU}")

    if is_main_process:
        writer.close()
    # pbar stays None when no epoch ran (e.g. resuming a finished run).
    if pbar is not None:
        pbar.close()

    end = time.gmtime(time.time() - start)

    table = [
        ['Best mIoU', f"{best_mIoU:.2f}"],
        ['Total Training Time', time.strftime("%H:%M:%S", end)]
    ]
    logger.info(tabulate(table, numalign='right'))
if __name__ == '__main__':
    # Script entry point: load the config, prepare the (optionally
    # distributed) environment, pick the output directory and start training.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--cfg', type=str, default='configs/deliver_rgbdel.yaml', help='Configuration file to use')
    cli_args = arg_parser.parse_args()

    with open(cli_args.cfg) as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.SafeLoader)

    fix_seeds(3407)
    setup_cudnn()
    gpu = setup_ddp()

    # Experiment directory: <SAVE_DIR>/<dataset>_<backbone>_<modal initials>.
    modal_initials = ''.join(name[0] for name in cfg['DATASET']['MODALS'])
    backbone = cfg['MODEL']['BACKBONE']
    experiment = '_'.join((cfg['DATASET']['NAME'], backbone, modal_initials))
    save_dir = Path(cfg['SAVE_DIR'], experiment)

    # When resuming, keep writing next to the checkpoint being resumed.
    resume_from = cfg['MODEL']['RESUME']
    if os.path.isfile(resume_from):
        save_dir = Path(os.path.dirname(resume_from))
    os.makedirs(save_dir, exist_ok=True)

    # main() reads this module-level logger.
    logger = get_logger(save_dir / 'train.log')
    main(cfg, gpu, save_dir)
    cleanup_ddp()
================================================
FILE: tools/val_mm.py
================================================
import torch
import argparse
import yaml
import math
import os
import time
from pathlib import Path
from tqdm import tqdm
from tabulate import tabulate
from torch.utils.data import DataLoader
from torch.nn import functional as F
from semseg.models import *
from semseg.datasets import *
from semseg.augmentations_mm import get_val_augmentation
from semseg.metrics import Metrics
from semseg.utils.utils import setup_cudnn
from math import ceil
import numpy as np
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou
def pad_image(img, target_size):
    """Zero-pad a batch on the bottom and right up to ``target_size``.

    Args:
        img: Tensor whose dims 2 and 3 are treated as height and width.
        target_size: ``(height, width)`` to reach; a dimension that is
            already large enough is left unchanged.

    Returns:
        The padded tensor.
    """
    height, width = img.shape[2], img.shape[3]
    pad_bottom = max(target_size[0] - height, 0)
    pad_right = max(target_size[1] - width, 0)
    # F.pad takes the last dimension first: (left, right, top, bottom).
    return F.pad(img, (0, pad_right, 0, pad_bottom), mode="constant", value=0)
@torch.no_grad()
def sliding_predict(model, image, num_classes, flip=True):
    """Predict tile by tile and merge the tile outputs into one score map.

    The tile size currently equals the full image size, so a single tile
    covers the image; the loop still handles the general overlapping case.

    Changes from the previous version:
      * buffers are allocated on the input's device instead of a hard-coded
        ``cuda`` device, so CPU and non-default GPU inputs work;
      * overlapping tiles are averaged with the per-pixel tile count, which
        was accumulated but never applied (no effect for a single tile);
      * the model output is no longer modified in place when adding the
        flipped prediction.

    Args:
        model: Callable taking a list of ``[1, C, h, w]`` tensors and
            returning ``[1, num_classes, h, w]`` scores.
        image: List of modality tensors; the first one fixes the spatial size.
            Batch size 1 is assumed (the batch dim is squeezed).
        num_classes: Number of score channels produced by ``model``.
        flip: Also predict on the horizontally flipped tile and add the
            un-flipped result.

    Returns:
        Tensor of shape ``[1, num_classes, H, W]`` with the merged scores.
    """
    image_size = image[0].shape
    device = image[0].device
    tile_size = (int(ceil(image_size[2]*1)), int(ceil(image_size[3]*1)))
    overlap = 1/3

    stride = ceil(tile_size[0] * (1 - overlap))

    num_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1)
    num_cols = int(ceil((image_size[3] - tile_size[1]) / stride) + 1)
    total_predictions = torch.zeros((num_classes, image_size[2], image_size[3]), device=device)
    count_predictions = torch.zeros((image_size[2], image_size[3]), device=device)

    def _pad_to_tile(tile):
        # Zero-pad bottom/right so border tiles reach the full tile size.
        rows_to_pad = max(tile_size[0] - tile.shape[2], 0)
        cols_to_pad = max(tile_size[1] - tile.shape[3], 0)
        return F.pad(tile, (0, cols_to_pad, 0, rows_to_pad), "constant", 0)

    for row in range(num_rows):
        for col in range(num_cols):
            x_min, y_min = int(col * stride), int(row * stride)
            x_max = min(x_min + tile_size[1], image_size[3])
            y_max = min(y_min + tile_size[0], image_size[2])

            img = [modal[:, :, y_min:y_max, x_min:x_max] for modal in image]
            padded_img = [_pad_to_tile(modal) for modal in img]
            padded_prediction = model(padded_img)
            if flip:
                fliped_img = [padded_modal.flip(-1) for padded_modal in padded_img]
                fliped_predictions = model(fliped_img)
                padded_prediction = padded_prediction + fliped_predictions.flip(-1)
            # Drop the padded border again before accumulating.
            predictions = padded_prediction[:, :, :img[0].shape[2], :img[0].shape[3]]
            count_predictions[y_min:y_max, x_min:x_max] += 1
            total_predictions[:, y_min:y_max, x_min:x_max] += predictions.squeeze(0)

    # Every pixel is covered by at least one tile, so the count is never zero.
    total_predictions = total_predictions / count_predictions
    return total_predictions.unsqueeze(0)
@torch.no_grad()
def evaluate(model, dataloader, device):
    """Run single-scale evaluation and return accuracy, F1 and IoU metrics.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Loader yielding ``(list of modality tensors, labels)``;
            its dataset must expose ``n_classes`` and ``ignore_label``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(acc, macc, f1, mf1, ious, miou)`` as produced by ``Metrics``.
    """
    print('Evaluating...')
    model.eval()

    dataset = dataloader.dataset
    n_classes = dataset.n_classes
    metrics = Metrics(n_classes, dataset.ignore_label, device)
    sliding = False  # switch to tile-based prediction

    for images, labels in tqdm(dataloader):
        images = [modal.to(device) for modal in images]
        labels = labels.to(device)
        if sliding:
            logits = sliding_predict(model, images, num_classes=n_classes)
        else:
            logits = model(images)
        metrics.update(logits.softmax(dim=1), labels)

    ious, miou = metrics.compute_iou()
    acc, macc = metrics.compute_pixel_acc()
    f1, mf1 = metrics.compute_f1()
    return acc, macc, f1, mf1, ious, miou
@torch.no_grad()
def evaluate_msf(model, dataloader, device, scales, flip):
    """Evaluate with multi-scale (and optionally flipped) test-time augmentation.

    For each scale the inputs are bilinearly resized to the scaled size
    rounded up to a multiple of 32, the logits are resized back to the label
    size, and their softmax is accumulated. With ``flip`` the same is repeated
    on horizontally flipped inputs.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Loader yielding ``(list of modality tensors, labels)``;
            its dataset must expose ``n_classes`` and ``ignore_label``.
        device: Device used for the model inputs and the accumulator.
        scales: Iterable of scale factors.
        flip: Whether to add the horizontally flipped prediction.

    Returns:
        Tuple ``(acc, macc, f1, mf1, ious, miou)`` as produced by ``Metrics``.
    """
    model.eval()

    dataset = dataloader.dataset
    n_classes = dataset.n_classes
    metrics = Metrics(n_classes, dataset.ignore_label, device)

    for images, labels in tqdm(dataloader):
        labels = labels.to(device)
        B, H, W = labels.shape
        scaled_logits = torch.zeros(B, n_classes, H, W).to(device)

        for scale in scales:
            # Round the scaled size up to the next multiple of 32.
            new_H = int(math.ceil(int(scale * H) / 32)) * 32
            new_W = int(math.ceil(int(scale * W) / 32)) * 32
            scaled_images = [
                F.interpolate(img, size=(new_H, new_W), mode='bilinear', align_corners=True).to(device)
                for img in images
            ]

            logits = model(scaled_images)
            logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
            scaled_logits += logits.softmax(dim=1)

            if not flip:
                continue
            scaled_images = [torch.flip(scaled_img, dims=(3,)) for scaled_img in scaled_images]
            logits = torch.flip(model(scaled_images), dims=(3,))
            logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
            scaled_logits += logits.softmax(dim=1)

        metrics.update(scaled_logits, labels)

    acc, macc = metrics.compute_pixel_acc()
    f1, mf1 = metrics.compute_f1()
    ious, miou = metrics.compute_iou()
    return acc, macc, f1, mf1, ious, miou
def main(cfg):
    """Evaluate a trained checkpoint and append the results to a report file.

    The report is written next to the checkpoint as
    ``eval_<YYYYmmdd_HHMMSS>.txt``.

    Args:
        cfg: Parsed configuration mapping (``DEVICE``, ``EVAL``, ``DATASET``,
            ``MODEL``).

    Raises:
        FileNotFoundError: If ``EVAL.MODEL_PATH`` does not exist.
    """
    device = torch.device(cfg['DEVICE'])

    eval_cfg = cfg['EVAL']
    transform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])
    # cases = ['cloud', 'fog', 'night', 'rain', 'sun']
    # cases = ['motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']
    cases = [None] # all

    model_path = Path(eval_cfg['MODEL_PATH'])
    if not model_path.exists():
        # Name the missing file instead of raising a bare exception.
        raise FileNotFoundError(model_path)
    print(f"Evaluating {model_path}...")

    exp_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    eval_path = os.path.join(os.path.dirname(eval_cfg['MODEL_PATH']), 'eval_{}.txt'.format(exp_time))

    for case in cases:
        # NOTE(review): eval() resolves classes from config strings; trusted configs only.
        dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'val', transform, cfg['DATASET']['MODALS'], case)
        # --- test set
        # dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'test', transform, cfg['DATASET']['MODALS'], case)

        model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], dataset.n_classes, cfg['DATASET']['MODALS'])
        msg = model.load_state_dict(torch.load(str(model_path), map_location='cpu'))
        print(msg)
        model = model.to(device)
        sampler_val = None
        # NOTE(review): the worker count is tied to the batch size here - confirm this is intended.
        dataloader = DataLoader(dataset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=eval_cfg['BATCH_SIZE'], pin_memory=False, sampler=sampler_val)

        if eval_cfg['MSF']['ENABLE']:
            acc, macc, f1, mf1, ious, miou = evaluate_msf(model, dataloader, device, eval_cfg['MSF']['SCALES'], eval_cfg['MSF']['FLIP'])
        else:
            acc, macc, f1, mf1, ious, miou = evaluate(model, dataloader, device)

        table = {
            'Class': list(dataset.CLASSES) + ['Mean'],
            'IoU': ious + [miou],
            'F1': f1 + [mf1],
            'Acc': acc + [macc]
        }
        print("mIoU : {}".format(miou))

        with open(eval_path, 'a+') as f:
            f.write(eval_cfg['MODEL_PATH'])
            f.write("\n============== Eval on {} {} images =================\n".format(case, len(dataset)))
            f.write("\n")
            print(tabulate(table, headers='keys'), file=f)
        # Report the file that was actually written (used to print the checkpoint path).
        print("Results saved in {}".format(eval_path))
if __name__ == '__main__':
    # Script entry point: load the YAML config and run the evaluation.
    parser = argparse.ArgumentParser()
    # Default now points at a config that exists in the repo (same default as train_mm.py).
    parser.add_argument('--cfg', type=str, default='configs/deliver_rgbdel.yaml')
    args = parser.parse_args()

    with open(args.cfg) as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)

    setup_cudnn()
    # gpu = setup_ddp()
    # main(cfg, gpu)
    main(cfg)